Commit c149277

rrules for solve!, init

1 parent f247c7f commit c149277

3 files changed: +267 -0 lines changed

ext/LinearSolveMooncakeExt.jl

Lines changed: 16 additions & 0 deletions
@@ -29,4 +29,20 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
     end
 end
 
+function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
+    println("inside increment and get rdata 2")
+    f.fields.A .+= t.A
+    f.fields.b .+= t.b
+
+    return NoRData()
+end
+
+# rrules for LinearCache
+@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
+@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode
+
+# rrule for solve!
+@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm} true ReverseMode
+@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing} true ReverseMode
+
 end
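
Once these signatures are registered, Mooncake can differentiate through code that builds a LinearCache with `init` and solves it with `solve!`. A minimal usage sketch, mirroring the tests added below (the entry points `Mooncake.prepare_gradient_cache` and `Mooncake.value_and_gradient!!` are assumed to be used exactly as in the test file; the function name `loss` is illustrative):

using LinearSolve, Mooncake

function loss(A, b)
    cache = init(LinearProblem(A, b), LUFactorization())
    sum(solve!(cache).u)    # forward solve; the registered rrules handle the reverse pass
end

A, b = rand(4, 4), rand(4)
cache = Mooncake.prepare_gradient_cache(loss, A, b)
val, grads = Mooncake.value_and_gradient!!(cache, loss, A, b)
# grads[2] holds ∂loss/∂A, grads[3] holds ∂loss/∂b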

src/adjoint.jl

Lines changed: 75 additions & 0 deletions
@@ -99,3 +99,78 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
     ∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
     return prob, ∇prob
 end
+
+function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...)
+    assump = OperatorAssumptions(issquare(prob.A))
+    alg = defaultalg(prob.A, prob.b, assump)
+    CRC.rrule(T, prob, alg, args...; kwargs...)
+end
+
+function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm, Nothing}, args...; kwargs...)
+    init_res = LinearSolve.init(prob, alg)
+    function init_adjoint(∂init)
+        ∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent())
+        return NoTangent(), ∂prob, NoTangent(), ntuple(_ -> NoTangent(), length(args))...
+    end
+
+    return init_res, init_adjoint
+end
+
+function CRC.rrule(T::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::Nothing, args...; kwargs...)
+    assump = OperatorAssumptions()
+    alg = defaultalg(cache.A, cache.b, assump)
+    CRC.rrule(T, cache, alg, args...; kwargs...)
+end
+
+function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...;
+        alias_A = default_alias_A(alg, cache.A, cache.b), kwargs...)
+    _cache = deepcopy(cache)
+    (; A, sensealg) = _cache
+    @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
+
+    # logic behind caching `A` and `b` for the reverse pass based on the rrule above for SciMLBase.solve
+    if sensealg.linsolve === missing
+        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
+             alg isa DefaultLinearSolver)
+            A_ = alias_A ? deepcopy(A) : A
+        end
+    else
+        A_ = deepcopy(A)
+    end
+
+    sol = solve!(_cache)
+
+    function solve!_adjoint(∂sol)
+        ∂∅ = NoTangent()
+        ∂u = ∂sol.u
+
+        if sensealg.linsolve === missing
+            λ = if _cache.cacheval isa Factorization
+                _cache.cacheval' \ ∂u
+            elseif _cache.cacheval isa Tuple && _cache.cacheval[1] isa Factorization
+                first(_cache.cacheval)' \ ∂u
+            elseif alg isa AbstractKrylovSubspaceMethod
+                invprob = LinearProblem(adjoint(_cache.A), ∂u)
+                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            elseif alg isa DefaultLinearSolver
+                LinearSolve.defaultalg_adjoint_eval(_cache, ∂u)
+            else
+                invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
+                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            end
+        else
+            invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
+            λ = solve(
+                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
+        end
+
+        tu = adjoint(sol.u)
+        ∂A = BroadcastArray(@~ .-(λ .* tu))
+        ∂b = λ
+        ∂prob = LinearProblem(∂A, ∂b, ∂∅)
+        ∂cache = LinearSolve.init(∂prob)
+        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
+    end
+
+    return sol, solve!_adjoint
+end
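
For reference, the reverse pass above implements the standard linear-solve adjoint: with the primal solve `A x = b` and an incoming cotangent `∂u` for `x`, it solves the adjoint system `Aᵀ λ = ∂u` and returns `∂A = -λ xᵀ` (built lazily via `BroadcastArray`) and `∂b = λ`. A minimal standalone check of that identity against a finite difference; the names here are illustrative and not part of the commit:

using LinearAlgebra

A = [4.0 1.0; 2.0 3.0]
b = [1.0, 2.0]
x = A \ b                  # primal solve
∂u = [0.5, -1.0]           # incoming cotangent for x

λ = A' \ ∂u                # adjoint solve, as in solve!_adjoint
∂A = -λ * x'               # cotangent w.r.t. A
∂b = λ                     # cotangent w.r.t. b

# finite-difference spot check of one entry of ∂A
ϵ = 1e-6
Aϵ = copy(A); Aϵ[1, 2] += ϵ
fd = (dot(∂u, Aϵ \ b) - dot(∂u, x)) / ϵ
@assert isapprox(fd, ∂A[1, 2]; rtol = 1e-4)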

test/nopre/mooncake.jl

Lines changed: 176 additions & 0 deletions
@@ -153,3 +153,179 @@ for alg in (
     @test results[1] ≈ fA(A)
     @test mooncake_gradient ≈ fd_jac rtol = 1e-5
 end
+
+# Tests for solve! and init rrules.
+
+n = 4
+A = rand(n, n);
+b1 = rand(n);
+b2 = rand(n);
+
+function f(A, b1, b2; alg = LUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+f_primal = f(copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_gradient!!(
+    prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)),
+    f, copy(A), copy(b1), copy(b2)
+)
+
+dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
+@test value == f_primal
+@test gradient[2] ≈ dA2
+@test gradient[3] ≈ db12
+@test gradient[4] ≈ db22
+
+function f2(A, b1, b2; alg = RFLUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+f_primal = f2(copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_gradient!!(
+    prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)),
+    f2, copy(A), copy(b1), copy(b2)
+)
+
+@test value == f_primal
+@test gradient[2] ≈ dA2
+@test gradient[3] ≈ db12
+@test gradient[4] ≈ db22
+
+function f3(A, b1, b2; alg = LUFactorization())
+    # alg = KrylovJL_GMRES()
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    s1 = copy(solve!(cache).u)
+    cache.b = b2
+    s2 = solve!(cache).u
+    norm(s1 + s2)
+end
+
+f_primal = f3(copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_gradient!!(
+    prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)),
+    f3, copy(A), copy(b1), copy(b2)
+)
+
+@test value == f_primal
+@test gradient[2] ≈ dA2 atol = 5e-5
+@test gradient[3] ≈ db12
+@test gradient[4] ≈ db22
+
+A = rand(n, n);
+b1 = rand(n);
+
+function fnice(A, b, alg)
+    prob = LinearProblem(A, b)
+    sol1 = solve(prob, alg)
+    return sum(sol1.u)
+end
+
+@testset for alg in (
+    LUFactorization(),
+    RFLUFactorization(),
+    KrylovJL_GMRES()
+)
+    # for b
+    fb_closure = b -> fnice(A, b, alg)
+    fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
+
+    val, en_jac = Mooncake.value_and_gradient!!(
+        prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
+        fnice, copy(A), copy(b1), alg
+    )
+    @test en_jac[3] ≈ fd_jac_b rtol = 1e-5
+
+    # for A
+    fA_closure = A -> fnice(A, b1, alg)
+    fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
+    A_grad = en_jac[2] |> vec
+    @test A_grad ≈ fd_jac_A rtol = 1e-5
+end
+
+# The test function below fails!
+# AVOID patterns like `solve!(cache); s1 = copy(cache.u)`.
+# Instead stick to code like `sol = solve!(cache); s1 = copy(sol.u)`.
+
+function f4(A, b1, b2; alg = LUFactorization())
+    prob = LinearProblem(A, b1)
+    cache = init(prob, alg)
+    solve!(cache)
+    s1 = copy(cache.u)
+    cache.b = b2
+    solve!(cache)
+    s2 = copy(cache.u)
+    norm(s1 + s2)
+end
+
+# value, grad = Mooncake.value_and_gradient!!(
+#     prepare_gradient_cache(f4, copy(A), copy(b1), copy(b2)),
+#     f4, copy(A), copy(b1), copy(b2)
+# )
+# (0.0, (Mooncake.NoTangent(), [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]))
+
+# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
+# @test value == f_primal
+# @test grad[2] ≈ dA2
+# @test grad[3] ≈ db12
+# @test grad[4] ≈ db22
+
+function testls(A, b, u)
+    oa = OperatorAssumptions(
+        true, condition = LinearSolve.OperatorCondition.WellConditioned)
+    prob = LinearProblem(A, b)
+    linsolve = init(prob, LUFactorization(), assumptions = oa)
+    cache = solve!(linsolve)
+    sum(cache.u)
+end
+
+# A = [1.0 2.0; 3.0 4.0]
+# b = [1.0, 2.0]
+# u = zero(b)
+# value, gradient = Mooncake.value_and_gradient!!(
+#     prepare_gradient_cache(testls, copy(A), copy(b), copy(u)),
+#     testls, copy(A), copy(b), copy(u)
+# )
+
+# dA = gradient[2]
+# db = gradient[3]
+# du = gradient[4]
+
+function testls(A, b, u)
+    oa = OperatorAssumptions(
+        true, condition = LinearSolve.OperatorCondition.WellConditioned)
+    prob = LinearProblem(A, b)
+    linsolve = init(prob, LUFactorization(), assumptions = oa)
+    solve!(linsolve)
+    sum(linsolve.u)
+end
+
+# value, gradient = Mooncake.value_and_gradient!!(
+#     prepare_gradient_cache(testls, copy(A), copy(b), copy(u)),
+#     testls, copy(A), copy(b), copy(u)
+# )
+
+# dA2 = gradient[2]
+# db2 = gradient[3]
+# du2 = gradient[4]
+
+# @test dA == dA2
+# @test db == db2
+# @test du == du2
