22
33from pytensor import tensor as pt
44from pytensor .tensor .slinalg import block_diag
5- from scipy import linalg
65
76from pymc_extras .statespace .models .structural .core import Component
87from pymc_extras .statespace .models .structural .utils import _frequency_transition_block
@@ -190,22 +189,17 @@ def __init__(
190189 )
191190
192191 def make_symbolic_graph (self ) -> None :
193- if self .k_endog == 1 :
194- self .ssm ["design" , 0 , slice (0 , self .k_states , 2 )] = 1
195- self .ssm ["selection" , :, :] = np .eye (self .k_states )
196- init_state = self .make_and_register_variable (f"{ self .name } " , shape = (self .k_states ,))
197-
198- else :
199- Z = np .array ([1.0 , 0.0 ]).reshape ((1 , - 1 ))
200- design_matrix = linalg .block_diag (* [Z for _ in range (self .k_endog )])
201- self .ssm ["design" , :, :] = pt .as_tensor_variable (design_matrix )
192+ Z = np .array ([1.0 , 0.0 ]).reshape ((1 , - 1 ))
193+ design_matrix = block_diag (* [Z for _ in range (self .k_endog )])
194+ self .ssm ["design" , :, :] = pt .as_tensor_variable (design_matrix )
202195
203- R = np .eye (2 ) # 2x2 identity for each cycle component
204- selection_matrix = linalg .block_diag (* [R for _ in range (self .k_endog )])
205- self .ssm ["selection" , :, :] = pt .as_tensor_variable (selection_matrix )
206-
207- init_state = self .make_and_register_variable (f"{ self .name } " , shape = (self .k_endog , 2 ))
196+ R = np .eye (2 ) # 2x2 identity for each cycle component
197+ selection_matrix = block_diag (* [R for _ in range (self .k_endog )])
198+ self .ssm ["selection" , :, :] = pt .as_tensor_variable (selection_matrix )
208199
200+ init_state = self .make_and_register_variable (
201+ f"{ self .name } " , shape = (self .k_endog , 2 ) if self .k_endog > 1 else (self .k_states ,)
202+ )
209203 self .ssm ["initial_state" , :] = init_state .ravel ()
210204
211205 if self .estimate_cycle_length :
@@ -219,11 +213,8 @@ def make_symbolic_graph(self) -> None:
219213 rho = 1
220214
221215 T = rho * _frequency_transition_block (lamb , j = 1 )
222- if self .k_endog == 1 :
223- self .ssm ["transition" , :, :] = T
224- else :
225- transition = block_diag (* [T for _ in range (self .k_endog )])
226- self .ssm ["transition" ] = pt .specify_shape (transition , (self .k_states , self .k_states ))
216+ transition = block_diag (* [T for _ in range (self .k_endog )])
217+ self .ssm ["transition" ] = pt .specify_shape (transition , (self .k_states , self .k_states ))
227218
228219 if self .innovations :
229220 if self .k_endog == 1 :
@@ -239,13 +230,11 @@ def make_symbolic_graph(self) -> None:
239230 self .ssm ["state_cov" ] = pt .specify_shape (state_cov , (self .k_states , self .k_states ))
240231
241232 def populate_component_properties (self ):
242- if self .k_endog == 1 :
243- self .state_names = [f"{ self .name } _{ f } " for f in ["Cos" , "Sin" ]]
244- else :
245- # For multivariate cycles, create state names for each observed state
246- self .state_names = []
247- for var_name in self .observed_state_names :
248- self .state_names .extend ([f"{ self .name } _{ var_name } _{ f } " for f in ["Cos" , "Sin" ]])
233+ self .state_names = [
234+ f"{ self .name } _{ f } [{ var_name } ]" if self .k_endog > 1 else f"{ self .name } _{ f } "
235+ for var_name in self .observed_state_names
236+ for f in ["Cos" , "Sin" ]
237+ ]
249238
250239 self .param_names = [f"{ self .name } " ]
251240
@@ -276,17 +265,17 @@ def populate_component_properties(self):
276265 if self .estimate_cycle_length :
277266 self .param_names += [f"{ self .name } _length" ]
278267 self .param_info [f"{ self .name } _length" ] = {
279- "shape" : (),
268+ "shape" : () if self . k_endog == 1 else ( self . k_endog ,) ,
280269 "constraints" : "Positive, non-zero" ,
281- "dims" : None ,
270+ "dims" : None if self . k_endog == 1 else f" { self . name } _endog" ,
282271 }
283272
284273 if self .dampen :
285274 self .param_names += [f"{ self .name } _dampening_factor" ]
286275 self .param_info [f"{ self .name } _dampening_factor" ] = {
287- "shape" : (),
276+ "shape" : () if self . k_endog == 1 else ( self . k_endog ,) ,
288277 "constraints" : "0 < x ≤ 1" ,
289- "dims" : None ,
278+ "dims" : None if self . k_endog == 1 else f" { self . name } _endog" ,
290279 }
291280
292281 if self .innovations :
0 commit comments