@@ -45,7 +45,7 @@ const DualBLinearProblem = LinearProblem{
4545} where {iip}
4646
4747const DualAbstractLinearProblem = Union{
48- SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem}
48+ SingleDualLinearProblem, DualALinearProblem, DualBLinearProblem, NestedDualLinearProblem }
4949
5050LinearSolve. @concrete mutable struct DualLinearCache
5151 linear_cache
@@ -60,6 +60,36 @@ LinearSolve.@concrete mutable struct DualLinearCache
6060 dual_u
6161end
6262
63+ # function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
64+ # # Solve the primal problem
65+ # dual_u0 = copy(cache.linear_cache.u)
66+ # sol = solve!(cache.linear_cache, alg, args...; kwargs...)
67+ # primal_b = copy(cache.linear_cache.b)
68+ # uu = sol.u
69+
70+ # primal_sol = deepcopy(sol)
71+
72+ # # Solves Dual partials separately
73+ # ∂_A = cache.partials_A
74+ # ∂_b = cache.partials_b
75+
76+ # rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
77+
78+ # cache.linear_cache.u = dual_u0
79+ # # We can reuse the linear cache, because the same factorization will work for the partials.
80+ # for i in eachindex(rhs_list)
81+ # cache.linear_cache.b = rhs_list[i]
82+ # rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
83+ # end
84+
85+ # # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
86+ # cache.linear_cache.b = primal_b
87+
88+ # partial_sols = rhs_list
89+
90+ # primal_sol, partial_sols
91+ # end
92+
6393function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
6494 # Solve the primal problem
6595 dual_u0 = copy (cache. linear_cache. u)
@@ -77,16 +107,17 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
77107
78108 cache. linear_cache. u = dual_u0
79109 # We can reuse the linear cache, because the same factorization will work for the partials.
110+ partial_sols = []
80111 for i in eachindex (rhs_list)
81112 cache. linear_cache. b = rhs_list[i]
82- rhs_list[i] = copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u)
113+ # For nested duals, the result of this solve might also be a dual number
114+ # which will be handled recursively by the same mechanism
115+ push! (partial_sols, copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u))
83116 end
84117
85- # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
118+ # Reset to the original `b` and `u`
86119 cache. linear_cache. b = primal_b
87120
88- partial_sols = rhs_list
89-
90121 primal_sol, partial_sols
91122end
92123
@@ -136,14 +167,55 @@ function linearsolve_dual_solution(
136167 return dual_type (u, partials)
137168end
138169
139- function linearsolve_dual_solution (
140- u:: AbstractArray , partials, dual_type)
170+ function linearsolve_dual_solution (u:: Number , partials,
171+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
172+ # Handle single-level duals
173+ return dual_type (u, partials)
174+ end
175+
176+ # function linearsolve_dual_solution(
177+ # u::AbstractArray, partials, dual_type)
178+ # partials_list = RecursiveArrayTools.VectorOfArray(partials)
179+ # return map(((uᵢ, pᵢ),) -> dual_type(uᵢ, Partials(Tuple(pᵢ))),
180+ # zip(u, partials_list[i, :] for i in 1:length(partials_list[1])))
181+ # end
182+
183+ function linearsolve_dual_solution (u:: AbstractArray , partials,
184+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: AbstractFloat , P}
185+ # Handle single-level duals for arrays
141186 partials_list = RecursiveArrayTools. VectorOfArray (partials)
142187 return map (((uᵢ, pᵢ),) -> dual_type (uᵢ, Partials (Tuple (pᵢ))),
143188 zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
144189end
145190
146- #=
191+
192+ function linearsolve_dual_solution (
193+ u:: Number , partials, dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
194+ # Handle nested duals - recursive case
195+ # For nested duals, u itself could be a dual number with its own partials
196+ inner_dual_type = V
197+ outer_tag_type = T
198+
199+ # Reconstruct the nested dual by first building the inner dual, then the outer one
200+ inner_dual = u # u is already a dual for the inner level
201+
202+ # Create outer dual with the inner dual as its value
203+ return Dual {outer_tag_type, typeof(inner_dual), P} (inner_dual, partials)
204+ end
205+
206+ function linearsolve_dual_solution (u:: AbstractArray , partials,
207+ dual_type:: Type{<:Dual{T, V, P}} ) where {T, V <: Dual , P}
208+ # Handle nested duals for arrays - recursive case
209+ inner_dual_type = V
210+ outer_tag_type = T
211+
212+ partials_list = RecursiveArrayTools. VectorOfArray (partials)
213+
214+ # For nested duals, each element of u could be a dual number with its own partials
215+ return map (((uᵢ, pᵢ),) -> Dual {outer_tag_type, typeof(uᵢ), P} (uᵢ, Partials (Tuple (pᵢ))),
216+ zip (u, partials_list[i, :] for i in 1 : length (partials_list[1 ])))
217+ end
218+
147219function SciMLBase. init (
148220 prob:: DualAbstractLinearProblem , alg:: LinearSolve.SciMLLinearSolveAlgorithm ,
149221 args... ;
@@ -235,18 +307,23 @@ end
235307
236308
237309
238- # Helper functions for Dual numbers
239- get_dual_type (x:: Dual ) = typeof (x)
310+ # Enhanced helper functions for Dual numbers to handle recursion
311+ get_dual_type (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = typeof (x)
312+ get_dual_type (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = typeof (x)
240313get_dual_type (x:: AbstractArray{<:Dual} ) = eltype (x)
241314get_dual_type (x) = nothing
242315
243- partial_vals (x:: Dual ) = ForwardDiff. partials (x)
316+ # Add recursive handling for nested dual partials
317+ partial_vals (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. partials (x)
318+ partial_vals (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = ForwardDiff. partials (x)
244319partial_vals (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. partials, x)
245320partial_vals (x) = nothing
246321
322+ # Add recursive handling for nested dual values
247323nodual_value (x) = x
248- nodual_value (x:: Dual ) = ForwardDiff. value (x)
249- nodual_value (x:: AbstractArray{<:Dual} ) = map (ForwardDiff. value, x)
324+ nodual_value (x:: Dual{T, V, P} ) where {T, V <: AbstractFloat , P} = ForwardDiff. value (x)
325+ nodual_value (x:: Dual{T, V, P} ) where {T, V <: Dual , P} = x. value # Keep the inner dual intact
326+ nodual_value (x:: AbstractArray{<:Dual} ) = map (nodual_value, x)
250327
251328
252329function partials_to_list (partial_matrix:: Vector )
0 commit comments