Skip to content

Commit c7e3866

Browse files
committed
update JumpSystem for auto-alg support
1 parent adb6a84 commit c7e3866

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ jprob = JumpProblem(complete(js), dprob, Direct())
426426
sol = solve(jprob, SSAStepper())
427427
```
428428
"""
429-
function JumpProcesses.JumpProblem(js::JumpSystem, prob, aggregator; callback = nothing,
429+
function JumpProcesses.JumpProblem(js::JumpSystem, prob,
430+
aggregator = JumpProcesses.NullAggregator(); callback = nothing,
430431
eval_expression = false, eval_module = @__MODULE__, kwargs...)
431432
if !iscomplete(js)
432433
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `JumpProblem`")

test/jumpsystem.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,21 @@ parammap = [β => 0.1 / 1000, γ => 0.01]
6969
dprob = DiscreteProblem(js2, u₀map, tspan, parammap)
7070
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
7171
Nsims = 30000
72-
function getmean(jprob, Nsims)
72+
function getmean(jprob, Nsims; use_stepper = true)
7373
m = 0.0
7474
for i in 1:Nsims
75-
sol = solve(jprob, SSAStepper())
75+
sol = use_stepper ? solve(jprob, SSAStepper()) : solve(jprob)
7676
m += sol[end, end]
7777
end
7878
m / Nsims
7979
end
8080
m = getmean(jprob, Nsims)
8181

82+
# test auto-alg selection works
83+
jprobb = JumpProblem(js2, dprob; save_positions = (false, false), rng)
84+
mb = getmean(jprobb, Nsims; use_stepper = false)
85+
@test abs(m - mb) / m < 0.01
86+
8287
@variables S2(t)
8388
obs = [S2 ~ 2 * S]
8489
@named js2b = JumpSystem([j₁, j₃], t, [S, I, R], [β, γ], observed = obs)
@@ -89,7 +94,6 @@ sol = solve(jprob, SSAStepper(), saveat = tspan[2] / 10)
8994
@test all(2 .* sol[S] .== sol[S2])
9095

9196
# test save_positions is working
92-
9397
jprob = JumpProblem(js2, dprob, Direct(), save_positions = (false, false), rng = rng)
9498
sol = solve(jprob, SSAStepper(), saveat = 1.0)
9599
@test all((sol.t) .== collect(0.0:tspan[2]))

0 commit comments

Comments
 (0)