@@ -43,35 +43,62 @@ function MTKParameters(
4343 end
4444 defs = merge (defs, u0)
4545 defs = merge (Dict (eq. lhs => eq. rhs for eq in observed (sys)), defs)
46- p = merge (defs, p)
47- p = merge (Dict (unwrap (k) => v for (k, v) in p),
48- Dict (default_toterm (unwrap (k)) => v for (k, v) in p))
49- p = Dict (unwrap (k) => fixpoint_sub (v, p) for (k, v) in p)
50- for (sym, _) in p
51- if iscall (sym) && operation (sym) === getindex &&
52- first (arguments (sym)) in all_ps
53- error (" Scalarized parameter values ($sym ) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`" )
46+ bigdefs = merge (defs, p)
47+ p = Dict ()
48+ missing_params = Set ()
49+ pdeps = has_parameter_dependencies (sys) ? parameter_dependencies (sys) : nothing
50+
51+ for sym in all_ps
52+ ttsym = default_toterm (sym)
53+ isarr = iscall (sym) && operation (sym) === getindex
54+ arrparent = isarr ? arguments (sym)[1 ] : nothing
55+ ttarrparent = isarr ? default_toterm (arrparent) : nothing
56+ pname = hasname (sym) ? getname (sym) : nothing
57+ ttpname = hasname (ttsym) ? getname (ttsym) : nothing
58+ p[sym] = p[ttsym] = if haskey (bigdefs, sym)
59+ bigdefs[sym]
60+ elseif haskey (bigdefs, ttsym)
61+ bigdefs[ttsym]
62+ elseif haskey (bigdefs, pname)
63+ isarr ? bigdefs[pname][arguments (sym)[2 : end ]. .. ] : bigdefs[pname]
64+ elseif haskey (bigdefs, ttpname)
65+ isarr ? bigdefs[ttpname][arguments (sym)[2 : end ]. .. ] : bigdefs[pname]
66+ elseif isarr && haskey (bigdefs, arrparent)
67+ bigdefs[arrparent][arguments (sym)[2 : end ]. .. ]
68+ elseif isarr && haskey (bigdefs, ttarrparent)
69+ bigdefs[ttarrparent][arguments (sym)[2 : end ]. .. ]
5470 end
71+ if get (p, sym, nothing ) === nothing
72+ push! (missing_params, sym)
73+ continue
74+ end
75+ # We may encounter the `ttsym` version first, add it to `missing_params`
76+ # then encounter the "normal" version of a parameter or vice versa
77+ # Remove the old one in `missing_params` just in case
78+ delete! (missing_params, sym)
79+ delete! (missing_params, ttsym)
5580 end
5681
57- missing_params = Set ()
58- for idxmap in (ic. tunable_idx, ic. discrete_idx, ic. constant_idx, ic. nonnumeric_idx)
59- for sym in keys (idxmap)
60- sym isa Symbol && continue
61- haskey (p, sym) && continue
62- hasname (sym) && haskey (p, getname (sym)) && continue
82+ if pdeps != = nothing
83+ for (sym, expr) in pdeps
84+ sym = unwrap (sym)
6385 ttsym = default_toterm (sym)
64- haskey (p, ttsym) && continue
65- hasname (ttsym) && haskey (p, getname (ttsym)) && continue
66-
67- iscall (sym) && operation (sym) === getindex && haskey (p, arguments (sym)[1 ]) &&
68- continue
69- push! (missing_params, sym)
86+ delete! (missing_params, sym)
87+ delete! (missing_params, ttsym)
88+ p[sym] = p[ttsym] = expr
7089 end
7190 end
7291
7392 isempty (missing_params) || throw (MissingParametersError (collect (missing_params)))
7493
94+ p = Dict (unwrap (k) => fixpoint_sub (v, bigdefs) for (k, v) in p)
95+ for (sym, _) in p
96+ if iscall (sym) && operation (sym) === getindex &&
97+ first (arguments (sym)) in all_ps
98+ error (" Scalarized parameter values ($sym ) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`" )
99+ end
100+ end
101+
75102 tunable_buffer = Tuple (Vector {temp.type} (undef, temp. length)
76103 for temp in ic. tunable_buffer_sizes)
77104 disc_buffer = Tuple (Vector {temp.type} (undef, temp. length)
@@ -135,8 +162,7 @@ function MTKParameters(
135162 # Don't narrow nonnumeric types
136163 nonnumeric_buffer = nonnumeric_buffer
137164
138- if has_parameter_dependencies (sys) &&
139- (pdeps = parameter_dependencies (sys)) != = nothing
165+ if pdeps != = nothing
140166 pdeps = Dict (k => fixpoint_sub (v, pdeps) for (k, v) in pdeps)
141167 dep_exprs = ArrayPartition ((Any[missing for _ in 1 : length (v)] for v in dep_buffer). .. )
142168 for (sym, val) in pdeps
0 commit comments