Skip to content

Commit cd6280a

Browse files
Refactor iW code generation
1 parent 4a8a86e commit cd6280a

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ system_vars(sys::DiffEqSystem) = sys.dvs
6060
system_params(sys::DiffEqSystem) = sys.ps
6161

6262

63-
function generate_ode_iW(sys::DiffEqSystem, simplify=true)
64-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
65-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
63+
function generate_ode_iW(sys::DiffEqSystem, simplify=true; version::FunctionVersion = ArrayFunction)
6664
jac = calculate_jacobian(sys)
6765

6866
gam = Parameter(:gam)
@@ -82,14 +80,11 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true)
8280
iW_t = simplify_constants.(iW_t)
8381
end
8482

85-
iW_exprs = [:(iW[$i,$j] = $(convert(Expr, iW[i,j]))) for i in 1:size(iW,1), j in 1:size(iW,2)]
86-
exprs = vcat(var_exprs,param_exprs,vec(iW_exprs))
87-
block = expr_arr_to_block(exprs)
83+
vs, ps = system_vars(sys), system_params(sys)
84+
iW_func = build_function(iW , vs, ps, (:gam,:t); version = version)
85+
iW_t_func = build_function(iW_t, vs, ps, (:gam,:t); version = version)
8886

89-
iW_t_exprs = [:(iW[$i,$j] = $(convert(Expr, iW_t[i,j]))) for i in 1:size(iW_t,1), j in 1:size(iW_t,2)]
90-
exprs = vcat(var_exprs,param_exprs,vec(iW_t_exprs))
91-
block2 = expr_arr_to_block(exprs)
92-
:((iW,u,p,gam,t)->$(block)),:((iW,u,p,gam,t)->$(block2))
87+
return (iW_func, iW_t_func)
9388
end
9489

9590
function DiffEqBase.ODEFunction(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction)

src/systems/systems.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ end
2121
function generate_function(sys::AbstractSystem; version::FunctionVersion = ArrayFunction)
2222
sys_eqs = system_eqs(sys)
2323
vs, ps = system_vars(sys), system_params(sys)
24-
return build_function([eq.rhs for eq sys_eqs], vs, ps; version = version)
24+
return build_function([eq.rhs for eq sys_eqs], vs, ps, (:t,); version = version)
2525
end
2626

27-
function build_function(rhss, vs, ps; version::FunctionVersion)
27+
function build_function(rhss, vs, ps, args; version::FunctionVersion)
2828
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
2929
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
3030
(ls, rs) = collect(zip(var_pairs..., param_pairs...))
@@ -35,11 +35,11 @@ function build_function(rhss, vs, ps; version::FunctionVersion)
3535
X = gensym()
3636
sys_exprs = [:($X[$i] = $(convert(Expr, rhs))) for (i, rhs) enumerate(rhss)]
3737
let_expr = Expr(:let, var_eqs, build_expr(:block, sys_exprs))
38-
:(($X,u,p,t) -> $let_expr)
38+
:(($X,u,p,$(args...)) -> $let_expr)
3939
elseif version === SArrayFunction
4040
sys_expr = build_expr(:tuple, [convert(Expr, rhs) for rhs rhss])
4141
let_expr = Expr(:let, var_eqs, sys_expr)
42-
:((u,p,t) -> begin
42+
:((u,p,$(args...)) -> begin
4343
X = $let_expr
4444
T = StaticArrays.similar_type(typeof(u), eltype(X))
4545
T(X)

0 commit comments

Comments
 (0)