Skip to content

Commit 5639bd1

Browse files
feat: pre-compute observed equations of previous SCCs
1 parent 04c0cf8 commit 5639bd1

File tree

1 file changed

+34
-19
lines changed

1 file changed

+34
-19
lines changed

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -560,18 +560,11 @@ end
560560
struct SCCNonlinearFunction{iip} end
561561

562562
function SCCNonlinearFunction{iip}(
563-
sys::NonlinearSystem, vscc, escc, cachesyms; eval_expression = false,
563+
sys::NonlinearSystem, _eqs, _dvs, _obs, cachesyms; eval_expression = false,
564564
eval_module = @__MODULE__, kwargs...) where {iip}
565-
dvs = unknowns(sys)
566565
ps = parameters(sys)
567566
rps = reorder_parameters(sys, ps)
568-
eqs = equations(sys)
569-
obs = observed(sys)
570567

571-
_dvs = dvs[vscc]
572-
_eqs = eqs[escc]
573-
obsidxs = observed_equations_used_by(sys, _eqs)
574-
_obs = obs[obsidxs]
575568
obs_assignments = [eq.lhs eq.rhs for eq in _obs]
576569

577570
cmap, cs = get_cmap(sys)
@@ -621,24 +614,46 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
621614

622615
_, u0, p = process_SciMLProblem(
623616
EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...)
624-
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(u0)))
625617

626-
subprobs = []
627618
explicitfuns = []
619+
nlfuns = []
620+
prevobsidxs = Int[]
621+
cachevars = []
622+
cacheexprs = []
628623
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
629-
oldvars = dvs[reduce(vcat, view(var_sccs, 1:(i - 1)); init = Int[])]
630-
if isempty(oldvars)
631-
push!(explicitfuns, (_...) -> nothing)
624+
# subset unknowns and equations
625+
_dvs = dvs[vscc]
626+
_eqs = eqs[escc]
627+
# get observed equations required by this SCC
628+
obsidxs = observed_equations_used_by(sys, _eqs)
629+
# the ones used by previous SCCs can be precomputed into the cache
630+
setdiff!(obsidxs, prevobsidxs)
631+
_obs = obs[obsidxs]
632+
633+
if isempty(cachevars)
634+
push!(explicitfuns, Returns(nothing))
632635
else
633636
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
634637
push!(explicitfuns,
635-
CacheWriter(sys, oldvars, solsyms; eval_expression, eval_module))
638+
CacheWriter(sys, cacheexprs, solsyms; eval_expression, eval_module))
639+
end
640+
f = SCCNonlinearFunction{iip}(
641+
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)
642+
push!(nlfuns, f)
643+
append!(cachevars, _dvs)
644+
append!(cacheexprs, _dvs)
645+
for i in obsidxs
646+
push!(cachevars, obs[i].lhs)
647+
push!(cacheexprs, obs[i].rhs)
636648
end
637-
prob = NonlinearProblem(
638-
SCCNonlinearFunction{iip}(
639-
sys, vscc, escc, (oldvars,); eval_expression, eval_module, kwargs...),
640-
u0[vscc],
641-
p)
649+
append!(prevobsidxs, obsidxs)
650+
end
651+
652+
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(cachevars)))
653+
654+
subprobs = []
655+
for (f, vscc) in zip(nlfuns, var_sccs)
656+
prob = NonlinearProblem(f, u0[vscc], p)
642657
push!(subprobs, prob)
643658
end
644659

0 commit comments

Comments
 (0)