@@ -36,14 +36,23 @@ function DiffEqSystem(eqs, ivs;
3636 DiffEqSystem (eqs, ivs, dvs, vs, ps, ivs[1 ]. subtype, dv_name, p_name, Matrix {Expression} (undef,0 ,0 ))
3737end
3838
39- function generate_ode_function (sys:: DiffEqSystem )
39+ function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction )
4040 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in 1 : length (sys. dvs)]
4141 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in 1 : length (sys. ps)]
4242 sys_exprs = build_equals_expr .(sys. eqs)
43- dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
44- exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
45- block = expr_arr_to_block (exprs)
46- :((du,u,p,t)-> $ (block))
43+
44+ if version == ArrayFunction
45+ dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
46+ exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
47+ block = expr_arr_to_block (exprs)
48+ :((du,u,p,t)-> $ (block))
49+ elseif version == SArrayFunction
50+ dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[1 ]. name) " ))) for i in 1 : length (sys. dvs)]
51+ svector_expr = :(typeof (u)($ (dvar_exprs... )))
52+ exprs = vcat (var_exprs,param_exprs,sys_exprs,svector_expr)
53+ block = expr_arr_to_block (exprs)
54+ :((u,p,t)-> $ (block))
55+ end
4756end
4857
4958isintermediate (eq) = eq. args[1 ]. diff == nothing
0 commit comments