Skip to content

Commit 487f67f

Browse files
Merge pull request #290 from SciML/problem
add a direct problem building interface for AbstractSystems
2 parents cf740b6 + 9883cf8 commit 487f67f

File tree

9 files changed

+137
-11
lines changed

9 files changed

+137
-11
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
44
version = "2.0.0"
55

66
[deps]
7+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
78
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
89
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
910
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -21,7 +22,7 @@ TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"
2122
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2223

2324
[compat]
24-
DiffEqBase = "6"
25+
DiffEqBase = "6.28"
2526
DiffRules = "0.1, 1.0"
2627
DocStringExtensions = "0.7, 0.8"
2728
GeneralizedGenerated = "0.1.4, 0.2"

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ModelingToolkit
22

33
using DiffEqBase, Distributed
44
using StaticArrays, LinearAlgebra, SparseArrays
5-
using Latexify, Unitful
5+
using Latexify, Unitful, ArrayInterface
66
using MacroTools
77

88
using Base.Threads
@@ -122,6 +122,7 @@ export simplify_constants
122122
export Operation, Expression, Variable
123123
export calculate_jacobian, generate_jacobian, generate_function
124124
export calculate_tgrad, generate_tgrad
125+
export calculate_gradient, generate_gradient
125126
export calculate_hessian, generate_hessian
126127
export calculate_massmatrix, generate_diffusion_function
127128
export independent_variable, states, parameters, equations

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
100100
j = findfirst(x->isequal(x.name,var_from_nested_derivative(eq.lhs)[1].name),dvs)
101101
M[i,j] = 1
102102
else
103-
error("Only semi-explicit mass matrices are currently supported")
103+
error("Only semi-explicit constant mass matrices are currently supported")
104104
end
105105
end
106106
M = simplify ? simplify_constants.(M) : M
@@ -121,30 +121,30 @@ respectively.
121121
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
122122
ps = parameters(sys);
123123
version = nothing, tgrad=false,
124-
jac = false, Wfact = false) where {iip}
125-
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
124+
jac = false, Wfact = false, kwargs...) where {iip}
125+
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
126126

127127
f(u,p,t) = f_oop(u,p,t)
128128
f(du,u,p,t) = f_iip(du,u,p,t)
129129

130130
if tgrad
131-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
131+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false}; kwargs...)
132132
_tgrad(u,p,t) = tgrad_oop(u,p,t)
133133
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
134134
else
135135
_tgrad = nothing
136136
end
137137

138138
if jac
139-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
139+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; kwargs...)
140140
_jac(u,p,t) = jac_oop(u,p,t)
141141
_jac(J,u,p,t) = jac_iip(J,u,p,t)
142142
else
143143
_jac = nothing
144144
end
145145

146146
if Wfact
147-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false})
147+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false}; kwargs...)
148148
Wfact_oop, Wfact_iip = tmp_Wfact
149149
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
150150
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
@@ -162,5 +162,23 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
162162
Wfact = _Wfact,
163163
Wfact_t = _Wfact_t,
164164
mass_matrix = M,
165-
syms = Symbol.(sys.states))
165+
syms = Symbol.(states(sys)))
166+
end
167+
168+
function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
169+
ODEProblem{true}(sys, args...; kwargs...)
170+
end
171+
172+
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
173+
parammap=DiffEqBase.NullParameters();
174+
version = nothing, tgrad=false,
175+
jac = false, Wfact = false,
176+
checkbounds = false,
177+
linenumbers = true, multithread=false,
178+
kwargs...) where iip
179+
f = ODEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
180+
linenumbers=linenumbers,multithread=multithread)
181+
u0 = varmap_to_vars(u0map,states(sys))
182+
p = varmap_to_vars(parammap,parameters(sys))
183+
ODEProblem(f,u0,tspan,p;kwargs...)
166184
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,17 @@ end
121121
function rename(sys::SDESystem,name)
122122
ODESystem(sys.eqs, sys.noiseeqs, sys.iv, sys.states, sys.ps, sys.tgrad, sys.jac, sys.Wfact, sys.Wfact_t, name, sys.systems)
123123
end
124+
125+
function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,p=parammap;
126+
version = nothing, tgrad=false,
127+
jac = false, Wfact = false,
128+
checkbounds = false,
129+
linenumbers = true, multithread=false,
130+
kwargs...) where iip
131+
132+
f = SDEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
133+
linenumbers=linenumbers,multithread=multithread)
134+
u0 = varmap_to_vars(u0map,states(sys))
135+
p = varmap_to_vars(parammap,parameters(sys))
136+
SDEProblem(f,f.g,u0,tspan,p;kwargs...)
137+
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,19 @@ function generate_function(sys::NonlinearSystem, vs = states(sys), ps = paramete
7171
ps′ = convert.(Variable,ps)
7272
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys), expression; kwargs...)
7373
end
74+
75+
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
76+
parammap=DiffEqBase.NullParameters();
77+
jac = false,
78+
checkbounds = false,
79+
linenumbers = true, multithread=false,
80+
kwargs...) where iip
81+
dvs = states(sys)
82+
ps = parameters(sys)
83+
84+
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
85+
multithread=multithread)
86+
u0 = varmap_to_vars(u0map,dvs)
87+
p = varmap_to_vars(parammap,ps)
88+
NonlinearProblem(f,u0,tspan,p;kwargs...)
89+
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,22 @@ function OptimizationSystem(op, states, ps;
3939
OptimizationSystem(op, convert.(Variable,states), convert.(Variable,ps), name, systems)
4040
end
4141

42+
function calculate_gradient(sys::OptimizationSystem)
43+
expand_derivatives.(gradient(equations(sys), [dv() for dv in states(sys)]))
44+
end
45+
46+
function generate_gradient(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
47+
grad = calculate_gradient(sys)
48+
return build_function(grad, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x); kwargs...)
49+
end
50+
4251
function calculate_hessian(sys::OptimizationSystem)
4352
expand_derivatives.(hessian(equations(sys), [dv() for dv in states(sys)]))
4453
end
4554

4655
function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
4756
hes = calculate_hessian(sys)
48-
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x))
57+
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x); kwargs...)
4958
end
5059

5160
function generate_function(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
@@ -56,3 +65,22 @@ end
5665

5766
equations(sys::OptimizationSystem) = isempty(sys.systems) ? sys.op : sys.op + reduce(+,namespace_operation.(sys.systems))
5867
namespace_operation(sys::OptimizationSystem) = namespace_operation(sys.op,sys.name,nothing)
68+
69+
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
70+
parammap=DiffEqBase.NullParameters();
71+
u0=nothing, lb=nothing, ub=nothing,
72+
hes = false,
73+
checkbounds = false,
74+
linenumbers = true, multithread=false,
75+
kwargs...) where iip
76+
dvs = states(sys)
77+
ps = parameters(sys)
78+
79+
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
80+
multithread=multithread)
81+
u0 = varmap_to_vars(u0,dvs)
82+
p = varmap_to_vars(parammap,ps)
83+
lb = varmap_to_vars(lb,dvs)
84+
ub = varmap_to_vars(ub,dvs)
85+
OptimizationProblem(f,p;u0=u0,lb=lb,ub=ub,kwargs...)
86+
end

src/variables.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,25 @@ function TreeViews.treelabel(io::IO,x::Variable,
171171
mime::MIME"text/plain" = MIME"text/plain"())
172172
show(io,mime,Text(x.name))
173173
end
174+
175+
"""
176+
varmap_to_vars(varmap,varlist)
177+
178+
Takes a list of pairs of variables=>values and an ordered list of variables and
179+
creates the array of values in the correct order
180+
"""
181+
function varmap_to_vars(varmap,varlist)
182+
out = similar(varmap,typeof(last(first(varmap))))
183+
for i in 1:length(varmap)
184+
ivar = convert(Variable,varmap[i][1])
185+
j = findfirst(x->ivar.name == convert(Variable,x).name,varlist)
186+
out[j] = varmap[i][2]
187+
end
188+
189+
# Make output match varmap in type and shape
190+
# Does things like MArray->SArray
191+
ArrayInterface.restructure(varmap,out)
192+
end
193+
194+
varmap_to_vars(varmap::DiffEqBase.NullParameters,varlist) = varmap
195+
varmap_to_vars(varmap::Nothing,varlist) = varmap

test/components.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,30 @@ connnectedeqs = [D(a) ~ a*lorenz1.x]
2020

2121
connected1 = ODESystem(connnectedeqs,t,[a],[α],systems=[lorenz1,lorenz2],name=:connected1)
2222

23+
u0map = [lorenz1.x => 1.0,
24+
lorenz2.x => 2.0,
25+
lorenz1.y => 3.0,
26+
lorenz2.y => 4.0,
27+
lorenz1.z => 5.0,
28+
lorenz2.z => 6.0,
29+
a => 7.0]
30+
31+
parammap = SA[lorenz1.σ => 1.0,
32+
lorenz1.ρ => 2.0,
33+
lorenz1.β => 3.0,
34+
lorenz2.σ => 4.0,
35+
lorenz2.ρ => 5.0,
36+
lorenz2.β => 6.0,
37+
α => 7.0]
38+
39+
prob = ODEProblem(connected1,u0map,(0.0,100.0),parammap,jac=true)
40+
41+
@test prob.u0 == [7.0,1.0,3.0,5.0,2.0,4.0,6.0]
42+
@test prob.p == [7.0,1.0,2.0,3.0,4.0,5.0,6.0]
43+
@test prob.u0 isa Vector
44+
@test prob.p isa SVector
45+
@test prob.f.syms == [:a,:lorenz1′x,:lorenz1′y,:lorenz1′z,:lorenz2′x,:lorenz2′y,:lorenz2′z]
46+
2347
eqs_flat = [D(a) ~ a*lorenz1.x,
2448
D(lorenz1.x) ~ lorenz1.σ*(lorenz1.y-lorenz1.x),
2549
D(lorenz1.y) ~ lorenz1.x*(lorenz1.ρ-lorenz1.z)-lorenz1.y,

test/optimizationsystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ sys2 = OptimizationSystem(loss,[x,y],[a,b],name=:sys2)
88
@variables z
99
@parameters β
1010
loss2 = sys1.x - sys2.y + z*β
11-
combinedsys = OptimizationSystem(loss,[z],[β],systems=[sys1,sys2],name=:combinedsys)
11+
combinedsys = OptimizationSystem(loss2,[z],[β],systems=[sys1,sys2],name=:combinedsys)
1212

1313
equations(combinedsys)
1414
states(combinedsys)
1515
parameters(combinedsys)
1616

17+
calculate_gradient(combinedsys)
1718
calculate_hessian(combinedsys)
1819
generate_function(combinedsys)
20+
generate_gradient(combinedsys)
1921
generate_hessian(combinedsys)

0 commit comments

Comments
 (0)