Skip to content

Commit 416f5e8

Browse files
Merge pull request #267 from SciML/calc
Fix calculated functions on components
2 parents 1497481 + 4a9a4e2 commit 416f5e8

File tree

4 files changed

+98
-32
lines changed

4 files changed

+98
-32
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
function calculate_tgrad(sys::AbstractODESystem)
22
isempty(sys.tgrad[]) || return sys.tgrad[] # use cached tgrad, if possible
3-
rhs = [detime_dvs(eq.rhs) for eq sys.eqs]
3+
rhs = [detime_dvs(eq.rhs) for eq equations(sys)]
44
iv = sys.iv()
55
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
6-
tgrad = retime_dvs.(notime_tgrad,(sys.dvs,),iv)
6+
tgrad = retime_dvs.(notime_tgrad,(states(sys),),iv)
77
sys.tgrad[] = tgrad
88
return tgrad
99
end
1010

1111
function calculate_jacobian(sys::AbstractODESystem)
1212
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
13-
rhs = [eq.rhs for eq sys.eqs]
13+
rhs = [eq.rhs for eq equations(sys)]
1414

1515
iv = sys.iv()
16-
dvs = [dv(iv) for dv sys.dvs]
16+
dvs = [dv(iv) for dv states(sys)]
1717

1818
jac = expand_derivatives.(calculate_jacobian(rhs, dvs))
1919
sys.jac[] = jac # cache Jacobian
@@ -34,18 +34,18 @@ function (f::ODEToExpr)(O::Operation)
3434
end
3535
(f::ODEToExpr)(x) = convert(Expr, x)
3636

37-
function generate_tgrad(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
37+
function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
3838
tgrad = calculate_tgrad(sys)
3939
return build_function(tgrad, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
4040
end
4141

42-
function generate_jacobian(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
42+
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
4343
jac = calculate_jacobian(sys)
4444
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
4545
end
4646

47-
function generate_function(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
48-
rhss = [deq.rhs for deq sys.eqs]
47+
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
48+
rhss = [deq.rhs for deq equations(sys)]
4949
dvs′ = [clean(dv) for dv dvs]
5050
ps′ = [clean(p) for p ps]
5151
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
@@ -75,7 +75,7 @@ function calculate_factorized_W(sys::AbstractODESystem, simplify=true)
7575
(Wfact,Wfact_t)
7676
end
7777

78-
function generate_factorized_W(sys::AbstractODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true}; kwargs...)
78+
function generate_factorized_W(sys::AbstractODESystem, vs = states(sys), ps = parameters(sys), simplify=true, expression = Val{true}; kwargs...)
7979
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
8080
siz = size(Wfact)
8181
constructor = :(x -> begin
@@ -90,13 +90,14 @@ function generate_factorized_W(sys::AbstractODESystem, vs = sys.dvs, ps = sys.ps
9090
end
9191

9292
function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
93-
eqs = sys.eqs
93+
eqs = equations(sys)
94+
dvs = states(sys)
9495
M = zeros(length(eqs),length(eqs))
9596
for (i,eq) in enumerate(eqs)
9697
if eq.lhs isa Constant
9798
@assert eq.lhs.value == 0
9899
elseif eq.lhs.op isa Differential
99-
j = findfirst(x->isequal(x.name,var_from_nested_derivative(eq.lhs)[1].name),sys.dvs)
100+
j = findfirst(x->isequal(x.name,var_from_nested_derivative(eq.lhs)[1].name),dvs)
100101
M[i,j] = 1
101102
else
102103
error("Only semi-explicit mass matrices are currently supported")
@@ -113,7 +114,8 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
113114
are used to set the order of the dependent variable and parameter vectors,
114115
respectively.
115116
"""
116-
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = sys.dvs, ps = sys.ps;
117+
function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
118+
ps = parameters(sys);
117119
version = nothing, tgrad=false,
118120
jac = false, Wfact = false) where {iip}
119121
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
@@ -166,14 +168,14 @@ function DiffEqBase.ODEFunction(sys::AbstractODESystem, args...; kwargs...)
166168
end
167169

168170
function namespace_variables(sys::AbstractODESystem)
169-
[rename(x,renamespace(sys.name,x.name)) for x in sys.dvs]
171+
[rename(x,renamespace(sys.name,x.name)) for x in states(sys)]
170172
end
171173

172174
function namespace_parameters(sys::AbstractODESystem)
173-
[rename(x,renamespace(sys.name,x.name)) for x in sys.ps]
175+
[rename(x,renamespace(sys.name,x.name)) for x in parameters(sys)]
174176
end
175177

176-
namespace_equations(sys::AbstractODESystem) = namespace_equation.(sys.eqs,sys.name,sys.iv.name)
178+
namespace_equations(sys::AbstractODESystem) = namespace_equation.(equations(sys),sys.name,sys.iv.name)
177179

178180
function namespace_equation(eq::Equation,name,ivname)
179181
_lhs = namespace_operation(eq.lhs,name,ivname)
@@ -191,11 +193,11 @@ end
191193
namespace_operation(O::Constant,name,ivname) = O
192194

193195
independent_variable(sys::AbstractODESystem) = sys.iv
194-
states(sys::AbstractODESystem) = [sys.dvs;reduce(vcat,namespace_variables.(sys.systems))]
195-
parameters(sys::AbstractODESystem) = [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
196+
states(sys::AbstractODESystem) = isempty(sys.systems) ? sys.dvs : [sys.dvs;reduce(vcat,namespace_variables.(sys.systems))]
197+
parameters(sys::AbstractODESystem) = isempty(sys.systems) ? sys.ps : [sys.ps;reduce(vcat,namespace_parameters.(sys.systems))]
196198

197199
function equations(sys::AbstractODESystem)
198-
[sys.eqs;reduce(vcat,namespace_equations.(sys.systems))]
200+
isempty(sys.systems) ? sys.eqs : [sys.eqs;reduce(vcat,namespace_equations.(sys.systems))]
199201
end
200202

201203
function states(sys::AbstractODESystem,name::Symbol)
@@ -208,6 +210,18 @@ function parameters(sys::AbstractODESystem,name::Symbol)
208210
Variable(Symbol(string(sys.name)*""*string(x.name)),known=x.known)(sys.iv())
209211
end
210212

213+
function states(sys::AbstractODESystem,args...)
214+
name = last(args)
215+
extra_names = reduce(*,["$(x.name)" for x in args[1:end-1]])
216+
Variable(Symbol(string(sys.name)*extra_names*""*string(name)))(sys.iv())
217+
end
218+
219+
function parameters(sys::AbstractODESystem,args...)
220+
name = last(args)
221+
extra_names = reduce(*,["$(x.name)" for x in args[1:end-1]])
222+
Variable(Symbol(string(sys.name)*extra_names*""*string(name)))(sys.iv())
223+
end
224+
211225
function _eq_unordered(a, b)
212226
length(a) === length(b) || return false
213227
n = length(a)

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,8 @@ function _find_params(O, vs)
6262
return ps
6363
end
6464

65-
66-
independent_variables(sys::NonlinearSystem) = Set{Variable}()
67-
dependent_variables(sys::NonlinearSystem) = Set{Expression}(sys.vs)
68-
parameters(sys::NonlinearSystem) = Set{Variable}(sys.ps)
65+
states(sys::NonlinearSystem) = sys.vs
66+
parameters(sys::NonlinearSystem) = sys.ps
6967

7068

7169
function calculate_jacobian(sys::NonlinearSystem)

test/components.jl

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ lorenz2 = ODESystem(eqs,t,[x,y,z],[σ,ρ,β],name=:lorenz2)
1818
@variables a(t)
1919
connnectedeqs = [D(a) ~ a*states(lorenz1,:x)]
2020

21-
connected = ODESystem(connnectedeqs,t,[a],[α],systems=[lorenz1,lorenz2],name=:connectedlorenz)
21+
connected1 = ODESystem(connnectedeqs,t,[a],[α],systems=[lorenz1,lorenz2],name=:connected1)
2222

2323
@variables lorenz1′x(t) lorenz1′y(t) lorenz1′z(t) lorenz2′x(t) lorenz2′y(t) lorenz2′z(t)
2424
@parameters lorenz1′σ lorenz1′ρ lorenz1′β lorenz2′σ lorenz2′ρ lorenz2′β
@@ -31,6 +31,60 @@ eqs_flat = [D(a) ~ a*lorenz1′x,
3131
D(lorenz2′y) ~ lorenz2′x*(lorenz2′ρ-lorenz2′z)-lorenz2′y,
3232
0 ~ lorenz2′x + lorenz2′y + lorenz2′β*lorenz2′z]
3333

34-
@test [x.name for x in states(connected)] == [:a,:lorenz1′x,:lorenz1′y,:lorenz1′z,:lorenz2′x,:lorenz2′y,:lorenz2′z]
35-
@test [x.name for x in parameters(connected)] == [,:lorenz1′σ,:lorenz1′ρ,:lorenz1′β,:lorenz2′σ,:lorenz2′ρ,:lorenz2′β]
36-
@test eqs_flat == equations(connected)
34+
@test [x.name for x in states(connected1)] == [:a,:lorenz1′x,:lorenz1′y,:lorenz1′z,:lorenz2′x,:lorenz2′y,:lorenz2′z]
35+
@test [x.name for x in parameters(connected1)] == [,:lorenz1′σ,:lorenz1′ρ,:lorenz1′β,:lorenz2′σ,:lorenz2′ρ,:lorenz2′β]
36+
@test eqs_flat == equations(connected1)
37+
38+
connected2 = ODESystem(connnectedeqs,t,[a],[α],systems=[lorenz1,lorenz2],name=:connected2)
39+
40+
@parameters γ
41+
@variables g(t)
42+
connnectedeqs2 = [D(g) ~ g*states(connected1,lorenz1,:x)]
43+
doublelevel = ODESystem(connnectedeqs2,t,[g],[γ],systems=[connected1,connected2],name=:doublelevel)
44+
45+
@test [x.name for x in states(doublelevel)] == [:g,
46+
:connected1′a,:connected1′lorenz1′x,:connected1′lorenz1′y,:connected1′lorenz1′z,:connected1′lorenz2′x,:connected1′lorenz2′y,:connected1′lorenz2′z,
47+
:connected2′a,:connected2′lorenz1′x,:connected2′lorenz1′y,:connected2′lorenz1′z,:connected2′lorenz2′x,:connected2′lorenz2′y,:connected2′lorenz2′z]
48+
@test [x.name for x in parameters(doublelevel)] == [,
49+
:connected1′α,:connected1′lorenz1′σ,:connected1′lorenz1′ρ,:connected1′lorenz1′β,:connected1′lorenz2′σ,:connected1′lorenz2′ρ,:connected1′lorenz2′β,
50+
:connected2′α,:connected2′lorenz1′σ,:connected2′lorenz1′ρ,:connected2′lorenz1′β,:connected2′lorenz2′σ,:connected2′lorenz2′ρ,:connected2′lorenz2′β]
51+
52+
@variables connected1′a(t) connected1′lorenz1′x(t) connected1′lorenz1′y(t) connected1′lorenz1′z(t) connected1′lorenz2′x(t) connected1′lorenz2′y(t) connected1′lorenz2′z(t)
53+
@variables connected2′a(t) connected2′lorenz1′x(t) connected2′lorenz1′y(t) connected2′lorenz1′z(t) connected2′lorenz2′x(t) connected2′lorenz2′y(t) connected2′lorenz2′z(t)
54+
@parameters connected1′α connected1′lorenz1′σ connected1′lorenz1′ρ connected1′lorenz1′β connected1′lorenz2′σ connected1′lorenz2′ρ connected1′lorenz2′β
55+
@parameters connected2′α connected2′lorenz1′σ connected2′lorenz1′ρ connected2′lorenz1′β connected2′lorenz2′σ connected2′lorenz2′ρ connected2′lorenz2′β
56+
57+
eqs_flat2 = [D(g) ~ g*connected1′lorenz1′x,
58+
D(connected1′a) ~ connected1′a*connected1′lorenz1′x,
59+
D(connected1′lorenz1′x) ~ connected1′lorenz1′σ*(connected1′lorenz1′y-connected1′lorenz1′x),
60+
D(connected1′lorenz1′y) ~ connected1′lorenz1′x*(connected1′lorenz1′ρ-connected1′lorenz1′z)-connected1′lorenz1′y,
61+
0 ~ connected1′lorenz1′x + connected1′lorenz1′y + connected1′lorenz1′β*connected1′lorenz1′z,
62+
D(connected1′lorenz2′x) ~ connected1′lorenz2′σ*(connected1′lorenz2′y-connected1′lorenz2′x),
63+
D(connected1′lorenz2′y) ~ connected1′lorenz2′x*(connected1′lorenz2′ρ-connected1′lorenz2′z)-connected1′lorenz2′y,
64+
0 ~ connected1′lorenz2′x + connected1′lorenz2′y + connected1′lorenz2′β*connected1′lorenz2′z,
65+
D(connected2′a) ~ connected2′a*connected2′lorenz1′x,
66+
D(connected2′lorenz1′x) ~ connected2′lorenz1′σ*(connected2′lorenz1′y-connected2′lorenz1′x),
67+
D(connected2′lorenz1′y) ~ connected2′lorenz1′x*(connected2′lorenz1′ρ-connected2′lorenz1′z)-connected2′lorenz1′y,
68+
0 ~ connected2′lorenz1′x + connected2′lorenz1′y + connected2′lorenz1′β*connected2′lorenz1′z,
69+
D(connected2′lorenz2′x) ~ connected2′lorenz2′σ*(connected2′lorenz2′y-connected2′lorenz2′x),
70+
D(connected2′lorenz2′y) ~ connected2′lorenz2′x*(connected2′lorenz2′ρ-connected2′lorenz2′z)-connected2′lorenz2′y,
71+
0 ~ connected2′lorenz2′x + connected2′lorenz2′y + connected2′lorenz2′β*connected2′lorenz2′z]
72+
73+
@test eqs_flat2 == equations(doublelevel)
74+
75+
M = Array(I,15,15)
76+
M[5,5] = false
77+
M[8,8] = false
78+
M[12,12] = false
79+
M[15,15] = false
80+
@test calculate_massmatrix(doublelevel) == M
81+
82+
jac = [connected1′lorenz1′x 0 g zeros(1,12)
83+
zeros(7,1) calculate_jacobian(connected1) zeros(Expression,7,7)
84+
zeros(Expression,7,8) calculate_jacobian(connected2)]
85+
86+
jac2 = [connected1′lorenz1′x 0 g zeros(1,12)
87+
zeros(7,1) ModelingToolkit.namespace_operation.(calculate_jacobian(connected1),connected1.name,:t) zeros(Expression,7,7)
88+
zeros(Expression,7,8) ModelingToolkit.namespace_operation.(calculate_jacobian(connected2),connected2.name,:t)]
89+
90+
@test all(isequal.(calculate_jacobian(doublelevel),jac2))

test/odesystem.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,19 @@ end
2424
_clean(x::Variable) = x
2525
function test_diffeq_inference(name, sys, iv, dvs, ps)
2626
@testset "ODESystem construction: $name" begin
27-
@test independent_variables(sys) == Set([_clean(iv)])
28-
@test dependent_variables(sys) == Set(_clean.(dvs))
29-
@test parameters(sys) == Set(_clean.(ps))
27+
@test independent_variable(sys) == _clean(iv)
28+
@test Set(states(sys)) == Set(_clean.(dvs))
29+
@test Set(parameters(sys)) == Set(_clean.(ps))
3030
end
3131
end
3232
function test_nlsys_inference(name, sys, vs, ps)
3333
@testset "NonlinearSystem construction: $name" begin
34-
@test dependent_variables(sys) == Set(vs)
35-
@test parameters(sys) == Set(_clean.(ps))
34+
@test Set(states(sys)) == Set(vs)
35+
@test Set(parameters(sys)) == Set(_clean.(ps))
3636
end
3737
end
3838

39-
test_diffeq_inference("standard", de, t, (x, y, z), (σ, ρ, β))
39+
test_diffeq_inference("standard", de, t, [x, y, z], [ρ, σ, β])
4040
generate_function(de, [x,y,z], [σ,ρ,β])
4141
jac_expr = generate_jacobian(de)
4242
jac = calculate_jacobian(de)

0 commit comments

Comments
 (0)