Skip to content

Commit 0388574

Browse files
committed
Update (dual) linear cache values/partials in-place
1 parent d9b3a96 commit 0388574

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,11 @@ function SciMLBase.solve!(
309309
end
310310

311311
# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
312-
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
312+
function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val::AbstractArray)
313313
# If the property is A or b, also update it in the LinearCache
314314
if sym === :A || sym === :b || sym === :u
315-
setproperty!(dc.linear_cache, sym, nodual_value(val))
315+
prop = nodual_value!(getproperty(dc.linear_cache, sym), val) # Update in-place
316+
setproperty!(dc.linear_cache, sym, prop) # Does additional invalidation logic etc.
316317
elseif hasfield(DualLinearCache, sym)
317318
setfield!(dc, sym, val)
318319
elseif hasfield(LinearSolve.LinearCache, sym)
@@ -322,15 +323,15 @@ function Base.setproperty!(dc::DualLinearCache, sym::Symbol, val)
322323
# Update the partials and invalidate cache if setting A or b
323324
if sym === :A
324325
setfield!(dc, :dual_A, val)
325-
setfield!(dc, :partials_A, partial_vals(val))
326+
partial_vals!(getfield(dc, :partials_A), val) # Update in-place
326327
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
327328
elseif sym === :b
328329
setfield!(dc, :dual_b, val)
329-
setfield!(dc, :partials_b, partial_vals(val))
330+
partial_vals!(getfield(dc, :partials_b), val) # Update in-place
330331
setfield!(dc, :rhs_cache_valid, false) # Invalidate cache
331332
elseif sym === :u
332333
setfield!(dc, :dual_u, val)
333-
setfield!(dc, :partials_u, partial_vals(val))
334+
partial_vals!(getfield(dc, :partials_u), val) # Update in-place
334335
end
335336
end
336337

@@ -360,11 +361,13 @@ partial_vals(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.pa
360361
partial_vals(x::Dual{T, V, P}) where {T, V <: Dual, P} = ForwardDiff.partials(x)
361362
partial_vals(x::AbstractArray{<:Dual}) = map(ForwardDiff.partials, x)
362363
partial_vals(x) = nothing
364+
partial_vals!(out, x) = map!(partial_vals, out, x) # Update in-place
363365

364366
# Add recursive handling for nested dual values
365367
nodual_value(x) = x
366368
nodual_value(x::Dual{T, V, P}) where {T, V <: AbstractFloat, P} = ForwardDiff.value(x)
367369
nodual_value(x::Dual{T, V, P}) where {T, V <: Dual, P} = x.value # Keep the inner dual intact
370+
nodual_value!(out, x) = map!(nodual_value, out, x) # Update in-place
368371

369372
function nodual_value(x::AbstractArray{<:Dual})
370373
# Create a similar array with the appropriate element type

0 commit comments

Comments
 (0)