11module LinearSolveForwardDiffExt
22
33using LinearSolve
4+ using LinearSolve: SciMLLinearSolveAlgorithm
45using LinearAlgebra
56using ForwardDiff
67using ForwardDiff: Dual, Partials
9293
9394function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
9495 # Solve the primal problem
96+ # Main.@infiltrate
9597 dual_u0 = copy (cache. linear_cache. u)
98+ # Main.@infiltrate
9699 sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
97100 primal_b = copy (cache. linear_cache. b)
98101 uu = sol. u
@@ -104,6 +107,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
104107 ∂_b = cache. partials_b
105108
106109 rhs_list = xp_linsolve_rhs (uu, ∂_A, ∂_b)
110+ # Main.@infiltrate
107111
108112 cache. linear_cache. u = dual_u0
109113 # We can reuse the linear cache, because the same factorization will work for the partials.
@@ -152,8 +156,16 @@ function SciMLBase.solve(prob::DualAbstractLinearProblem, args...; kwargs...)
152156end
153157
154158function SciMLBase.solve(prob::DualAbstractLinearProblem, ::Nothing, args...;
155- assump = OperatorAssumptions(issquare(prob.A)), kwargs...)
156- return solve(prob, LinearSolve.defaultalg(prob.A, prob.b, assump), args...; kwargs...)
159+ assump = OperatorAssumptions(issquare(nodual_value(prob.A))), kwargs...)
160+ # Extract primal values
161+ primal_A = nodual_value(prob.A)
162+ primal_b = nodual_value(prob.b)
163+
164+ # Use the default algorithm selection based on primal values
165+ default_alg = LinearSolve.defaultalg(primal_A, primal_b, assump)
166+
167+ # Solve with the selected algorithm
168+ return solve(prob, default_alg, args...; kwargs...)
157169end
158170
159171function SciMLBase.solve(prob::DualAbstractLinearProblem,
@@ -226,10 +238,10 @@ function SciMLBase.init(
226238 verbose:: Bool = false ,
227239 Pl = nothing ,
228240 Pr = nothing ,
229- assumptions = OperatorAssumptions ( issquare (prob . A)) ,
241+ assumptions = nothing ,
230242 sensealg = LinearSolveAdjoint (),
231243 kwargs... )
232-
244+ @info " here! "
233245 (; A, b, u0, p) = prob
234246 new_A = nodual_value (A)
235247 new_b = nodual_value (b)
@@ -240,12 +252,14 @@ function SciMLBase.init(
240252
241253 primal_prob = remake (prob; A = new_A, b = new_b, u0 = new_u0)
242254
255+ assumptions = OperatorAssumptions (issquare (primal_prob. A))
256+
243257 if get_dual_type (prob. A) != = nothing
244258 dual_type = get_dual_type (prob. A)
245259 elseif get_dual_type (prob. b) != = nothing
246260 dual_type = get_dual_type (prob. b)
247261 end
248-
262+ # Main.@infiltrate
249263 non_partial_cache = init (
250264 primal_prob, alg, args... ; alias = alias, abstol = abstol, reltol = reltol,
251265 maxiters = maxiters, verbose = verbose, Pl = Pl, Pr = Pr, assumptions = assumptions,
@@ -254,10 +268,15 @@ function SciMLBase.init(
254268end
255269
256270function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
271+ solve! (cache, cache. alg, args... ; kwargs... )
272+ end
273+
274+ function SciMLBase. solve! (cache:: DualLinearCache , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... )
275+ # Main.@infiltrate
257276 sol,
258277 partials = linearsolve_forwarddiff_solve (
259278 cache:: DualLinearCache , cache. alg, args... ; kwargs... )
260-
279+ # Main.@infiltrate
261280 dual_sol = linearsolve_dual_solution (sol. u, partials, cache. dual_type)
262281
263282 cache. dual_u = dual_sol
334353function partials_to_list (partial_matrix)
335354 p = length (first (partial_matrix))
336355 m, n = size (partial_matrix)
337- res_list = fill (zeros (m, n), p)
356+ res_list = fill (zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n), p)
338357 for k in 1 : p
339- res = zeros (m, n)
358+ res = zeros (typeof (partial_matrix[ 1 , 1 ][ 1 ]), m, n)
340359 for i in 1 : m
341360 for j in 1 : n
342361 res[i, j] = partial_matrix[i, j][k]
0 commit comments