Skip to content

Commit b5a85e2

Browse files
committed
Implement compare
1 parent 0b2deda commit b5a85e2

File tree

5 files changed

+64
-8
lines changed

5 files changed

+64
-8
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ ModelTestingCalibration = "JuliaSimModelOptimizer"
2222

2323
[compat]
2424
julia = "1.10"
25+
SymbolicIndexingInterface = "^0.3.15"
2526

2627
[extras]
2728
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/ModelTesting.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ include("test/continuous/delta_sol.jl")
88
include("test/discrete/single_shooting.jl")
99
include("test/discrete/merge.jl")
1010
include("test/metric/metric.jl")
11+
include("test/discrete/compare.jl")
12+
include("test/metric/spec.jl")
1113
include("test/instantaneous/instant.jl")
1214
include("problem_config/problem.jl")
1315
end

src/test/continuous/delta_sol.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function make_cols(names, rows)
1717
end
1818

1919
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::SciMLBase.AbstractTimeseriesSolution; all_observed=false)
20-
container = symbolic_container(time_ref)
20+
container = SymbolicIndexingInterface.symbolic_container(time_ref)
2121
ref_t_vars = independent_variable_symbols(container)
2222
if length(ref_t_vars) > 1
2323
@error "PDE solutions not currently supported; only one iv is allowed"
@@ -29,15 +29,15 @@ function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, tim
2929
return discretize_solution(solution, collect(time_ref[!, "timestamp"]); all_observed=all_observed)
3030
end
3131
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution; all_observed=false)
32-
container = symbolic_container(solution)
32+
container = SymbolicIndexingInterface.symbolic_container(solution)
3333
ref_t_vars = independent_variable_symbols(container)
3434
if length(ref_t_vars) > 1
3535
@error "PDE solutions not currently supported; only one iv is allowed"
3636
end
37-
return discretize_solution(solution, solution[ref_t_var]; all_observed=all_observed )
37+
return discretize_solution(solution, solution[first(ref_t_vars)]; all_observed=all_observed )
3838
end
3939
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::AbstractArray; measured=nothing, all_observed=false)
40-
container = symbolic_container(solution)
40+
container = SymbolicIndexingInterface.symbolic_container(solution)
4141
if isnothing(measured)
4242
if all_observed
4343
measured = all_variable_symbols(container)
@@ -72,8 +72,8 @@ function compare_dense_solutions(
7272
integrator=Tsit5()
7373
)
7474
results = Dict{Symbol, Any}()
75-
reference_container = symbolic_container(reference)
76-
containers = symbolic_container(sol)
75+
reference_container = SymbolicIndexingInterface.symbolic_container(reference)
76+
containers = SymbolicIndexingInterface.symbolic_container(sol)
7777

7878
measured_reference = isnothing(reference_measured) ? measured_values(reference_container) : reference_measured
7979
sol_measured = isnothing(solution_measured) ? measured_values(containers) : sol_measured

src/test/discrete/compare.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
function compare(
2+
new_sol::SciMLBase.AbstractTimeseriesSolution, basis::Vector,
3+
reference::DataFrame, reference_basis::Vector,
4+
cmp::Function, red::Function; init=nothing)
5+
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
6+
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
7+
@assert all(SymbolicIndexingInterface.is_observed.((new_container, ), basis) .| SymbolicIndexingInterface.is_variable.((new_container, ), basis)) "All basis symbols must be observed in the new system"
8+
@assert all(b names(reference) for b in reference_basis) "The reference basis must be a subset of the columns in the reference data"
9+
if new_sol.dense
10+
foldl(red, cmp(row[:timestamp], new_sol(row[:timestamp], idxs=basis), row[reference_basis]) for row in eachrow(reference); init=isnothing(init) ? nothing : init(basis))
11+
else
12+
foldl(red, cmp(row[:timestamp], new_sol[row[:timestamp], basis], row[reference_basis]) for row in eachrow(reference); init=isnothing(init) ? nothing : init(basis))
13+
end
14+
end
15+
16+
function compare(
17+
new_sol::SciMLBase.AbstractTimeseriesSolution,
18+
reference::DataFrame,
19+
cmp::Function, red::Function; init=nothing, to_name=string, warn_observed=true)
20+
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
21+
22+
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
23+
all_available = SymbolicIndexingInterface.all_variable_symbols(new_container)
24+
all_available_syms = Dict(to_name.(all_available) .=> all_available)
25+
eval_syms = [ begin
26+
@assert ref_name keys(all_available_syms) "Reference value $ref_name not found in the new solution (either as state or observed)"
27+
matching_sym = all_available_syms[ref_name]
28+
if warn_observed && SymbolicIndexingInterface.is_observed(new_container, matching_sym)
29+
@warn "The variable $matching_sym is observed in the new solution while the past $ref_name is provided explicitly; problem structure may have changed"
30+
end
31+
matching_sym
32+
end for ref_name in setdiff(names(reference), ["timestamp"])]
33+
return compare(new_sol, eval_syms, reference, setdiff(names(reference), ["timestamp"]), cmp, red; init=init)
34+
end
35+
36+
function compare(
37+
new_sol::SciMLBase.AbstractTimeseriesSolution,
38+
reference::DataFrame; warn_observed=true)
39+
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
40+
return compare(new_sol, reference,
41+
(t, n, r) -> abs.(collect(n) .- collect(r)),
42+
(acc, nv) -> begin acc[:, :l∞] = max.(collect(acc[:, :l∞]), collect(nv)); acc end;
43+
init=(v) -> DataFrame(:name => v, :l∞ => zeros.(length(v))),
44+
warn_observed=warn_observed)
45+
end
46+
export compare

test/block_modeling.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,24 @@ using ModelingToolkitStandardLibrary.Blocks: Constant
6565
prob1 = ODEProblem(sys, Pair[], (0, 10.0))
6666
sol1 = solve(prob1, Tsit5())
6767

68+
#=
6869
ref = DataFrame(:timestamp => sol1.t, Symbol(resistor.v) => sol1[resistor.v])
69-
measure, measured_sys = compare_data(resistor.v, ref)(sys)
70+
measure, measured_sys = compare_data([resistor.v], ref)(sys)
7071
prob_measured = ODEProblem(measured_sys, Pair[], (0, 10.0))
7172
sol_measured = solve(prob_measured, Tsit5())
72-
73+
=#
7374
prob2 = ODEProblem(sys, Pair[capacitor.C => 0.9], (0, 10.0))
7475
sol2 = solve(prob2, Tsit5())
7576
prob3 = ODEProblem(sys, Pair[capacitor.C => 5.0], (0, 10.0))
7677
sol3 = solve(prob3, Tsit5())
7778

7879
d1 = discretize_solution(sol1, sol1)
80+
81+
ds1 = discretize_solution(sol1, sol1; measured=states(sys))
82+
println(compare(sol3, ds1))
83+
println(compare(sol2, ds1))
84+
println(compare(sol1, ds1))
85+
7986
d2 = discretize_solution(sol2, sol1)
8087
d3 = discretize_solution(sol3, sol1)
8188
results_good = compare_discrete(sys, d1, d2)

0 commit comments

Comments
 (0)