Skip to content

Commit 8d2f993

Browse files
Merge pull request #295 from SciML/lowering
Add lowering integration test
2 parents d82c0c8 + 02df4d1 commit 8d2f993

File tree

7 files changed

+39
-11
lines changed

7 files changed

+39
-11
lines changed

src/build_function.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ function _build_function(target::JuliaTarget, op::Operation, vs, ps = (), args =
1313
conv = simplified_expr, expression = Val{true};
1414
checkbounds = false, constructor=nothing,
1515
linenumbers = true)
16-
_vs = map(x-> x isa Operation ? x.op : x, vs)
17-
_ps = map(x-> x isa Operation ? x.op : x, ps)
16+
_vs = convert.(Variable,vs)
17+
_ps = convert.(Variable,ps)
1818
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
1919
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
2020
(ls, rs) = zip(var_pairs..., param_pairs...)
@@ -48,8 +48,8 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
4848
conv = simplified_expr, expression = Val{true};
4949
checkbounds = false, constructor=nothing,
5050
linenumbers = true, multithread=false)
51-
_vs = map(x-> x isa Operation ? x.op : x, vs)
52-
_ps = map(x-> x isa Operation ? x.op : x, ps)
51+
_vs = convert.(Variable,vs)
52+
_ps = convert.(Variable,ps)
5353
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
5454
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(_ps)]
5555
(ls, rs) = zip(var_pairs..., param_pairs...)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ end
4242
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; sparse = false, kwargs...)
4343
jac = calculate_jacobian(sys)
4444
if sparse
45-
jac = sparse(jac)
45+
jac = SparseArrays.sparse(jac)
4646
end
4747
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
4848
end
@@ -125,6 +125,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
125125
ps = parameters(sys);
126126
version = nothing, tgrad=false,
127127
jac = false, Wfact = false,
128+
sparse = false,
128129
kwargs...) where {iip}
129130
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
130131

@@ -140,7 +141,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
140141
end
141142

142143
if jac
143-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; kwargs...)
144+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; sparse = sparse, kwargs...)
144145
_jac(u,p,t) = jac_oop(u,p,t)
145146
_jac(J,u,p,t) = jac_iip(J,u,p,t)
146147
else

src/systems/diffeqs/sdesystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ are used to set the order of the dependent variable and parameter vectors,
6666
respectively.
6767
"""
6868
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;
69-
version = nothing, tgrad=false,
69+
version = nothing, tgrad=false, sparse = false,
7070
jac = false, Wfact = false, kwargs...) where {iip}
7171
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
7272
g_oop,g_iip = generate_diffusion_function(sys, dvs, ps, Val{false}; kwargs...)
@@ -85,7 +85,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
8585
end
8686

8787
if jac
88-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; kwargs...)
88+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; sparse=sparse, kwargs...)
8989
_jac(u,p,t) = jac_oop(u,p,t)
9090
_jac(J,u,p,t) = jac_iip(J,u,p,t)
9191
else
@@ -131,7 +131,7 @@ function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,p=parammap;
131131

132132
f = SDEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
133133
linenumbers=linenumbers,multithread=multithread,
134-
sparse=false)
134+
sparse=sparse)
135135
u0 = varmap_to_vars(u0map,states(sys))
136136
p = varmap_to_vars(parammap,parameters(sys))
137137
SDEProblem(f,f.g,u0,tspan,p;kwargs...)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = paramete
5151
sparse = false, kwargs...)
5252
jac = calculate_jacobian(sys)
5353
if sparse
54-
jac = sparse(jac)
54+
jac = SparseArrays.sparse(jac)
5555
end
5656
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps), (), NLSysToExpr(sys))
5757
end

src/systems/optimization/optimizationsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
8181
ps = parameters(sys)
8282

8383
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
84-
multithread=multithread,sparse=sparse)
84+
multithread=multithread)
8585
u0 = varmap_to_vars(u0,dvs)
8686
p = varmap_to_vars(parammap,ps)
8787
lb = varmap_to_vars(lb,dvs)

test/lowering_solving.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using ModelingToolkit, OrdinaryDiffEq
2+
3+
@parameters t σ ρ β
4+
@variables x(t) y(t) z(t)
5+
@derivatives D'~t
6+
7+
eqs = [D(D(x)) ~ σ*(y-x),
8+
D(y) ~ x*-z)-y,
9+
D(z) ~ x*y - β*z]
10+
11+
sys = ODESystem(eqs)
12+
sys = ode_order_lowering(sys)
13+
14+
u0 = [D(x) => 2.0,
15+
x => 1.0,
16+
y => 0.0,
17+
z => 0.0]
18+
19+
p ==> 28.0,
20+
ρ => 10.0,
21+
β => 8/3]
22+
23+
tspan = (0.0,100.0)
24+
prob = ODEProblem(sys,u0,tspan,p,jac=true)
25+
sol = solve(prob,Tsit5())
26+
#using Plots; plot(sol,vars=(:x,:y))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using SafeTestsets, Test
1515
@safetestset "Domain Test" begin include("domains.jl") end
1616
@safetestset "Constraints Test" begin include("constraints.jl") end
1717
@safetestset "PDE Construction Test" begin include("pde.jl") end
18+
@safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end
1819
@safetestset "Test Big System Usage" begin include("bigsystem.jl") end
1920
@testset "Latexify recipes Test" begin include("latexify.jl") end
2021
@testset "Distributed Test" begin include("distributed.jl") end

0 commit comments

Comments
 (0)