@@ -82,7 +82,7 @@ function calculate_hessian end
8282
8383"""
8484```julia
85- generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = parameters (sys),
85+ generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = full_parameters (sys),
8686 expression = Val{true}; kwargs...)
8787```
8888
@@ -93,7 +93,7 @@ function generate_tgrad end
9393
9494"""
9595```julia
96- generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
96+ generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
9797 expression = Val{true}; kwargs...)
9898```
9999
@@ -104,7 +104,7 @@ function generate_gradient end
104104
105105"""
106106```julia
107- generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
107+ generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
108108 expression = Val{true}; sparse = false, kwargs...)
109109```
110110
@@ -115,7 +115,7 @@ function generate_jacobian end
115115
116116"""
117117```julia
118- generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
118+ generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
119119 expression = Val{true}; sparse = false, kwargs...)
120120```
121121
@@ -126,7 +126,7 @@ function generate_factorized_W end
126126
127127"""
128128```julia
129- generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
129+ generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
130130 expression = Val{true}; sparse = false, kwargs...)
131131```
132132
@@ -137,14 +137,158 @@ function generate_hessian end
137137
138138"""
139139```julia
140- generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters (sys),
140+ generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters (sys),
141141 expression = Val{true}; kwargs...)
142142```
143143
144144Generate a function to evaluate the system's equations.
145145"""
146146function generate_function end
147147
148+ function generate_custom_function (sys:: AbstractSystem , exprs, dvs = unknowns (sys),
149+ ps = parameters (sys); wrap_code = nothing , kwargs... )
150+ p = reorder_parameters (sys, ps)
151+ isscalar = ! (exprs isa AbstractArray)
152+ if wrap_code === nothing
153+ wrap_code = isscalar ? identity : (identity, identity)
154+ end
155+ pre, sol_states = get_substitutions_and_solved_unknowns (sys)
156+
157+ if is_time_dependent (sys)
158+ return build_function (exprs,
159+ dvs,
160+ p... ,
161+ get_iv (sys);
162+ kwargs... ,
163+ postprocess_fbody = pre,
164+ states = sol_states,
165+ wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
166+ wrap_array_vars (sys, exprs; dvs)
167+ )
168+ else
169+ return build_function (exprs,
170+ dvs,
171+ p... ;
172+ kwargs... ,
173+ postprocess_fbody = pre,
174+ states = sol_states,
175+ wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
176+ wrap_array_vars (sys, exprs; dvs)
177+ )
178+ end
179+ end
180+
181+ function wrap_array_vars (sys:: AbstractSystem , exprs; dvs = unknowns (sys))
182+ isscalar = ! (exprs isa AbstractArray)
183+ allvars = if isscalar
184+ Set (get_variables (exprs))
185+ else
186+ union (get_variables .(exprs)... )
187+ end
188+ array_vars = Dict {Any, AbstractArray{Int}} ()
189+ for (j, x) in enumerate (dvs)
190+ if istree (x) && operation (x) == getindex
191+ arg = arguments (x)[1 ]
192+ arg in allvars || continue
193+ inds = get! (() -> Int[], array_vars, arg)
194+ push! (inds, j)
195+ end
196+ end
197+ for (k, inds) in array_vars
198+ if inds == (inds′ = inds[1 ]: inds[end ])
199+ array_vars[k] = inds′
200+ end
201+ end
202+ if isscalar
203+ function (expr)
204+ Func (
205+ expr. args,
206+ [],
207+ Let (
208+ [k ← :(view ($ (expr. args[1 ]. name), $ v)) for (k, v) in array_vars],
209+ expr. body,
210+ false
211+ )
212+ )
213+ end
214+ else
215+ function (expr)
216+ Func (
217+ expr. args,
218+ [],
219+ Let (
220+ [k ← :(view ($ (expr. args[1 ]. name), $ v)) for (k, v) in array_vars],
221+ expr. body,
222+ false
223+ )
224+ )
225+ end ,
226+ function (expr)
227+ Func (
228+ expr. args,
229+ [],
230+ Let (
231+ [k ← :(view ($ (expr. args[2 ]. name), $ v)) for (k, v) in array_vars],
232+ expr. body,
233+ false
234+ )
235+ )
236+ end
237+ end
238+ end
239+
240+ function wrap_mtkparameters (sys:: AbstractSystem , isscalar:: Bool )
241+ if has_index_cache (sys) && get_index_cache (sys) != = nothing
242+ offset = Int (is_time_dependent (sys))
243+
244+ if isscalar
245+ function (expr)
246+ p = gensym (:p )
247+ Func (
248+ [
249+ expr. args[1 ],
250+ DestructuredArgs (
251+ [arg. name for arg in expr. args[2 : (end - offset)]], p),
252+ (isone (offset) ? (expr. args[end ],) : ()). ..
253+ ],
254+ [],
255+ Let (expr. args[2 : (end - offset)], expr. body, false )
256+ )
257+ end
258+ else
259+ function (expr)
260+ p = gensym (:p )
261+ Func (
262+ [
263+ expr. args[1 ],
264+ DestructuredArgs (
265+ [arg. name for arg in expr. args[2 : (end - offset)]], p),
266+ (isone (offset) ? (expr. args[end ],) : ()). ..
267+ ],
268+ [],
269+ Let (expr. args[2 : (end - offset)], expr. body, false )
270+ )
271+ end ,
272+ function (expr)
273+ p = gensym (:p )
274+ Func (
275+ [
276+ expr. args[1 ],
277+ expr. args[2 ],
278+ DestructuredArgs (
279+ [arg. name for arg in expr. args[3 : (end - offset)]], p),
280+ (isone (offset) ? (expr. args[end ],) : ()). ..
281+ ],
282+ [],
283+ Let (expr. args[3 : (end - offset)], expr. body, false )
284+ )
285+ end
286+ end
287+ else
288+ identity
289+ end
290+ end
291+
148292mutable struct Substitutions
149293 subs:: Vector{Equation}
150294 deps:: Vector{Vector{Int}}
0 commit comments