Skip to content

Commit c2bc515

Browse files
some refactoring
1 parent 5cdcb6e commit c2bc515

File tree

6 files changed

+230
-178
lines changed

6 files changed

+230
-178
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "1.4.3"
4+
version = "1.4.4"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/ModelingToolkit.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ $(TYPEDEF)
3131
TODO
3232
"""
3333
abstract type AbstractSystem end
34+
abstract type AbstractODESystem <: AbstractSystem end
3435

3536
Base.promote_rule(::Type{<:Number},::Type{<:Expression}) = Expression
3637
Base.zero(::Type{<:Expression}) = Constant(0)
@@ -81,8 +82,6 @@ Get the set of parameters variables for the given system.
8182
"""
8283
function parameters end
8384

84-
@enum FunctionVersion ArrayFunction=1 SArrayFunction=2
85-
8685
include("variables.jl")
8786
include("operations.jl")
8887
include("differentials.jl")
@@ -93,7 +92,9 @@ include("utils.jl")
9392
include("direct.jl")
9493
include("domains.jl")
9594
include("systems/diffeqs/odesystem.jl")
95+
include("systems/diffeqs/abstractodesystem.jl")
9696
include("systems/diffeqs/first_order_transform.jl")
97+
include("systems/diffeqs/modelingtoolkitize.jl")
9798
include("systems/nonlinear/nonlinear_system.jl")
9899
include("systems/pde/pdesystem.jl")
99100
include("latexify_recipes.jl")
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
function calculate_tgrad(sys::AbstractODESystem)
2+
isempty(sys.tgrad[]) || return sys.tgrad[] # use cached tgrad, if possible
3+
rhs = [detime_dvs(eq.rhs) for eq sys.eqs]
4+
iv = sys.iv()
5+
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
6+
tgrad = retime_dvs.(notime_tgrad,(sys.dvs,),iv)
7+
sys.tgrad[] = tgrad
8+
return tgrad
9+
end
10+
11+
function calculate_jacobian(sys::AbstractODESystem)
12+
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
13+
rhs = [eq.rhs for eq sys.eqs]
14+
15+
iv = sys.iv()
16+
dvs = [dv(iv) for dv sys.dvs]
17+
18+
jac = expand_derivatives.(calculate_jacobian(rhs, dvs))
19+
sys.jac[] = jac # cache Jacobian
20+
return jac
21+
end
22+
23+
struct ODEToExpr
24+
sys::ODESystem
25+
end
26+
function (f::ODEToExpr)(O::Operation)
27+
if isa(O.op, Variable)
28+
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
29+
O.op f.sys.dvs && return O.op.name # dependent variables
30+
isempty(O.args) && return O.op.name # 0-ary parameters
31+
return build_expr(:call, Any[O.op.name; f.(O.args)])
32+
end
33+
return build_expr(:call, Any[Symbol(O.op); f.(O.args)])
34+
end
35+
(f::ODEToExpr)(x) = convert(Expr, x)
36+
37+
function generate_tgrad(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
38+
tgrad = calculate_tgrad(sys)
39+
return build_function(tgrad, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
40+
end
41+
42+
function generate_jacobian(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
43+
jac = calculate_jacobian(sys)
44+
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
45+
end
46+
47+
function generate_function(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
48+
rhss = [deq.rhs for deq sys.eqs]
49+
dvs′ = [clean(dv) for dv dvs]
50+
ps′ = [clean(p) for p ps]
51+
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
52+
end
53+
54+
function calculate_factorized_W(sys::AbstractODESystem, simplify=true)
55+
isempty(sys.Wfact[]) || return (sys.Wfact[],sys.Wfact_t[])
56+
57+
jac = calculate_jacobian(sys)
58+
gam = Variable(:gam; known = true)()
59+
60+
W = - LinearAlgebra.I + gam*jac
61+
Wfact = lu(W, Val(false), check=false).factors
62+
63+
if simplify
64+
Wfact = simplify_constants.(Wfact)
65+
end
66+
67+
W_t = - LinearAlgebra.I/gam + jac
68+
Wfact_t = lu(W_t, Val(false), check=false).factors
69+
if simplify
70+
Wfact_t = simplify_constants.(Wfact_t)
71+
end
72+
sys.Wfact[] = Wfact
73+
sys.Wfact_t[] = Wfact_t
74+
75+
(Wfact,Wfact_t)
76+
end
77+
78+
function generate_factorized_W(sys::AbstractODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true}; kwargs...)
79+
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
80+
siz = size(Wfact)
81+
constructor = :(x -> begin
82+
A = SMatrix{$siz...}(x)
83+
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
84+
end)
85+
86+
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
87+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
88+
89+
return (Wfact_func, Wfact_t_func)
90+
end
91+
92+
"""
93+
$(SIGNATURES)
94+
95+
Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `ps`
96+
are used to set the order of the dependent variable and parameter vectors,
97+
respectively.
98+
"""
99+
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps;
100+
version = nothing, tgrad=false,
101+
jac = false, Wfact = false) where {iip}
102+
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
103+
104+
f(u,p,t) = f_oop(u,p,t)
105+
f(du,u,p,t) = f_iip(du,u,p,t)
106+
107+
if tgrad
108+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
109+
_tgrad(u,p,t) = tgrad_oop(u,p,t)
110+
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
111+
else
112+
_tgrad = nothing
113+
end
114+
115+
if jac
116+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
117+
_jac(u,p,t) = jac_oop(u,p,t)
118+
_jac(J,u,p,t) = jac_iip(J,u,p,t)
119+
else
120+
_jac = nothing
121+
end
122+
123+
if Wfact
124+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false})
125+
Wfact_oop, Wfact_iip = tmp_Wfact
126+
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
127+
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
128+
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
129+
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
130+
_Wfact_t(W,u,p,dtgamma,t) = Wfact_iip_t(W,u,p,dtgamma,t)
131+
else
132+
_Wfact,_Wfact_t = nothing,nothing
133+
end
134+
135+
ODEFunction{iip}(f,jac=_jac,
136+
tgrad = _tgrad,
137+
Wfact = _Wfact,
138+
Wfact_t = _Wfact_t,
139+
syms = Symbol.(sys.dvs))
140+
end
141+
142+
function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
143+
ODEFunction{true}(sys, args...; kwargs...)
144+
end
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
$(SIGNATURES)
3+
4+
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
5+
"""
6+
function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
7+
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
8+
return (prob.f.sys, prob.f.sys.dvs, prob.f.sys.ps)
9+
@parameters t
10+
vars = [Variable(:x, i)(t) for i in eachindex(prob.u0)]
11+
params = prob.p isa DiffEqBase.NullParameters ? [] :
12+
[Variable(,i; known = true)() for i in eachindex(prob.p)]
13+
@derivatives D'~t
14+
15+
rhs = [D(var) for var in vars]
16+
17+
if DiffEqBase.isinplace(prob)
18+
lhs = similar(vars, Any)
19+
prob.f(lhs, vars, params, t)
20+
else
21+
lhs = prob.f(vars, params, t)
22+
end
23+
24+
eqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
25+
de = ODESystem(eqs)
26+
27+
de, vars, params
28+
end

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ eqs = [D(x) ~ σ*(y-x),
5050
de = ODESystem(eqs)
5151
```
5252
"""
53-
struct ODESystem <: AbstractSystem
53+
struct ODESystem <: AbstractODESystem
5454
"""The ODEs defining the system."""
5555
eqs::Vector{ODEExpr}
5656
"""Independent variable."""
@@ -133,177 +133,3 @@ Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
133133
independent_variables(sys::ODESystem) = Set{Variable}([sys.iv])
134134
dependent_variables(sys::ODESystem) = Set{Variable}(sys.dvs)
135135
parameters(sys::ODESystem) = Set{Variable}(sys.ps)
136-
137-
function calculate_tgrad(sys::ODESystem)
138-
isempty(sys.tgrad[]) || return sys.tgrad[] # use cached tgrad, if possible
139-
rhs = [detime_dvs(eq.rhs) for eq sys.eqs]
140-
iv = sys.iv()
141-
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
142-
tgrad = retime_dvs.(notime_tgrad,(sys.dvs,),iv)
143-
sys.tgrad[] = tgrad
144-
return tgrad
145-
end
146-
147-
function calculate_jacobian(sys::ODESystem)
148-
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
149-
rhs = [eq.rhs for eq sys.eqs]
150-
151-
iv = sys.iv()
152-
dvs = [dv(iv) for dv sys.dvs]
153-
154-
jac = expand_derivatives.(calculate_jacobian(rhs, dvs))
155-
sys.jac[] = jac # cache Jacobian
156-
return jac
157-
end
158-
159-
struct ODEToExpr
160-
sys::ODESystem
161-
end
162-
function (f::ODEToExpr)(O::Operation)
163-
if isa(O.op, Variable)
164-
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
165-
O.op f.sys.dvs && return O.op.name # dependent variables
166-
isempty(O.args) && return O.op.name # 0-ary parameters
167-
return build_expr(:call, Any[O.op.name; f.(O.args)])
168-
end
169-
return build_expr(:call, Any[Symbol(O.op); f.(O.args)])
170-
end
171-
(f::ODEToExpr)(x) = convert(Expr, x)
172-
173-
function generate_tgrad(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
174-
tgrad = calculate_tgrad(sys)
175-
return build_function(tgrad, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
176-
end
177-
178-
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
179-
jac = calculate_jacobian(sys)
180-
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
181-
end
182-
183-
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
184-
rhss = [deq.rhs for deq sys.eqs]
185-
dvs′ = [clean(dv) for dv dvs]
186-
ps′ = [clean(p) for p ps]
187-
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
188-
end
189-
190-
function calculate_factorized_W(sys::ODESystem, simplify=true)
191-
isempty(sys.Wfact[]) || return (sys.Wfact[],sys.Wfact_t[])
192-
193-
jac = calculate_jacobian(sys)
194-
gam = Variable(:gam; known = true)()
195-
196-
W = - LinearAlgebra.I + gam*jac
197-
Wfact = lu(W, Val(false), check=false).factors
198-
199-
if simplify
200-
Wfact = simplify_constants.(Wfact)
201-
end
202-
203-
W_t = - LinearAlgebra.I/gam + jac
204-
Wfact_t = lu(W_t, Val(false), check=false).factors
205-
if simplify
206-
Wfact_t = simplify_constants.(Wfact_t)
207-
end
208-
sys.Wfact[] = Wfact
209-
sys.Wfact_t[] = Wfact_t
210-
211-
(Wfact,Wfact_t)
212-
end
213-
214-
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true}; kwargs...)
215-
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
216-
siz = size(Wfact)
217-
constructor = :(x -> begin
218-
A = SMatrix{$siz...}(x)
219-
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
220-
end)
221-
222-
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
223-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
224-
225-
return (Wfact_func, Wfact_t_func)
226-
end
227-
228-
"""
229-
$(SIGNATURES)
230-
231-
Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `ps`
232-
are used to set the order of the dependent variable and parameter vectors,
233-
respectively.
234-
"""
235-
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs = sys.dvs, ps = sys.ps;
236-
version = nothing, tgrad=false,
237-
jac = false, Wfact = false) where {iip}
238-
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
239-
240-
f(u,p,t) = f_oop(u,p,t)
241-
f(du,u,p,t) = f_iip(du,u,p,t)
242-
243-
if tgrad
244-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
245-
_tgrad(u,p,t) = tgrad_oop(u,p,t)
246-
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
247-
else
248-
_tgrad = nothing
249-
end
250-
251-
if jac
252-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
253-
_jac(u,p,t) = jac_oop(u,p,t)
254-
_jac(J,u,p,t) = jac_iip(J,u,p,t)
255-
else
256-
_jac = nothing
257-
end
258-
259-
if Wfact
260-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false})
261-
Wfact_oop, Wfact_iip = tmp_Wfact
262-
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
263-
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
264-
_Wfact(W,u,p,dtgamma,t) = Wfact_iip(W,u,p,dtgamma,t)
265-
_Wfact_t(u,p,dtgamma,t) = Wfact_oop_t(u,p,dtgamma,t)
266-
_Wfact_t(W,u,p,dtgamma,t) = Wfact_iip_t(W,u,p,dtgamma,t)
267-
else
268-
_Wfact,_Wfact_t = nothing,nothing
269-
end
270-
271-
ODEFunction{iip}(f,jac=_jac,
272-
tgrad = _tgrad,
273-
Wfact = _Wfact,
274-
Wfact_t = _Wfact_t,
275-
syms = Symbol.(sys.dvs))
276-
end
277-
278-
function DiffEqBase.ODEFunction(sys::ODESystem, args...; kwargs...)
279-
ODEFunction{true}(sys, args...; kwargs...)
280-
end
281-
282-
"""
283-
$(SIGNATURES)
284-
285-
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
286-
"""
287-
function modelingtoolkitize(prob::DiffEqBase.ODEProblem)
288-
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
289-
return (prob.f.sys, prob.f.sys.dvs, prob.f.sys.ps)
290-
@parameters t
291-
vars = [Variable(:x, i)(t) for i in eachindex(prob.u0)]
292-
params = prob.p isa DiffEqBase.NullParameters ? [] :
293-
[Variable(,i; known = true)() for i in eachindex(prob.p)]
294-
@derivatives D'~t
295-
296-
rhs = [D(var) for var in vars]
297-
298-
if DiffEqBase.isinplace(prob)
299-
lhs = similar(vars, Any)
300-
prob.f(lhs, vars, params, t)
301-
else
302-
lhs = prob.f(vars, params, t)
303-
end
304-
305-
eqs = vcat([rhs[i] ~ lhs[i] for i in eachindex(prob.u0)]...)
306-
de = ODESystem(eqs)
307-
308-
de, vars, params
309-
end

0 commit comments

Comments
 (0)