|
1 | 1 | mutable struct DiffEqSystem <: AbstractSystem |
2 | | - eqs::Vector{Operation} |
| 2 | + eqs::Vector{Equation} |
3 | 3 | ivs::Vector{Variable} |
4 | 4 | dvs::Vector{Variable} |
5 | 5 | vs::Vector{Variable} |
@@ -41,71 +41,64 @@ function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction) |
41 | 41 | param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)] |
42 | 42 | sys_exprs = build_equals_expr.(sys.eqs) |
43 | 43 | if version == ArrayFunction |
44 | | - dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)] |
45 | | - exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs) |
46 | | - block = expr_arr_to_block(exprs) |
47 | | - :((du,u,p,t)->$(toexpr(block))) |
| 44 | + dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)] |
| 45 | + exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs) |
| 46 | + block = expr_arr_to_block(exprs) |
| 47 | + :((du,u,p,t)->$(toexpr(block))) |
48 | 48 | elseif version == SArrayFunction |
49 | | - dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)] |
50 | | - svector_expr = quote |
51 | | - E = eltype(tuple($(dvar_exprs...))) |
52 | | - T = StaticArrays.similar_type(typeof(u), E) |
53 | | - T($(dvar_exprs...)) |
54 | | - end |
55 | | - exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr) |
56 | | - block = expr_arr_to_block(exprs) |
57 | | - :((u,p,t)->$(toexpr(block))) |
| 49 | + dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)] |
| 50 | + svector_expr = quote |
| 51 | + E = eltype(tuple($(dvar_exprs...))) |
| 52 | + T = StaticArrays.similar_type(typeof(u), E) |
| 53 | + T($(dvar_exprs...)) |
| 54 | + end |
| 55 | + exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr) |
| 56 | + block = expr_arr_to_block(exprs) |
| 57 | + :((u,p,t)->$(toexpr(block))) |
58 | 58 | end |
59 | 59 | end |
60 | 60 |
|
61 | | -isintermediate(eq) = eq.args[1].diff == nothing |
| 61 | +isintermediate(eq::Equation) = eq.lhs.diff === nothing |
62 | 62 |
|
63 | | -function build_equals_expr(eq) |
64 | | - @assert typeof(eq.args[1]) <: Variable |
65 | | - if !(isintermediate(eq)) |
66 | | - # Differential statement |
67 | | - :($(Symbol("$(eq.args[1].name)_$(eq.args[1].diff.x.name)")) = $(eq.args[2])) |
68 | | - else |
69 | | - # Intermediate calculation |
70 | | - :($(Symbol("$(eq.args[1].name)")) = $(eq.args[2])) |
71 | | - end |
| 63 | +function build_equals_expr(eq::Equation) |
| 64 | + @assert typeof(eq.lhs) <: Variable |
| 65 | + |
| 66 | + lhs = eq.lhs.name |
| 67 | + isintermediate(eq) || (lhs = Symbol(lhs, :_, "$(eq.lhs.diff.x.name)")) |
| 68 | + |
| 69 | + return :($lhs = $(convert(Expr, eq.rhs))) |
72 | 70 | end |
73 | 71 |
|
74 | | -function calculate_jacobian(sys::DiffEqSystem,simplify=true) |
75 | | - diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs) |
76 | | - diff_exprs = sys.eqs[diff_idxs] |
77 | | - rhs = [eq.args[2] for eq in diff_exprs] |
| 72 | +function calculate_jacobian(sys::DiffEqSystem, simplify=true) |
| 73 | + calcs, diff_exprs = partition(isintermediate, sys.eqs) |
| 74 | + rhs = [eq.rhs for eq in diff_exprs] |
| 75 | + |
78 | 76 | # Handle intermediate calculations by substitution |
79 | | - calcs = sys.eqs[.!(diff_idxs)] |
80 | | - for i in 1:length(calcs) |
81 | | - find_replace!.(rhs,calcs[i].args[1],calcs[i].args[2]) |
| 77 | + for calc ∈ calcs |
| 78 | + find_replace!.(rhs, calc.lhs, calc.rhs) |
82 | 79 | end |
83 | | - sys_exprs = calculate_jacobian(rhs,sys.dvs) |
| 80 | + |
| 81 | + sys_exprs = calculate_jacobian(rhs, sys.dvs) |
84 | 82 | sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs] |
85 | | - if simplify |
86 | | - sys_exprs = Expression[simplify_constants(expr) for expr in sys_exprs] |
87 | | - end |
88 | 83 | sys_exprs |
89 | 84 | end |
90 | 85 |
|
91 | | -function generate_ode_jacobian(sys::DiffEqSystem,simplify=true) |
| 86 | +function generate_ode_jacobian(sys::DiffEqSystem, simplify=true) |
92 | 87 | var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)] |
93 | 88 | param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)] |
94 | | - diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs) |
95 | | - diff_exprs = sys.eqs[diff_idxs] |
96 | | - jac = calculate_jacobian(sys,simplify) |
| 89 | + diff_exprs = filter(!isintermediate, sys.eqs) |
| 90 | + jac = calculate_jacobian(sys, simplify) |
97 | 91 | sys.jac = jac |
98 | 92 | jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)] |
99 | 93 | exprs = vcat(var_exprs,param_exprs,vec(jac_exprs)) |
100 | 94 | block = expr_arr_to_block(exprs) |
101 | 95 | :((J,u,p,t)->$(block)) |
102 | 96 | end |
103 | 97 |
|
104 | | -function generate_ode_iW(sys::DiffEqSystem,simplify=true) |
| 98 | +function generate_ode_iW(sys::DiffEqSystem, simplify=true) |
105 | 99 | var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)] |
106 | 100 | param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)] |
107 | | - diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs) |
108 | | - diff_exprs = sys.eqs[diff_idxs] |
| 101 | + diff_exprs = filter(!isintermediate, sys.eqs) |
109 | 102 | jac = sys.jac |
110 | 103 |
|
111 | 104 | gam = Variable(:gam) |
@@ -144,5 +137,6 @@ function DiffEqBase.ODEFunction(sys::DiffEqSystem;version = ArrayFunction,kwargs |
144 | 137 | end |
145 | 138 | end |
146 | 139 |
|
| 140 | + |
147 | 141 | export DiffEqSystem, ODEFunction |
148 | 142 | export generate_ode_function |
0 commit comments