1616
1717
1818struct DiffEq # dⁿx/dtⁿ = rhs
19- x:: Expression
20- t:: Variable
19+ x:: Variable
2120 n:: Int
2221 rhs:: Expression
2322end
24- function Base . convert ( :: Type{DiffEq} , eq:: Equation )
23+ function to_diffeq ( eq:: Equation )
2524 isintermediate (eq) && throw (ArgumentError (" intermediate equation received" ))
2625 (x, t, n) = flatten_differential (eq. lhs)
27- return DiffEq (x, t, n, eq. rhs)
26+ (isa (t, Operation) && isa (t. op, Variable) && isempty (t. args)) ||
27+ throw (ArgumentError (" invalid independent variable $t " ))
28+ (isa (x, Operation) && isa (x. op, Variable) && length (x. args) == 1 && isequal (first (x. args), t)) ||
29+ throw (ArgumentError (" invalid dependent variable $x " ))
30+ return t. op, DiffEq (x. op, n, eq. rhs)
2831end
2932Base.:(== )(a:: DiffEq , b:: DiffEq ) = isequal ((a. x, a. t, a. n, a. rhs), (b. x, b. t, b. n, b. rhs))
30- get_args (eq:: DiffEq ) = Expression[eq. x, eq. t, eq. rhs]
3133
3234struct DiffEqSystem <: AbstractSystem
3335 eqs:: Vector{DiffEq}
3436 iv:: Variable
3537 dvs:: Vector{Variable}
3638 ps:: Vector{Variable}
3739 jac:: RefValue{Matrix{Expression}}
38- function DiffEqSystem (eqs, iv, dvs, ps)
39- jac = RefValue (Matrix {Expression} (undef, 0 , 0 ))
40- new (eqs, iv, dvs, ps, jac)
41- end
42- end
40+ function DiffEqSystem (eqs)
41+ reformatted = to_diffeq .(eqs)
4342
44- function DiffEqSystem (eqs)
45- dvs, = extract_elements (eqs, [_is_dependent])
46- ivs = unique (vcat ((dv. dependents for dv ∈ dvs). .. ))
47- length (ivs) == 1 || throw (ArgumentError (" one independent variable currently supported" ))
48- iv = first (ivs)
49- ps, = extract_elements (eqs, [_is_parameter (iv)])
50- DiffEqSystem (eqs, iv, dvs, ps)
51- end
43+ ivs = unique (r[1 ] for r ∈ reformatted)
44+ length (ivs) == 1 || throw (ArgumentError (" one independent variable currently supported" ))
45+ iv = first (ivs)
5246
53- function DiffEqSystem (eqs, iv)
54- dvs, ps = extract_elements (eqs, [_is_dependent, _is_parameter (iv)])
55- DiffEqSystem (eqs, iv, dvs, ps)
47+ deqs = [r[2 ] for r ∈ reformatted]
48+
49+ dvs = [deq. x for deq ∈ deqs]
50+ ps = filter (vars (deq. rhs for deq ∈ deqs)) do x
51+ x. known & ! isequal (x, iv)
52+ end |> collect
53+
54+ jac = RefValue (Matrix {Expression} (undef, 0 , 0 ))
55+
56+ new (deqs, iv, dvs, ps, jac)
57+ end
5658end
5759
5860
5961function calculate_jacobian (sys:: DiffEqSystem )
6062 isempty (sys. jac[]) || return sys. jac[] # use cached Jacobian, if possible
61- rhs = [eq. rhs for eq in sys. eqs]
63+ rhs = [eq. rhs for eq ∈ sys. eqs]
6264
63- jac = expand_derivatives .(calculate_jacobian (rhs, sys. dvs))
65+ jac = expand_derivatives .(calculate_jacobian (rhs, sys. dvs, sys . iv ))
6466 sys. jac[] = jac # cache Jacobian
6567 return jac
6668end
@@ -70,16 +72,30 @@ function generate_jacobian(sys::DiffEqSystem; version::FunctionVersion = ArrayFu
7072 return build_function (jac, sys. dvs, sys. ps, (sys. iv. name,); version = version)
7173end
7274
75+ struct DiffEqToExpr
76+ sys:: DiffEqSystem
77+ end
78+ function (f:: DiffEqToExpr )(O:: Operation )
79+ if isa (O. op, Variable)
80+ isequal (O. op, f. sys. iv) && return O. op. name # independent variable
81+ O. op ∈ f. sys. dvs && return O. op. name # dependent variables
82+ isempty (O. args) && return O. op. name # 0-ary parameters
83+ return build_expr (:call , Any[O. op. name; f .(O. args)])
84+ end
85+ return build_expr (:call , Any[O. op; f .(O. args)])
86+ end
87+ (f:: DiffEqToExpr )(x) = convert (Expr, x)
88+
7389function generate_function (sys:: DiffEqSystem ; version:: FunctionVersion = ArrayFunction)
74- rhss = [eq . rhs for eq ∈ sys. eqs]
75- return build_function (rhss, sys. dvs, sys. ps, (sys. iv. name,); version = version)
90+ rhss = [deq . rhs for deq ∈ sys. eqs]
91+ return build_function (rhss, sys. dvs, sys. ps, (sys. iv. name,), DiffEqToExpr (sys) ; version = version)
7692end
7793
7894
7995function generate_ode_iW (sys:: DiffEqSystem , simplify= true ; version:: FunctionVersion = ArrayFunction)
8096 jac = calculate_jacobian (sys)
8197
82- gam = Variable (:gam ; known = true )
98+ gam = Variable (:gam ; known = true )()
8399
84100 W = LinearAlgebra. I - gam* jac
85101 W = SMatrix {size(W,1),size(W,2)} (W)
0 commit comments