Skip to content

Commit c8b35f1

Browse files
committed
Clean up and test compare interface
1 parent fe874a5 commit c8b35f1

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

src/test/compare.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,28 @@ struct DefaultComparison
66
:L∞ => (delta, t) -> norm.(delta, Inf),
77
:L1 => (delta, t) -> norm.(delta, 1),
88
:L2 => (delta, t) -> norm.(delta, 2),
9-
:rms => (delta, t) -> sqrt.(1/length(t) .* sum.(map(d-> d .^ 2, delta)))]))
9+
:rms => (delta, t) -> sqrt.(1/length(t) .* sum.(map(d-> d .^ 2, delta))),
10+
:final => (delta, t) -> last.(delta)]))
1011
end
1112
return new(field_cmp)
1213
end
1314
end
14-
function (d::DefaultComparison)(c, b, t, n, r)
15+
function (d::DefaultComparison)(c, names, b, t, n, r)
1516
delta = map((o, re) -> o .- re, n, r)
1617
cmps = [name => cmper(delta, t) for (name, cmper) in d.field_cmp]
17-
return DataFrame([:var => b, cmps...,
18+
return DataFrame([:var => names, cmps...,
1819
:observed => SymbolicIndexingInterface.is_observed.(c, b)])
1920
end
2021

2122

2223
function compare(
23-
new_sol::SciMLBase.AbstractTimeseriesSolution, basis::Vector,
24-
reference::DataFrame, reference_basis::Vector,
25-
cmp; warn_observed=true)
24+
new_sol::SciMLBase.AbstractTimeseriesSolution,
25+
reference::DataFrame,
26+
over::Vector{<:Union{Pair{<:Any, String}, Pair{<:Any, Pair{String, String}}}},
27+
cmp=DefaultComparison(); warn_observed=true)
28+
basis = first.(over)
29+
reference_basis = map(e -> e isa Pair ? first(e) : e, last.(over))
30+
output_names = map((n1, n2) -> n2 isa Pair ? last(n2) : string(n1), basis, last.(over))
2631
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
2732
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
2833
@assert all(SymbolicIndexingInterface.is_observed.((new_container, ), basis) .| SymbolicIndexingInterface.is_variable.((new_container, ), basis)) "All basis symbols must be observed in the new system"
@@ -31,7 +36,7 @@ function compare(
3136
dat = new_sol(reference[:, :timestamp], idxs=basis)
3237
obs = [[dat[j][i] for j=1:nrow(reference)] for i in eachindex(basis)]
3338
ref = collect.(eachcol(reference[:, Not(:timestamp)]))
34-
cmp(new_container, basis, reference[:, :timestamp], obs, ref)
39+
cmp(new_container, output_names, basis, reference[:, :timestamp], obs, ref)
3540
else
3641
foldl(red, cmp(row[:timestamp], new_sol[row[:timestamp], basis], row[reference_basis]) for row in eachrow(reference); init=isnothing(init) ? nothing : init(basis))
3742
end
@@ -54,6 +59,6 @@ function compare(
5459
end
5560
matching_sym
5661
end for ref_name in setdiff(names(reference), ["timestamp"])]
57-
return compare(new_sol, eval_syms, reference, setdiff(names(reference), ["timestamp"]), cmp)
62+
return compare(new_sol, reference, eval_syms .=> setdiff(names(reference), ["timestamp"]), cmp)
5863
end
5964
export compare

test/block_modeling.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,26 @@ using SymbolicIndexingInterface
2828
sys = structural_simplify(rc_model)
2929
prob1 = ODEProblem(sys, Pair[], (0, 10.0))
3030
sol1 = solve(prob1, Tsit5())
31-
32-
#=
33-
ref = DataFrame(:timestamp => sol1.t, Symbol(resistor.v) => sol1[resistor.v])
34-
measure, measured_sys = compare_data([resistor.v], ref)(sys)
35-
prob_measured = ODEProblem(measured_sys, Pair[], (0, 10.0))
36-
sol_measured = solve(prob_measured, Tsit5())
37-
=#
38-
prob2 = ODEProblem(sys, Pair[capacitor.C => 0.9], (0, 10.0))
31+
prob2 = ODEProblem(sys, Pair[capacitor.C => 0.9], (0, 10.0)) # this converges to nearly the same solution but is a little too fast
3932
sol2 = solve(prob2, Tsit5())
40-
prob3 = ODEProblem(sys, Pair[capacitor.C => 5.0], (0, 10.0))
33+
prob3 = ODEProblem(sys, Pair[capacitor.C => 5.0], (0, 10.0)) # this doesn't stabilize in the allotted time
4134
sol3 = solve(prob3, Tsit5())
4235

4336
d1 = discretize_solution(sol1, sol1)
44-
4537
ds1 = discretize_solution(sol1, sol1; measured=SymbolicIndexingInterface.all_variable_symbols(sys))
46-
println(compare(sol3, ds1))
47-
println(compare(sol2, ds1))
48-
println(compare(sol1, ds1))
38+
39+
@test sum(compare(sol1, ds1; warn_observed=false)[:, :L∞]) < 0.01
40+
@test sum(compare(sol2, ds1; warn_observed=false)[:, :L∞]) > sum(compare(sol1, ds1; warn_observed=false)[:, :L∞])
41+
@test sum(compare(sol3, ds1; warn_observed=false)[:, :L∞]) > sum(compare(sol2, ds1; warn_observed=false)[:, :L∞])
42+
@test sum(compare(sol1, ds1; warn_observed=false)[:, :L∞]) > sum(compare(sol1, d1; warn_observed=false)[:, :L∞])
43+
@test sum(compare(sol2, ds1; warn_observed=false)[:, :L∞]) > sum(compare(sol2, d1; warn_observed=false)[:, :L∞])
44+
@test sum(compare(sol3, ds1; warn_observed=false)[:, :L∞]) > sum(compare(sol3, d1; warn_observed=false)[:, :L∞])
45+
46+
# construct a fictional power measurement
47+
power_synth = select(ds1, :timestamp, ["capacitor₊v(t)", "capacitor₊i(t)"] => ((v, i) -> v .* i) => "power")
48+
@test compare(sol1, power_synth, [capacitor.i * capacitor.v => "power" => "power"])[1, "L∞"] < 0.01
49+
@test compare(sol2, power_synth, [capacitor.i * capacitor.v => "power" => "power"])[1, "L∞"] < 0.05
50+
@test compare(sol3, power_synth, [capacitor.i * capacitor.v => "power" => "power"])[1, "L∞"] < 0.3
51+
4952
end
5053
end

0 commit comments

Comments
 (0)