Skip to content

Commit 0e25e2c

Browse files
Expose function generation keyword arguments from problem interface
1 parent d328a60 commit 0e25e2c

File tree

4 files changed

+26
-12
lines changed

4 files changed

+26
-12
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
100100
j = findfirst(x->isequal(x.name,var_from_nested_derivative(eq.lhs)[1].name),dvs)
101101
M[i,j] = 1
102102
else
103-
error("Only semi-explicit mass matrices are currently supported")
103+
error("Only semi-explicit constant mass matrices are currently supported")
104104
end
105105
end
106106
M = simplify ? simplify_constants.(M) : M
@@ -121,30 +121,30 @@ respectively.
121121
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
122122
ps = parameters(sys);
123123
version = nothing, tgrad=false,
124-
jac = false, Wfact = false) where {iip}
125-
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
124+
jac = false, Wfact = false; kwargs...) where {iip}
125+
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false}; kwargs...)
126126

127127
f(u,p,t) = f_oop(u,p,t)
128128
f(du,u,p,t) = f_iip(du,u,p,t)
129129

130130
if tgrad
131-
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
131+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false}; kwargs...)
132132
_tgrad(u,p,t) = tgrad_oop(u,p,t)
133133
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
134134
else
135135
_tgrad = nothing
136136
end
137137

138138
if jac
139-
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
139+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false}; kwargs...)
140140
_jac(u,p,t) = jac_oop(u,p,t)
141141
_jac(J,u,p,t) = jac_iip(J,u,p,t)
142142
else
143143
_jac = nothing
144144
end
145145

146146
if Wfact
147-
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false})
147+
tmp_Wfact,tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true, Val{false}; kwargs...)
148148
Wfact_oop, Wfact_iip = tmp_Wfact
149149
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
150150
_Wfact(u,p,dtgamma,t) = Wfact_oop(u,p,dtgamma,t)
@@ -173,8 +173,11 @@ function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
173173
parammap=DiffEqBase.NullParameters();
174174
version = nothing, tgrad=false,
175175
jac = false, Wfact = false,
176+
checkbounds = false,
177+
linenumbers = true, multithread=false,
176178
kwargs...) where iip
177-
f = ODEFunction(sys)
179+
f = ODEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
180+
linenumbers=linenumbers,multithread=multithread)
178181
u0 = varmap_to_vars(u0map,states(sys))
179182
p = varmap_to_vars(parammap,parameters(sys))
180183
ODEProblem(f,u0,tspan,p;kwargs...)

src/systems/diffeqs/sdesystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,13 @@ end
125125
function DiffEqBase.SDEProblem{iip}(sys::SDESystem,u0map,tspan,p=parammap;
126126
version = nothing, tgrad=false,
127127
jac = false, Wfact = false,
128+
checkbounds = false,
129+
linenumbers = true, multithread=false,
128130
kwargs...) where iip
129-
f = SDEFunction(sys)
131+
132+
f = SDEFunction(sys;tgrad=tgrad,jac=jac,Wfact=Wfact,checkbounds=checkbounds,
133+
linenumbers=linenumbers,multithread=multithread)
130134
u0 = varmap_to_vars(u0map,states(sys))
131135
p = varmap_to_vars(parammap,parameters(sys))
132-
SDEProblem(f,g,u0,tspan,p;kwargs...)
136+
SDEProblem(f,f.g,u0,tspan,p;kwargs...)
133137
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,15 @@ end
7474

7575
function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem,u0map,tspan,
7676
parammap=DiffEqBase.NullParameters();
77-
jac = false,kwargs...) where iip
77+
jac = false,
78+
checkbounds = false,
79+
linenumbers = true, multithread=false,
80+
kwargs...) where iip
7881
dvs = states(sys)
7982
ps = parameters(sys)
8083

81-
f = generate_function(sys)
84+
f = generate_function(sys;checkbounds=checkbounds,linenumbers=linenumbers,
85+
multithread=multithread)
8286
u0 = varmap_to_vars(u0map,dvs)
8387
p = varmap_to_vars(parammap,ps)
8488
NonlinearProblem(f,u0,tspan,p;kwargs...)

src/systems/optimization/optimizationsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,14 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem,
7070
parammap=DiffEqBase.NullParameters();
7171
u0=nothing, lb=nothing, ub=nothing,
7272
hes = false,
73+
checkbounds = false,
74+
linenumbers = true, multithread=false,
7375
kwargs...) where iip
7476
dvs = states(sys)
7577
ps = parameters(sys)
7678

77-
f = generate_function(sys)
79+
f = generate_function(sys,checkbounds=checkbounds,linenumbers=linenumbers,
80+
multithread=multithread)
7881
u0 = varmap_to_vars(u0,dvs)
7982
p = varmap_to_vars(parammap,ps)
8083
lb = varmap_to_vars(lb,dvs)

0 commit comments

Comments
 (0)