@@ -12,7 +12,9 @@ function generate_initializesystem(sys::ODESystem;
1212 algebraic_only = false ,
1313 check_units = true , check_defguess = false ,
1414 name = nameof (sys), kwargs... )
15- vars = unique ([unknowns (sys); getfield .((observed (sys)), :lhs )])
15+ trueobs = unhack_observed (observed (sys))
16+ @show trueobs
17+ vars = unique ([unknowns (sys); getfield .(trueobs, :lhs )])
1618 vars_set = Set (vars) # for efficient in-lookup
1719
1820 eqs = equations (sys)
@@ -24,7 +26,7 @@ function generate_initializesystem(sys::ODESystem;
2426 D = Differential (get_iv (sys))
2527 diffmap = merge (
2628 Dict (eq. lhs => eq. rhs for eq in eqs_diff),
27- Dict (D (eq. lhs) => D (eq. rhs) for eq in observed (sys) )
29+ Dict (D (eq. lhs) => D (eq. rhs) for eq in trueobs )
2830 )
2931
3032 # 1) process dummy derivatives and u0map into initialization system
@@ -166,15 +168,14 @@ function generate_initializesystem(sys::ODESystem;
166168 )
167169
168170 # 7) use observed equations for guesses of observed variables if not provided
169- obseqs = observed (sys)
170- for eq in obseqs
171+ for eq in trueobs
171172 haskey (defs, eq. lhs) && continue
172173 any (x -> isequal (default_toterm (x), eq. lhs), keys (defs)) && continue
173174
174175 defs[eq. lhs] = eq. rhs
175176 end
176177
177- eqs_ics = Symbolics. substitute .([eqs_ics; obseqs ], (paramsubs,))
178+ eqs_ics = Symbolics. substitute .([eqs_ics; trueobs ], (paramsubs,))
178179 vars = [vars; collect (values (paramsubs))]
179180 for k in keys (defs)
180181 defs[k] = substitute (defs[k], paramsubs)
@@ -324,3 +325,37 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
324325 return nothing , nothing , nothing , nothing
325326 end
326327end
328+
329+ """
330+ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works with
331+ initialization.
332+ """
333+ function unhack_observed (eqs:: Vector{Equation} )
334+ subs = Dict ()
335+ tempvars = Set ()
336+ rm_idxs = Int[]
337+ for (i, eq) in enumerate (eqs)
338+ iscall (eq. rhs) || continue
339+ if operation (eq. rhs) == StructuralTransformations. change_origin
340+ push! (rm_idxs, i)
341+ continue
342+ end
343+ if operation (eq. rhs) == StructuralTransformations. getindex_wrapper
344+ var, idxs = arguments (eq. rhs)
345+ subs[eq. rhs] = var[idxs... ]
346+ push! (tempvars, var)
347+ end
348+ end
349+
350+ for (i, eq) in enumerate (eqs)
351+ if eq. lhs in tempvars
352+ subs[eq. lhs] = eq. rhs
353+ push! (rm_idxs, i)
354+ end
355+ end
356+
357+ eqs = eqs[setdiff (eachindex (eqs), rm_idxs)]
358+ return map (eqs) do eq
359+ fixpoint_sub (eq. lhs, subs) ~ fixpoint_sub (eq. rhs, subs)
360+ end
361+ end
0 commit comments