Skip to content

Commit 13712c8

Browse files
committed
Change cmp interface to work over vectors rather than implicit foldl
1 parent 54ef15d commit 13712c8

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

src/test/discrete/compare.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
function compare(
22
new_sol::SciMLBase.AbstractTimeseriesSolution, basis::Vector,
33
reference::DataFrame, reference_basis::Vector,
4-
cmp::Function, red::Function; init=nothing)
4+
cmp::Function; warn_observed=true)
55
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
66
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
77
@assert all(SymbolicIndexingInterface.is_observed.((new_container, ), basis) .| SymbolicIndexingInterface.is_variable.((new_container, ), basis)) "All basis symbols must be observed in the new system"
88
@assert all(b names(reference) for b in reference_basis) "The reference basis must be a subset of the columns in the reference data"
99
if new_sol.dense
10-
foldl(red, cmp(row[:timestamp], new_sol(row[:timestamp], idxs=basis), row[reference_basis]) for row in eachrow(reference); init=isnothing(init) ? nothing : init(basis))
10+
dat = new_sol(reference[:, :timestamp], idxs=basis)
11+
obs = [[dat[j][i] for j=1:nrow(reference)] for i in eachindex(basis)]
12+
ref = collect.(eachcol(reference[:, Not(:timestamp)]))
13+
cmp(basis, reference[:, :timestamp], obs, ref)
1114
else
1215
foldl(red, cmp(row[:timestamp], new_sol[row[:timestamp], basis], row[reference_basis]) for row in eachrow(reference); init=isnothing(init) ? nothing : init(basis))
1316
end
@@ -16,7 +19,7 @@ end
1619
function compare(
1720
new_sol::SciMLBase.AbstractTimeseriesSolution,
1821
reference::DataFrame,
19-
cmp::Function, red::Function; init=nothing, to_name=string, warn_observed=true)
22+
cmp::Function; to_name=string, warn_observed=true)
2023
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
2124

2225
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
@@ -30,17 +33,24 @@ function compare(
3033
end
3134
matching_sym
3235
end for ref_name in setdiff(names(reference), ["timestamp"])]
33-
return compare(new_sol, eval_syms, reference, setdiff(names(reference), ["timestamp"]), cmp, red; init=init)
36+
return compare(new_sol, eval_syms, reference, setdiff(names(reference), ["timestamp"]), cmp)
3437
end
3538

3639
function compare(
3740
new_sol::SciMLBase.AbstractTimeseriesSolution,
3841
reference::DataFrame; warn_observed=true)
3942
@assert "timestamp" names(reference) "The dataset must contain a column named `timestamp`"
43+
new_container = SymbolicIndexingInterface.symbolic_container(new_sol)
4044
return compare(new_sol, reference,
41-
(t, n, r) -> abs.(collect(n) .- collect(r)),
42-
(acc, nv) -> begin acc[:, :l∞] = max.(collect(acc[:, :l∞]), collect(nv)); acc end;
43-
init=(v) -> DataFrame(:name => v, :l∞ => zeros.(length(v))),
45+
(b, t, n, r) -> begin
46+
delta = map((o, re) -> o .- re, n, r)
47+
return DataFrame(:var => b,
48+
:L∞ => norm.(delta, Inf),
49+
:L1 => norm.(delta, 1),
50+
:L2 => norm.(delta, 2),
51+
:rms => sqrt.(1/length(t) .* sum.(map(d-> d .^ 2, delta))),
52+
:observed => SymbolicIndexingInterface.is_observed.(new_container, b))
53+
end,
4454
warn_observed=warn_observed)
4555
end
4656
export compare

0 commit comments

Comments
 (0)