Skip to content

Commit b6f73f5

Browse files
committed
Refactor continous discretization outside of specific comparison methods
1 parent 0fc57c5 commit b6f73f5

File tree

5 files changed

+95
-120
lines changed

5 files changed

+95
-120
lines changed

ext/ModelTestingCalibration.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ function ModelTesting.validate(model::AbstractTimeDependentSystem, data; pem_coe
1515
else
1616
model_transformations = []
1717
end
18-
println(experiment_kwargs)
1918
experiment = Experiment(data, model;
2019
model_transformations = model_transformations,
2120
filter(arg->first(arg) != :model_transformations, experiment_kwargs)...)

src/test/continuous/delta_sol.jl

Lines changed: 44 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,101 +16,79 @@ function make_cols(names, rows)
1616
return cols
1717
end
1818

19-
function compare_solutions(
20-
(ref_name, reference)::Pair{Symbol, <:SciMLBase.AbstractTimeseriesSolution},
21-
sols::Vector{<:Pair{Symbol, <:SciMLBase.AbstractTimeseriesSolution}};
22-
knots=nothing)
23-
if isnothing(knots)
24-
knots = reference.t
19+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::SciMLBase.AbstractTimeseriesSolution)
20+
container = symbolic_container(time_ref)
21+
ref_t_vars = independent_variable_symbols(container)
22+
if length(ref_t_vars) > 1
23+
@error "PDE solutions not currently supported; only one iv is allowed"
2524
end
26-
results = Dict{Symbol, Any}()
27-
reference_container = symbolic_container(reference)
28-
containers = symbolic_container.(last.(sols))
29-
30-
measured_reference = measured_values(reference_container)
31-
sols_measured = measured_values.(containers)
32-
@assert all(_symbolic_subset.((measured_reference,), sols_measured)) "Test solutions must expose a superset of the reference's variables for comparison"
33-
@assert length(measured_reference) > 0 "Compared solutions must share at least one measured variable"
34-
measured = measured_reference
35-
36-
if knots != reference.t && !reference.dense
37-
@assert reference.dense "Interpolated evaluation points require a dense reference solution"
25+
return discretize_solution(solution, time_ref[first(ref_t_vars)])
26+
end
27+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::DataFrame)
28+
@assert "timestamp" names(time_ref) "The dataset B must contain a column named `timestamp`"
29+
return discretize_solution(solution, collect(time_ref[!, "timestamp"]))
30+
end
31+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution)
32+
container = symbolic_container(solution)
33+
ref_t_vars = independent_variable_symbols(container)
34+
if length(ref_t_vars) > 1
35+
@error "PDE solutions not currently supported; only one iv is allowed"
3836
end
39-
40-
ref_t_vars = independent_variable_symbols(reference_container)
37+
return discretize_solution(solution, solution[ref_t_var] )
38+
end
39+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::AbstractArray)
40+
container = symbolic_container(solution)
41+
measured = measured_values(container)
42+
ref_t_vars = independent_variable_symbols(container)
4143
if length(ref_t_vars) > 1
4244
@error "PDE solutions not currently supported; only one iv is allowed"
4345
end
4446
ref_t_var = first(ref_t_vars)
45-
if knots == reference.t
46-
ref_soln = reference[measured]
47-
cols = make_cols([Symbol(ref_t_var); namespace_symbol.((ref_name,), measured)], reference[[ref_t_var; measured]])
48-
else
49-
ref_soln = reference(knots, idxs=measured)
50-
cols = make_cols([Symbol(ref_t_var); namespace_symbol.((ref_name,), measured)], reference(knots, idxs=[ref_t_var; measured]))
51-
end
52-
testsols = Dict{Symbol, Any}()
53-
results[:metrics] = testsols
54-
for (name, test) in sols
55-
test_results = Dict{Symbol, Any}()
56-
if test.dense
57-
test_soln = test(knots, idxs=measured)
58-
elseif knots == sol.t
59-
test_soln = test[measured]
60-
else
61-
@warn "Comparison of $ref_name to $name is invalid (cannot match timebases); only computing final state error"
62-
continue
63-
end
64-
test_results[:final] = recursive_mean(abs.(ref_soln[end] - test_soln[end]))
65-
compute_error_metrics(test_results, ref_soln, test_soln)
66-
append!(cols, make_cols(namespace_symbol.((name,), measured), test_soln))
67-
testsols[name] = test_results
47+
48+
matching_timebase = solution[ref_t_var] == time_ref
49+
50+
if matching_timebase # if the time_ref match the timebase of the problem then use the value at the nodes regardless of if it's dense or sparse
51+
cols = make_cols(String["timestamp"; measured_names(measured)], solution[[ref_t_var; measured]])
52+
elseif solution.dense # continious-time solution, use the interpolant
53+
cols = make_cols(String["timestamp"; measured_names(measured)], solution(time_ref, idxs=[ref_t_var; measured]))
54+
else
55+
throw("Cannot discretize_solution sparse solution about a different timebase.")
6856
end
69-
results[:data] = DataFrame(cols)
70-
return results
57+
return DataFrame(cols)
7158
end
72-
export compare_solutions
59+
export discretize_solution
7360

7461
function compare_dense_solutions(
75-
(ref_name, reference)::Pair{Symbol, <:SciMLBase.AbstractTimeseriesSolution},
76-
sols::Vector{<:Pair{Symbol, <:SciMLBase.AbstractTimeseriesSolution}};
77-
integrator=Tsit5(),
78-
metric=abs
62+
reference::SciMLBase.AbstractTimeseriesSolution,
63+
sol::SciMLBase.AbstractTimeseriesSolution;
64+
integrator=Tsit5()
7965
)
8066
results = Dict{Symbol, Any}()
8167
reference_container = symbolic_container(reference)
82-
containers = symbolic_container.(last.(sols))
68+
containers = symbolic_container(sol)
8369

8470
measured_reference = measured_values(reference_container)
85-
sols_measured = measured_values.(containers)
86-
@assert all(_symbolic_subset.((measured_reference,), sols_measured)) "Test solutions must expose a superset of the reference's variables for comparison"
71+
sol_measured = measured_values(containers)
72+
@assert _symbolic_subset(measured_reference, sol_measured) "Test solutions must expose a superset of the reference's variables for comparison"
8773
@assert length(measured_reference) > 0 "Compared solutions must share at least one measured variable"
8874
measured = measured_reference
8975

9076
timebounds(sol) = (sol.t[1], sol.t[end])
9177
@assert reference.dense "Dense (integrated) comparision requires a dense reference solution"
92-
for (test_name, test_sol) in sols
93-
@assert test_sol.dense "Test solution $(test_name) must be dense in order to use continous-time comparison"
94-
@assert timebounds(test_sol) == timebounds(reference) "Test solution $(test_name) has time range $(timebounds(test_sol)) which differs from the reference $(timebounds(reference))"
95-
end
96-
78+
@assert sol.dense "Test solution must be dense in order to use continous-time comparison"
79+
@assert timebounds(sol) == timebounds(reference) "Test solution has time range $(timebounds(sol)) which differs from the reference $(timebounds(reference))"
9780

9881
ref_t_vars = independent_variable_symbols(reference_container)
9982
if length(ref_t_vars) > 1
10083
@error "PDE solutions not currently supported; only one iv is allowed"
10184
end
10285
ref_t_var = first(ref_t_vars)
10386

104-
state_size = length(measured)
10587
function compare!(du, u, p, t)
106-
offs = 1
107-
for (test_name, test_sol) in sols
108-
du[offs:offs+state_size-1] .= abs.(reference(t, idxs=measured) .- test_sol(t, idxs=measured))
109-
offs += state_size
110-
end
88+
du .= abs.(reference(t, idxs=measured) .- sol(t, idxs=measured))
11189
end
112-
func = ODEFunction(compare!; sys = SymbolCache(collect(Iterators.flatten([namespace_symbol.((test_name, ), measured) for (test_name, _) in sols])), [], ref_t_var))
113-
prob = ODEProblem(func, zeros(length(sols) * length(measured)), timebounds(reference))
90+
func = ODEFunction(compare!; sys = SymbolCache(collect(Symbol.(measured)), [], ref_t_var))
91+
prob = ODEProblem(func, zeros(length(measured)), timebounds(reference))
11492
soln = solve(prob, integrator)
11593
return soln
11694
end

src/test/discrete/single_shooting.jl

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,19 @@ function validate(model, data)
55
throw("Discrete comparison currently requires JSMO!")
66
end
77

8-
function compare_discrete_to_continous(
9-
sol::SciMLBase.AbstractTimeseriesSolution,
10-
data::DataFrame)
11-
@assert "timestamp" names(data) "The data must contain a column named `timestamp`"
12-
sort!(data, [:timestamp])
13-
@assert data[!, :timestamp] == sol.t "The data's discretization points and the solution time discretization must match"
14-
15-
results = Dict{Symbol, Any}()
16-
reference_container = symbolic_container(sol)
17-
measured_reference = measured_values(reference_container)
18-
@assert length(measured_reference) > 0 "At least one variable must be marked as measured"
19-
measured = measured_reference
20-
21-
measured_names = string.(measured)
22-
if !(all(name->name names(data), measured_names))
23-
error("Measured data points must exist in both model solution & test data; measured parameters in solution: $(measured_names) vs. in data $(names(data))")
24-
end
25-
26-
data_matrix = collect.(eachrow(data[!, string.(measured)])) # lame and slow
27-
solution_data = sol[measured]
28-
metrics = Dict{Symbol, Any}()
29-
compute_error_metrics(metrics, solution_data, data_matrix)
30-
metrics[:final] = recursive_mean(abs.(solution_data[end] - collect(data[end, string.(measured)])))
31-
results[:metrics] = metrics
32-
cols = []
33-
push!(cols, :timestamp => data[:, :timestamp])
34-
append!(cols, make_cols(namespace_symbol.((:simulated,), measured), solution_data))
35-
append!(cols, make_cols(namespace_symbol.((:data,), measured), data_matrix))
36-
results[:data] = DataFrame(cols)
37-
return results
8+
function compare_discrete(sys, data_a::DataFrame, data_b::DataFrame)
9+
@assert "timestamp" names(data_a) "The dataset A must contain a column named `timestamp`"
10+
@assert "timestamp" names(data_b) "The dataset B must contain a column named `timestamp`"
3811

12+
measured = measured_values(sys)
13+
measured_cols = measured_names(measured)
14+
@assert all(c->c names(data_a), measured_cols) "All measured values must exist in both datasets (missing value in A)"
15+
@assert all(c->c names(data_b), measured_cols) "All measured values must exist in both datasets (missing value in B)"
16+
17+
test_results = Dict{Symbol, Any}()
18+
delta_sol = data_a[:, measured_cols] .- data_b[:, measured_cols]
19+
test_results[:final] = recursive_mean(abs.(collect(delta_sol[end, :])))
20+
compute_error_metrics(test_results, collect.(eachrow(data_a[:, measured_cols])), collect.(eachrow(data_b[:, measured_cols])))
21+
return test_results
3922
end
40-
export validate, compare_discrete_to_continous
23+
export validate, compare_discrete_to_continous, compare_discrete

src/test/measured.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,12 @@ end
1111
function measured_values(sys, v=all_variable_symbols(sys))
1212
filter(x -> ismeasured(x, false), v)
1313
end
14+
15+
function measured_system_values(sys)
16+
reference_container = symbolic_container(sys)
17+
return measured_values(reference_container)
18+
end
19+
20+
function measured_names(measured)
21+
return string.(measured)
22+
end

test/timeseries.jl

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,30 @@ tol = 1e-6
1717
prob3 = ODEProblem(fol, [fol.x => 0.0], (0.0, 10.0), [fol.τ => 1.0])
1818
sol3 = solve(prob3, Tsit5(), reltol = 1e-8, abstol = 1e-8)
1919

20-
results = compare_solutions(:reference=>sol1, [:good=>sol2, :bad=>sol3])
21-
22-
@test results[:metrics][:good][:l∞] < tol
23-
@test results[:metrics][:good][:l2] < tol
24-
@test results[:metrics][:good][:final] < tol
25-
@test results[:metrics][:bad][:l∞] > tol
26-
@test results[:metrics][:bad][:l2] > tol
27-
@test results[:metrics][:bad][:final] > tol
20+
d1 = discretize_solution(sol1, sol1)
21+
d2 = discretize_solution(sol2, sol1)
22+
d3 = discretize_solution(sol3, sol1)
23+
results_good = compare_discrete(fol, d1, d2)
24+
results_bad = compare_discrete(fol, d1, d3)
25+
@test results_good[:l∞] < tol
26+
@test results_good[:l2] < tol
27+
@test results_good[:final] < tol
28+
@test results_bad[:l∞] > tol
29+
@test results_bad[:l2] > tol
30+
@test results_bad[:final] > tol
2831

2932
knots = collect(1:0.1:10)
30-
results = compare_solutions(:reference=>sol1, [:good=>sol2, :bad=>sol3]; knots=knots)
31-
32-
@test nrow(results[:data]) == length(knots)
33-
@test results[:metrics][:good][:l∞] < tol
34-
@test results[:metrics][:good][:l2] < tol
35-
@test results[:metrics][:good][:final] < tol
36-
@test results[:metrics][:bad][:l∞] > tol
37-
@test results[:metrics][:bad][:l2] > tol
38-
@test results[:metrics][:bad][:final] > tol
33+
d1 = discretize_solution(sol1, knots)
34+
d2 = discretize_solution(sol2, knots)
35+
d3 = discretize_solution(sol3, knots)
36+
results_good = compare_discrete(fol, d1, d2)
37+
results_bad = compare_discrete(fol, d1, d3)
38+
@test results_good[:l∞] < tol
39+
@test results_good[:l2] < tol
40+
@test results_good[:final] < tol
41+
@test results_bad[:l∞] > tol
42+
@test results_bad[:l2] > tol
43+
@test results_bad[:final] > tol
3944
end
4045
@testset "Model-Model Continous Comparison" begin
4146
@variables t
@@ -52,11 +57,12 @@ tol = 1e-6
5257
prob3 = ODEProblem(fol, [fol.x => 0.0], (0.0, 10.0), [fol.τ => 1.0])
5358
sol3 = solve(prob3, Tsit5(), reltol = 1e-8, abstol = 1e-8)
5459

55-
results = compare_dense_solutions(:reference=>sol1, [:good=>sol2, :bad=>sol3])
56-
@test results[Symbol("good/x(t)"), end] < tol
57-
@test results[Symbol("good/y(t)"), end] < tol
58-
@test results[Symbol("bad/x(t)"), end] > tol
59-
@test results[Symbol("bad/y(t)"), end] > tol
60+
results_good = compare_dense_solutions(sol1, sol2)
61+
results_bad = compare_dense_solutions(sol1, sol3)
62+
@test results_good[Symbol("x(t)")][end] < tol
63+
@test results_good[Symbol("y(t)")][end] < tol
64+
@test results_bad[Symbol("x(t)")][end] > tol
65+
@test results_bad[Symbol("y(t)")][end] > tol
6066
end
6167
end
6268

0 commit comments

Comments
 (0)