Skip to content

Commit 02ed45d

Browse files
add and test automatic multithreading
1 parent 6b8fbc3 commit 02ed45d

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using StaticArrays, LinearAlgebra, SparseArrays
55
using Latexify, Unitful
66
using MacroTools
77

8-
using MacroTools
8+
using Base.Threads
99
import MacroTools: splitdef, combinedef, postwalk, striplines
1010
import GeneralizedGenerated
1111
using DocStringExtensions

src/build_function.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747
function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
4848
conv = simplified_expr, expression = Val{true};
4949
checkbounds = false, constructor=nothing,
50-
linenumbers = true)
50+
linenumbers = true, multithread=true)
5151
_vs = map(x-> x isa Operation ? x.op : x, vs)
5252
_ps = map(x-> x isa Operation ? x.op : x, ps)
5353
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
@@ -67,6 +67,21 @@ function _build_function(target::JuliaTarget, rhss, vs, ps = (), args = (),
6767

6868
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
6969

70+
if multithread
71+
lens = Int(ceil(length(ip_let_expr.args[2].args)/Threads.nthreads()))
72+
threaded_exprs = vcat([quote
73+
Threads.@spawn begin
74+
$(ip_let_expr.args[2].args[((i-1)*lens+1):i*lens]...)
75+
end
76+
end for i in 1:Threads.nthreads()-1],
77+
quote
78+
Threads.@spawn begin
79+
$(ip_let_expr.args[2].args[((Threads.nthreads()-1)*lens+1):end]...)
80+
end
81+
end)
82+
ip_let_expr.args[2] = ModelingToolkit.build_expr(:block, threaded_exprs)
83+
end
84+
7085
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
7186

7287
if rhss isa Matrix

test/bigsystem.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using ModelingToolkit, LinearAlgebra
2+
3+
# Define the constants for the PDE
4+
const α₂ = 1.0
5+
const α₃ = 1.0
6+
const β₁ = 1.0
7+
const β₂ = 1.0
8+
const β₃ = 1.0
9+
const r₁ = 1.0
10+
const r₂ = 1.0
11+
const D = 100.0
12+
const γ₁ = 0.1
13+
const γ₂ = 0.1
14+
const γ₃ = 0.1
15+
const N = 8
16+
const X = reshape([i for i in 1:N for j in 1:N],N,N)
17+
const Y = reshape([j for i in 1:N for j in 1:N],N,N)
18+
const α₁ = 1.0.*(X.>=4*N/5)
19+
20+
const Mx = Tridiagonal([1.0 for i in 1:N-1],[-2.0 for i in 1:N],[1.0 for i in 1:N-1])
21+
const My = copy(Mx)
22+
Mx[2,1] = 2.0
23+
Mx[end-1,end] = 2.0
24+
My[1,2] = 2.0
25+
My[end,end-1] = 2.0
26+
27+
# Define the initial condition as normal arrays
28+
@variables du[1:N,1:N,1:3] u[1:N,1:N,1:3] MyA[1:N,1:N] AMx[1:N,1:N] DA[1:N,1:N]
29+
30+
# Define the discretized PDE as an ODE function
31+
function f(du,u,p,t)
32+
A = @view u[:,:,1]
33+
B = @view u[:,:,2]
34+
C = @view u[:,:,3]
35+
dA = @view du[:,:,1]
36+
dB = @view du[:,:,2]
37+
dC = @view du[:,:,3]
38+
mul!(MyA,My,A)
39+
mul!(AMx,A,Mx)
40+
@. DA = D*(MyA + AMx)
41+
@. dA = DA + α₁ - β₁*A - r₁*A*B + r₂*C
42+
@. dB = α₂ - β₂*B - r₁*A*B + r₂*C
43+
@. dC = α₃ - β₃*C + r₁*A*B - r₂*C
44+
end
45+
46+
f(du,u,nothing,0.0)
47+
48+
ModelingToolkit.build_function(du,u,multithread=true);
49+
jac = ModelingToolkit.jacobian(vec(du),vec(u),simplify=false)
50+
ModelingToolkit.build_function(vec(jac),u,multithread=true);

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ using ModelingToolkit, Test
1515
@testset "Constraints Test" begin include("constraints.jl") end
1616
@testset "PDE Construction Test" begin include("pde.jl") end
1717
@testset "Distributed Test" begin include("distributed.jl") end
18+
@testset "Test Big System Usage" begin include("bigsystem.jl") end
1819
#@testset "Latexify recipes Test" begin include("latexify.jl") end

0 commit comments

Comments
 (0)