Skip to content

Commit c1e1523

Browse files
feat: better handle observed variables, constants in SCCNonlinearProblem
1 parent e79dc90 commit c1e1523

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -543,17 +543,22 @@ function (cw::CacheWriter)(p, sols)
543543
cw.fn(p.caches[1], sols, p...)
544544
end
545545

546-
function CacheWriter(sys::AbstractSystem, exprs, solsyms;
546+
function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation};
547547
eval_expression = false, eval_module = @__MODULE__)
548548
ps = parameters(sys)
549549
rps = reorder_parameters(sys, ps)
550+
obs_assigns = [eq.lhs eq.rhs for eq in obseqs]
551+
cmap, cs = get_cmap(sys)
552+
cmap_assigns = [eq.lhs eq.rhs for eq in cmap]
550553
fn = Func(
551554
[:out, DestructuredArgs(DestructuredArgs.(solsyms)),
552555
DestructuredArgs.(rps)...],
553556
[],
554557
SetArray(true, :out, exprs)
555-
) |> wrap_parameter_dependencies(sys, false)[2] |>
556-
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |> toexpr
558+
) |> wrap_assignments(false, obs_assigns)[2] |>
559+
wrap_parameter_dependencies(sys, false)[2] |>
560+
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |>
561+
wrap_assignments(false, cmap_assigns)[2] |> toexpr
557562
return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
558563
end
559564

@@ -612,7 +617,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
612617
var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts)
613618

614619
if length(var_sccs) == 1
615-
return NonlinearProblem{iip}(sys, u0map, parammap; eval_expression, eval_module, kwargs...)
620+
return NonlinearProblem{iip}(
621+
sys, u0map, parammap; eval_expression, eval_module, kwargs...)
616622
end
617623

618624
condensed_graph = MatchedCondensationGraph(
@@ -664,7 +670,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
664670

665671
# cached variables and their corresponding expressions
666672
cachevars = Any[obs[i].lhs for i in prevobsidxs]
667-
cacheexprs = Any[obs[i].rhs for i in prevobsidxs]
673+
cacheexprs = Any[obs[i].lhs for i in prevobsidxs]
668674
for (k, v) in state
669675
push!(cachevars, unwrap(v))
670676
push!(cacheexprs, unwrap(k))
@@ -676,7 +682,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
676682
else
677683
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
678684
push!(explicitfuns,
679-
CacheWriter(sys, cacheexprs, solsyms; eval_expression, eval_module))
685+
CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs];
686+
eval_expression, eval_module))
680687
end
681688
f = SCCNonlinearFunction{iip}(
682689
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)

0 commit comments

Comments
 (0)