Skip to content

Commit 80b0cb6

Browse files
fix erroring out when using solve! as is.
1 parent c149277 commit 80b0cb6

File tree

3 files changed

+26
-65
lines changed

3 files changed

+26
-65
lines changed

ext/LinearSolveMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
3030
end
3131

3232
function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
33-
println("inside increment and get rdata 2")
3433
f.fields.A .+= t.A
3534
f.fields.b .+= t.b
35+
f.fields.u .+= t.u
3636

3737
return NoRData()
3838
end

src/adjoint.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ end
124124

125125
function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; alias_A=default_alias_A(
126126
alg, cache.A, cache.b), kwargs...)
127-
_cache = deepcopy(cache)
128-
(; A, sensealg) = _cache
127+
(; A, sensealg) = cache
129128
@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
130129

131130
# logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve
@@ -138,22 +137,21 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, a
138137
A_ = deepcopy(A)
139138
end
140139

141-
sol = solve!(_cache)
142-
140+
sol = solve!(cache)
143141
function solve!_adjoint(∂sol)
144142
∂∅ = NoTangent()
145143
∂u = ∂sol.u
146144

147145
if sensealg.linsolve === missing
148-
λ = if _cache.cacheval isa Factorization
149-
_cache.cacheval' \ ∂u
150-
elseif _cache.cacheval isa Tuple && _cache.cacheval[1] isa Factorization
151-
first(_cache.cacheval)' \ ∂u
146+
λ = if cache.cacheval isa Factorization
147+
cache.cacheval' \ ∂u
148+
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
149+
first(cache.cacheval)' \ ∂u
152150
elseif alg isa AbstractKrylovSubspaceMethod
153-
invprob = LinearProblem(adjoint(_cache.A), ∂u)
151+
invprob = LinearProblem(adjoint(cache.A), ∂u)
154152
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
155153
elseif alg isa DefaultLinearSolver
156-
LinearSolve.defaultalg_adjoint_eval(_cache, ∂u)
154+
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
157155
else
158156
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
159157
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
@@ -167,8 +165,13 @@ function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, a
167165
tu = adjoint(sol.u)
168166
∂A = BroadcastArray(@~ .-.* tu))
169167
∂b = λ
168+
169+
if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
170+
error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
171+
end
172+
170173
∂prob = LinearProblem(∂A, ∂b, ∂∅)
171-
∂cache = LinearSolve.init(∂prob)
174+
∂cache = LinearSolve.init(∂prob, u=∂u)
172175
return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
173176
end
174177

test/nopre/mooncake.jl

Lines changed: 11 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,6 @@ end
257257
@test A_grad fd_jac_A rtol = 1e-5
258258
end
259259

260-
# The below test function cases fails !
261-
# AVOID Adjoint case in code as : `solve!(cache); s1 = copy(cache.u)`.
262-
# Instead stick to code like : `sol = solve!(cache); s1 = copy(sol.u)`.
263-
264260
function f4(A, b1, b2; alg=LUFactorization())
265261
prob = LinearProblem(A, b1)
266262
cache = init(prob, alg)
@@ -272,11 +268,16 @@ function f4(A, b1, b2; alg=LUFactorization())
272268
norm(s1 + s2)
273269
end
274270

275-
# value, grad = Mooncake.value_and_gradient!!(
276-
# prepare_gradient_cache(f4, copy(A), copy(b1), copy(b2)),
277-
# f4, copy(A), copy(b1), copy(b2)
278-
# )
279-
# (0.0, (Mooncake.NoTangent(), [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]))
271+
A = rand(n, n);
272+
b1 = rand(n);
273+
b2 = rand(n);
274+
# f_primal = f4(copy(A), copy(b1), copy(b2))
275+
276+
rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2))
277+
@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
278+
rule, 1.0,
279+
f4, copy(A), copy(b1), copy(b2)
280+
)
280281

281282
# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
282283
# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
@@ -285,47 +286,4 @@ end
285286
# @test value == f_primal
286287
# @test grad[2] ≈ dA2
287288
# @test grad[3] ≈ db12
288-
# @test grad[4] ≈ db22
289-
290-
function testls(A, b, u)
291-
oa = OperatorAssumptions(
292-
true, condition=LinearSolve.OperatorCondition.WellConditioned)
293-
prob = LinearProblem(A, b)
294-
linsolve = init(prob, LUFactorization(), assumptions=oa)
295-
cache = solve!(linsolve)
296-
sum(cache.u)
297-
end
298-
299-
# A = [1.0 2.0; 3.0 4.0]
300-
# b = [1.0, 2.0]
301-
# u = zero(b)
302-
# value, gradient = Mooncake.value_and_gradient!!(
303-
# prepare_gradient_cache(testls, copy(A), copy(b), copy(u)),
304-
# testls, copy(A), copy(b), copy(u)
305-
# )
306-
307-
# dA = gradient[2]
308-
# db = gradient[3]
309-
# du = gradient[4]
310-
311-
function testls(A, b, u)
312-
oa = OperatorAssumptions(
313-
true, condition=LinearSolve.OperatorCondition.WellConditioned)
314-
prob = LinearProblem(A, b)
315-
linsolve = init(prob, LUFactorization(), assumptions=oa)
316-
solve!(linsolve)
317-
sum(linsolve.u)
318-
end
319-
320-
# value, gradient = Mooncake.value_and_gradient!!(
321-
# prepare_gradient_cache(testls, copy(A), copy(b), copy(u)),
322-
# testls, copy(A), copy(b), copy(u)
323-
# )
324-
325-
# dA2 = gradient[2]
326-
# db2 = gradient[3]
327-
# du2 = gradient[4]
328-
329-
# @test dA == dA2
330-
# @test db == db2
331-
# @test du == du2
289+
# @test grad[4] ≈ db22

0 commit comments

Comments
 (0)