Skip to content

Commit c7d6cd5

Browse files
Merge pull request #263 from SciML/components
Remove ODEExpr
2 parents a914076 + da7e231 commit c7d6cd5

File tree

11 files changed

+136
-113
lines changed

11 files changed

+136
-113
lines changed

src/build_function.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,16 @@ function numbered_expr(O::Operation,vars,parameters;
120120
varname=varname,paramname=paramname) for x in O.args]...)
121121
end
122122

123-
function numbered_expr(de::ModelingToolkit.ODEExpr,vars::Vector{Variable},parameters;
123+
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{Variable},parameters;
124124
derivname=:du,varname=:u,paramname=:p)
125-
i = findfirst(x->isequal(x.name,de.x.name),vars)
125+
i = findfirst(x->isequal(x.name,var_from_nested_derivative(de.lhs)[1].name),vars)
126126
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;
127127
derivname=derivname,
128128
varname=varname,paramname=paramname)))
129129
end
130-
function numbered_expr(de::ModelingToolkit.ODEExpr,vars::Vector{Operation},parameters;
130+
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{Operation},parameters;
131131
derivname=:du,varname=:u,paramname=:p)
132-
i = findfirst(x->isequal(x.op.name,de.x.name),vars)
132+
i = findfirst(x->isequal(x.op.name,var_from_nested_derivative(de.lhs)[1].name),vars)
133133
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;
134134
derivname=derivname,
135135
varname=varname,paramname=paramname)))

src/latexify_recipes.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,6 @@
2222
return lhs, rhs
2323
end
2424

25-
@latexrecipe function f(eqs::Vector{ModelingToolkit.ODEExpr}; iv=:t)
26-
# Set default option values.
27-
env --> :align
28-
29-
# Convert both the left and right hand side to expressions of basic types
30-
# that latexify can deal with.
31-
32-
rhs = getfield.(eqs, :rhs)
33-
rhs = convert.(Expr, rhs)
34-
rhs = [postwalk(x -> x isa ModelingToolkit.Constant ? x.value : x, eq) for eq in rhs]
35-
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
36-
rhs = [postwalk(x -> x isa Expr && x.args[1] == :Differential && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
37-
rhs = [postwalk(x -> x isa Expr && x.args[1] == :Differential ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$iv}" : x, eq) for eq in rhs]
38-
39-
var = getfield.(getfield.(eqs, :x),:name)
40-
ns = getfield.(eqs, :n)
41-
lhs = [ns[i] == 1 ? Latexify.LaTeXString("\\frac{d$(Latexify.latexraw(var[i]))}{d$iv}") : LaTeXString("\\frac{d^{$(ns[i])}$(Latexify.latexraw(var[i]))}{d$iv^{$(ns[i])}}") for i in 1:length(var)]
42-
return lhs, rhs
43-
end
44-
4525
@latexrecipe function f(sys::ModelingToolkit.ODESystem; iv=:t)
4626
return latexify(sys.eqs; iv=iv)
4727
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,20 @@ end
142142
function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
143143
ODEFunction{true}(sys, args...; kwargs...)
144144
end
145+
146+
independent_variables(sys::AbstractODESystem) = Set{Variable}([sys.iv])
147+
dependent_variables(sys::AbstractODESystem) = Set{Variable}(sys.dvs)
148+
parameters(sys::AbstractODESystem) = Set{Variable}(sys.ps)
149+
150+
function _eq_unordered(a, b)
151+
length(a) === length(b) || return false
152+
n = length(a)
153+
idxs = Set(1:n)
154+
for x a
155+
idx = findfirst(isequal(x), b)
156+
idx === nothing && return false
157+
idx idxs || return false
158+
delete!(idxs, idx)
159+
end
160+
return true
161+
end

src/systems/diffeqs/first_order_transform.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,34 @@ function lower_varname(var::Variable, idv, order)
44
return Variable(name; known = var.known)
55
end
66

7+
function flatten_differential(O::Operation)
8+
@assert is_derivative(O) "invalid differential: $O"
9+
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
10+
(x, t, order) = flatten_differential(O.args[1])
11+
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
12+
return (x, t, order + 1)
13+
end
14+
715
function ode_order_lowering(sys::ODESystem)
816
eqs_lowered, _ = ode_order_lowering(sys.eqs, sys.iv)
9-
ODESystem(eqs_lowered, sys.iv, [eq.x for eq in eqs_lowered], sys.ps)
17+
ODESystem(eqs_lowered, sys.iv, [var_from_nested_derivative(eq.lhs)[1] for eq in eqs_lowered], sys.ps)
1018
end
19+
1120
function ode_order_lowering(eqs, iv)
1221
var_order = Dict{Variable,Int}()
1322
vars = Variable[]
14-
new_eqs = similar(eqs, ODEExpr)
23+
new_eqs = Equation[]
1524
new_vars = Variable[]
1625

1726
for (i, eq) enumerate(eqs)
18-
var, maxorder = eq.x, eq.n
27+
var, maxorder = var_from_nested_derivative(eq.lhs)
1928
if maxorder > get(var_order, var, 0)
2029
var_order[var] = maxorder
2130
any(isequal(var), vars) || push!(vars, var)
2231
end
23-
var′ = lower_varname(eq.x, iv, eq.n - 1)
32+
var′ = lower_varname(var, iv, maxorder - 1)
2433
rhs′ = rename(eq.rhs)
25-
new_eqs[i] = ODEExpr(var′, 1, rhs′)
34+
push!(new_eqs,Differential(iv())(var′(iv())) ~ rhs′)
2635
end
2736

2837
for var vars
@@ -33,7 +42,7 @@ function ode_order_lowering(eqs, iv)
3342
push!(new_vars, rvar)
3443

3544
rhs = rvar(iv())
36-
eq = ODEExpr(lvar, 1, rhs)
45+
eq = Differential(iv())(lvar(iv())) ~ rhs
3746
push!(new_eqs, eq)
3847
end
3948
end

src/systems/diffeqs/odesystem.jl

Lines changed: 25 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,3 @@
1-
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
2-
3-
function flatten_differential(O::Operation)
4-
@assert is_derivative(O) "invalid differential: $O"
5-
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
6-
(x, t, order) = flatten_differential(O.args[1])
7-
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
8-
return (x, t, order + 1)
9-
end
10-
11-
struct ODEExpr # dⁿx/dtⁿ = rhs
12-
x::Variable
13-
n::Int
14-
rhs::Expression
15-
end
16-
function Base.convert(::Type{ODEExpr},eq::Equation)
17-
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
18-
(x, t, n) = flatten_differential(eq.lhs)
19-
(isa(t, Operation) && isa(t.op, Variable) && isempty(t.args)) ||
20-
throw(ArgumentError("invalid independent variable $t"))
21-
(isa(x, Operation) && isa(x.op, Variable) && length(x.args) == 1 && isequal(first(x.args), t)) ||
22-
throw(ArgumentError("invalid dependent variable $x"))
23-
return t.op, ODEExpr(x.op, n, eq.rhs)
24-
end
25-
Base.:(==)(a::ODEExpr, b::ODEExpr) = isequal((a.x, a.n, a.rhs), (b.x, b.n, b.rhs))
26-
271
"""
282
$(TYPEDEF)
293
@@ -50,7 +24,7 @@ de = ODESystem(eqs)
5024
"""
5125
struct ODESystem <: AbstractODESystem
5226
"""The ODEs defining the system."""
53-
eqs::Vector{ODEExpr}
27+
eqs::Vector{Equation}
5428
"""Independent variable."""
5529
iv::Variable
5630
"""Dependent (state) variables."""
@@ -77,57 +51,41 @@ struct ODESystem <: AbstractODESystem
7751
[`generate_factorized_W`](@ref) is called on the system.
7852
"""
7953
Wfact_t::RefValue{Matrix{Expression}}
54+
"""
55+
Name: the name of the system
56+
"""
57+
name::Symbol
8058
end
8159

82-
function ODESystem(eqs)
83-
reformatted = convert.(ODEExpr,eqs)
84-
85-
ivs = unique(r[1] for r reformatted)
86-
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
87-
iv = first(ivs)
88-
89-
deqs = [r[2] for r reformatted]
90-
91-
dvs = unique(deq.x for deq deqs)
92-
ps = filter(vars(deq.rhs for deq deqs)) do x
93-
x.known & !isequal(x, iv)
94-
end |> collect
95-
96-
ODESystem(deqs, iv, dvs, ps)
97-
end
98-
99-
function ODESystem(deqs::AbstractVector{ODEExpr}, iv, dvs, ps)
60+
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
61+
name=gensym(:ODESystem))
62+
iv′ = clean(iv)
63+
dvs′ = [clean(dv) for dv dvs]
64+
ps′ = [clean(p) for p ps]
10065
tgrad = RefValue(Vector{Expression}(undef, 0))
10166
jac = RefValue(Matrix{Expression}(undef, 0, 0))
10267
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
10368
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
104-
ODESystem(deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t)
69+
ODESystem(deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t, name)
10570
end
10671

107-
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps)
108-
_dvs = [deq.op for deq dvs]
109-
_iv = iv.op
110-
_ps = [p.op for p ps]
111-
ODESystem(getindex.(convert.(ODEExpr,deqs),2), _iv, _dvs, _ps)
112-
end
72+
var_from_nested_derivative(x) = var_from_nested_derivative(x,0)
73+
var_from_nested_derivative(x,i) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x.op,i)
74+
iv_from_nested_derivative(x) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1].op
11375

114-
function _eq_unordered(a, b)
115-
length(a) === length(b) || return false
116-
n = length(a)
117-
idxs = Set(1:n)
118-
for x a
119-
idx = findfirst(isequal(x), b)
120-
idx === nothing && return false
121-
idx idxs || return false
122-
delete!(idxs, idx)
123-
end
124-
return true
76+
function ODESystem(eqs; kwargs...)
77+
ivs = unique(iv_from_nested_derivative(eq.lhs) for eq eqs)
78+
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
79+
iv = first(ivs)
80+
81+
dvs = unique(var_from_nested_derivative(eq.lhs)[1] for eq eqs)
82+
ps = filter(vars(eq.rhs for eq eqs)) do x
83+
x.known & !isequal(x, iv)
84+
end |> collect
85+
ODESystem(eqs, iv, dvs, ps; kwargs...)
12586
end
87+
12688
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
12789
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
12890
_eq_unordered(sys1.dvs, sys2.dvs) && _eq_unordered(sys1.ps, sys2.ps)
12991
# NOTE: equality does not check cached Jacobian
130-
131-
independent_variables(sys::ODESystem) = Set{Variable}([sys.iv])
132-
dependent_variables(sys::ODESystem) = Set{Variable}(sys.dvs)
133-
parameters(sys::ODESystem) = Set{Variable}(sys.ps)

src/systems/diffeqs/sdesystem.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
struct SDESystem <: AbstractODESystem
22
"""The expressions defining the drift term."""
3-
eqs::Vector{ODEExpr}
3+
eqs::Vector{Equation}
44
"""The expressions defining the diffusion term."""
55
noiseeqs
66
"""Independent variable."""
@@ -29,21 +29,21 @@ struct SDESystem <: AbstractODESystem
2929
[`generate_factorized_W`](@ref) is called on the system.
3030
"""
3131
Wfact_t::RefValue{Matrix{Expression}}
32+
"""
33+
Name: the name of the system
34+
"""
35+
name::Symbol
3236
end
3337

34-
function SDESystem(deqs::AbstractVector{ODEExpr}, neqs, iv, dvs, ps)
38+
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; name = gensym(:SDESystem))
39+
dvs′ = [clean(dv) for dv dvs]
40+
ps′ = [clean(p) for p ps]
41+
iv′ = clean(iv)
3542
tgrad = RefValue(Vector{Expression}(undef, 0))
3643
jac = RefValue(Matrix{Expression}(undef, 0, 0))
3744
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
3845
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
39-
SDESystem(deqs, neqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t)
40-
end
41-
42-
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps)
43-
_dvs = [deq.op for deq dvs]
44-
_iv = iv.op
45-
_ps = [p.op for p ps]
46-
SDESystem(getindex.(convert.(ODEExpr,deqs),2), neqs, _iv, _dvs, _ps)
46+
SDESystem(deqs, neqs, iv′, dvs′, ps′, tgrad, jac, Wfact, Wfact_t, name)
4747
end
4848

4949
function generate_diffusion_function(sys::SDESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)

src/systems/diffeqs/validation.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
2+
3+
struct ODEExpr # dⁿx/dtⁿ = rhs
4+
x::Variable
5+
n::Int
6+
rhs::Expression
7+
end
8+
function Base.convert(::Type{ODEExpr},eq::Equation)
9+
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
10+
(x, t, n) = flatten_differential(eq.lhs)
11+
(isa(t, Operation) && isa(t.op, Variable) && isempty(t.args)) ||
12+
throw(ArgumentError("invalid independent variable $t"))
13+
(isa(x, Operation) && isa(x.op, Variable) && length(x.args) == 1 && isequal(first(x.args), t)) ||
14+
throw(ArgumentError("invalid dependent variable $x"))
15+
return t.op, ODEExpr(x.op, n, eq.rhs)
16+
end
17+
Base.:(==)(a::ODEExpr, b::ODEExpr) = isequal((a.x, a.n, a.rhs), (b.x, b.n, b.rhs))

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@ struct NonlinearSystem <: AbstractSystem
3434
vs::Vector{Expression}
3535
"""Parameters."""
3636
ps::Vector{Variable}
37-
function NonlinearSystem(eqs, vs)
37+
"""
38+
Name: the name of the system
39+
"""
40+
name::Symbol
41+
function NonlinearSystem(eqs, vs; name = gensym(:NonlinearSystem))
3842
rhss = [eq.rhs for eq eqs]
3943
ps = reduce(, map(_find_params(vs), rhss); init = vnil())
40-
new(eqs, vs, collect(ps))
44+
new(eqs, vs, collect(ps), name)
4145
end
4246

43-
function NonlinearSystem(eqs, vs, ps)
47+
function NonlinearSystem(eqs, vs, ps; name = gensym(:NonlinearSystem))
4448
rhss = [eq.rhs for eq eqs]
45-
new(eqs, vs, [p.op for p in ps])
49+
new(eqs, vs, [p.op for p in ps], name)
4650
end
4751
end
4852

test/build_targets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit, Test
44
@derivatives D'~t
55
eqs = [D(x) ~ a*x - x*y,
66
D(y) ~ -3y + x*y]
7-
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
7+
@test ModelingToolkit.build_function(eqs,ModelingToolkit.clean.([x,y]),ModelingToolkit.clean.([a]),t,target = ModelingToolkit.StanTarget()) ==
88
"""
99
real[] diffeqf(real t,real[] internal_var___u,real[] internal_var___p,real[] x_r,int[] x_i) {
1010
real internal_var___du[2];

test/components.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using ModelingToolkit, StaticArrays, LinearAlgebra
2+
using DiffEqBase
3+
using Test
4+
5+
# Define some variables
6+
@parameters t σ ρ β
7+
@variables x(t) y(t) z(t)
8+
@derivatives D'~t
9+
10+
eqs = [D(x) ~ σ*(y-x),
11+
D(y) ~ x*-z)-y,
12+
0 ~ x + y + z]
13+
14+
de1 = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],name=:lorenz1)
15+
de2 = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],name=:lorenz2)
16+
17+
@parameters α
18+
@variables a(t)
19+
for
20+
21+
convert(DAESystem,sys)
22+
convert(ODESystem,sys)
23+
24+
connected = ODESystems(connnectedeqs,[a],[α],systems=[de1,de2],name=:connectedlorenz)
25+
26+
deleteat!(de2.eqs,3)
27+
de2
28+
push!(de2.eqs,convert(ModelingToolkit.ODEExpr,D(z) ~ z^2)[2])
29+
30+
check_consistency(de2)
31+
32+
b = a*z
33+
[states(de1,:x) ~ α*states(de2,:z),
34+
D(a) ~ b*states(de1,:x)]
35+
36+
states(states) == [lorenz1′x,lorenz1′y,lorenz1′z,lorenz2′x,lorenz2′y,lorenz2′z,a]
37+
parameters(connected) == [lorenz1′σ,lorenz1′ρ,lorenz1′β,lorenz2′σ,lorenz2′ρ,lorenz2′β,α]

0 commit comments

Comments
 (0)