@@ -61,36 +61,6 @@ LinearSolve.@concrete mutable struct DualLinearCache
6161 dual_u
6262end
6363
64- # function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
65- # # Solve the primal problem
66- # dual_u0 = copy(cache.linear_cache.u)
67- # sol = solve!(cache.linear_cache, alg, args...; kwargs...)
68- # primal_b = copy(cache.linear_cache.b)
69- # uu = sol.u
70-
71- # primal_sol = deepcopy(sol)
72-
73- # # Solves Dual partials separately
74- # ∂_A = cache.partials_A
75- # ∂_b = cache.partials_b
76-
77- # rhs_list = xp_linsolve_rhs(uu, ∂_A, ∂_b)
78-
79- # cache.linear_cache.u = dual_u0
80- # # We can reuse the linear cache, because the same factorization will work for the partials.
81- # for i in eachindex(rhs_list)
82- # cache.linear_cache.b = rhs_list[i]
83- # rhs_list[i] = copy(solve!(cache.linear_cache, alg, args...; kwargs...).u)
84- # end
85-
86- # # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
87- # cache.linear_cache.b = primal_b
88-
89- # partial_sols = rhs_list
90-
91- # primal_sol, partial_sols
92- # end
93-
9464function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
9565 # Solve the primal problem
9666 dual_u0 = copy (cache. linear_cache. u)
@@ -108,17 +78,16 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
10878
10979 cache. linear_cache. u = dual_u0
11080 # We can reuse the linear cache, because the same factorization will work for the partials.
111- partial_sols = []
11281 for i in eachindex (rhs_list)
11382 cache. linear_cache. b = rhs_list[i]
114- # For nested duals, the result of this solve might also be a dual number
115- # which will be handled recursively by the same mechanism
116- push! (partial_sols, copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u))
83+ rhs_list[i] = copy (solve! (cache. linear_cache, alg, args... ; kwargs... ). u)
11784 end
11885
119- # Reset to the original `b` and `u`
86+ # Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
12087 cache. linear_cache. b = primal_b
12188
89+ partial_sols = rhs_list
90+
12291 primal_sol, partial_sols
12392end
12493
@@ -147,30 +116,6 @@ function xp_linsolve_rhs(
147116 b_list
148117end
149118
150- #=
151- function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
152- return solve(prob, nothing, args...; kwargs...)
153- end
154-
155- function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
156- assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...)
157- # Extract primal values
158- primal_A = nodual_value(prob.A)
159- primal_b = nodual_value(prob.b)
160-
161- # Use the default algorithm selection based on primal values
162- default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump)
163-
164- # Solve with the selected algorithm
165- return solve(prob, default_alg, args...; kwargs...)
166- end
167-
168- function SciMLBase.solve(prob::DualAbstractLinearProblem,
169- alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; kwargs...)
170- solve!(init(prob, alg, args...; kwargs...))
171- end
172- =#
173-
174119function linearsolve_dual_solution (
175120 u:: Number , partials, dual_type)
176121 return dual_type (u, partials)
@@ -252,7 +197,7 @@ function SciMLBase.init(
252197 primal_prob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
253198 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
254199 sensealg = sensealg, u0 = new_u0, kwargs... )
255- return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b, ! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zero .(b ))
200+ return DualLinearCache (non_partial_cache, dual_type, ∂_A, ∂_b, ! isnothing (∂_b) ? zero .(∂_b) : ∂_b, A, b, zeros (dual_type, length (b) ))
256201end
257202
258203function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
@@ -264,14 +209,13 @@ function SciMLBase.solve!(cache::DualLinearCache, alg::SciMLLinearSolveAlgorithm
264209 partials = linearsolve_forwarddiff_solve (
265210 cache:: DualLinearCache , cache. alg, args... ; kwargs... )
266211 dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
267-
212+ Main . @infiltrate
268213 cache. dual_u = dual_sol
269214
270215 return SciMLBase. build_linear_solution (
271216 cache. alg, dual_sol, sol. resid, cache; sol. retcode, sol. iters, sol. stats
272217 )
273218end
274- = #
275219
276220# If setting A or b for DualLinearCache, put the Dual-stripped versions in the LinearCache
277221function Base. setproperty! (dc:: DualLinearCache , sym:: Symbol , val)
0 commit comments