@@ -6,23 +6,28 @@ struct DefaultComparison
66 :L∞ => (delta, t) -> norm .(delta, Inf ),
77 :L1 => (delta, t) -> norm .(delta, 1 ),
88 :L2 => (delta, t) -> norm .(delta, 2 ),
9- :rms => (delta, t) -> sqrt .(1 / length (t) .* sum .(map (d-> d .^ 2 , delta)))]))
9+ :rms => (delta, t) -> sqrt .(1 / length (t) .* sum .(map (d-> d .^ 2 , delta))),
10+ :final => (delta, t) -> last .(delta)]))
1011 end
1112 return new (field_cmp)
1213 end
1314end
14- function (d:: DefaultComparison )(c, b, t, n, r)
15+ function (d:: DefaultComparison )(c, names, b, t, n, r)
1516 delta = map ((o, re) -> o .- re, n, r)
1617 cmps = [name => cmper (delta, t) for (name, cmper) in d. field_cmp]
17- return DataFrame ([:var => b , cmps... ,
18+ return DataFrame ([:var => names , cmps... ,
1819 :observed => SymbolicIndexingInterface. is_observed .(c, b)])
1920end
2021
2122
2223function compare (
23- new_sol:: SciMLBase.AbstractTimeseriesSolution , basis:: Vector ,
24- reference:: DataFrame , reference_basis:: Vector ,
25- cmp; warn_observed= true )
24+ new_sol:: SciMLBase.AbstractTimeseriesSolution ,
25+ reference:: DataFrame ,
26+ over:: Vector {<: Union{Pair{<:Any, String}, Pair{<:Any, Pair{String, String}}} },
27+ cmp= DefaultComparison (); warn_observed= true )
28+ basis = first .(over)
29+ reference_basis = map (e -> e isa Pair ? first (e) : e, last .(over))
30+ output_names = map ((n1, n2) -> n2 isa Pair ? last (n2) : string (n1), basis, last .(over))
2631 @assert " timestamp" ∈ names (reference) " The dataset must contain a column named `timestamp`"
2732 new_container = SymbolicIndexingInterface. symbolic_container (new_sol)
2833 @assert all (SymbolicIndexingInterface. is_observed .((new_container, ), basis) .| SymbolicIndexingInterface. is_variable .((new_container, ), basis)) " All basis symbols must be observed in the new system"
@@ -31,7 +36,7 @@ function compare(
3136 dat = new_sol (reference[:, :timestamp ], idxs= basis)
3237 obs = [[dat[j][i] for j= 1 : nrow (reference)] for i in eachindex (basis)]
3338 ref = collect .(eachcol (reference[:, Not (:timestamp )]))
34- cmp (new_container, basis, reference[:, :timestamp ], obs, ref)
39+ cmp (new_container, output_names, basis, reference[:, :timestamp ], obs, ref)
3540 else
3641 foldl (red, cmp (row[:timestamp ], new_sol[row[:timestamp ], basis], row[reference_basis]) for row in eachrow (reference); init= isnothing (init) ? nothing : init (basis))
3742 end
@@ -54,6 +59,6 @@ function compare(
5459 end
5560 matching_sym
5661 end for ref_name in setdiff (names (reference), [" timestamp" ])]
57- return compare (new_sol, eval_syms, reference, setdiff (names (reference), [" timestamp" ]), cmp)
62+ return compare (new_sol, reference, eval_syms .=> setdiff (names (reference), [" timestamp" ]), cmp)
5863end
5964export compare
0 commit comments