Skip to content

Commit da7e231

Browse files
remove ODEExpr
1 parent 8584ab4 commit da7e231

File tree

10 files changed

+114
-103
lines changed

10 files changed

+114
-103
lines changed

src/build_function.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,16 @@ function numbered_expr(O::Operation,vars,parameters;
120120
varname=varname,paramname=paramname) for x in O.args]...)
121121
end
122122

123-
function numbered_expr(de::ModelingToolkit.ODEExpr,vars::Vector{Variable},parameters;
123+
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{Variable},parameters;
124124
derivname=:du,varname=:u,paramname=:p)
125-
i = findfirst(x->isequal(x.name,de.x.name),vars)
125+
i = findfirst(x->isequal(x.name,var_from_nested_derivative(de.lhs)[1].name),vars)
126126
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;
127127
derivname=derivname,
128128
varname=varname,paramname=paramname)))
129129
end
130-
function numbered_expr(de::ModelingToolkit.ODEExpr,vars::Vector{Operation},parameters;
130+
function numbered_expr(de::ModelingToolkit.Equation,vars::Vector{Operation},parameters;
131131
derivname=:du,varname=:u,paramname=:p)
132-
i = findfirst(x->isequal(x.op.name,de.x.name),vars)
132+
i = findfirst(x->isequal(x.op.name,var_from_nested_derivative(de.lhs)[1].name),vars)
133133
:($derivname[$i] = $(numbered_expr(de.rhs,vars,parameters;
134134
derivname=derivname,
135135
varname=varname,paramname=paramname)))

src/latexify_recipes.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,6 @@
2222
return lhs, rhs
2323
end
2424

25-
@latexrecipe function f(eqs::Vector{ModelingToolkit.ODEExpr}; iv=:t)
26-
# Set default option values.
27-
env --> :align
28-
29-
# Convert both the left and right hand side to expressions of basic types
30-
# that latexify can deal with.
31-
32-
rhs = getfield.(eqs, :rhs)
33-
rhs = convert.(Expr, rhs)
34-
rhs = [postwalk(x -> x isa ModelingToolkit.Constant ? x.value : x, eq) for eq in rhs]
35-
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
36-
rhs = [postwalk(x -> x isa Expr && x.args[1] == :Differential && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
37-
rhs = [postwalk(x -> x isa Expr && x.args[1] == :Differential ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$iv}" : x, eq) for eq in rhs]
38-
39-
var = getfield.(getfield.(eqs, :x),:name)
40-
ns = getfield.(eqs, :n)
41-
lhs = [ns[i] == 1 ? Latexify.LaTeXString("\\frac{d$(Latexify.latexraw(var[i]))}{d$iv}") : LaTeXString("\\frac{d^{$(ns[i])}$(Latexify.latexraw(var[i]))}{d$iv^{$(ns[i])}}") for i in 1:length(var)]
42-
return lhs, rhs
43-
end
44-
4525
@latexrecipe function f(sys::ModelingToolkit.ODESystem; iv=:t)
4626
return latexify(sys.eqs; iv=iv)
4727
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,20 @@ end
142142
function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
143143
ODEFunction{true}(sys, args...; kwargs...)
144144
end
145+
146+
independent_variables(sys::AbstractODESystem) = Set{Variable}([sys.iv])
147+
dependent_variables(sys::AbstractODESystem) = Set{Variable}(sys.dvs)
148+
parameters(sys::AbstractODESystem) = Set{Variable}(sys.ps)
149+
150+
function _eq_unordered(a, b)
151+
length(a) === length(b) || return false
152+
n = length(a)
153+
idxs = Set(1:n)
154+
for x a
155+
idx = findfirst(isequal(x), b)
156+
idx === nothing && return false
157+
idx idxs || return false
158+
delete!(idxs, idx)
159+
end
160+
return true
161+
end

src/systems/diffeqs/first_order_transform.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,34 @@ function lower_varname(var::Variable, idv, order)
44
return Variable(name; known = var.known)
55
end
66

7+
function flatten_differential(O::Operation)
8+
@assert is_derivative(O) "invalid differential: $O"
9+
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
10+
(x, t, order) = flatten_differential(O.args[1])
11+
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
12+
return (x, t, order + 1)
13+
end
14+
715
function ode_order_lowering(sys::ODESystem)
816
eqs_lowered, _ = ode_order_lowering(sys.eqs, sys.iv)
9-
ODESystem(eqs_lowered, sys.iv, [eq.x for eq in eqs_lowered], sys.ps)
17+
ODESystem(eqs_lowered, sys.iv, [var_from_nested_derivative(eq.lhs)[1] for eq in eqs_lowered], sys.ps)
1018
end
19+
1120
function ode_order_lowering(eqs, iv)
1221
var_order = Dict{Variable,Int}()
1322
vars = Variable[]
14-
new_eqs = similar(eqs, ODEExpr)
23+
new_eqs = Equation[]
1524
new_vars = Variable[]
1625

1726
for (i, eq) enumerate(eqs)
18-
var, maxorder = eq.x, eq.n
27+
var, maxorder = var_from_nested_derivative(eq.lhs)
1928
if maxorder > get(var_order, var, 0)
2029
var_order[var] = maxorder
2130
any(isequal(var), vars) || push!(vars, var)
2231
end
23-
var′ = lower_varname(eq.x, iv, eq.n - 1)
32+
var′ = lower_varname(var, iv, maxorder - 1)
2433
rhs′ = rename(eq.rhs)
25-
new_eqs[i] = ODEExpr(var′, 1, rhs′)
34+
push!(new_eqs,Differential(iv())(var′(iv())) ~ rhs′)
2635
end
2736

2837
for var vars
@@ -33,7 +42,7 @@ function ode_order_lowering(eqs, iv)
3342
push!(new_vars, rvar)
3443

3544
rhs = rvar(iv())
36-
eq = ODEExpr(lvar, 1, rhs)
45+
eq = Differential(iv())(lvar(iv())) ~ rhs
3746
push!(new_eqs, eq)
3847
end
3948
end

src/systems/diffeqs/odesystem.jl

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,3 @@
1-
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
2-
3-
function flatten_differential(O::Operation)
4-
@assert is_derivative(O) "invalid differential: $O"
5-
is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1)
6-
(x, t, order) = flatten_differential(O.args[1])
7-
isequal(t, O.op.x) || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)"))
8-
return (x, t, order + 1)
9-
end
10-
11-
struct ODEExpr # dⁿx/dtⁿ = rhs
12-
x::Variable
13-
n::Int
14-
rhs::Expression
15-
end
16-
function Base.convert(::Type{ODEExpr},eq::Equation)
17-
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
18-
(x, t, n) = flatten_differential(eq.lhs)
19-
(isa(t, Operation) && isa(t.op, Variable) && isempty(t.args)) ||
20-
throw(ArgumentError("invalid independent variable $t"))
21-
(isa(x, Operation) && isa(x.op, Variable) && length(x.args) == 1 && isequal(first(x.args), t)) ||
22-
throw(ArgumentError("invalid dependent variable $x"))
23-
return t.op, ODEExpr(x.op, n, eq.rhs)
24-
end
25-
Base.:(==)(a::ODEExpr, b::ODEExpr) = isequal((a.x, a.n, a.rhs), (b.x, b.n, b.rhs))
26-
271
"""
282
$(TYPEDEF)
293
@@ -50,7 +24,7 @@ de = ODESystem(eqs)
5024
"""
5125
struct ODESystem <: AbstractODESystem
5226
"""The ODEs defining the system."""
53-
eqs::Vector{ODEExpr}
27+
eqs::Vector{Equation}
5428
"""Independent variable."""
5529
iv::Variable
5630
"""Dependent (state) variables."""
@@ -83,55 +57,35 @@ struct ODESystem <: AbstractODESystem
8357
name::Symbol
8458
end
8559

86-
function ODESystem(deqs::AbstractVector{ODEExpr}, iv, dvs, ps; name=gensym(:ODESystem))
60+
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
61+
name=gensym(:ODESystem))
62+
iv′ = clean(iv)
63+
dvs′ = [clean(dv) for dv dvs]
64+
ps′ = [clean(p) for p ps]
8765
tgrad = RefValue(Vector{Expression}(undef, 0))
8866
jac = RefValue(Matrix{Expression}(undef, 0, 0))
8967
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
9068
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
91-
ODESystem(deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t, name)
69+
ODESystem(deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t, name)
9270
end
9371

94-
function ODESystem(eqs; kwargs...)
95-
reformatted = convert.(ODEExpr,eqs)
72+
var_from_nested_derivative(x) = var_from_nested_derivative(x,0)
73+
var_from_nested_derivative(x,i) = x.op isa Differential ? var_from_nested_derivative(x.args[1],i+1) : (x.op,i)
74+
iv_from_nested_derivative(x) = x.op isa Differential ? iv_from_nested_derivative(x.args[1]) : x.args[1].op
9675

97-
ivs = unique(r[1] for r reformatted)
76+
function ODESystem(eqs; kwargs...)
77+
ivs = unique(iv_from_nested_derivative(eq.lhs) for eq eqs)
9878
length(ivs) == 1 || throw(ArgumentError("one independent variable currently supported"))
9979
iv = first(ivs)
10080

101-
deqs = [r[2] for r reformatted]
102-
103-
dvs = unique(deq.x for deq deqs)
104-
ps = filter(vars(deq.rhs for deq deqs)) do x
81+
dvs = unique(var_from_nested_derivative(eq.lhs)[1] for eq eqs)
82+
ps = filter(vars(eq.rhs for eq eqs)) do x
10583
x.known & !isequal(x, iv)
10684
end |> collect
107-
108-
ODESystem(deqs, iv, dvs, ps; kwargs...)
85+
ODESystem(eqs, iv, dvs, ps; kwargs...)
10986
end
11087

111-
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; kwargs...)
112-
_dvs = [deq.op for deq dvs]
113-
_iv = iv.op
114-
_ps = [p.op for p ps]
115-
ODESystem(getindex.(convert.(ODEExpr,deqs),2), _iv, _dvs, _ps; kwargs...)
116-
end
117-
118-
function _eq_unordered(a, b)
119-
length(a) === length(b) || return false
120-
n = length(a)
121-
idxs = Set(1:n)
122-
for x a
123-
idx = findfirst(isequal(x), b)
124-
idx === nothing && return false
125-
idx idxs || return false
126-
delete!(idxs, idx)
127-
end
128-
return true
129-
end
13088
Base.:(==)(sys1::ODESystem, sys2::ODESystem) =
13189
_eq_unordered(sys1.eqs, sys2.eqs) && isequal(sys1.iv, sys2.iv) &&
13290
_eq_unordered(sys1.dvs, sys2.dvs) && _eq_unordered(sys1.ps, sys2.ps)
13391
# NOTE: equality does not check cached Jacobian
134-
135-
independent_variables(sys::ODESystem) = Set{Variable}([sys.iv])
136-
dependent_variables(sys::ODESystem) = Set{Variable}(sys.dvs)
137-
parameters(sys::ODESystem) = Set{Variable}(sys.ps)

src/systems/diffeqs/sdesystem.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
struct SDESystem <: AbstractODESystem
22
"""The expressions defining the drift term."""
3-
eqs::Vector{ODEExpr}
3+
eqs::Vector{Equation}
44
"""The expressions defining the diffusion term."""
55
noiseeqs
66
"""Independent variable."""
@@ -35,19 +35,15 @@ struct SDESystem <: AbstractODESystem
3535
name::Symbol
3636
end
3737

38-
function SDESystem(deqs::AbstractVector{ODEExpr}, neqs, iv, dvs, ps; name = gensym(:SDESystem))
38+
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; name = gensym(:SDESystem))
39+
dvs′ = [clean(dv) for dv dvs]
40+
ps′ = [clean(p) for p ps]
41+
iv′ = clean(iv)
3942
tgrad = RefValue(Vector{Expression}(undef, 0))
4043
jac = RefValue(Matrix{Expression}(undef, 0, 0))
4144
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
4245
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
43-
SDESystem(deqs, neqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t, name)
44-
end
45-
46-
function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps; kwargs...)
47-
_dvs = [deq.op for deq dvs]
48-
_iv = iv.op
49-
_ps = [p.op for p ps]
50-
SDESystem(getindex.(convert.(ODEExpr,deqs),2), neqs, _iv, _dvs, _ps; kwargs...)
46+
SDESystem(deqs, neqs, iv′, dvs′, ps′, tgrad, jac, Wfact, Wfact_t, name)
5147
end
5248

5349
function generate_diffusion_function(sys::SDESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)

src/systems/diffeqs/validation.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential))
2+
3+
struct ODEExpr # dⁿx/dtⁿ = rhs
4+
x::Variable
5+
n::Int
6+
rhs::Expression
7+
end
8+
function Base.convert(::Type{ODEExpr},eq::Equation)
9+
isintermediate(eq) && throw(ArgumentError("intermediate equation received"))
10+
(x, t, n) = flatten_differential(eq.lhs)
11+
(isa(t, Operation) && isa(t.op, Variable) && isempty(t.args)) ||
12+
throw(ArgumentError("invalid independent variable $t"))
13+
(isa(x, Operation) && isa(x.op, Variable) && length(x.args) == 1 && isequal(first(x.args), t)) ||
14+
throw(ArgumentError("invalid dependent variable $x"))
15+
return t.op, ODEExpr(x.op, n, eq.rhs)
16+
end
17+
Base.:(==)(a::ODEExpr, b::ODEExpr) = isequal((a.x, a.n, a.rhs), (b.x, b.n, b.rhs))

test/build_targets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit, Test
44
@derivatives D'~t
55
eqs = [D(x) ~ a*x - x*y,
66
D(y) ~ -3y + x*y]
7-
@test ModelingToolkit.build_function(eqs,[x,y],[a],t,target = ModelingToolkit.StanTarget()) ==
7+
@test ModelingToolkit.build_function(eqs,ModelingToolkit.clean.([x,y]),ModelingToolkit.clean.([a]),t,target = ModelingToolkit.StanTarget()) ==
88
"""
99
real[] diffeqf(real t,real[] internal_var___u,real[] internal_var___p,real[] x_r,int[] x_i) {
1010
real internal_var___du[2];

test/components.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using ModelingToolkit, StaticArrays, LinearAlgebra
2+
using DiffEqBase
3+
using Test
4+
5+
# Define some variables
6+
@parameters t σ ρ β
7+
@variables x(t) y(t) z(t)
8+
@derivatives D'~t
9+
10+
eqs = [D(x) ~ σ*(y-x),
11+
D(y) ~ x*-z)-y,
12+
0 ~ x + y + z]
13+
14+
de1 = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],name=:lorenz1)
15+
de2 = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],name=:lorenz2)
16+
17+
@parameters α
18+
@variables a(t)
19+
for
20+
21+
convert(DAESystem,sys)
22+
convert(ODESystem,sys)
23+
24+
connected = ODESystems(connnectedeqs,[a],[α],systems=[de1,de2],name=:connectedlorenz)
25+
26+
deleteat!(de2.eqs,3)
27+
de2
28+
push!(de2.eqs,convert(ModelingToolkit.ODEExpr,D(z) ~ z^2)[2])
29+
30+
check_consistency(de2)
31+
32+
b = a*z
33+
[states(de1,:x) ~ α*states(de2,:z),
34+
D(a) ~ b*states(de1,:x)]
35+
36+
states(states) == [lorenz1′x,lorenz1′y,lorenz1′z,lorenz2′x,lorenz2′y,lorenz2′z,a]
37+
parameters(connected) == [lorenz1′σ,lorenz1′ρ,lorenz1′β,lorenz2′σ,lorenz2′ρ,lorenz2′β,α]

test/odesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,11 @@ lowered_eqs = [D(u_tt) ~ 2u_tt + u_t + x_t + 1
137137
D(u_t) ~ u_tt
138138
D(u) ~ u_t
139139
D(x) ~ x_t]
140+
140141
@test de1 == ODESystem(lowered_eqs)
141142

142143
# issue #219
143-
@test de1.dvs == [eq.x for eq in de1.eqs] == ODESystem(lowered_eqs).dvs
144+
@test de1.dvs == [ModelingToolkit.var_from_nested_derivative(eq.lhs)[1] for eq in de1.eqs] == ODESystem(lowered_eqs).dvs
144145

145146
test_diffeq_inference("first-order transform", de1, t, [u_tt, x_t, u_t, u, x], [])
146147
du = zeros(5)

0 commit comments

Comments
 (0)