Skip to content

Commit c4cbac0

Browse files
authored
Update parameter_buffer.jl
1 parent bd09d73 commit c4cbac0

File tree

1 file changed

+27
-14
lines changed

1 file changed

+27
-14
lines changed

src/systems/parameter_buffer.jl

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -362,37 +362,50 @@ function Base.setindex!(p::MTKParameters, val, i)
362362
end
363363

364364
function Base.getindex(p::MTKParameters, pind::ParameterIndex)
365-
(;portion, idx) = pind, (i,j) = idx
365+
(;portion, idx) = pind
366+
if length(idx) > 2
367+
i, j, k... = idx
368+
indexer = (v) -> v[i][j][k...]
369+
else
370+
i, j = idx
371+
indexer = (v) -> v[i][j]
372+
end
366373
if portion isa SciMLStructures.Tunable
367-
p.tunable[i][j]
374+
indexer(p.tunable)
368375
elseif portion isa SciMLStructures.Discrete
369-
p.discrete[i][j]
376+
indexer(p.discrete)
370377
elseif portion isa SciMLStructures.Constants
371-
p.constant[i][j]
378+
indexer(p.constant)
372379
elseif portion === DEPENDENT_PORTION
373-
p.dependent[i][j]
380+
indexer(p.dependent)
374381
elseif portion === NONNUMERIC_PORTION
375-
p.nonnumeric[i][j]
382+
indexer(p.nonnumeric)
376383
else
377-
error("Unhandled portion $portion")
384+
error("Unhandled portion ", portion)
378385
end
379386
end
380387

381388
function Base.setindex!(p::MTKParameters, val, pind::ParameterIndex)
382389
(;portion, idx) = pind
383-
(i,j) = idx
390+
if length(idx) > 2
391+
i, j, k... = idx
392+
setindexer = (v) -> v[i][j][k...] = val
393+
else
394+
i, j = idx
395+
setindexer = (v) -> v[i][j] = val
396+
end
384397
if portion isa SciMLStructures.Tunable
385-
p.tunable[i][j] = val
398+
setindexer(p.tunable)
386399
elseif portion isa SciMLStructures.Discrete
387-
p.discrete[i][j] = val
400+
setindexer(p.discrete)
388401
elseif portion isa SciMLStructures.Constants
389-
p.constant[i][j] = val
402+
setindexer(p.constant)
390403
elseif portion === DEPENDENT_PORTION
391-
p.dependent[i][j] = val
404+
setindexer(p.dependent)
392405
elseif portion === NONNUMERIC_PORTION
393-
p.nonnumeric[i][j] = val
406+
setindexer(p.nonnumeric)
394407
else
395-
error("Unhandled portion $portion")
408+
error("Unhandled portion", portion)
396409
end
397410
end
398411

0 commit comments

Comments
 (0)