Skip to content

Commit 870b331

Browse files
Merge pull request #296 from SciML/expression
fix composed expression generation
2 parents 868e59e + 77223a7 commit 870b331

File tree

5 files changed

+38
-31
lines changed

5 files changed

+38
-31
lines changed

README.md

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ sys = ODESystem(eqs)
4242
sys = ode_order_lowering(sys)
4343

4444
u0 = [D(x) => 2.0,
45-
x => 1.0,
45+
x => 1.0,
4646
y => 0.0,
4747
z => 0.0]
4848

@@ -53,6 +53,7 @@ p = [σ => 28.0,
5353
tspan = (0.0,100.0)
5454
prob = ODEProblem(sys,u0,tspan,p,jac=true)
5555
sol = solve(prob,Tsit5())
56+
using Plots; plot(sol,vars=(:x,:y))
5657
```
5758

5859
![Lorenz2](https://user-images.githubusercontent.com/1814174/79118645-744eb580-7d5c-11ea-9c37-13c4efd585ca.png)
@@ -77,26 +78,28 @@ lorenz2 = ODESystem(eqs,name=:lorenz2)
7778

7879
@variables α
7980
@parameters γ
80-
connections = [0 ~ lorenz1.x + lorenz2.y + sin(α*γ)]
81+
connections = [0 ~ lorenz1.x + lorenz2.y + α*γ]
8182
connected = ODESystem(connections,t,[α],[γ],systems=[lorenz1,lorenz2])
8283

8384
u0 = [lorenz1.x => 1.0,
8485
lorenz1.y => 0.0,
8586
lorenz1.z => 0.0,
86-
lorenz2.x => 1.0,
87-
lorenz2.y => 0.0,
88-
lorenz2.z => 0.0,
89-
α => 2.0]
87+
lorenz2.x => 0.0,
88+
lorenz2.y => 1.0,
89+
lorenz2.z => 0.0,
90+
α => 2.0]
9091

91-
p = [lorenz1.σ => 28.0,
92-
lorenz1.ρ => 10.0,
92+
p = [lorenz1.σ => 10.0,
93+
lorenz1.ρ => 28.0,
9394
lorenz1.β => 8/3,
94-
lorenz2.σ => 28.0,
95-
lorenz2.ρ => 10.0,
96-
lorenz2.β => 8/3,
97-
γ => 2.0]
95+
lorenz2.σ => 10.0,
96+
lorenz2.ρ => 28.0,
97+
lorenz2.β => 8/3,
98+
γ => 2.0]
9899

99100
tspan = (0.0,100.0)
100101
prob = ODEProblem(connected,u0,tspan,p)
101102
sol = solve(prob,Rodas5())
103+
104+
using Plots; plot(sol,vars=(,Symbol(lorenz1.x),Symbol(lorenz2.y)))
102105
```

src/systems/abstractsystem.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,18 @@ function parameters(sys::AbstractSystem,args...)
8383
newname = renamespace(extra_names,name)
8484
rename(x,renamespace(sys.name,newname))()
8585
end
86+
87+
struct AbstractSysToExpr
88+
sys::AbstractSystem
89+
states::Vector{Variable}
90+
end
91+
AbstractSysToExpr(sys) = AbstractSysToExpr(sys,states(sys))
92+
function (f::AbstractSysToExpr)(O::Operation)
93+
any(isequal(O), f.states) && return O.op.name # variables
94+
if isa(O.op, Variable)
95+
isempty(O.args) && return O.op.name # 0-ary parameters
96+
return build_expr(:call, Any[O.op.name; f.(O.args)])
97+
end
98+
return build_expr(:call, Any[O.op; f.(O.args)])
99+
end
100+
(f::AbstractSysToExpr)(x) = convert(Expr, x)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ end
2222

2323
struct ODEToExpr
2424
sys::AbstractODESystem
25+
states::Vector{Variable}
2526
end
27+
ODEToExpr(sys) = ODEToExpr(sys,states(sys))
2628
function (f::ODEToExpr)(O::Operation)
2729
if isa(O.op, Variable)
2830
isequal(O.op, f.sys.iv) && return O.op.name # independent variable
29-
O.op f.sys.states && return O.op.name # dependent variables
31+
O.op f.states && return O.op.name # dependent variables
3032
isempty(O.args) && return O.op.name # 0-ary parameters
3133
return build_expr(:call, Any[O.op.name; f.(O.args)])
3234
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,14 @@ function generate_jacobian(sys::NonlinearSystem, vs = states(sys), ps = paramete
5353
if sparse
5454
jac = SparseArrays.sparse(jac)
5555
end
56-
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps), (), NLSysToExpr(sys))
56+
return build_function(jac, convert.(Variable,vs), convert.(Variable,ps), (), AbstractSysToExpr(sys))
5757
end
5858

59-
struct NLSysToExpr
60-
sys::NonlinearSystem
61-
end
62-
function (f::NLSysToExpr)(O::Operation)
63-
any(isequal(O), f.sys.states) && return O.op.name # variables
64-
if isa(O.op, Variable)
65-
isempty(O.args) && return O.op.name # 0-ary parameters
66-
return build_expr(:call, Any[O.op.name; f.(O.args)])
67-
end
68-
return build_expr(:call, Any[O.op; f.(O.args)])
69-
end
70-
(f::NLSysToExpr)(x) = convert(Expr, x)
71-
7259
function generate_function(sys::NonlinearSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
7360
rhss = [eq.rhs for eq sys.eqs]
7461
vs′ = convert.(Variable,vs)
7562
ps′ = convert.(Variable,ps)
76-
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys), expression; kwargs...)
63+
return build_function(rhss, vs′, ps′, (), AbstractSysToExpr(sys), expression; kwargs...)
7764
end
7865

7966
"""

src/systems/optimization/optimizationsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545

4646
function generate_gradient(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
4747
grad = calculate_gradient(sys)
48-
return build_function(grad, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x); kwargs...)
48+
return build_function(grad, convert.(Variable,vs), convert.(Variable,ps), (), AbstractSysToExpr(sys); kwargs...)
4949
end
5050

5151
function calculate_hessian(sys::OptimizationSystem)
@@ -58,13 +58,13 @@ function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parame
5858
if sparse
5959
hes = sparse(hes)
6060
end
61-
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), x->convert(Expr, x); kwargs...)
61+
return build_function(hes, convert.(Variable,vs), convert.(Variable,ps), (), AbstractSysToExpr(sys); kwargs...)
6262
end
6363

6464
function generate_function(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys), expression = Val{true}; kwargs...)
6565
vs′ = convert.(Variable,vs)
6666
ps′ = convert.(Variable,ps)
67-
return build_function(equations(sys), vs′, ps′, (), x->convert(Expr, x), expression; kwargs...)
67+
return build_function(equations(sys), vs′, ps′, (), AbstractSysToExpr(sys), expression; kwargs...)
6868
end
6969

7070
equations(sys::OptimizationSystem) = isempty(sys.systems) ? sys.op : sys.op + reduce(+,namespace_operation.(sys.systems))

0 commit comments

Comments
 (0)