1- abstract type SymbolHash end
2-
3- function getsymbolhash (sym)
4- sym = unwrap (sym)
5- hasmetadata (sym, SymbolHash) ? getmetadata (sym, SymbolHash) : hash (sym)
6- end
7-
81struct BufferTemplate
92 type:: DataType
103 length:: Int
@@ -18,38 +11,38 @@ struct ParameterIndex{P, I}
1811 idx:: I
1912end
2013
21- const IndexMap = Dict{UInt, Tuple{Int, Int}}
14+ const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
15+ const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
2216
2317struct IndexCache
24- unknown_idx:: Dict{UInt, Union{Int, UnitRange{Int}}}
25- discrete_idx:: IndexMap
26- param_idx :: IndexMap
27- constant_idx:: IndexMap
28- dependent_idx:: IndexMap
29- nonnumeric_idx:: IndexMap
18+ unknown_idx:: UnknownIndexMap
19+ discrete_idx:: ParamIndexMap
20+ tunable_idx :: ParamIndexMap
21+ constant_idx:: ParamIndexMap
22+ dependent_idx:: ParamIndexMap
23+ nonnumeric_idx:: ParamIndexMap
3024 discrete_buffer_sizes:: Vector{BufferTemplate}
31- param_buffer_sizes :: Vector{BufferTemplate}
25+ tunable_buffer_sizes :: Vector{BufferTemplate}
3226 constant_buffer_sizes:: Vector{BufferTemplate}
3327 dependent_buffer_sizes:: Vector{BufferTemplate}
3428 nonnumeric_buffer_sizes:: Vector{BufferTemplate}
3529end
3630
3731function IndexCache (sys:: AbstractSystem )
3832 unks = solved_unknowns (sys)
39- unk_idxs = Dict {UInt, Union{Int, UnitRange{Int}}} ()
33+ unk_idxs = UnknownIndexMap ()
4034 let idx = 1
4135 for sym in unks
42- h = getsymbolhash (sym)
36+ usym = unwrap (sym)
4337 sym_idx = if Symbolics. isarraysymbolic (sym)
4438 idx: (idx + length (sym) - 1 )
4539 else
4640 idx
4741 end
48- unk_idxs[h ] = sym_idx
42+ unk_idxs[usym ] = sym_idx
4943
5044 if hasname (sym)
51- h = hash (getname (sym))
52- unk_idxs[h] = sym_idx
45+ unk_idxs[getname (usym)] = sym_idx
5346 end
5447 idx += length (sym)
5548 end
@@ -120,17 +113,15 @@ function IndexCache(sys::AbstractSystem)
120113 end
121114
122115 function get_buffer_sizes_and_idxs (buffers:: Dict{Any, Set{BasicSymbolic}} )
123- idxs = IndexMap ()
116+ idxs = ParamIndexMap ()
124117 buffer_sizes = BufferTemplate[]
125118 for (i, (T, buf)) in enumerate (buffers)
126119 for (j, p) in enumerate (buf)
127- h = getsymbolhash (p)
128- idxs[h] = (i, j)
129- h = getsymbolhash (default_toterm (p))
130- idxs[h] = (i, j)
120+ idxs[p] = (i, j)
121+ idxs[default_toterm (p)] = (i, j)
131122 if hasname (p)
132- h = hash ( getname (p))
133- idxs[h ] = (i, j)
123+ idxs[ getname (p)] = (i, j )
124+ idxs[getname ( default_toterm (p)) ] = (i, j)
134125 end
135126 end
136127 push! (buffer_sizes, BufferTemplate (T, length (buf)))
@@ -139,39 +130,87 @@ function IndexCache(sys::AbstractSystem)
139130 end
140131
141132 disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs (disc_buffers)
142- param_idxs, param_buffer_sizes = get_buffer_sizes_and_idxs (tunable_buffers)
133+ tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs (tunable_buffers)
143134 const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs (constant_buffers)
144135 dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs (dependent_buffers)
145136 nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs (nonnumeric_buffers)
146137
147138 return IndexCache (
148139 unk_idxs,
149140 disc_idxs,
150- param_idxs ,
141+ tunable_idxs ,
151142 const_idxs,
152143 dependent_idxs,
153144 nonnumeric_idxs,
154145 discrete_buffer_sizes,
155- param_buffer_sizes ,
146+ tunable_buffer_sizes ,
156147 const_buffer_sizes,
157148 dependent_buffer_sizes,
158149 nonnumeric_buffer_sizes
159150 )
160151end
161152
153+ function SymbolicIndexingInterface. is_variable (ic:: IndexCache , sym)
154+ return check_index_map (ic. unknown_idx, sym) != = nothing
155+ end
156+
157+ function SymbolicIndexingInterface. variable_index (ic:: IndexCache , sym)
158+ return check_index_map (ic. unknown_idx, sym)
159+ end
160+
161+ function SymbolicIndexingInterface. is_parameter (ic:: IndexCache , sym)
162+ return check_index_map (ic. tunable_idx, sym) != = nothing ||
163+ check_index_map (ic. discrete_idx, sym) != = nothing ||
164+ check_index_map (ic. constant_idx, sym) != = nothing ||
165+ check_index_map (ic. nonnumeric_idx, sym) != = nothing ||
166+ check_index_map (ic. dependent_idx, sym) != = nothing
167+ end
168+
169+ function SymbolicIndexingInterface. parameter_index (ic:: IndexCache , sym)
170+ return if (idx = check_index_map (ic. tunable_idx, sym)) != = nothing
171+ ParameterIndex (SciMLStructures. Tunable (), idx)
172+ elseif (idx = check_index_map (ic. discrete_idx, sym)) != = nothing
173+ ParameterIndex (SciMLStructures. Discrete (), idx)
174+ elseif (idx = check_index_map (ic. constant_idx, sym)) != = nothing
175+ ParameterIndex (SciMLStructures. Constants (), idx)
176+ elseif (idx = check_index_map (ic. nonnumeric_idx, sym)) != = nothing
177+ ParameterIndex (NONNUMERIC_PORTION, idx)
178+ elseif (idx = check_index_map (ic. dependent_idx, sym)) != = nothing
179+ ParameterIndex (DEPENDENT_PORTION, idx)
180+ else
181+ nothing
182+ end
183+ end
184+
185+ function check_index_map (idxmap, sym)
186+ if (idx = get (idxmap, sym, nothing )) != = nothing
187+ return idx
188+ elseif hasname (sym) && (idx = get (idxmap, getname (sym), nothing )) != = nothing
189+ return idx
190+ end
191+ dsym = default_toterm (sym)
192+ isequal (sym, dsym) && return nothing
193+ if (idx = get (idxmap, dsym, nothing )) != = nothing
194+ idx
195+ elseif hasname (dsym) && (idx = get (idxmap, getname (dsym), nothing )) != = nothing
196+ idx
197+ else
198+ nothing
199+ end
200+ end
201+
162202function ParameterIndex (ic:: IndexCache , p, sub_idx = ())
163203 p = unwrap (p)
164- h = p isa Symbol ? hash (p) : getsymbolhash (p)
165- return if haskey (ic. param_idx, h)
166- ParameterIndex (SciMLStructures. Tunable (), (ic. param_idx[h]. .. , sub_idx... ))
167- elseif haskey (ic. discrete_idx, h)
168- ParameterIndex (SciMLStructures. Discrete (), (ic. discrete_idx[h]. .. , sub_idx... ))
169- elseif haskey (ic. constant_idx, h)
170- ParameterIndex (SciMLStructures. Constants (), (ic. constant_idx[h]. .. , sub_idx... ))
171- elseif haskey (ic. dependent_idx, h)
172- ParameterIndex (DEPENDENT_PORTION, (ic. dependent_idx[h]. .. , sub_idx... ))
173- elseif haskey (ic. nonnumeric_idx, h)
174- ParameterIndex (NONNUMERIC_PORTION, (ic. nonnumeric_idx[h]. .. , sub_idx... ))
204+ return if haskey (ic. tunable_idx, p)
205+ ParameterIndex (SciMLStructures. Tunable (), (ic. tunable_idx[p]. .. , sub_idx... ))
206+ elseif haskey (ic. discrete_idx, p)
207+ ParameterIndex (SciMLStructures. Discrete (), (ic. discrete_idx[p]. .. , sub_idx... ))
208+ elseif haskey (ic. constant_idx, p)
209+ ParameterIndex (SciMLStructures. Constants (), (ic. constant_idx[p]. .. , sub_idx... ))
210+ elseif haskey (ic. dependent_idx, p)
211+ ParameterIndex (DEPENDENT_PORTION, (ic. dependent_idx[p]. .. , sub_idx... ))
212+ elseif haskey (ic. nonnumeric_idx, p)
213+ ParameterIndex (NONNUMERIC_PORTION, (ic. nonnumeric_idx[p]. .. , sub_idx... ))
175214 elseif istree (p) && operation (p) === getindex
176215 _p, sub_idx... = arguments (p)
177216 ParameterIndex (ic, _p, sub_idx)
182221
183222function discrete_linear_index (ic:: IndexCache , idx:: ParameterIndex )
184223 idx. portion isa SciMLStructures. Discrete || error (" Discrete variable index expected" )
185- ind = sum (temp. length for temp in ic. param_buffer_sizes ; init = 0 )
224+ ind = sum (temp. length for temp in ic. tunable_buffer_sizes ; init = 0 )
186225 ind += sum (
187226 temp. length for temp in Iterators. take (ic. discrete_buffer_sizes, idx. idx[1 ] - 1 );
188227 init = 0 )
202241
203242function reorder_parameters (ic:: IndexCache , ps; drop_missing = false )
204243 param_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
205- for temp in ic. param_buffer_sizes )
244+ for temp in ic. tunable_buffer_sizes )
206245 disc_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
207246 for temp in ic. discrete_buffer_sizes)
208247 const_buf = Tuple (BasicSymbolic[unwrap (variable (:DEF )) for _ in 1 : (temp. length)]
@@ -213,21 +252,20 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
213252 for temp in ic. nonnumeric_buffer_sizes)
214253
215254 for p in ps
216- h = getsymbolhash (p)
217- if haskey (ic. discrete_idx, h)
218- i, j = ic. discrete_idx[h]
255+ if haskey (ic. discrete_idx, p)
256+ i, j = ic. discrete_idx[p]
219257 disc_buf[i][j] = unwrap (p)
220- elseif haskey (ic. param_idx, h )
221- i, j = ic. param_idx[h ]
258+ elseif haskey (ic. tunable_idx, p )
259+ i, j = ic. tunable_idx[p ]
222260 param_buf[i][j] = unwrap (p)
223- elseif haskey (ic. constant_idx, h )
224- i, j = ic. constant_idx[h ]
261+ elseif haskey (ic. constant_idx, p )
262+ i, j = ic. constant_idx[p ]
225263 const_buf[i][j] = unwrap (p)
226- elseif haskey (ic. dependent_idx, h )
227- i, j = ic. dependent_idx[h ]
264+ elseif haskey (ic. dependent_idx, p )
265+ i, j = ic. dependent_idx[p ]
228266 dep_buf[i][j] = unwrap (p)
229- elseif haskey (ic. nonnumeric_idx, h )
230- i, j = ic. nonnumeric_idx[h ]
267+ elseif haskey (ic. nonnumeric_idx, p )
268+ i, j = ic. nonnumeric_idx[p ]
231269 nonnumeric_buf[i][j] = unwrap (p)
232270 else
233271 error (" Invalid parameter $p " )
0 commit comments