Skip to content

Commit 9ae43bc

Browse files
mass matrix ODE support
1 parent c7d6cd5 commit 9ae43bc

File tree

6 files changed

+69
-9
lines changed

6 files changed

+69
-9
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ export Equation, ConstrainedEquation
107107
export simplify_constants
108108

109109
export Operation, Expression
110-
export calculate_jacobian, generate_jacobian, generate_function, generate_diffusion_function
110+
export calculate_jacobian, generate_jacobian, generate_function
111+
export calculate_massmatrix, generate_diffusion_function
111112
export independent_variables, dependent_variables, parameters
112113
export simplified_expr, eval_function
113114
export @register, @I

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,23 @@ function generate_factorized_W(sys::AbstractODESystem, vs = sys.dvs, ps = sys.ps
8989
return (Wfact_func, Wfact_t_func)
9090
end
9191

92+
function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
93+
eqs = sys.eqs
94+
M = zeros(length(eqs),length(eqs))
95+
for (i,eq) in enumerate(eqs)
96+
if eq.lhs isa Constant
97+
@assert eq.lhs.value == 0
98+
elseif eq.lhs.op isa Differential
99+
j = findfirst(x->isequal(x.name,var_from_nested_derivative(eq.lhs)[1].name),sys.dvs)
100+
M[i,j] = 1
101+
else
102+
error("Only semi-explicit mass matrices are currently supported")
103+
end
104+
end
105+
M = simplify ? simplify_constants.(M) : M
106+
M == I ? I : M
107+
end
108+
92109
"""
93110
$(SIGNATURES)
94111
@@ -132,10 +149,13 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = sys.dvs, ps =
132149
_Wfact,_Wfact_t = nothing,nothing
133150
end
134151

152+
M = calculate_massmatrix(sys)
153+
135154
ODEFunction{iip}(f,jac=_jac,
136155
tgrad = _tgrad,
137156
Wfact = _Wfact,
138157
Wfact_t = _Wfact_t,
158+
mass_matrix = M,
139159
syms = Symbol.(sys.dvs))
140160
end
141161

src/systems/diffeqs/sdesystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,13 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.dvs, ps = sys.ps;
9898
_Wfact,_Wfact_t = nothing,nothing
9999
end
100100

101+
M = calculate_massmatrix(sys)
102+
101103
SDEFunction{iip}(f,g,jac=_jac,
102104
tgrad = _tgrad,
103105
Wfact = _Wfact,
104106
Wfact_t = _Wfact_t,
107+
mass_matrix = M,
105108
syms = Symbol.(sys.dvs))
106109
end
107110

test/mass_matrix.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using OrdinaryDiffEq, ModelingToolkit, Test
2+
@parameters t
3+
@variables y[1:3](t)
4+
@parameters k[1:3]
5+
@derivatives D'~t
6+
7+
eqs = [D(y[1]) ~ -k[1]*y[1] + k[3]*y[2]*y[3],
8+
D(y[2]) ~ k[1]*y[1] - k[3]*y[2]*y[3] - k[2]*y[2]^2,
9+
0 ~ y[1] + y[2] + y[3] - 1]
10+
11+
sys = ODESystem(eqs,t,y,k)
12+
M = calculate_massmatrix(sys)
13+
M == [1 0 0
14+
0 1 0
15+
0 0 0]
16+
17+
f = ODEFunction(sys)
18+
prob_mm = ODEProblem(f,[1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4))
19+
sol = solve(prob_mm,Rodas5(),reltol=1e-8,abstol=1e-8)
20+
21+
function rober(du,u,p,t)
22+
y₁,y₂,y₃ = u
23+
k₁,k₂,k₃ = p
24+
du[1] = -k₁*y₁ + k₃*y₂*y₃
25+
du[2] = k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2
26+
du[3] = y₁ + y₂ + y₃ - 1
27+
nothing
28+
end
29+
f = ODEFunction(rober,mass_matrix=M)
30+
prob_mm2 = ODEProblem(f,[1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4))
31+
sol2 = solve(prob_mm2,Rodas5(),reltol=1e-8,abstol=1e-8)
32+
33+
@test Array(sol) == Array(sol2)

test/odesystem.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ using Test
77
@variables x(t) y(t) z(t)
88
@derivatives D'~t
99

10+
# Define a differential equation
11+
eqs = [D(x) ~ σ*(y-x),
12+
D(y) ~ x*-z)-y,
13+
D(z) ~ x*y - β*z]
14+
15+
ModelingToolkit.simplified_expr.(eqs)[1]
16+
:(derivative(x(t), t) = σ * (y(t) - x(t))).args
17+
de = ODESystem(eqs)
18+
19+
1020
function _clean(O::Operation)
1121
@assert isa(O.op, Variable)
1222
return O.op
@@ -26,14 +36,6 @@ function test_nlsys_inference(name, sys, vs, ps)
2636
end
2737
end
2838

29-
# Define a differential equation
30-
eqs = [D(x) ~ σ*(y-x),
31-
D(y) ~ x*-z)-y,
32-
D(z) ~ x*y - β*z]
33-
34-
ModelingToolkit.simplified_expr.(eqs)[1]
35-
:(derivative(x(t), t) = σ * (y(t) - x(t))).args
36-
de = ODESystem(eqs)
3739
test_diffeq_inference("standard", de, t, (x, y, z), (σ, ρ, β))
3840
generate_function(de, [x,y,z], [σ,ρ,β])
3941
jac_expr = generate_jacobian(de)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ModelingToolkit, Test
66
@testset "Operation Overloads Test" begin include("operation_overloads.jl") end
77
@testset "Direct Usage Test" begin include("direct.jl") end
88
@testset "ODESystem Test" begin include("odesystem.jl") end
9+
@testset "Mass Matrix Test" begin include("odesystem.jl") end
910
@testset "SDESystem Test" begin include("sdesystem.jl") end
1011
@testset "Build Targets Test" begin include("build_targets.jl") end
1112
@testset "Domain Test" begin include("domains.jl") end

0 commit comments

Comments
 (0)