@@ -51,6 +51,10 @@ struct ODESystem <: AbstractODESystem
5151 observed:: Vector{Equation}
5252 """ System of constraints that must be satisfied by the solution to the system."""
5353 constraintsystem:: Union{Nothing, ConstraintsSystem}
54+ """ A set of expressions defining the costs of the system for optimal control."""
55+ costs:: Vector
56+ """ Takes the cost vector and returns a scalar for optimization."""
57+ coalesce:: Function
5458 """
5559 Time-derivative matrix. Note: this field will not be defined until
5660 [`calculate_tgrad`](@ref) is called on the system.
@@ -338,7 +342,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
338342 metadata, gui_metadata, is_dde, tstops, checks = checks)
339343end
340344
341- function ODESystem (eqs, iv; constraints = Equation[], kwargs... )
345+ function ODESystem (eqs, iv; constraints = Equation[], costs = Equation[], kwargs... )
342346 diffvars, allunknowns, ps, eqs = process_equations (eqs, iv)
343347
344348 for eq in get (kwargs, :parameter_dependencies , Equation[])
@@ -384,6 +388,13 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
384388 end
385389 end
386390
391+ if ! isempty (costs)
392+ coststs, costps = process_costs (costs, allunknowns, new_ps, iv)
393+ for p in costps
394+ ! in (p, new_ps) && push! (new_ps, p)
395+ end
396+ end
397+
387398 return ODESystem (eqs, iv, collect (Iterators. flatten ((diffvars, algevars, consvars))),
388399 collect (new_ps); constraintsystem, kwargs... )
389400end
@@ -734,22 +745,52 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
734745 return nothing
735746end
736747
737- # Validate that all the variables in the BVP constraints are well-formed states or parameters.
738- # - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
739- # - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
748+ """
749+ Build the constraint system for the ODESystem.
750+ """
740751function process_constraint_system (
741752 constraints:: Vector{Equation} , sts, ps, iv; consname = :cons )
742753 isempty (constraints) && return nothing
743754
744755 constraintsts = OrderedSet ()
745756 constraintps = OrderedSet ()
746-
747757 for cons in constraints
748758 collect_vars! (constraintsts, constraintps, cons, iv)
749759 end
750760
751761 # Validate the states.
752- for var in constraintsts
762+ validate_vars_and_find_ps! (coststs, costps, sts, iv)
763+
764+ ConstraintsSystem (
765+ constraints, collect (constraintsts), collect (constraintps); name = consname)
766+ end
767+
768+ """
769+ Process the costs for the constraint system.
770+ """
771+ function process_costs (costs:: Vector{Equation} , sts, ps, iv)
772+ coststs = OrderedSet ()
773+ costps = OrderedSet ()
774+ for cost in costs
775+ collect_vars! (coststs, costps, cost, iv)
776+ end
777+
778+ validate_vars_and_find_ps! (coststs, costps, sts, iv)
779+ end
780+
781+ """
782+ Validate that all the variables in an auxiliary system of the ODESystem (constraint or costs) are
783+ well-formed states or parameters.
784+ - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
785+ - Callable/delay parameters should be parameters of the system
786+
787+ Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 then p should be added as a
788+ parameter of the system.
789+ """
790+ function validate_vars_and_find_ps! (auxvars, auxps, sysvars, iv)
791+ sts = sysvars
792+
793+ for var in auxvars
753794 if ! iscall (var)
754795 occursin (iv, var) && (var ∈ sts ||
755796 throw (ArgumentError (" Time-dependent variable $var is not an unknown of the system." )))
@@ -764,13 +805,17 @@ function process_constraint_system(
764805 arg isa AbstractFloat ||
765806 throw (ArgumentError (" Invalid argument specified for variable $var . The argument of the variable should be either $iv , a parameter, or a value specifying the time that the constraint holds." ))
766807
767- isparameter (arg) && push! (constraintps , arg)
808+ isparameter (arg) && push! (auxps , arg)
768809 else
769810 var ∈ sts &&
770811 @warn " Variable $var has no argument. It will be interpreted as $var ($iv ), and the constraint will apply to the entire interval."
771812 end
772813 end
814+ end
773815
774- ConstraintsSystem (
775- constraints, collect (constraintsts), collect (constraintps); name = consname)
816+ function generate_cost_function (sys:: ODESystem )
817+ costs = get_costs (sys)
818+ coalesce = get_coalesce (sys)
819+ cost_fn = build_function_wrapper ()
820+ return (u, p, t) -> coalesce (cost_fn (u, p, t))
776821end
0 commit comments