@@ -357,7 +357,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
357357 analytic = nothing ,
358358 split_idxs = nothing ,
359359 initializeprob = nothing ,
360+ update_initializeprob! = nothing ,
360361 initializeprobmap = nothing ,
362+ initializeprobpmap = nothing ,
361363 kwargs... ) where {iip, specialize}
362364 if ! iscomplete (sys)
363365 error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`" )
@@ -459,7 +461,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
459461 sparsity = sparsity ? jacobian_sparsity (sys) : nothing ,
460462 analytic = analytic,
461463 initializeprob = initializeprob,
462- initializeprobmap = initializeprobmap)
464+ update_initializeprob! = update_initializeprob!,
465+ initializeprobmap = initializeprobmap,
466+ initializeprobpmap = initializeprobpmap)
463467end
464468
465469"""
@@ -789,6 +793,45 @@ function get_u0(
789793 return u0, defs
790794end
791795
796+ struct GetUpdatedMTKParameters{G, S}
797+ # `getu` functor which gets parameters that are unknowns during initialization
798+ getpunknowns:: G
799+ # `setu` functor which returns a modified MTKParameters using those parameters
800+ setpunknowns:: S
801+ end
802+
803+ function (f:: GetUpdatedMTKParameters )(prob, initializesol)
804+ mtkp = copy (parameter_values (prob))
805+ f. setpunknowns (mtkp, f. getpunknowns (initializesol))
806+ mtkp
807+ end
808+
809+ struct UpdateInitializeprob{G, S}
810+ # `getu` functor which gets all values from prob
811+ getvals:: G
812+ # `setu` functor which updates initializeprob with values
813+ setvals:: S
814+ end
815+
816+ function (f:: UpdateInitializeprob )(initializeprob, prob)
817+ f. setvals (initializeprob, f. getvals (prob))
818+ end
819+
820+ function get_temporary_value (p)
821+ stype = symtype (unwrap (p))
822+ return if stype == Real
823+ zero (Float64)
824+ elseif stype <: AbstractArray{Real}
825+ zeros (Float64, size (p))
826+ elseif stype <: Real
827+ zero (stype)
828+ elseif stype <: AbstractArray
829+ zeros (eltype (stype), size (p))
830+ else
831+ error (" Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization" )
832+ end
833+ end
834+
792835function process_DEProblem (constructor, sys:: AbstractODESystem , u0map, parammap;
793836 implicit_dae = false , du0map = nothing ,
794837 version = nothing , tgrad = false ,
@@ -829,18 +872,38 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
829872 end
830873
831874 if eltype (parammap) <: Pair
832- parammap = Dict (unwrap (k) => v for (k, v) in todict ( parammap) )
875+ parammap = Dict {Any, Any} (unwrap (k) => v for (k, v) in parammap)
833876 elseif parammap isa AbstractArray
834877 if isempty (parammap)
835878 parammap = SciMLBase. NullParameters ()
836879 else
837- parammap = Dict (unwrap .(parameters (sys)) .=> parammap)
880+ parammap = Dict {Any, Any} (unwrap .(parameters (sys)) .=> parammap)
838881 end
839882 end
840-
883+ defs = defaults (sys)
884+ if has_guesses (sys)
885+ guesses = merge (
886+ ModelingToolkit. guesses (sys), isempty (guesses) ? Dict () : todict (guesses))
887+ solvablepars = [p
888+ for p in parameters (sys)
889+ if is_parameter_solvable (p, parammap, defs, guesses)]
890+
891+ pvarmap = if parammap === nothing || parammap == SciMLBase. NullParameters () || ! (eltype (parammap) <: Pair ) && isempty (parammap)
892+ defs
893+ else
894+ merge (defs, todict (parammap))
895+ end
896+ setparobserved = filter (keys (pvarmap)) do var
897+ has_parameter_dependency_with_lhs (sys, var)
898+ end
899+ else
900+ solvablepars = ()
901+ setparobserved = ()
902+ end
841903 # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
842904 if sys isa ODESystem && build_initializeprob &&
843- (((implicit_dae || ! isempty (missingvars) || ! isempty (setobserved)) &&
905+ (((implicit_dae || ! isempty (missingvars) || ! isempty (solvablepars) ||
906+ ! isempty (setobserved) || ! isempty (setparobserved)) &&
844907 ModelingToolkit. get_tearing_state (sys) != = nothing ) ||
845908 ! isempty (initialization_equations (sys))) && t != = nothing
846909 if eltype (u0map) <: Number
@@ -854,14 +917,32 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
854917 sys, t, u0map, parammap; guesses, warn_initialize_determined,
855918 initialization_eqs, eval_expression, eval_module, fully_determined, check_units)
856919 initializeprobmap = getu (initializeprob, unknowns (sys))
920+ punknowns = [p
921+ for p in all_variable_symbols (initializeprob) if is_parameter (sys, p)]
922+ getpunknowns = getu (initializeprob, punknowns)
923+ setpunknowns = setp (sys, punknowns)
924+ initializeprobpmap = GetUpdatedMTKParameters (getpunknowns, setpunknowns)
925+ reqd_syms = parameter_symbols (initializeprob)
926+ update_initializeprob! = UpdateInitializeprob (
927+ getu (sys, reqd_syms), setu (initializeprob, reqd_syms))
857928
858929 zerovars = Dict (setdiff (unknowns (sys), keys (defaults (sys))) .=> 0.0 )
930+ if parammap isa SciMLBase. NullParameters
931+ parammap = Dict ()
932+ end
933+ for p in punknowns
934+ p = unwrap (p)
935+ stype = symtype (p)
936+ parammap[p] = get_temporary_value (p)
937+ end
859938 trueinit = collect (merge (zerovars, eltype (u0map) <: Pair ? todict (u0map) : u0map))
860939 u0map isa StaticArraysCore. StaticArray &&
861940 (trueinit = SVector {length(trueinit)} (trueinit))
862941 else
863942 initializeprob = nothing
943+ update_initializeprob! = nothing
864944 initializeprobmap = nothing
945+ initializeprobpmap = nothing
865946 trueinit = u0map
866947 end
867948
@@ -909,7 +990,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
909990 sparse = sparse, eval_expression = eval_expression,
910991 eval_module = eval_module,
911992 initializeprob = initializeprob,
993+ update_initializeprob! = update_initializeprob!,
912994 initializeprobmap = initializeprobmap,
995+ initializeprobpmap = initializeprobpmap,
913996 kwargs... )
914997 implicit_dae ? (f, du0, u0, p) : (f, u0, p)
915998end
@@ -1471,10 +1554,12 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
14711554 isys = get_initializesystem (sys; initialization_eqs, check_units)
14721555 elseif isempty (u0map) && get_initializesystem (sys) === nothing
14731556 isys = structural_simplify (
1474- generate_initializesystem (sys; initialization_eqs, check_units); fully_determined)
1557+ generate_initializesystem (
1558+ sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
14751559 else
14761560 isys = structural_simplify (
1477- generate_initializesystem (sys; u0map, initialization_eqs, check_units); fully_determined)
1561+ generate_initializesystem (
1562+ sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
14781563 end
14791564
14801565 uninit = setdiff (unknowns (sys), [unknowns (isys); getfield .(observed (isys), :lhs )])
@@ -1498,14 +1583,15 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
14981583 parammap = parammap isa DiffEqBase. NullParameters || isempty (parammap) ?
14991584 [get_iv (sys) => t] :
15001585 merge (todict (parammap), Dict (get_iv (sys) => t))
1586+ parammap = Dict (k => v for (k, v) in parammap if v != = missing )
15011587 if isempty (u0map)
15021588 u0map = Dict ()
15031589 end
15041590 if isempty (guesses)
15051591 guesses = Dict ()
15061592 end
15071593
1508- u0map = merge (todict (guesses), todict (u0map))
1594+ u0map = merge (ModelingToolkit . guesses (sys), todict (guesses), todict (u0map))
15091595 if neqs == nunknown
15101596 NonlinearProblem (isys, u0map, parammap; kwargs... )
15111597 else
0 commit comments