@@ -132,7 +132,8 @@ function MTKParameters(
132132 tunable_buffer = narrow_buffer_type .(tunable_buffer)
133133 disc_buffer = narrow_buffer_type .(disc_buffer)
134134 const_buffer = narrow_buffer_type .(const_buffer)
135- nonnumeric_buffer = narrow_buffer_type .(nonnumeric_buffer)
135+ # Don't narrow nonnumeric types
136+ nonnumeric_buffer = nonnumeric_buffer
136137
137138 if has_parameter_dependencies (sys) &&
138139 (pdeps = get_parameter_dependencies (sys)) != = nothing
@@ -308,22 +309,31 @@ end
308309
309310function SymbolicIndexingInterface. set_parameter! (
310311 p:: MTKParameters , val, idx:: ParameterIndex )
311- @unpack portion, idx = idx
312+ @unpack portion, idx, validate_size = idx
312313 i, j, k... = idx
313314 if portion isa SciMLStructures. Tunable
314315 if isempty (k)
316+ if validate_size && size (val) != = size (p. tunable[i][j])
317+ throw (InvalidParameterSizeException (size (p. tunable[i][j]), size (val)))
318+ end
315319 p. tunable[i][j] = val
316320 else
317321 p. tunable[i][j][k... ] = val
318322 end
319323 elseif portion isa SciMLStructures. Discrete
320324 if isempty (k)
325+ if validate_size && size (val) != = size (p. discrete[i][j])
326+ throw (InvalidParameterSizeException (size (p. discrete[i][j]), size (val)))
327+ end
321328 p. discrete[i][j] = val
322329 else
323330 p. discrete[i][j][k... ] = val
324331 end
325332 elseif portion isa SciMLStructures. Constants
326333 if isempty (k)
334+ if validate_size && size (val) != = size (p. constant[i][j])
335+ throw (InvalidParameterSizeException (size (p. constant[i][j]), size (val)))
336+ end
327337 p. constant[i][j] = val
328338 else
329339 p. constant[i][j][k... ] = val
@@ -392,14 +402,73 @@ function narrow_buffer_type_and_fallback_undefs(oldbuf::Vector, newbuf::Vector)
392402 isassigned (newbuf, i) || continue
393403 type = promote_type (type, typeof (newbuf[i]))
394404 end
405+ if type == Union{}
406+ type = eltype (oldbuf)
407+ end
395408 for i in eachindex (newbuf)
396409 isassigned (newbuf, i) && continue
397410 newbuf[i] = convert (type, oldbuf[i])
398411 end
399412 return convert (Vector{type}, newbuf)
400413end
401414
402- function SymbolicIndexingInterface. remake_buffer (sys, oldbuf:: MTKParameters , vals:: Dict )
415+ function validate_parameter_type (ic:: IndexCache , p, index, val)
416+ p = unwrap (p)
417+ if p isa Symbol
418+ p = get (ic. symbol_to_variable, p, nothing )
419+ if p === nothing
420+ @warn " No matching variable found for `Symbol` $p , skipping type validation."
421+ return nothing
422+ end
423+ end
424+ (; portion) = index
425+ # Nonnumeric parameters have to match the type
426+ if portion === NONNUMERIC_PORTION
427+ stype = symtype (p)
428+ val isa stype && return nothing
429+ throw (ParameterTypeException (:validate_parameter_type , p, stype, val))
430+ end
431+ stype = symtype (p)
432+ # Array parameters need array values...
433+ if stype <: AbstractArray && ! isa (val, AbstractArray)
434+ throw (ParameterTypeException (:validate_parameter_type , p, stype, val))
435+ end
436+ # ... and must match sizes
437+ if stype <: AbstractArray && Symbolics. shape (p) != = Symbolics. Unknown () &&
438+ size (val) != size (p)
439+ throw (InvalidParameterSizeException (p, val))
440+ end
441+ # Early exit
442+ val isa stype && return nothing
443+ if stype <: AbstractArray
444+ # Arrays need handling when eltype is `Real` (accept any real array)
445+ etype = eltype (stype)
446+ if etype <: Real
447+ etype = Real
448+ end
449+ # This is for duals and other complicated number types
450+ etype = SciMLBase. parameterless_type (etype)
451+ eltype (val) <: etype || throw (ParameterTypeException (
452+ :validate_parameter_type , p, AbstractArray{etype}, val))
453+ else
454+ # Real check
455+ if stype <: Real
456+ stype = Real
457+ end
458+ stype = SciMLBase. parameterless_type (stype)
459+ val isa stype ||
460+ throw (ParameterTypeException (:validate_parameter_type , p, stype, val))
461+ end
462+ end
463+
464+ function indp_to_system (indp)
465+ while hasmethod (symbolic_container, Tuple{typeof (indp)})
466+ indp = symbolic_container (indp)
467+ end
468+ return indp
469+ end
470+
471+ function SymbolicIndexingInterface. remake_buffer (indp, oldbuf:: MTKParameters , vals:: Dict )
403472 newbuf = @set oldbuf. tunable = Tuple (Vector {Any} (undef, length (buf))
404473 for buf in oldbuf. tunable)
405474 @set! newbuf. discrete = Tuple (Vector {Any} (undef, length (buf))
@@ -409,9 +478,15 @@ function SymbolicIndexingInterface.remake_buffer(sys, oldbuf::MTKParameters, val
409478 @set! newbuf. nonnumeric = Tuple (Vector {Any} (undef, length (buf))
410479 for buf in newbuf. nonnumeric)
411480
481+ # If the parameter buffer is an `MTKParameters` object, `indp` must eventually drill
482+ # down to an `AbstractSystem` using `symbolic_container`. We leverage this to get
483+ # the index cache.
484+ ic = get_index_cache (indp_to_system (indp))
412485 for (p, val) in vals
486+ idx = parameter_index (indp, p)
487+ validate_parameter_type (ic, p, idx, val)
413488 _set_parameter_unchecked! (
414- newbuf, val, parameter_index (sys, p) ; update_dependent = false )
489+ newbuf, val, idx ; update_dependent = false )
415490 end
416491
417492 @set! newbuf. tunable = narrow_buffer_type_and_fallback_undefs .(
@@ -588,3 +663,15 @@ function Base.showerror(io::IO, e::MissingParametersError)
588663 println (io, MISSING_PARAMETERS_MESSAGE)
589664 println (io, e. vars)
590665end
666+
667+ function InvalidParameterSizeException (param, val)
668+ DimensionMismatch (" InvalidParameterSizeException: For parameter $(param) expected value of size $(size (param)) . Received value $(val) of size $(size (val)) ." )
669+ end
670+
671+ function InvalidParameterSizeException (param:: Tuple , val:: Tuple )
672+ DimensionMismatch (" InvalidParameterSizeException: Expected value of size $(param) . Received value of size $(val) ." )
673+ end
674+
675+ function ParameterTypeException (func, param, expected, val)
676+ TypeError (func, " Parameter $param " , expected, val)
677+ end
0 commit comments