Skip to content

Commit 4cc15bc

Browse files
committed
Add helpers for problem manipulation
1 parent ff8b9a0 commit 4cc15bc

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
1010
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
13+
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1314
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1415
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1516

src/ModelTesting.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ include("test/continuous/delta_sol.jl")
88
include("test/discrete/single_shooting.jl")
99
include("test/discrete/merge.jl")
1010
include("test/metric/metric.jl")
11+
include("problem_config/problem.jl")
1112
end

src/problem_config/problem.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module ProblemHelpers
2+
import SciMLBase
3+
using DataFrames, DifferentialEquations
4+
# some thin wrappers over remake
5+
initial_condition(prob::SciMLBase.AbstractDEProblem, u0) = remake(prob, u0 = u0)
6+
tspan(prob::SciMLBase.AbstractDEProblem, tspan) = remake(prob, tspan = tspan)
7+
parameters(prob::SciMLBase.AbstractDEProblem, ps) = remake(prob, p = ps)
8+
9+
remake_solver_kwarg(prob, merge; kwargs...) = remake(prob; kwargs = mergewith(merge, kwargs, prob.kwargs))
10+
tstops(prob::SciMLBase.AbstractDEProblem, tstops) = remake_solver_kwarg(prob, (new, old) -> sort!([new; old]); tstops = tstops)
11+
tstops(prob::SciMLBase.AbstractDEProblem, tstops::DataFrame) = tstops(prob, collect(tstops[:, :timestamp]))
12+
13+
export initial_condition, tspan, parameters, tstops
14+
end

src/test/discrete/single_shooting.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ function compare_discrete(sys, data_a::DataFrame, data_b::DataFrame, evaluations
99
@assert "timestamp" names(data_a) "The dataset A must contain a column named `timestamp`"
1010
@assert "timestamp" names(data_b) "The dataset B must contain a column named `timestamp`"
1111

12+
data_a = semijoin(data_a, data_b, on=:timestamp)
13+
data_b = semijoin(data_b, data_a, on=:timestamp)
14+
if nrow(data_a) == 0 || nrow(data_b) == 0
15+
throw("Comparison datasets are empty; check that the intersection of the discretization times is nonempty")
16+
end
17+
if nrow(data_a) <= 2 || nrow(data_b) <= 2
18+
@warn "Two or fewer (typically start and end) timestamps exist in the compared datasets; check that the time discretization of the comparisons align"
19+
end
1220
test_results = Dict{Symbol, Any}()
1321
for evaluation in evaluations
1422
test_results = mergewith(merge, test_results, Dict([evaluate(evaluation, sys, data_a, data_b); ]))

test/block_modeling.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ using ModelingToolkitStandardLibrary.Blocks: Constant
3838
d3 = discretize_solution(sol3, sol1)
3939
results_good = compare_discrete(sys, d1, d2)
4040
results_bad = compare_discrete(sys, d1, d3)
41-
42-
println(results_bad)
4341
end
4442
@testset "RC Functional" begin
4543
R = 1.0
@@ -76,7 +74,5 @@ using ModelingToolkitStandardLibrary.Blocks: Constant
7674
d3 = discretize_solution(sol3, sol1)
7775
results_good = compare_discrete(sys, d1, d2)
7876
results_bad = compare_discrete(sys, d1, d3)
79-
80-
println(results_bad)
8177
end
8278
end

0 commit comments

Comments
 (0)