@@ -89,30 +89,43 @@ function linearization_function(sys::AbstractSystem, inputs,
8989 t0 = current_time (prob)
9090 inputvals = [p[idx] for idx in input_idxs]
9191
92- uf_fun = let fun = prob. f
93- function uff (du, u, p, t)
94- SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
95- end
96- end
97- uf_jac = PreparedJacobian {true} (uf_fun, similar (prob. u0), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
98- # observed function is a `GeneratedFunctionWrapper` with iip component
99- h_jac = PreparedJacobian {true} (h, similar (prob. u0, size (outputs)), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
100- pf_fun = let fun = prob. f, setter = setp_oop (sys, input_idxs)
101- function pff (du, input, u, p, t)
102- p = setter (p, input)
103- SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
104- end
105- end
106- pf_jac = PreparedJacobian {true} (pf_fun, similar (prob. u0), autodiff, inputvals, DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
10792 hp_fun = let fun = h, setter = setp_oop (sys, input_idxs)
10893 function hpf (du, input, u, p, t)
10994 p = setter (p, input)
11095 fun (du, u, p, t)
11196 return du
11297 end
11398 end
114- hp_jac = PreparedJacobian {true} (hp_fun, similar (prob. u0, size (outputs)), autodiff, inputvals, DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
115-
99+ if u0 === nothing
100+ uf_jac = h_jac = pf_jac = nothing
101+ T = p isa MTKParameters ? eltype (p. tunable) : eltype (p)
102+ hp_jac = PreparedJacobian {true} (
103+ hp_fun, zeros (T, size (outputs)), autodiff, inputvals,
104+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
105+ else
106+ uf_fun = let fun = prob. f
107+ function uff (du, u, p, t)
108+ SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
109+ end
110+ end
111+ uf_jac = PreparedJacobian {true} (
112+ uf_fun, similar (prob. u0), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
113+ # observed function is a `GeneratedFunctionWrapper` with iip component
114+ h_jac = PreparedJacobian {true} (h, similar (prob. u0, size (outputs)), autodiff,
115+ prob. u0, DI. Constant (p), DI. Constant (t0))
116+ pf_fun = let fun = prob. f, setter = setp_oop (ssimilarys, input_idxs)
117+ function pff (du, input, u, p, t)
118+ p = setter (p, input)
119+ SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
120+ end
121+ end
122+ pf_jac = PreparedJacobian {true} (pf_fun, similar (prob. u0), autodiff, inputvals,
123+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
124+ hp_jac = PreparedJacobian {true} (
125+ hp_fun, similar (prob. u0, size (outputs)), autodiff, inputvals,
126+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
127+ end
128+
116129 lin_fun = LinearizationFunction (
117130 diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)),
118131 prob, h, u0 === nothing ? nothing : similar (u0), uf_jac, h_jac, pf_jac,
@@ -151,12 +164,14 @@ end
151164
152165function PreparedJacobian {true} (f, buf, autodiff, args... )
153166 prep = DI. prepare_jacobian (f, buf, autodiff, args... )
154- return PreparedJacobian {true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)} (prep, f, buf, autodiff)
167+ return PreparedJacobian {true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)} (
168+ prep, f, buf, autodiff)
155169end
156170
157171function PreparedJacobian {false} (f, autodiff, args... )
158172 prep = DI. prepare_jacobian (f, autodiff, args... )
159- return PreparedJacobian {true, typeof(prep), typeof(f), Nothing, typeof(autodiff)} (prep, f, nothing )
173+ return PreparedJacobian {true, typeof(prep), typeof(f), Nothing, typeof(autodiff)} (
174+ prep, f, nothing )
160175end
161176
162177function (pj:: PreparedJacobian{true} )(args... )
@@ -279,14 +294,16 @@ function (linfun::LinearizationFunction)(u, p, t)
279294 end
280295 fg_xz = linfun. uf_jac (u, DI. Constant (p), DI. Constant (t))
281296 h_xz = linfun. h_jac (u, DI. Constant (p), DI. Constant (t))
282- fg_u = linfun. pf_jac ([p[idx] for idx in linfun. input_idxs], DI. Constant (u), DI. Constant (p), DI. Constant (t))
297+ fg_u = linfun. pf_jac ([p[idx] for idx in linfun. input_idxs],
298+ DI. Constant (u), DI. Constant (p), DI. Constant (t))
283299 else
284300 linfun. num_states == 0 ||
285301 error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
286302 fg_xz = zeros (0 , 0 )
287303 h_xz = fg_u = zeros (0 , length (linfun. input_idxs))
288304 end
289- h_u = linfun. hp_jac ([p[idx] for idx in linfun. input_idxs], DI. Constant (u), DI. Constant (p), DI. Constant (t))
305+ h_u = linfun. hp_jac ([p[idx] for idx in linfun. input_idxs],
306+ DI. Constant (u), DI. Constant (p), DI. Constant (t))
290307 (f_x = fg_xz[linfun. diff_idxs, linfun. diff_idxs],
291308 f_z = fg_xz[linfun. diff_idxs, linfun. alge_idxs],
292309 g_x = fg_xz[linfun. alge_idxs, linfun. diff_idxs],
0 commit comments