|
| 1 | +struct Differential <: Function |
| 2 | + x::Expression |
| 3 | + order::Int |
| 4 | +end |
| 5 | +Differential(x) = Differential(x,1) |
| 6 | + |
| 7 | +Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))") |
| 8 | +Base.Expr(D::Differential) = D |
| 9 | + |
| 10 | +function Derivative end |
| 11 | +(D::Differential)(x::Operation) = Operation(D, Expression[x]) |
| 12 | +function (D::Differential)(x::Variable) |
| 13 | + D.x === x && return Constant(1) |
| 14 | + has_dependent(x, D.x) || return Constant(0) |
| 15 | + return Variable(x,D) |
| 16 | +end |
| 17 | +Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x |
| 18 | + |
| 19 | +Variable(x::Variable, D::Differential) = Variable(x.name,x.value,x.value_type, |
| 20 | + x.subtype,D,x.dependents,x.description,x.flow,x.domain, |
| 21 | + x.size,x.context) |
| 22 | + |
| 23 | +function expand_derivatives(O::Operation) |
| 24 | + @. O.args = expand_derivatives(O.args) |
| 25 | + |
| 26 | + if O.op isa Differential |
| 27 | + D = O.op |
| 28 | + o = O.args[1] |
| 29 | + return simplify_constants(sum(i->Derivative(o,i)*expand_derivatives(D(o.args[i])),1:length(o.args))) |
| 30 | + end |
| 31 | + |
| 32 | + return O |
| 33 | +end |
| 34 | +expand_derivatives(x::Variable) = x |
| 35 | + |
| 36 | +# Don't specialize on the function here |
| 37 | +function Derivative(O::Operation,idx) |
| 38 | + # This calls the Derivative dispatch from the user or pre-defined code |
| 39 | + Derivative(O.op, O.args, Val(idx)) |
| 40 | +end |
| 41 | +Derivative(op, args, idx) = Derivative(op, (args...,), idx) |
| 42 | + |
| 43 | +# Pre-defined derivatives |
| 44 | +import DiffRules, SpecialFunctions, NaNMath |
| 45 | +for (modu, fun, arity) ∈ DiffRules.diffrules() |
| 46 | + for i ∈ 1:arity |
| 47 | + @eval function Derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i}) |
| 48 | + M, f = $(modu, fun) |
| 49 | + partials = DiffRules.diffrule(M, f, args...) |
| 50 | + dx = @static $arity == 1 ? partials : partials[$i] |
| 51 | + parse(Operation,dx) |
| 52 | + end |
| 53 | + end |
| 54 | +end |
| 55 | + |
| 56 | +function count_order(x) |
| 57 | + @assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!" |
| 58 | + n = 1 |
| 59 | + while !(x.args[1] isa Symbol) |
| 60 | + n = n+1 |
| 61 | + x = x.args[1] |
| 62 | + end |
| 63 | + n, x.args[1] |
| 64 | +end |
| 65 | + |
| 66 | +function _differential_macro(x) |
| 67 | + ex = Expr(:block) |
| 68 | + lhss = Symbol[] |
| 69 | + x = flatten_expr!(x) |
| 70 | + for di in x |
| 71 | + @assert di isa Expr && di.args[1] == :~ "@Deriv expects a form that looks like `@Deriv D''~t E'~t`" |
| 72 | + lhs = di.args[2] |
| 73 | + rhs = di.args[3] |
| 74 | + order, lhs = count_order(lhs) |
| 75 | + push!(lhss, lhs) |
| 76 | + expr = :($lhs = Differential($rhs, $order)) |
| 77 | + push!(ex.args, expr) |
| 78 | + end |
| 79 | + push!(ex.args, Expr(:tuple, lhss...)) |
| 80 | + ex |
| 81 | +end |
| 82 | + |
| 83 | +macro Deriv(x...) |
| 84 | + esc(_differential_macro(x)) |
| 85 | +end |
| 86 | + |
| 87 | +function calculate_jacobian(eqs,vars) |
| 88 | + Expression[Differential(vars[j])(eqs[i]) for i in 1:length(eqs), j in 1:length(vars)] |
| 89 | +end |
| 90 | + |
| 91 | +export Differential, Derivative, expand_derivatives, @Deriv, calculate_jacobian |
0 commit comments