1- mutable struct DiffEqSystem <: AbstractSystem
2- eqs:: Vector{Equation}
3- ivs:: Vector{Variable}
1+ using Base: RefValue
2+
3+
4+ isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
5+
6+ struct DiffEq # D(x) = t
7+ D:: Differential # D
8+ var:: Variable # x
9+ rhs:: Expression # t
10+ end
11+ function Base. convert (:: Type{DiffEq} , eq:: Equation )
12+ isintermediate (eq) && throw (ArgumentError (" intermediate equation received" ))
13+ return DiffEq (eq. lhs. op, eq. lhs. args[1 ], eq. rhs)
14+ end
15+ Base.:(== )(a:: DiffEq , b:: DiffEq ) = (a. D, a. var, a. rhs) == (b. D, b. var, b. rhs)
16+ get_args (eq:: DiffEq ) = Expression[eq. var, eq. rhs]
17+
18+ struct DiffEqSystem <: AbstractSystem
19+ eqs:: Vector{DiffEq}
20+ iv:: Variable
421 dvs:: Vector{Variable}
522 ps:: Vector{Variable}
6- jac:: Matrix{Expression}
7- function DiffEqSystem (eqs, ivs, dvs, ps, jac)
8- all (! isintermediate, eqs) ||
9- throw (ArgumentError (" no intermediate equations permitted in DiffEqSystem" ))
10-
11- new (eqs, ivs, dvs, ps, jac)
23+ jac:: RefValue{Matrix{Expression}}
24+ function DiffEqSystem (eqs, iv, dvs, ps)
25+ jac = RefValue (Matrix {Expression} (undef, 0 , 0 ))
26+ new (eqs, iv, dvs, ps, jac)
1227 end
1328end
1429
15- DiffEqSystem (eqs, ivs, dvs, ps) = DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
16-
1730function DiffEqSystem (eqs)
1831 dvs, = extract_elements (eqs, [_is_dependent])
1932 ivs = unique (vcat ((dv. dependents for dv ∈ dvs). .. ))
20- ps, = extract_elements (eqs, [_is_parameter (ivs)])
21- DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
33+ length (ivs) == 1 || throw (ArgumentError (" one independent variable currently supported" ))
34+ iv = first (ivs)
35+ ps, = extract_elements (eqs, [_is_parameter (iv)])
36+ DiffEqSystem (eqs, iv, dvs, ps)
2237end
2338
24- function DiffEqSystem (eqs, ivs )
25- dvs, ps = extract_elements (eqs, [_is_dependent, _is_parameter (ivs )])
26- DiffEqSystem (eqs, ivs , dvs, ps, Matrix {Expression} (undef, 0 , 0 ) )
39+ function DiffEqSystem (eqs, iv )
40+ dvs, ps = extract_elements (eqs, [_is_dependent, _is_parameter (iv )])
41+ DiffEqSystem (eqs, iv , dvs, ps)
2742end
2843
29- isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
30-
3144
32- function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction)
45+ function generate_ode_function (sys:: DiffEqSystem ; version:: FunctionVersion = ArrayFunction)
3346 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
3447 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
3548 sys_exprs = build_equals_expr .(sys. eqs)
36- if version == ArrayFunction
37- dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[ 1 ] . name) " ))) for i in eachindex (sys. dvs)]
49+ if version === ArrayFunction
50+ dvar_exprs = [:(du[$ i] = $ (Symbol (" $(sys. dvs[i]. name) _$(sys. iv . name) " ))) for i in eachindex (sys. dvs)]
3851 exprs = vcat (var_exprs,param_exprs,sys_exprs,dvar_exprs)
3952 block = expr_arr_to_block (exprs)
4053 :((du,u,p,t)-> $ (toexpr (block)))
41- elseif version == SArrayFunction
42- dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. ivs[ 1 ] . name) " ))) for i in eachindex (sys. dvs)]
54+ elseif version === SArrayFunction
55+ dvar_exprs = [:($ (Symbol (" $(sys. dvs[i]. name) _$(sys. iv . name) " ))) for i in eachindex (sys. dvs)]
4356 svector_expr = quote
4457 E = eltype (tuple ($ (dvar_exprs... )))
4558 T = StaticArrays. similar_type (typeof (u), E)
@@ -51,26 +64,24 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
5164 end
5265end
5366
54- function build_equals_expr (eq:: Equation )
55- @assert ! isintermediate (eq)
56-
57- lhs = Symbol (eq. lhs. args[1 ]. name, :_ , eq. lhs. op. x. name)
67+ function build_equals_expr (eq:: DiffEq )
68+ lhs = Symbol (eq. var. name, :_ , eq. D. x. name)
5869 return :($ lhs = $ (convert (Expr, eq. rhs)))
5970end
6071
6172function calculate_jacobian (sys:: DiffEqSystem , simplify= true )
73+ isempty (sys. jac[]) || return sys. jac[] # use cached Jacobian, if possible
6274 rhs = [eq. rhs for eq in sys. eqs]
6375
64- sys_exprs = calculate_jacobian (rhs, sys. dvs)
65- sys_exprs = Expression[ expand_derivatives (expr) for expr in sys_exprs]
66- sys_exprs
76+ jac = expand_derivatives .( calculate_jacobian (rhs, sys. dvs) )
77+ sys . jac[] = jac # cache Jacobian
78+ return jac
6779end
6880
6981function generate_ode_jacobian (sys:: DiffEqSystem , simplify= true )
7082 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
7183 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
7284 jac = calculate_jacobian (sys, simplify)
73- sys. jac = jac
7485 jac_exprs = [:(J[$ i,$ j] = $ (convert (Expr, jac[i,j]))) for i in 1 : size (jac,1 ), j in 1 : size (jac,2 )]
7586 exprs = vcat (var_exprs,param_exprs,vec (jac_exprs))
7687 block = expr_arr_to_block (exprs)
8091function generate_ode_iW (sys:: DiffEqSystem , simplify= true )
8192 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
8293 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
83- jac = sys. jac
94+ jac = calculate_jacobian ( sys, simplify)
8495
8596 gam = Parameter (:gam )
8697
@@ -109,12 +120,12 @@ function generate_ode_iW(sys::DiffEqSystem, simplify=true)
109120 :((iW,u,p,gam,t)-> $ (block)),:((iW,u,p,gam,t)-> $ (block2))
110121end
111122
112- function DiffEqBase. ODEFunction (sys:: DiffEqSystem ;version = ArrayFunction,kwargs ... )
113- expr = generate_ode_function (sys;version= version,kwargs ... )
114- if version == ArrayFunction
115- ODEFunction {true} (eval (expr))
116- elseif version == SArrayFunction
117- ODEFunction {false} (eval (expr))
123+ function DiffEqBase. ODEFunction (sys:: DiffEqSystem ; version:: FunctionVersion = ArrayFunction)
124+ expr = generate_ode_function (sys; version = version)
125+ if version === ArrayFunction
126+ ODEFunction {true} (eval (expr))
127+ elseif version === SArrayFunction
128+ ODEFunction {false} (eval (expr))
118129 end
119130end
120131
0 commit comments