@@ -3,29 +3,44 @@ $(TYPEDSIGNATURES)
33
44Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
55"""
6- function modelingtoolkitize (prob:: DiffEqBase.ODEProblem ; kwargs... )
6+ function modelingtoolkitize (
7+ prob:: DiffEqBase.ODEProblem ; u_names = nothing , p_names = nothing , kwargs... )
78 prob. f isa DiffEqBase. AbstractParameterizedFunction &&
89 return prob. f. sys
9- @parameters t
10-
10+ t = t_nounits
1111 p = prob. p
1212 has_p = ! (p isa Union{DiffEqBase. NullParameters, Nothing})
1313
14- _vars = define_vars (prob. u0, t)
14+ if u_names != = nothing
15+ varnames_length_check (prob. u0, u_names; is_unknowns = true )
16+ _vars = [_defvar (name)(t) for name in u_names]
17+ elseif SciMLBase. has_sys (prob. f)
18+ varnames = getname .(variable_symbols (prob. f. sys))
19+ varidxs = variable_index .((prob. f. sys,), varnames)
20+ invpermute! (varnames, varidxs)
21+ _vars = [_defvar (name)(t) for name in varnames]
22+ else
23+ _vars = define_vars (prob. u0, t)
24+ end
1525
1626 vars = prob. u0 isa Number ? _vars : ArrayInterface. restructure (prob. u0, _vars)
1727 params = if has_p
18- _params = define_params (p)
28+ if p_names === nothing && SciMLBase. has_sys (prob. f)
29+ p_names = Dict (parameter_index (prob. f. sys, sym) => sym
30+ for sym in parameter_symbols (prob. f. sys))
31+ end
32+ _params = define_params (p, p_names)
1933 p isa Number ? _params[1 ] :
20- (p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
34+ (p isa Tuple || p isa NamedTuple || p isa AbstractDict || p isa MTKParameters ?
35+ _params :
2136 ArrayInterface. restructure (p, _params))
2237 else
2338 []
2439 end
2540
2641 var_set = Set (vars)
2742
28- D = Differential (t)
43+ D = D_nounits
2944 mm = prob. f. mass_matrix
3045
3146 if mm === I
@@ -70,6 +85,8 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
7085 default_p = if has_p
7186 if prob. p isa AbstractDict
7287 Dict (v => prob. p[k] for (k, v) in pairs (_params))
88+ elseif prob. p isa MTKParameters
89+ Dict (params .=> reduce (vcat, prob. p))
7390 else
7491 Dict (params .=> vec (collect (prob. p)))
7592 end
@@ -125,44 +142,96 @@ function Base.showerror(io::IO, e::ModelingtoolkitizeParametersNotSupportedError
125142 println (io, e. type)
126143end
127144
128- function define_params (p)
145+ function varnames_length_check (vars, names; is_unknowns = false )
146+ if length (names) != length (vars)
147+ throw (ArgumentError ("""
148+ Number of $(is_unknowns ? " unknowns" : " parameters" ) ($(length (vars)) ) \
149+ does not match number of names ($(length (names)) ).
150+ """ ))
151+ end
152+ end
153+
154+ function define_params (p, _ = nothing )
129155 throw (ModelingtoolkitizeParametersNotSupportedError (typeof (p)))
130156end
131157
132- function define_params (p:: AbstractArray )
133- [toparam (variable (:α , i)) for i in eachindex (p)]
158+ function define_params (p:: AbstractArray , names = nothing )
159+ if names === nothing
160+ [toparam (variable (:α , i)) for i in eachindex (p)]
161+ else
162+ varnames_length_check (p, names)
163+ [toparam (variable (names[i])) for i in eachindex (p)]
164+ end
134165end
135166
136- function define_params (p:: Number )
137- [toparam (variable (:α ))]
167+ function define_params (p:: Number , names = nothing )
168+ if names === nothing
169+ [toparam (variable (:α ))]
170+ elseif names isa Union{AbstractArray, AbstractDict}
171+ varnames_length_check (p, names)
172+ [toparam (variable (names[i])) for i in eachindex (p)]
173+ else
174+ [toparam (variable (names))]
175+ end
138176end
139177
140- function define_params (p:: AbstractDict )
141- OrderedDict (k => toparam (variable (:α , i)) for (i, k) in zip (1 : length (p), keys (p)))
178+ function define_params (p:: AbstractDict , names = nothing )
179+ if names === nothing
180+ OrderedDict (k => toparam (variable (:α , i)) for (i, k) in zip (1 : length (p), keys (p)))
181+ else
182+ varnames_length_check (p, names)
183+ OrderedDict (k => toparam (variable (names[k])) for k in keys (p))
184+ end
142185end
143186
144- function define_params (p:: Union{SLArray, LArray} )
145- [toparam (variable (x)) for x in LabelledArrays. symnames (typeof (p))]
187+ function define_params (p:: Union{SLArray, LArray} , names = nothing )
188+ if names === nothing
189+ [toparam (variable (x)) for x in LabelledArrays. symnames (typeof (p))]
190+ else
191+ varnames_length_check (p, names)
192+ [toparam (variable (names[i])) for i in eachindex (p)]
193+ end
146194end
147195
148- function define_params (p:: Tuple )
149- tuple ((toparam (variable (:α , i)) for i in eachindex (p)). .. )
196+ function define_params (p:: Tuple , names = nothing )
197+ if names === nothing
198+ tuple ((toparam (variable (:α , i)) for i in eachindex (p)). .. )
199+ else
200+ varnames_length_check (p, names)
201+ tuple ((toparam (variable (names[i])) for i in eachindex (p)). .. )
202+ end
150203end
151204
152- function define_params (p:: NamedTuple )
153- NamedTuple (x => toparam (variable (x)) for x in keys (p))
205+ function define_params (p:: NamedTuple , names = nothing )
206+ if names === nothing
207+ NamedTuple (x => toparam (variable (x)) for x in keys (p))
208+ else
209+ varnames_length_check (p, names)
210+ NamedTuple (x => toparam (variable (names[x])) for x in keys (p))
211+ end
154212end
155213
156- function define_params (p:: MTKParameters )
157- bufs = (p... ,)
158- i = 1
159- ps = []
160- for buf in bufs
161- for _ in buf
162- push! (ps, toparam (variable (:α , i)))
214+ function define_params (p:: MTKParameters , names = nothing )
215+ if names === nothing
216+ bufs = (p... ,)
217+ i = 1
218+ ps = []
219+ for buf in bufs
220+ for _ in buf
221+ push! (
222+ ps,
223+ if names === nothing
224+ toparam (variable (:α , i))
225+ else
226+ toparam (variable (names[i]))
227+ end
228+ )
229+ end
163230 end
231+ return identity .(ps)
232+ else
233+ return collect (values (names))
164234 end
165- return identity .(ps)
166235end
167236
168237"""
0 commit comments