1+ export DiffEqSystem, ODEFunction
2+
3+
14using Base: RefValue
25
36
@@ -42,34 +45,7 @@ function DiffEqSystem(eqs, iv)
4245end
4346
4447
45- function generate_ode_function (sys:: DiffEqSystem ; version:: FunctionVersion = ArrayFunction)
46- var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
47- param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
48- sys_exprs = build_equals_expr .(sys. eqs)
49- if version === ArrayFunction
50- dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. iv. name) " ))) for i in eachindex (sys. dvs)]
51- exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
52- block = expr_arr_to_block (exprs)
53- :((du,u,p,t)-> $ (toexpr (block)))
54- elseif version === SArrayFunction
55- dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. iv. name) " ))) for i in eachindex (sys. dvs)]
56- svector_expr = quote
57- E = eltype (tuple ($ (dvar_exprs... )))
58- T = StaticArrays. similar_type (typeof (u), E)
59- T ($ (dvar_exprs... ))
60- end
61- exprs = vcat (var_exprs,param_exprs,sys_exprs,svector_expr)
62- block = expr_arr_to_block (exprs)
63- :((u,p,t)-> $ (toexpr (block)))
64- end
65- end
66-
67- function build_equals_expr (eq:: DiffEq )
68- lhs = Symbol (eq. var. name, :_ , eq. D. x. name)
69- return :($ lhs = $ (convert (Expr, eq. rhs)))
70- end
71-
72- function calculate_jacobian (sys:: DiffEqSystem , simplify= true )
48+ function calculate_jacobian (sys:: DiffEqSystem )
7349 isempty (sys. jac[]) || return sys. jac[] # use cached Jacobian, if possible
7450 rhs = [eq. rhs for eq in sys. eqs]
7551
@@ -78,20 +54,19 @@ function calculate_jacobian(sys::DiffEqSystem, simplify=true)
7854 return jac
7955end
8056
81- function generate_ode_jacobian (sys:: DiffEqSystem , simplify = true )
82- var_exprs = [:( $ ( sys. dvs[i] . name) = u[ $ i]) for i in eachindex (sys . dvs)]
83- param_exprs = [:( $ ( sys. ps[i] . name) = p[ $ i]) for i in eachindex (sys . ps)]
84- jac = calculate_jacobian (sys, simplify)
85- jac_exprs = [:(J[ $ i, $ j] = $ ( convert (Expr, jac[i,j]))) for i in 1 : size (jac, 1 ), j in 1 : size (jac, 2 )]
86- exprs = vcat (var_exprs,param_exprs, vec (jac_exprs) )
87- block = expr_arr_to_block (exprs)
88- :((J,u,p,t) -> $ (block) )
57+ function generate_jacobian (sys:: DiffEqSystem ; version :: FunctionVersion = ArrayFunction )
58+ jac = calculate_jacobian ( sys)
59+ return build_function (jac, sys. dvs, sys . ps, (sys . iv . name,); version = version)
60+ end
61+
62+ function generate_function (sys :: DiffEqSystem ; version :: FunctionVersion = ArrayFunction )
63+ rhss = [eq . rhs for eq ∈ sys . eqs]
64+ return build_function (rhss, sys . dvs, sys . ps, (sys . iv . name,); version = version )
8965end
9066
91- function generate_ode_iW (sys:: DiffEqSystem , simplify= true )
92- var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
93- param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
94- jac = calculate_jacobian (sys, simplify)
67+
68+ function generate_ode_iW (sys:: DiffEqSystem , simplify= true ; version:: FunctionVersion = ArrayFunction)
69+ jac = calculate_jacobian (sys)
9570
9671 gam = Parameter (:gam )
9772
@@ -110,25 +85,18 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true)
11085 iW_t = simplify_constants .(iW_t)
11186 end
11287
113- iW_exprs = [:(iW[ $ i, $ j] = $ ( convert (Expr, iW[i,j]))) for i in 1 : size (iW, 1 ), j in 1 : size (iW, 2 )]
114- exprs = vcat (var_exprs,param_exprs, vec (iW_exprs) )
115- block = expr_arr_to_block (exprs )
88+ vs, ps = sys . dvs, sys . ps
89+ iW_func = build_function (iW , vs, ps, ( :gam , :t ); version = version )
90+ iW_t_func = build_function (iW_t, vs, ps, ( :gam , :t ); version = version )
11691
117- iW_t_exprs = [:(iW[$ i,$ j] = $ (convert (Expr, iW_t[i,j]))) for i in 1 : size (iW_t,1 ), j in 1 : size (iW_t,2 )]
118- exprs = vcat (var_exprs,param_exprs,vec (iW_t_exprs))
119- block2 = expr_arr_to_block (exprs)
120- :((iW,u,p,gam,t)-> $ (block)),:((iW,u,p,gam,t)-> $ (block2))
92+ return (iW_func, iW_t_func)
12193end
12294
12395function DiffEqBase. ODEFunction (sys:: DiffEqSystem ; version:: FunctionVersion = ArrayFunction)
124- expr = generate_ode_function (sys; version = version)
96+ expr = generate_function (sys; version = version)
12597 if version === ArrayFunction
12698 ODEFunction {true} (eval (expr))
12799 elseif version === SArrayFunction
128100 ODEFunction {false} (eval (expr))
129101 end
130102end
131-
132-
133- export DiffEqSystem, ODEFunction
134- export generate_ode_function
0 commit comments