Skip to content

Commit 35a7659

Browse files
feat: propagate p_constructor to InitializationProblem
1 parent 3547062 commit 35a7659

File tree

2 files changed

+24
-23
lines changed

2 files changed

+24
-23
lines changed

src/systems/problem_utils.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -869,15 +869,17 @@ All other keyword arguments are forwarded to `InitializationProblem`.
869869
function maybe_build_initialization_problem(
870870
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
871871
guesses, missing_unknowns; implicit_dae = false, u0_constructor = identity,
872-
floatT = Float64, initialization_eqs = [], use_scc = true, kwargs...)
872+
p_constructor = identity, floatT = Float64, initialization_eqs = [],
873+
use_scc = true, kwargs...)
873874
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
874875

875876
if t === nothing && is_time_dependent(sys)
876877
t = zero(floatT)
877878
end
878879

879880
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
880-
sys, t, u0map, pmap; guesses, initialization_eqs, use_scc, kwargs...)
881+
sys, t, u0map, pmap; guesses, initialization_eqs,
882+
use_scc, u0_constructor, p_constructor, kwargs...)
881883
if state_values(initializeprob) !== nothing
882884
initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob)))
883885
end
@@ -1109,7 +1111,7 @@ function process_SciMLProblem(
11091111
warn_cyclic_dependency, check_units = check_initialization_units,
11101112
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
11111113
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
1112-
u0_constructor, floatT)
1114+
u0_constructor, p_constructor, floatT)
11131115

11141116
kwargs = merge(kwargs, kws)
11151117
end

test/initial_values.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, get_u0
33
using OrdinaryDiffEq
44
using DataInterpolations
55
using StaticArrays
6-
using SymbolicIndexingInterface: getu
6+
using SymbolicIndexingInterface
77

88
@variables x(t)[1:3]=[1.0, 2.0, 3.0] y(t) z(t)[1:2]
99

@@ -337,29 +337,28 @@ end
337337
end
338338

339339
@testset "`p_constructor` keyword argument" begin
340-
@parameters σ ρ β
341-
@variables x(t) y(t) z(t)
342-
343-
eqs = [D(D(x)) ~ σ * (y - x),
344-
D(y) ~ x *- z) - y,
345-
D(z) ~ x * y - β * z]
346-
347-
@mtkbuild sys = ODESystem(eqs, t)
348-
349-
u0 = [D(x) => 2.0f0,
350-
x => 1.0f0,
351-
y => 0.0f0,
352-
z => 0.0f0]
340+
@parameters g = 1.0
341+
@variables x(t) y(t) [state_priority = 10, guess = 1.0] λ(t) [guess = 1.0]
342+
eqs = [D(D(x)) ~ λ * x
343+
D(D(y)) ~ λ * y - g
344+
x^2 + y^2 ~ 1]
345+
@mtkbuild pend = ODESystem(eqs, t)
353346

354-
p ==> 28.0f0,
355-
ρ => 10.0f0,
356-
β => 8.0f0 / 3]
347+
u0 = [x => 1.0, D(x) => 0.0]
357348
u0_constructor = p_constructor = vals -> SVector{length(vals)}(vals...)
358-
prob = ODEProblem(sys, u0, tspan, p; u0_constructor, p_constructor)
349+
tspan = (0.0, 5.0)
350+
prob = ODEProblem(pend, u0, tspan; u0_constructor, p_constructor)
351+
@test prob.u0 isa SVector
359352
@test prob.p.tunable isa SVector
360353
@test prob.p.initials isa SVector
354+
initdata = prob.f.initialization_data
355+
@test state_values(initdata.initializeprob) isa SVector
356+
@test parameter_values(initdata.initializeprob).tunable isa SVector
361357

362-
@mtkbuild sys=ODESystem(eqs, t) split=false
363-
prob = ODEProblem(sys, u0, tspan, p; u0_constructor, p_constructor)
358+
@mtkbuild pend=ODESystem(eqs, t) split=false
359+
prob = ODEProblem(pend, u0, tspan; u0_constructor, p_constructor)
364360
@test prob.p isa SVector
361+
initdata = prob.f.initialization_data
362+
@test state_values(initdata.initializeprob) isa SVector
363+
@test parameter_values(initdata.initializeprob) isa SVector
365364
end

0 commit comments

Comments
 (0)