@@ -4,6 +4,12 @@ mutable struct DiffEqSystem <: AbstractSystem
44 dvs:: Vector{Variable}
55 ps:: Vector{Variable}
66 jac:: Matrix{Expression}
7+ function DiffEqSystem (eqs, ivs, dvs, ps, jac)
8+ all (! isintermediate, eqs) ||
9+ throw (ArgumentError (" no intermediate equations permitted in DiffEqSystem" ))
10+
11+ new (eqs, ivs, dvs, ps, jac)
12+ end
713end
814
915DiffEqSystem (eqs, ivs, dvs, ps) = DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
@@ -20,6 +26,9 @@ function DiffEqSystem(eqs, ivs)
2026 DiffEqSystem (eqs, ivs, dvs, ps, Matrix {Expression} (undef,0 ,0 ))
2127end
2228
29+ isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
30+
31+
2332function generate_ode_function (sys:: DiffEqSystem ;version = ArrayFunction)
2433 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
2534 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
@@ -42,8 +51,6 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
4251 end
4352end
4453
45- isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
46-
4754function build_equals_expr (eq:: Equation )
4855 @assert ! isintermediate (eq)
4956
@@ -52,13 +59,7 @@ function build_equals_expr(eq::Equation)
5259end
5360
5461function calculate_jacobian (sys:: DiffEqSystem , simplify= true )
55- calcs, diff_exprs = partition (isintermediate, sys. eqs)
56- rhs = [eq. rhs for eq in diff_exprs]
57-
58- # Handle intermediate calculations by substitution
59- for calc ∈ calcs
60- find_replace! .(rhs, calc. lhs, calc. rhs)
61- end
62+ rhs = [eq. rhs for eq in sys. eqs]
6263
6364 sys_exprs = calculate_jacobian (rhs, sys. dvs)
6465 sys_exprs = Expression[expand_derivatives (expr) for expr in sys_exprs]
6869function generate_ode_jacobian (sys:: DiffEqSystem , simplify= true )
6970 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
7071 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
71- diff_exprs = filter (! isintermediate, sys. eqs)
7272 jac = calculate_jacobian (sys, simplify)
7373 sys. jac = jac
7474 jac_exprs = [:(J[$ i,$ j] = $ (convert (Expr, jac[i,j]))) for i in 1 : size (jac,1 ), j in 1 : size (jac,2 )]
8080function generate_ode_iW (sys:: DiffEqSystem , simplify= true )
8181 var_exprs = [:($ (sys. dvs[i]. name) = u[$ i]) for i in eachindex (sys. dvs)]
8282 param_exprs = [:($ (sys. ps[i]. name) = p[$ i]) for i in eachindex (sys. ps)]
83- diff_exprs = filter (! isintermediate, sys. eqs)
8483 jac = sys. jac
8584
8685 gam = Parameter (:gam )
0 commit comments