@@ -6,10 +6,8 @@ using UnPack
66using NaNMath
77const MTK = ModelingToolkit
88
9- # NaNMath
109for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
1110 f = nameof (ff)
12- # These need to be defined so that JuMP can trace through functions built by Symbolics
1311 @eval NaNMath.$ f (x:: CasadiSymbolicObject ) = Base.$ f (x)
1412end
1513
@@ -76,78 +74,47 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
7674 dt = nothing ,
7775 steps = nothing ,
7876 guesses = Dict (), kwargs... )
79- MTK. warn_overdetermined (sys, u0map)
80- _u0map = has_alg_eqs (sys) ? MTK. to_varmap (u0map, unknowns (sys)) :
81- merge (Dict (u0map), Dict (guesses))
82- pmap = MTK. to_varmap (pmap, parameters (sys))
83- f, u0, p = MTK. process_SciMLProblem (ODEInputFunction, sys, merge (_u0map, pmap);
84- t = tspan != = nothing ? tspan[1 ] : tspan, output_type = MX, kwargs... )
85-
86- pmap = MTK. recursive_unwrap (MTK. AnyDict (pmap))
87- MTK. evaluate_varmap! (pmap, keys (pmap))
88- steps, is_free_t = MTK. process_tspan (tspan, dt, steps)
89- model = init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t)
90-
91- CasADiDynamicOptProblem (f, u0, tspan, p, model, kwargs... )
77+ process_DynamicOptProblem (CasADiDynamicOptProblem, CasADiModel, sys, u0map, tspan, pmap; dt, steps, guesses, kwargs... )
9278end
9379
9480MTK. generate_internal_model (:: Type{CasADiModel} ) = CasADi. opti ()
95- MTK. generate_state_variable (model, u0, ns, nt)
96- MTK. generate_input_variable (model, c0, nc, nt) = 1
97- MTK. generate_timescale (model, dims) = 1
9881
99- function init_model (sys, tspan, steps, u0map, pmap, u0; is_free_t = false )
100- ctrls = MTK. unbound_inputs (sys)
101- states = unknowns (sys)
102- opti = CasADi. Opti ()
82+ function MTK. generate_state_variable (model:: Opti , u0, ns, nt, tsteps)
83+ U = CasADi. variable! (model, ns, nt)
84+ set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
85+ MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
86+ end
10387
88+ function MTK. generate_input_variable (model:: Opti , c0, nc, nt, tsteps)
89+ V = CasADi. variable! (model, nc, nt)
90+ ! isempty (c0) && set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
91+ MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
92+ end
93+
94+ function MTK. generate_timescale (model:: Opti , guess, is_free_t)
10495 if is_free_t
105- (ts_sym, te_sym) = tspan
106- MTK. symbolic_type (ts_sym) != = MTK. NotSymbolic () &&
107- error (" Free initial time problems are not currently supported in CasADiDynamicOptProblem." )
108- tₛ = variable! (opti)
109- set_initial! (opti, tₛ, pmap[te_sym])
110- subject_to! (opti, tₛ >= ts_sym)
111- hasbounds (te_sym) && begin
112- lo, hi = getbounds (te_sym)
113- subject_to! (opti, tₛ >= lo)
114- subject_to! (opti, tₛ >= hi)
115- end
116- pmap[te_sym] = tₛ
117- tsteps = LinRange (0 , 1 , steps)
96+ tₛ = variable! (model)
97+ set_initial! (model, tₛ, guess)
98+ subject_to! (model, tₛ >= 0 )
99+ tₛ
118100 else
119- tₛ = MX (1 )
120- tsteps = LinRange (tspan[1 ], tspan[2 ], steps)
101+ MX (1 )
121102 end
103+ end
122104
123- U = CasADi. variable! (opti, length (states), steps)
124- V = CasADi. variable! (opti, length (ctrls), steps)
125- set_initial! (opti, U, DM (repeat (u0, 1 , steps)))
126- c0 = MTK. value .([pmap[c] for c in ctrls])
127- ! isempty (c0) && set_initial! (opti, V, DM (repeat (c0, 1 , steps)))
128-
129- U_interp = MXLinearInterpolation (U, tsteps, tsteps[2 ] - tsteps[1 ])
130- V_interp = MXLinearInterpolation (V, tsteps, tsteps[2 ] - tsteps[1 ])
131- for (i, ct) in enumerate (ctrls)
132- pmap[ct] = V[i, :]
105+ function MTK. add_constraint! (model:: CasADiModel , expr)
106+ @unpack opti = model
107+ if cons isa Equation
108+ subject_to! (opti, expr. lhs - expr. rhs == 0 )
109+ elseif cons. relational_op === Symbolics. geq
110+ subject_to! (opti, expr. lhs - expr. rhs ≥ 0 )
111+ else
112+ subject_to! (opti, expr. lhs - expr. rhs ≤ 0 )
133113 end
134-
135- model = CasADiModel (opti, U_interp, V_interp, tₛ)
136-
137- set_casadi_bounds! (model, sys, pmap)
138- add_cost_function! (model, sys, tspan, pmap; is_free_t)
139- add_user_constraints! (model, sys, tspan, pmap; is_free_t)
140-
141- stidxmap = Dict ([v => i for (i, v) in enumerate (states)])
142- u0map = Dict ([MTK. default_toterm (MTK. value (k)) => v for (k, v) in u0map])
143- u0_idxs = has_alg_eqs (sys) ? collect (1 : length (states)) :
144- [stidxmap[MTK. default_toterm (k)] for (k, v) in u0map]
145- add_initial_constraints! (model, u0, u0_idxs)
146-
147- model
148114end
115+ MTK. set_objective! (model:: CasADiModel , expr) = minimize! (model. opti, MX (expr))
149116
150- function set_casadi_bounds ! (model, sys, pmap)
117+ function MTK . set_variable_bounds ! (model, sys, pmap, tf )
151118 @unpack opti, U, V = model
152119 for (i, u) in enumerate (unknowns (sys))
153120 if MTK. hasbounds (u)
@@ -163,36 +130,56 @@ function set_casadi_bounds!(model, sys, pmap)
163130 subject_to! (opti, V. u[i, :] <= Symbolics. fast_substitute (hi, pmap))
164131 end
165132 end
133+ if MTK. symbolic_type (tf) === MTK. ScalarSymbolic () && hasbounds (tf)
134+ lo, hi = MTK. getbounds (tf)
135+ subject_to! (opti, model. tₛ >= lo)
136+ subject_to! (opti, model. tₛ <= hi)
137+ end
166138end
167139
168- function add_initial_constraints! (model:: CasADiModel , u0, u0_idxs)
140+ function MTK . add_initial_constraints! (model:: CasADiModel , u0, u0_idxs)
169141 @unpack opti, U = model
170142 for i in u0_idxs
171143 subject_to! (opti, U. u[i, 1 ] == u0[i])
172144 end
173145end
174146
175- function add_user_constraints! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
147+ function MTK. substitute_model_vars (
148+ model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
176149 @unpack opti, U, V, tₛ = model
177-
178150 iv = MTK. get_iv (sys)
179- jconstraints = MTK. get_constraints (sys)
180- (isnothing (jconstraints) || isempty (jconstraints)) && return nothing
181-
182- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
183- ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
184- cons_dvs, cons_ps = MTK. process_constraint_system (
185- jconstraints, Set (unknowns (sys)), parameters (sys), iv; validate = false )
186-
187- auxmap = Dict ([u => MTK. default_toterm (MTK. value (u)) for u in cons_dvs])
188- jconstraints = substitute_casadi_vars (model, sys, pmap, jconstraints; is_free_t, auxmap)
189- # Manually substitute fixed-t variables
190- for (i, cons) in enumerate (jconstraints)
191- consvars = MTK. vars (cons)
192- for st in consvars
151+ sts = unknowns (sys)
152+ cts = MTK. unbound_inputs (sys)
153+
154+ x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
155+ c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
156+
157+ exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
158+ exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
159+ # tf means different things in different contexts; a [tf] in a cost function
160+ # should be tₛ, while a x(tf) should translate to x[1]
161+ if is_free_t
162+ free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
163+ [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
164+ exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
165+ end
166+
167+ exprs = substitute_fixed_t_vars (exprs)
168+
169+ # for variables like x(t)
170+ whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
171+ [v => V. u[i, :] for (i, v) in enumerate (cts)]])
172+ exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
173+ exprs
174+ end
175+
176+ function substitute_fixed_t_vars (exprs)
177+ for i in 1 : length (exprs)
178+ subvars = MTK. vars (exprs[i])
179+ for st in subvars
193180 MTK. iscall (st) || continue
194- x = MTK . operation (st)
195- t = only (MTK . arguments (st))
181+ x = operation (st)
182+ t = only (arguments (st))
196183 MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
197184 if haskey (stidxmap, x (iv))
198185 idx = stidxmap[x (iv)]
@@ -201,52 +188,19 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
201188 idx = ctidxmap[x (iv)]
202189 cv = V
203190 end
204- cons = Symbolics. substitute (cons, Dict (x (t) => cv (t)[idx]))
205- end
206-
207- if cons isa Equation
208- subject_to! (opti, cons. lhs - cons. rhs == 0 )
209- elseif cons. relational_op === Symbolics. geq
210- subject_to! (opti, cons. lhs - cons. rhs ≥ 0 )
211- else
212- subject_to! (opti, cons. lhs - cons. rhs ≤ 0 )
191+ exprs[i] = Symbolics. fast_substitute (exprs[i], Dict (x (t) => cv (t)[idx]))
213192 end
193+ jcosts = Symbolics. substitute (jcosts, Dict (x (t) => cv (t)[idx]))
214194 end
215195end
216196
217- function add_cost_function! (model:: CasADiModel , sys, tspan, pmap; is_free_t)
218- @unpack opti, U, V, tₛ = model
219- jcosts = cost (sys)
220- if Symbolics. _iszero (jcosts)
221- minimize! (opti, MX (0 ))
222- return
223- end
224-
225- iv = MTK. get_iv (sys)
226- stidxmap = Dict ([v => i for (i, v) in enumerate (unknowns (sys))])
227- ctidxmap = Dict ([v => i for (i, v) in enumerate (MTK. unbound_inputs (sys))])
228-
229- jcosts = substitute_casadi_vars (model, sys, pmap, [jcosts]; is_free_t)[1 ]
230- # Substitute fixed-time variables.
231- costvars = MTK. vars (jcosts)
232- for st in costvars
233- MTK. iscall (st) || continue
234- x = operation (st)
235- t = only (arguments (st))
236- MTK. symbolic_type (t) === MTK. NotSymbolic () || continue
237- if haskey (stidxmap, x (iv))
238- idx = stidxmap[x (iv)]
239- cv = U
240- else
241- idx = ctidxmap[x (iv)]
242- cv = V
243- end
244- jcosts = Symbolics. substitute (jcosts, Dict (x (t) => cv (t)[idx]))
245- end
197+ MTK. substitute_differentials (model:: CasADiModel , exprs, args... ) = exprs
246198
199+ function MTK. substitute_integral (model:: CasADiModel , exprs)
200+ @unpack U, opti = model
247201 dt = U. t[2 ] - U. t[1 ]
248202 intmap = Dict ()
249- for int in MTK. collect_applied_operators (jcosts , Symbolics. Integral)
203+ for int in MTK. collect_applied_operators (exprs , Symbolics. Integral)
250204 op = MTK. operation (int)
251205 arg = only (arguments (MTK. value (int)))
252206 lo, hi = (op. domain. domain. left, op. domain. domain. right)
@@ -255,39 +209,11 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
255209 # Approximate integral as sum.
256210 intmap[int] = dt * tₛ * sum (arg)
257211 end
258- jcosts = Symbolics. substitute (jcosts, intmap)
259- jcosts = MTK. value (jcosts)
260- minimize! (opti, MX (jcosts))
261- end
262-
263- function substitute_casadi_vars (
264- model:: CasADiModel , sys, pmap, exprs; auxmap:: Dict = Dict (), is_free_t)
265- @unpack opti, U, V, tₛ = model
266- iv = MTK. get_iv (sys)
267- sts = unknowns (sys)
268- cts = MTK. unbound_inputs (sys)
269-
270- x_ops = [MTK. operation (MTK. unwrap (st)) for st in sts]
271- c_ops = [MTK. operation (MTK. unwrap (ct)) for ct in cts]
272-
273- exprs = map (c -> Symbolics. fast_substitute (c, auxmap), exprs)
274- exprs = map (c -> Symbolics. fast_substitute (c, Dict (pmap)), exprs)
275- # tf means different things in different contexts; a [tf] in a cost function
276- # should be tₛ, while a x(tf) should translate to x[1]
277- if is_free_t
278- free_t_map = Dict ([[x (tₛ) => U. u[i, end ] for (i, x) in enumerate (x_ops)];
279- [c (tₛ) => V. u[i, end ] for (i, c) in enumerate (c_ops)]])
280- exprs = map (c -> Symbolics. fast_substitute (c, free_t_map), exprs)
281- end
282-
283- # for variables like x(t)
284- whole_interval_map = Dict ([[v => U. u[i, :] for (i, v) in enumerate (sts)];
285- [v => V. u[i, :] for (i, v) in enumerate (cts)]])
286- exprs = map (c -> Symbolics. fast_substitute (c, whole_interval_map), exprs)
287- exprs
212+ exprs = map (c -> Symbolics. substitute (c, intmap), exprs)
213+ exprs = MTK. value .(exprs)
288214end
289215
290- function add_solve_constraints (prob, tableau)
216+ function add_solve_constraints! (prob, tableau)
291217 @unpack A, α, c = tableau
292218 @unpack model, f, p = prob
293219 @unpack opti, U, V, tₛ = model
@@ -332,57 +258,29 @@ function add_solve_constraints(prob, tableau)
332258 solver_opti
333259end
334260
335- """
336- solve(prob::CasADiDynamicOptProblem, casadi_solver, ode_solver; plugin_options, solver_options, silent)
337-
338- `plugin_options` and `solver_options` get propagated to the Opti object in CasADi.
339-
340- NOTE: the solver should be passed in as a string to CasADi. "ipopt"
341- """
342- function DiffEqBase. solve (
343- prob:: CasADiDynamicOptProblem , solver:: Union{String, Symbol} = " ipopt" ,
344- tableau_getter = MTK. constructDefault; plugin_options:: Dict = Dict (),
345- solver_options:: Dict = Dict (), silent = false )
346- @unpack model, u0, p, tspan, f = prob
347- tableau = tableau_getter ()
348- @unpack opti, U, V, tₛ = model
349-
261+ function MTK. prepare_solver ()
350262 opti = add_solve_constraints (prob, tableau)
351- silent && (solver_options[" print_level" ] = 0 )
352263 solver! (opti, " $solver " , plugin_options, solver_options)
264+ end
265+ function MTK. get_U_values ()
266+ U_vals = value_getter (U. u)
267+ size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
268+ U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 )] for j in 1 : length (ts)]
269+ end
270+ function MTK. get_V_values ()
271+ end
272+ function MTK. get_t_values ()
273+ ts = value_getter (tₛ) * U. t
274+ end
353275
354- failed = false
355- value_getter = nothing
356- sol = nothing
276+ function MTK. optimize_model! ()
357277 try
358278 sol = CasADi. solve! (opti)
359279 value_getter = x -> CasADi. value (sol, x)
360280 catch ErrorException
361281 value_getter = x -> CasADi. debug_value (opti, x)
362282 failed = true
363283 end
364-
365- ts = value_getter (tₛ) * U. t
366- U_vals = value_getter (U. u)
367- size (U_vals, 2 ) == 1 && (U_vals = U_vals' )
368- U_vals = [[U_vals[i, j] for i in 1 : size (U_vals, 1 )] for j in 1 : length (ts)]
369- ode_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, U_vals)
370-
371- input_sol = nothing
372- if prod (size (V. u)) != 0
373- V_vals = value_getter (V. u)
374- size (V_vals, 2 ) == 1 && (V_vals = V_vals' )
375- V_vals = [[V_vals[i, j] for i in 1 : size (V_vals, 1 )] for j in 1 : length (ts)]
376- input_sol = DiffEqBase. build_solution (prob, tableau_getter, ts, V_vals)
377- end
378-
379- if failed
380- ode_sol = SciMLBase. solution_new_retcode (
381- ode_sol, SciMLBase. ReturnCode. ConvergenceFailure)
382- ! isnothing (input_sol) && (input_sol = SciMLBase. solution_new_retcode (
383- input_sol, SciMLBase. ReturnCode. ConvergenceFailure))
384- end
385-
386- DynamicOptSolution (model, ode_sol, input_sol)
387284end
285+ MTK. successful_solve () = true
388286end
0 commit comments