Skip to content

Commit d20f68d

Browse files
author
Brad Carman
committed
working default Tuple type
1 parent f9b0fa9 commit d20f68d

File tree

5 files changed

+142
-56
lines changed

5 files changed

+142
-56
lines changed

src/parameters.jl

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,28 +62,50 @@ macro parameters(xs...)
6262
toparam) |> esc
6363
end
6464

65-
function split_parameters_by_type(ps)
66-
by = let set = Dict{Any, Int}(), counter = Ref(1)
65+
function find_types(array)
66+
by = let set = Dict{Any, Int}(), counter = Ref(0)
6767
x -> begin
68-
t = typeof(x)
68+
# t = typeof(x)
69+
6970
get!(set, typeof(x)) do
70-
if t == Float64
71-
1
72-
else
73-
counter[] += 1
74-
end
71+
# if t == Float64
72+
# 1
73+
# else
74+
counter[] += 1
75+
# end
7576
end
7677
end
7778
end
78-
idxs = by.(ps)
79-
split_idxs = [Int[]]
80-
for (i, idx) in enumerate(idxs)
81-
if idx > length(split_idxs)
82-
push!(split_idxs, Int[])
79+
return by.(array)
80+
end
81+
82+
83+
function split_parameters_by_type(ps)
84+
85+
if ps === SciMLBase.NullParameters()
86+
return Float64[],[] #use Float64 to avoid Any type warning
87+
else
88+
by = let set = Dict{Any, Int}(), counter = Ref(0)
89+
x -> begin
90+
get!(set, typeof(x)) do
91+
counter[] += 1
92+
end
93+
end
94+
end
95+
idxs = by.(ps)
96+
split_idxs = [Int[]]
97+
for (i, idx) in enumerate(idxs)
98+
if idx > length(split_idxs)
99+
push!(split_idxs, Int[])
100+
end
101+
push!(split_idxs[idx], i)
102+
end
103+
tighten_types = x -> identity.(x)
104+
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
105+
if length(split_ps) == 1 #Tuple not needed, only 1 type
106+
return split_ps[1], split_idxs
107+
else
108+
return (split_ps...,), split_idxs
83109
end
84-
push!(split_idxs[idx], i)
85110
end
86-
tighten_types = x -> identity.(x)
87-
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
88-
(split_ps...,), split_idxs
89111
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
122122
nothing,
123123
isdde = false,
124124
has_difference = false,
125-
split_parameters = false,
126125
kwargs...)
127126
if isdde
128127
eqs = delay_to_function(sys)
@@ -152,12 +151,14 @@ function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = param
152151
build_function(rhss, ddvs, u, p, t; postprocess_fbody = pre,
153152
states = sol_states,
154153
kwargs...)
155-
elseif split_parameters
156-
build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states,
157-
kwargs...)
158154
else
159-
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
155+
if p isa Tuple
156+
build_function(rhss, u, p..., t; postprocess_fbody = pre, states = sol_states,
157+
kwargs...)
158+
else
159+
build_function(rhss, u, p, t; postprocess_fbody = pre, states = sol_states,
160160
kwargs...)
161+
end
161162
end
162163
end
163164
end
@@ -329,15 +330,14 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
329330
checkbounds = false,
330331
sparsity = false,
331332
analytic = nothing,
332-
split_parameters = false,
333333
kwargs...) where {iip, specialize}
334334
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
335-
expression_module = eval_module, checkbounds = checkbounds, split_parameters,
335+
expression_module = eval_module, checkbounds = checkbounds,
336336
kwargs...)
337337
f_oop, f_iip = eval_expression ?
338338
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
339339
f_gen
340-
if split_parameters
340+
if p isa Tuple
341341
g(u, p, t) = f_oop(u, p..., t)
342342
g(du, u, p, t) = f_iip(du, u, p..., t)
343343
f = g
@@ -696,9 +696,8 @@ Take dictionaries with initial conditions and parameters and convert them to num
696696
function get_u0_p(sys,
697697
u0map,
698698
parammap;
699-
use_union = false,
700-
tofloat = !use_union,
701-
split_parameters = false,
699+
use_union = true,
700+
tofloat = true,
702701
symbolic_u0 = false)
703702
eqs = equations(sys)
704703
dvs = states(sys)
@@ -711,7 +710,7 @@ function get_u0_p(sys,
711710
if symbolic_u0
712711
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
713712
else
714-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = !split_parameters)
713+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
715714
end
716715
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
717716
p = p === nothing ? SciMLBase.NullParameters() : p
@@ -726,10 +725,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
726725
simplify = false,
727726
linenumbers = true, parallel = SerialForm(),
728727
eval_expression = true,
729-
use_union = false,
730-
tofloat = !use_union,
728+
use_union = true,
729+
tofloat = true,
731730
symbolic_u0 = false,
732-
split_parameters = false,
731+
# split_parameters = true,
733732
kwargs...)
734733
eqs = equations(sys)
735734
dvs = states(sys)
@@ -741,12 +740,15 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
741740
parammap;
742741
tofloat,
743742
use_union,
744-
symbolic_u0,
745-
split_parameters)
746-
if split_parameters
743+
symbolic_u0)
744+
745+
# if split_parameters
747746
p, split_idxs = split_parameters_by_type(p)
748-
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
749-
end
747+
if p isa Tuple
748+
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
749+
ps = (ps...,) #if p is Tuple, ps should be Tuple
750+
end
751+
# end
750752

751753
if implicit_dae && du0map !== nothing
752754
ddvs = map(Differential(iv), dvs)
@@ -763,7 +765,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
763765
f = constructor(sys, dvs, ps, u0; ddvs = ddvs, tgrad = tgrad, jac = jac,
764766
checkbounds = checkbounds, p = p,
765767
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
766-
sparse = sparse, eval_expression = eval_expression, split_parameters,
768+
sparse = sparse, eval_expression = eval_expression,
767769
kwargs...)
768770
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
769771
end

src/variables.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ applicable.
5858
"""
5959
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
6060
toterm = default_toterm, promotetoconcrete = nothing,
61-
tofloat = true, use_union = false)
61+
tofloat = true, use_union = true)
6262
varlist = collect(map(unwrap, varlist))
6363

6464
# Edge cases where one of the arguments is effectively empty.
@@ -75,9 +75,10 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
7575
end
7676
end
7777

78-
T = typeof(varmap)
79-
# We respect the input type
80-
container_type = T <: Dict ? Array : T
78+
# T = typeof(varmap)
79+
# We respect the input type (feature removed, not needed with Tuple support)
80+
# container_type = T <: Union{Dict,Tuple} ? Array : T
81+
container_type = Array
8182

8283
vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
8384
varmap = todict(varmap)

test/odesystem.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -734,18 +734,23 @@ let
734734
u0map = [A => 1.0]
735735
pmap = (k1 => 1.0, k2 => 1)
736736
tspan = (0.0, 1.0)
737-
prob = ODEProblem(sys, u0map, tspan, pmap)
738-
@test prob.p === Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
737+
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false)
738+
739+
@test prob.p == ([1], [1.0]) #Tuple([(Dict(pmap))[k] for k in values(parameters(sys))])
739740

740741
pmap = [k1 => 1, k2 => 1]
741742
tspan = (0.0, 1.0)
742743
prob = ODEProblem(sys, u0map, tspan, pmap)
743744
@test eltype(prob.p) === Float64
744-
745-
pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
746-
tspan = (0.0, 1.0)
747-
prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
748-
@test eltype(prob.p) === Union{Float64, Int}
745+
746+
prob = ODEProblem(sys, u0map, tspan, pmap; tofloat=false)
747+
@test eltype(prob.p) === Int
748+
749+
# No longer supported, Tuple used instead
750+
# pmap = Pair{Any, Union{Int, Float64}}[k1 => 1, k2 => 1.0]
751+
# tspan = (0.0, 1.0)
752+
# prob = ODEProblem(sys, u0map, tspan, pmap, use_union = true)
753+
# @test eltype(prob.p) === Union{Float64, Int}
749754
end
750755

751756
let

test/split_parameters.jl

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ using ModelingToolkit, Test
22
using ModelingToolkitStandardLibrary.Blocks
33
using OrdinaryDiffEq
44

5+
6+
7+
# ------------------------ Mixed Single Values and Vector
8+
59
dt = 4e-4
610
t_end = 10.0
711
time = 0:dt:t_end
@@ -10,8 +14,31 @@ x = @. time^2 + 1.0
1014
@parameters t
1115
D = Differential(t)
1216

17+
get_value(data, t, dt) = data[round(Int, t/dt+1)]
18+
@register_symbolic get_value(data, t, dt)
19+
20+
21+
function Sampled(; name, data=Float64[], dt=0.0)
22+
pars = @parameters begin
23+
data = data
24+
dt = dt
25+
end
26+
27+
vars = []
28+
systems = @named begin
29+
output = RealOutput()
30+
end
31+
32+
eqs = [
33+
output.u ~ get_value(data, t, dt)
34+
]
35+
36+
return ODESystem(eqs, t, vars, pars; name, systems,
37+
defaults = [output.u => data[1]])
38+
end
39+
1340
vars = @variables y(t)=1 dy(t)=0 ddy(t)=0
14-
@named src = SampledData(; data = Float64[], dt)
41+
@named src = Sampled(; data = Float64[], dt)
1542
@named int = Integrator()
1643

1744
eqs = [y ~ src.output.u
@@ -22,10 +49,39 @@ eqs = [y ~ src.output.u
2249
@named sys = ODESystem(eqs, t, vars, []; systems = [int, src])
2350
s = complete(sys)
2451
sys = structural_simplify(sys)
25-
26-
prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; split_parameters = true)
52+
prob = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x])
2753
@test prob.p isa Tuple{Vector{Float64}, Vector{Int}, Vector{Vector{Float64}}}
28-
@time sol = solve(prob, ImplicitEuler());
29-
prob2 = ODEProblem(sys, [], (0.0, t_end), [s.src.data => x]; to_float = false)
30-
@test prob2.p isa Vector{Union{Float64, Int64, Vector{Float64}}}
31-
@time sol2 = solve(prob2, ImplicitEuler());
54+
sol = solve(prob, ImplicitEuler());
55+
@test sol.retcode == ReturnCode.Success
56+
57+
58+
# ------------------------ Mixed Type Converted to float (default behavior)
59+
60+
vars = @variables y(t)=1 dy(t)=0 ddy(t)=0
61+
pars = @parameters a=1.0 b=2.0 c=3
62+
eqs = [
63+
D(y) ~ dy*a
64+
D(dy) ~ ddy*b
65+
ddy ~ sin(t)*c]
66+
67+
@named sys = ODESystem(eqs, t, vars, pars)
68+
sys = structural_simplify(sys)
69+
70+
tspan = (0.0, t_end)
71+
prob = ODEProblem(sys, [], tspan, [])
72+
73+
@test prob.p isa Vector{Float64}
74+
sol = solve(prob, ImplicitEuler());
75+
@test sol.retcode == ReturnCode.Success
76+
77+
78+
# ------------------------ Mixed Type Conserved
79+
80+
prob = ODEProblem(sys, [], tspan, []; tofloat=false)
81+
82+
@test prob.p isa Tuple{Vector{Float64}, Vector{Int64}}
83+
sol = solve(prob, ImplicitEuler());
84+
@test sol.retcode == ReturnCode.Success
85+
86+
87+

0 commit comments

Comments
 (0)