Skip to content

Commit d328a60

Browse files
add OptimizationProblem building
1 parent a443583 commit d328a60

File tree

6 files changed

+33
-5
lines changed

6 files changed

+33
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"
2222
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2323

2424
[compat]
25-
DiffEqBase = "6"
25+
DiffEqBase = "6.28"
2626
DiffRules = "0.1, 1.0"
2727
DocStringExtensions = "0.7, 0.8"
2828
GeneralizedGenerated = "0.1.4, 0.2"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ export simplify_constants
122122
export Operation, Expression, Variable
123123
export calculate_jacobian, generate_jacobian, generate_function
124124
export calculate_tgrad, generate_tgrad
125+
export calculate_gradient, generate_gradient
125126
export calculate_hessian, generate_hessian
126127
export calculate_massmatrix, generate_diffusion_function
127128
export independent_variable, states, parameters, equations

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ end
7474

7575
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
7676
parammap=DiffEqBase.NullParameters();
77-
version = nothing, jac = false,
78-
kwargs...) where iip
77+
jac = false,kwargs...) where iip
7978
dvs = states(sys)
8079
ps = parameters(sys)
8180

src/systems/optimization/optimizationsystem.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,22 @@ function OptimizationSystem(op, states, ps;
3939
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), name, systems)
4040
end
4141

42+
function calculate_gradient(sys::OptimizationSystem)
43+
expand_derivatives.(gradient(equations(sys), [dv() for dv in states(sys)]))
44+
end
45+
46+
function generate_gradient(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
47+
grad = calculate_gradient(sys)
48+
return build_function(grad, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x); kwargs...)
49+
end
50+
4251
function calculate_hessian(sys::OptimizationSystem)
4352
expand_derivatives.(hessian(equations(sys), [dv() for dv in states(sys)]))
4453
end
4554

4655
function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
4756
hes = calculate_hessian(sys)
48-
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x))
57+
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x); kwargs...)
4958
end
5059

5160
function generate_function(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
@@ -56,3 +65,19 @@ end
5665

5766
equations(sys::OptimizationSystem) = isempty(sys.systems) ? sys.op : sys.op + reduce(+,namespace_operation.(sys.systems))
5867
namespace_operation(sys::OptimizationSystem) = namespace_operation(sys.op,sys.name,nothing)
68+
69+
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
70+
parammap=DiffEqBase.NullParameters();
71+
u0=nothing, lb=nothing, ub=nothing,
72+
hes = false,
73+
kwargs...) where iip
74+
dvs = states(sys)
75+
ps = parameters(sys)
76+
77+
f = generate_function(sys)
78+
u0 = varmap_to_vars(u0,dvs)
79+
p = varmap_to_vars(parammap,ps)
80+
lb = varmap_to_vars(lb,dvs)
81+
ub = varmap_to_vars(ub,dvs)
82+
OptimizationProblem(f,p;u0=u0,lb=lb,ub=ub,kwargs...)
83+
end

src/variables.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,4 @@ function varmap_to_vars(varmap,varlist)
192192
end
193193

194194
varmap_to_vars(varmap::DiffEqBase.NullParameters,varlist) = varmap
195+
varmap_to_vars(varmap::Nothing,varlist) = varmap

test/optimizationsystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ sys2 = OptimizationSystem(loss,[x,y],[a,b],name=:sys2)
88
@variables z
99
@parameters β
1010
loss2 = sys1.x - sys2.y + z*β
11-
combinedsys = OptimizationSystem(loss,[z],[β],systems=[sys1,sys2],name=:combinedsys)
11+
combinedsys = OptimizationSystem(loss2,[z],[β],systems=[sys1,sys2],name=:combinedsys)
1212

1313
equations(combinedsys)
1414
states(combinedsys)
1515
parameters(combinedsys)
1616

17+
calculate_gradient(combinedsys)
1718
calculate_hessian(combinedsys)
1819
generate_function(combinedsys)
20+
generate_gradient(combinedsys)
1921
generate_hessian(combinedsys)

0 commit comments

Comments
 (0)