1- module MTKChainRulesCoreExt
2-
3- import ModelingToolkit as MTK
4- import ChainRulesCore
5- import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
6-
7- function ChainRulesCore. rrule (:: Type{MTK.MTKParameters} , tunables, args... )
1+ function ChainRulesCore. rrule (:: Type{MTKParameters} , tunables, args... )
82 function mtp_pullback (dt)
93 dt = unthunk (dt)
104 dtunables = dt isa AbstractArray ? dt : dt. tunable
115 (NoTangent (), dtunables[1 : length (tunables)],
126 ntuple (_ -> NoTangent (), length (args))... )
137 end
14- MTK . MTKParameters (tunables, args... ), mtp_pullback
8+ MTKParameters (tunables, args... ), mtp_pullback
159end
1610
1711function subset_idxs (idxs, portion, template)
@@ -70,23 +64,23 @@ function selected_tangents(
7064end
7165
7266function ChainRulesCore. rrule (
73- :: typeof (MTK . remake_buffer), indp, oldbuf:: MTK. MTKParameters , idxs, vals)
67+ :: typeof (remake_buffer), indp, oldbuf:: MTKParameters , idxs, vals)
7468 if idxs isa AbstractSet
7569 idxs = collect (idxs)
7670 end
7771 idxs = map (idxs) do i
78- i isa MTK . ParameterIndex ? i : MTK . parameter_index (indp, i)
72+ i isa ParameterIndex ? i : parameter_index (indp, i)
7973 end
80- newbuf = MTK . remake_buffer (indp, oldbuf, idxs, vals)
74+ newbuf = remake_buffer (indp, oldbuf, idxs, vals)
8175 tunable_idxs = reduce (
82- vcat, (idx. idx for idx in idxs if idx. portion isa MTK . SciMLStructures. Tunable);
76+ vcat, (idx. idx for idx in idxs if idx. portion isa SciMLStructures. Tunable);
8377 init = Union{Int, AbstractVector{Int}}[])
8478 initials_idxs = reduce (
85- vcat, (idx. idx for idx in idxs if idx. portion isa MTK . SciMLStructures. Initials);
79+ vcat, (idx. idx for idx in idxs if idx. portion isa SciMLStructures. Initials);
8680 init = Union{Int, AbstractVector{Int}}[])
87- disc_idxs = subset_idxs (idxs, MTK . SciMLStructures. Discrete (), oldbuf. discrete)
88- const_idxs = subset_idxs (idxs, MTK . SciMLStructures. Constants (), oldbuf. constant)
89- nn_idxs = subset_idxs (idxs, MTK . NONNUMERIC_PORTION, oldbuf. nonnumeric)
81+ disc_idxs = subset_idxs (idxs, SciMLStructures. Discrete (), oldbuf. discrete)
82+ const_idxs = subset_idxs (idxs, SciMLStructures. Constants (), oldbuf. constant)
83+ nn_idxs = subset_idxs (idxs, NONNUMERIC_PORTION, oldbuf. nonnumeric)
9084
9185 pullback = let idxs = idxs
9286 function remake_buffer_pullback (buf′)
@@ -102,13 +96,11 @@ function ChainRulesCore.rrule(
10296 oldbuf′ = Tangent {typeof(oldbuf)} (;
10397 tunable, initials, discrete, constant, nonnumeric)
10498 idxs′ = NoTangent ()
105- vals′ = map (i -> MTK . _ducktyped_parameter_values (buf′, i), idxs)
99+ vals′ = map (i -> _ducktyped_parameter_values (buf′, i), idxs)
106100 return f′, indp′, oldbuf′, idxs′, vals′
107101 end
108102 end
109103 newbuf, pullback
110104end
111105
112- ChainRulesCore. @non_differentiable Base. getproperty (sys:: MTK.AbstractSystem , x:: Symbol )
113-
114- end
106+ ChainRulesCore. @non_differentiable Base. getproperty (sys:: AbstractSystem , x:: Symbol )
0 commit comments