Skip to content

Commit 8d2e038

Browse files
Expose sparse matrix generation in higher level prob constructors
1 parent 487f67f commit 8d2e038

File tree

4 files changed

+32
-18
lines changed

4 files changed

+32
-18
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = paramete
3939
return build_function(tgrad, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
4040
end
4141

42-
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
42+
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; sparse = false, kwargs...)
4343
jac = calculate_jacobian(sys)
44+
if sparse
45+
jac = sparse(jac)
46+
end
4447
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
4548
end
4649

@@ -121,7 +124,8 @@ respectively.
121124
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
122125
ps = parameters(sys);
123126
version = nothing, tgrad=false,
124-
jac = false, Wfact = false, kwargs...) where {iip}
127+
jac = false, Wfact = false,
128+
kwargs...) where {iip}
125129
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
126130

127131
f(u,p,t) = f_oop(u,p,t)
@@ -173,11 +177,12 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
173177
parammap=DiffEqBase.NullParameters();
174178
version = nothing, tgrad=false,
175179
jac = false, Wfact = false,
176-
checkbounds = false,
180+
checkbounds = false, sparse = false,
177181
linenumbers = true, multithread=false,
178182
kwargs...) where iip
179183
f = ODEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
180-
linenumbers=linenumbers,multithread=multithread)
184+
linenumbers=linenumbers,multithread=multithread,
185+
sparse=sparse)
181186
u0 = varmap_to_vars(u0map,states(sys))
182187
p = varmap_to_vars(parammap,parameters(sys))
183188
ODEProblem(f,u0,tspan,p;kwargs...)

src/systems/diffeqs/sdesystem.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,33 +67,33 @@ respectively.
6767
"""
6868
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;
6969
version = nothing, tgrad=false,
70-
jac = false, Wfact = false) where {iip}
71-
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
72-
g_oop,g_iip = generate_diffusion_function(sys, dvs, ps, Val{false})
70+
jac = false, Wfact = false, kwargs...) where {iip}
71+
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
72+
g_oop,g_iip = generate_diffusion_function(sys, dvs, ps, Val{false}; kwargs...)
7373

7474
f(u,p,t) = f_oop(u,p,t)
7575
f(du,u,p,t) = f_iip(du,u,p,t)
7676
g(u,p,t) = g_oop(u,p,t)
7777
g(du,u,p,t) = g_iip(du,u,p,t)
7878

7979
if tgrad
80-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
80+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false}; kwargs...)
8181
_tgrad(u,p,t) = tgrad_oop(u,p,t)
8282
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
8383
else
8484
_tgrad = nothing
8585
end
8686

8787
if jac
88-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
88+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; kwargs...)
8989
_jac(u,p,t) = jac_oop(u,p,t)
9090
_jac(J,u,p,t) = jac_iip(J,u,p,t)
9191
else
9292
_jac = nothing
9393
end
9494

9595
if Wfact
96-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false})
96+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false}; kwargs...)
9797
Wfact_oop, Wfact_iip = tmp_Wfact
9898
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
9999
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
@@ -125,12 +125,13 @@ end
125125
function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,p=parammap;
126126
version = nothing, tgrad=false,
127127
jac = false, Wfact = false,
128-
checkbounds = false,
128+
checkbounds = false, sparse = false,
129129
linenumbers = true, multithread=false,
130130
kwargs...) where iip
131131

132132
f = SDEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
133-
linenumbers=linenumbers,multithread=multithread)
133+
linenumbers=linenumbers,multithread=multithread,
134+
sparse=false)
134135
u0 = varmap_to_vars(u0map,states(sys))
135136
p = varmap_to_vars(parammap,parameters(sys))
136137
SDEProblem(f,f.g,u0,tspan,p;kwargs...)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ function calculate_jacobian(sys::NonlinearSystem)
4747
return jac
4848
end
4949

50-
function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
50+
function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys), expression = Val{true};
51+
sparse = false, kwargs...)
5152
jac = calculate_jacobian(sys)
53+
if sparse
54+
jac = sparse(jac)
55+
end
5256
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps), (), NLSysToExpr(sys))
5357
end
5458

@@ -74,15 +78,15 @@ end
7478

7579
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
7680
parammap=DiffEqBase.NullParameters();
77-
jac = false,
81+
jac = false, sparse=false,
7882
checkbounds = false,
7983
linenumbers = true, multithread=false,
8084
kwargs...) where iip
8185
dvs = states(sys)
8286
ps = parameters(sys)
8387

8488
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
85-
multithread=multithread)
89+
multithread=multithread,sparse=sparse)
8690
u0 = varmap_to_vars(u0map,dvs)
8791
p = varmap_to_vars(parammap,ps)
8892
NonlinearProblem(f,u0,tspan,p;kwargs...)

src/systems/optimization/optimizationsystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@ function calculate_hessian(sys::OptimizationSystem)
5252
expand_derivatives.(hessian(equations(sys), [dv() for dv in states(sys)]))
5353
end
5454

55-
function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
55+
function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true};
56+
sparse = false, kwargs...)
5657
hes = calculate_hessian(sys)
58+
if sparse
59+
hes = sparse(hes)
60+
end
5761
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x); kwargs...)
5862
end
5963

@@ -69,15 +73,15 @@ namespace_operation(sys::OptimizationSystem) = namespace_operation(sys.op,sys.na
6973
function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
7074
parammap=DiffEqBase.NullParameters();
7175
u0=nothing, lb=nothing, ub=nothing,
72-
hes = false,
76+
hes = false, sparse = false,
7377
checkbounds = false,
7478
linenumbers = true, multithread=false,
7579
kwargs...) where iip
7680
dvs = states(sys)
7781
ps = parameters(sys)
7882

7983
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
80-
multithread=multithread)
84+
multithread=multithread,sparse=sparse)
8185
u0 = varmap_to_vars(u0,dvs)
8286
p = varmap_to_vars(parammap,ps)
8387
lb = varmap_to_vars(lb,dvs)

0 commit comments

Comments
 (0)