Skip to content

Commit 82b29fd

Browse files
committed
Support configurable comparison operations in the default comparison type
1 parent 13512bb commit 82b29fd

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

src/test/discrete/compare.jl

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,28 @@
1+
struct DefaultComparison
2+
field_cmp::Dict{Symbol, Function}
3+
function DefaultComparison(field_cmp::Dict{Symbol, Function}=Dict{Symbol, Function}(); use_defaults=true)
4+
if use_defaults
5+
merge!(field_cmp, Dict{Symbol, Function}([
6+
:L∞ => (delta, t) -> norm.(delta, Inf),
7+
:L1 => (delta, t) -> norm.(delta, 1),
8+
:L2 => (delta, t) -> norm.(delta, 2),
9+
:rms => (delta, t) -> sqrt.(1/length(t) .* sum.(map(d-> d .^ 2, delta)))]))
10+
end
11+
return new(field_cmp)
12+
end
13+
end
14+
function (d::DefaultComparison)(c, b, t, n, r)
15+
delta = map((o, re) -> o .- re, n, r)
16+
cmps = [name => cmper(delta, t) for (name, cmper) in d.field_cmp]
17+
return DataFrame([:var => b, cmps...,
18+
:observed => SymbolicIndexingInterface.is_observed.(c, b)])
19+
end
20+
21+
122
function compare(
223
new_sol::SciMLBase.AbstractTimeseriesSolution, basis::Vector,
324
reference::DataFrame, reference_basis::Vector,
4-
cmp::Function; warn_observed=true)
25+
cmp; warn_observed=true)
526
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
627
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
728
@assert all(SymbolicIndexingInterface.is_observed.((new_container, ), basis) .| SymbolicIndexingInterface.is_variable.((new_container, ), basis)) "All basis symbols must be observed in the new system"
@@ -10,7 +31,7 @@ function compare(
1031
dat = new_sol(reference[:, :timestamp], idxs=basis)
1132
obs = [[dat[j][i] for j=1:nrow(reference)] for i in eachindex(basis)]
1233
ref = collect.(eachcol(reference[:, Not(:timestamp)]))
13-
cmp(basis, reference[:, :timestamp], obs, ref)
34+
cmp(new_container, basis, reference[:, :timestamp], obs, ref)
1435
else
1536
foldl(red, cmp(row[:timestamp], new_sol[row[:timestamp], basis], row[reference_basis]) for row in eachrow(reference); init=isnothing(init) ? nothing : init(basis))
1637
end
@@ -19,7 +40,7 @@ end
1940
function compare(
2041
new_sol::SciMLBase.AbstractTimeseriesSolution,
2142
reference::DataFrame,
22-
cmp::Function; to_name=string, warn_observed=true)
43+
cmp=DefaultComparison(); to_name=string, warn_observed=true)
2344
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
2445

2546
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
@@ -35,22 +56,4 @@ function compare(
3556
end for ref_name in setdiff(names(reference), ["timestamp"])]
3657
return compare(new_sol, eval_syms, reference, setdiff(names(reference), ["timestamp"]), cmp)
3758
end
38-
39-
function compare(
40-
new_sol::SciMLBase.AbstractTimeseriesSolution,
41-
reference::DataFrame; warn_observed=true)
42-
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
43-
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
44-
return compare(new_sol, reference,
45-
(b, t, n, r) -> begin
46-
delta = map((o, re) -> o .- re, n, r)
47-
return DataFrame(:var => b,
48-
:L∞ => norm.(delta, Inf),
49-
:L1 => norm.(delta, 1),
50-
:L2 => norm.(delta, 2),
51-
:rms => sqrt.(1/length(t) .* sum.(map(d-> d .^ 2, delta))),
52-
:observed => SymbolicIndexingInterface.is_observed.(new_container, b))
53-
end,
54-
warn_observed=warn_observed)
55-
end
5659
export compare

0 commit comments

Comments
 (0)