Skip to content

Commit ff8b9a0

Browse files
committed
Make it easier to measure values (implement observation of all parameters and introduce a component that tracks a sensor value)
1 parent 5b8d104 commit ff8b9a0

File tree

11 files changed

+260
-43
lines changed

11 files changed

+260
-43
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
88
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
99
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
1010
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
13+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1214
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1315

1416
[weakdeps]

ext/ModelTestingCalibration.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function ModelTesting.validate(model::AbstractTimeDependentSystem, data; pem_coe
1313
model_transformations = []
1414
end
1515
else
16-
model_transformations = []
16+
model_transformations = experiment_kwargs[:model_transformations]
1717
end
1818
experiment = Experiment(data, model;
1919
model_transformations = model_transformations,

src/ModelTesting.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
module ModelTesting
2-
using DataFrames
2+
using DataFrames, StatsBase, LinearAlgebra
33
using ModelingToolkit, DifferentialEquations, DiffEqDevTools, SymbolicIndexingInterface
4-
4+
abstract type DiscreteEvaluation end
5+
abstract type Metric end
56
include("test/measured.jl")
67
include("test/continuous/delta_sol.jl")
78
include("test/discrete/single_shooting.jl")
89
include("test/discrete/merge.jl")
10+
include("test/metric/metric.jl")
911
end

src/test/continuous/delta_sol.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,33 @@ function make_cols(names, rows)
1616
return cols
1717
end
1818

19-
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::SciMLBase.AbstractTimeseriesSolution)
19+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::SciMLBase.AbstractTimeseriesSolution; all_observed=false)
2020
container = symbolic_container(time_ref)
2121
ref_t_vars = independent_variable_symbols(container)
2222
if length(ref_t_vars) > 1
2323
@error "PDE solutions not currently supported; only one iv is allowed"
2424
end
25-
return discretize_solution(solution, time_ref[first(ref_t_vars)])
25+
return discretize_solution(solution, time_ref[first(ref_t_vars)]; all_observed=all_observed)
2626
end
27-
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::DataFrame)
27+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::DataFrame; all_observed=false)
2828
@assert "timestamp" names(time_ref) "The dataset B must contain a column named `timestamp`"
29-
return discretize_solution(solution, collect(time_ref[!, "timestamp"]))
29+
return discretize_solution(solution, collect(time_ref[!, "timestamp"]); all_observed=all_observed)
3030
end
31-
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution)
31+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution; all_observed=false)
3232
container = symbolic_container(solution)
3333
ref_t_vars = independent_variable_symbols(container)
3434
if length(ref_t_vars) > 1
3535
@error "PDE solutions not currently supported; only one iv is allowed"
3636
end
37-
return discretize_solution(solution, solution[ref_t_var] )
37+
return discretize_solution(solution, solution[ref_t_var]; all_observed=all_observed )
3838
end
39-
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::AbstractArray)
39+
function discretize_solution(solution::SciMLBase.AbstractTimeseriesSolution, time_ref::AbstractArray; all_observed=false)
4040
container = symbolic_container(solution)
41-
measured = measured_values(container)
41+
if !all_observed
42+
measured = measured_values(container)
43+
else
44+
measured = all_variable_symbols(container)
45+
end
4246
ref_t_vars = independent_variable_symbols(container)
4347
if length(ref_t_vars) > 1
4448
@error "PDE solutions not currently supported; only one iv is allowed"

src/test/discrete/merge.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ function merge_results(
1616
validate=(true, true),
1717
renamecols = (x -> x) => col_rename)
1818
end
19-
display(merged)
2019
end
2120
return merged
2221
end

src/test/discrete/single_shooting.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,14 @@ function validate(model, data)
55
throw("Discrete comparison currently requires JSMO!")
66
end
77

8-
function compare_discrete(sys, data_a::DataFrame, data_b::DataFrame)
8+
function compare_discrete(sys, data_a::DataFrame, data_b::DataFrame, evaluations::Vector{DiscreteEvaluation}=[FinalState(L∞(name=:L∞)), Mean(L2(:L2)), Accumulate(L∞(name=:L∞))])
99
@assert "timestamp" names(data_a) "The dataset A must contain a column named `timestamp`"
1010
@assert "timestamp" names(data_b) "The dataset B must contain a column named `timestamp`"
11-
12-
measured = measured_values(sys)
13-
measured_cols = measured_names(measured)
14-
@assert all(c->c names(data_a), measured_cols) "All measured values must exist in both datasets (missing value in A)"
15-
@assert all(c->c names(data_b), measured_cols) "All measured values must exist in both datasets (missing value in B)"
1611

1712
test_results = Dict{Symbol, Any}()
18-
delta_sol = data_a[:, measured_cols] .- data_b[:, measured_cols]
19-
test_results[:final] = recursive_mean(abs.(collect(delta_sol[end, :])))
20-
compute_error_metrics(test_results, collect.(eachrow(data_a[:, measured_cols])), collect.(eachrow(data_b[:, measured_cols])))
13+
for evaluation in evaluations
14+
test_results = mergewith(merge, test_results, Dict([evaluate(evaluation, sys, data_a, data_b); ]))
15+
end
2116
return test_results
2217
end
2318
export validate, compare_discrete_to_continous, compare_discrete

src/test/measured.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
struct MeasuredVariable end
33
Symbolics.option_to_metadata_type(::Val{:measured}) = MeasuredVariable
44
ismeasured(x::Num, args...) = ismeasured(Symbolics.unwrap(x), args...)
5-
function ismeasured(x, default = true)
5+
function ismeasured(x, default = false)
66
p = Symbolics.getparent(x, nothing)
77
p === nothing || (x = p)
88
Symbolics.getmetadata(x, MeasuredVariable, default)
99
end
1010

11+
function setmeasured(x)
12+
setmetadata(x, MeasuredVariable, true)
13+
end
14+
1115
function measured_values(sys, v=all_variable_symbols(sys))
1216
filter(x -> ismeasured(x, false), v)
1317
end
@@ -19,4 +23,20 @@ end
1923

2024
function measured_names(measured)
2125
return string.(measured)
22-
end
26+
end
27+
28+
@mtkmodel MeasureComponent begin
29+
@variables begin
30+
value(t), [measured = true]
31+
end
32+
end
33+
34+
function Measurement(sensor; name)
35+
@assert length(sensor.states) == 1 "The Measurement helper requires that the measurement component have only one scalar-valued state"
36+
@variables t value(t) [measured = true]
37+
return ODESystem([
38+
value ~ first(sensor.states)
39+
], t; name = name)
40+
end
41+
42+
export Measurement, MeasureComponent

src/test/metric/metric.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
2+
3+
evaluate(e::DiscreteEvaluation, sys, dt1::DataFrame, dt2::DataFrame) = throw("Unimplemented evaluate for $(typeof(e))")
4+
5+
function _require_solution(sys, dt::DataFrame; name=nothing)
6+
@assert "timestamp" names(dt) "The dataset $name must contain a column named `timestamp`"
7+
measured = measured_values(sys)
8+
measured_cols = measured_names(measured)
9+
@assert all(c->c names(dt), measured_cols) "All measured values must exist in both datasets (missing value in $name)"
10+
end
11+
12+
function _require_joint_solution(sys, dt1, dt2; check_only_final = false)
13+
_require_solution(sys, dt1; name = "dt1")
14+
_require_solution(sys, dt2; name = "dt2")
15+
if !check_only_final
16+
@assert dt1[:, :timestamp] == dt2[:, :timestamp] "Both data frames need to share a temporal discretization"
17+
else
18+
@assert dt1[end, :timestamp] == dt2[end, :timestamp] "Ending times differ"
19+
end
20+
end
21+
22+
struct FinalState{M<:Metric} <: DiscreteEvaluation
23+
name::Symbol
24+
metric::M
25+
FinalState(m::M; name=:final) where {M <: Metric} = new{M}(name, m)
26+
end
27+
function evaluate(f::FinalState, sys, dt1::DataFrame, dt2::DataFrame)
28+
_require_joint_solution(sys, dt1, dt2; check_only_final = true)
29+
measured_cols = setdiff(intersect(names(dt1), names(dt2)), "timestamp")
30+
output = compare(f.metric, sys, dt1[end, measured_cols], dt2[end, measured_cols])
31+
return name(f) => Dict(outputs(f.metric) .=> output)
32+
end
33+
name(f::FinalState) = f.name
34+
35+
struct Accumulate{M<:Metric, T} <: DiscreteEvaluation
36+
name::Symbol
37+
metric::M
38+
op::T
39+
Accumulate(m::M; name=:stage) where {M<:Metric} = let op = acc_op(m); new{M, typeof(op)}(name, m, op) end
40+
Accumulate(m::M, op::T; name=:stage) where {M<:Metric, T} = new{M, T}(name, m, op)
41+
end
42+
name(f::Accumulate) = f.name
43+
44+
function evaluate(f::Accumulate, sys, dt1::DataFrame, dt2::DataFrame; init = nothing)
45+
_require_joint_solution(sys, dt1, dt2)
46+
measured_cols = setdiff(intersect(names(dt1), names(dt2)), "timestamp")
47+
output = isnothing(init) ? initial_output(f.metric) : init
48+
for (row1, row2) in Iterators.zip(eachrow(dt1[!, measured_cols]), eachrow(dt2[!, measured_cols]))
49+
new_output = compare(f.metric, sys, row1, row2)
50+
output = f.op(output, new_output)
51+
end
52+
return name(f) => Dict(outputs(f.metric) .=> output)
53+
end
54+
55+
struct Mean{M<:Metric, T} <: DiscreteEvaluation
56+
name::Symbol
57+
metric::M
58+
op::T
59+
Mean(m::M; name=:mean) where {M<:Metric} = let op = acc_op(m); new{M, typeof(op)}(name, m, op) end
60+
Mean(m::M, op::T; name=:mean) where {M<:Metric, T} = new{M, T}(name, m, op)
61+
end
62+
name(f::Mean) = f.name
63+
64+
function evaluate(f::Mean, sys, dt1::DataFrame, dt2::DataFrame; init = nothing)
65+
_require_joint_solution(sys, dt1, dt2)
66+
measured_cols = setdiff(intersect(names(dt1), names(dt2)), "timestamp")
67+
output = isnothing(init) ? initial_output(f.metric) : init
68+
for (row1, row2) in Iterators.zip(eachrow(dt1[!, measured_cols]), eachrow(dt2[!, measured_cols]))
69+
new_output = compare(f.metric, sys, row1, row2)
70+
output = f.op(output, new_output)
71+
end
72+
return name(f) => Dict(outputs(f.metric) .=> mean.(output))
73+
end
74+
75+
acc_op(m::Metric) = (a,b) -> b # throw away the prior state by default
76+
77+
name(m::Metric)::Symbol = throw("Unimplemented name for metric $(typeof(m))")
78+
outputs(m::Metric)::Vector{Symbol} = throw("Unimplemented outputs for metric $(typeof(m))")
79+
initial_output(m::Metric) = throw("Unimplemented metric output initialization for $(typeof(m)); specify an explicit initalization")
80+
81+
compare(m::Metric, sys, state_a, state_b) = throw("Unimplemented discrete metric comparison for $(typeof(m))")
82+
compare!(o, m::Metric, sys, state_a, state_b) = o .= compare(m, sys, state_a, state_b)
83+
84+
85+
struct L∞ <: Metric
86+
name::Union{Nothing, Symbol}
87+
L∞(;name::Union{Nothing, Symbol} = nothing) = new(name)
88+
end
89+
acc_op(::L∞) = (a,b) -> max.(a,b)
90+
name(m::L∞) = m.name
91+
initial_output(m::L∞) = [0.0]
92+
outputs(m::L∞) = [name(m)]
93+
compare(m::L∞, sys, state_a, state_b) = [maximum(abs.(collect(state_a) .- collect(state_b)))]
94+
95+
struct L2 <: Metric
96+
name::Union{Nothing, Symbol}
97+
end
98+
acc_op(::L2) = (a,b) -> a .+ b
99+
name(m::L2) = m.name
100+
initial_output(m::L2) = [0.0]
101+
outputs(m::L2) = [name(m)]
102+
compare(m::L2, sys, state_a, state_b) = [norm(collect(state_a) .- collect(state_b))]
103+
104+
export Metric, L∞, L2

test/block_modeling.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
using ModelingToolkitStandardLibrary.Electrical
2+
using ModelingToolkitStandardLibrary.Blocks: Constant
3+
4+
@testset "Block Modeling" begin
5+
@testset "RC Explicit" begin
6+
R = 1.0
7+
C = 1.0
8+
V = 1.0
9+
@variables t
10+
@named resistor = Resistor(R = R)
11+
@named capacitor = Capacitor(C = C)
12+
@named source = Voltage()
13+
@named constant = Constant(k = V)
14+
@named ground = Ground()
15+
16+
@named sensor = PotentialSensor()
17+
@named key_parameter = MeasureComponent()
18+
19+
rc_eqs = [connect(constant.output, source.V)
20+
connect(source.p, resistor.p)
21+
connect(resistor.n, capacitor.p)
22+
connect(capacitor.n, source.n, ground.g)
23+
connect(sensor.p, capacitor.p)
24+
sensor.phi ~ key_parameter.value]
25+
26+
@named rc_model = ODESystem(rc_eqs, t,
27+
systems = [resistor, capacitor, constant, source, ground, sensor, key_parameter])
28+
sys = structural_simplify(rc_model)
29+
prob1 = ODEProblem(sys, Pair[], (0, 10.0))
30+
sol1 = solve(prob1, Tsit5())
31+
prob2 = ODEProblem(sys, Pair[capacitor.C => 0.9], (0, 10.0))
32+
sol2 = solve(prob2, Tsit5())
33+
prob3 = ODEProblem(sys, Pair[capacitor.C => 5.0], (0, 10.0))
34+
sol3 = solve(prob3, Tsit5())
35+
36+
d1 = discretize_solution(sol1, sol1)
37+
d2 = discretize_solution(sol2, sol1)
38+
d3 = discretize_solution(sol3, sol1)
39+
results_good = compare_discrete(sys, d1, d2)
40+
results_bad = compare_discrete(sys, d1, d3)
41+
42+
println(results_bad)
43+
end
44+
@testset "RC Functional" begin
45+
R = 1.0
46+
C = 1.0
47+
V = 1.0
48+
@variables t
49+
@named resistor = Resistor(R = R)
50+
@named capacitor = Capacitor(C = C)
51+
@named source = Voltage()
52+
@named constant = Constant(k = V)
53+
@named ground = Ground()
54+
55+
@named sensor = PotentialSensor()
56+
@named key_parameter = Measurement(sensor)
57+
58+
rc_eqs = [connect(constant.output, source.V)
59+
connect(source.p, resistor.p)
60+
connect(resistor.n, capacitor.p)
61+
connect(capacitor.n, source.n, ground.g)
62+
connect(sensor.p, capacitor.p)]
63+
64+
@named rc_model = ODESystem(rc_eqs, t,
65+
systems = [resistor, capacitor, constant, source, ground, sensor, key_parameter])
66+
sys = structural_simplify(rc_model)
67+
prob1 = ODEProblem(sys, Pair[], (0, 10.0))
68+
sol1 = solve(prob1, Tsit5())
69+
prob2 = ODEProblem(sys, Pair[capacitor.C => 0.9], (0, 10.0))
70+
sol2 = solve(prob2, Tsit5())
71+
prob3 = ODEProblem(sys, Pair[capacitor.C => 5.0], (0, 10.0))
72+
sol3 = solve(prob3, Tsit5())
73+
74+
d1 = discretize_solution(sol1, sol1)
75+
d2 = discretize_solution(sol2, sol1)
76+
d3 = discretize_solution(sol3, sol1)
77+
results_good = compare_discrete(sys, d1, d2)
78+
results_bad = compare_discrete(sys, d1, d3)
79+
80+
println(results_bad)
81+
end
82+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ using Test
66

77
@testset "ModelTesting.jl" begin
88
include("timeseries.jl")
9+
include("block_modeling.jl")
910
end

0 commit comments

Comments
 (0)