
Commit 36de150

Mutation handling, nopre lts tests pass
1 parent 79fbb05 commit 36de150

3 files changed: 131 additions & 99 deletions


ext/LinearSolveMooncakeExt.jl

Lines changed: 94 additions & 9 deletions
@@ -1,18 +1,19 @@
 module LinearSolveMooncakeExt
 
 using Mooncake
-using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!
+using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!, @is_primitive, primal, zero_fcodual, CoDual, rdata, fdata
 using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearProblem,
-    LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver,
-    defaultalg_adjoint_eval, solve
+    LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver, LinearSolveAdjoint,
+    defaultalg_adjoint_eval, solve, LUFactorization
 using LinearSolve.LinearAlgebra
+using LazyArrays: @~, BroadcastArray
 using SciMLBase
 
-@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve), LinearProblem, Nothing} true ReverseMode
+@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve),LinearProblem,Nothing} true ReverseMode
 @from_chainrules MinimalCtx Tuple{
-    typeof(SciMLBase.solve), LinearProblem, SciMLLinearSolveAlgorithm} true ReverseMode
+    typeof(SciMLBase.solve),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
 @from_chainrules MinimalCtx Tuple{
-    Type{<:LinearProblem}, AbstractMatrix, AbstractVector, SciMLBase.NullParameters} true ReverseMode
+    Type{<:LinearProblem},AbstractMatrix,AbstractVector,SciMLBase.NullParameters} true ReverseMode
 
 function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearProblem)
     f.data.A .+= t.A
@@ -41,8 +42,92 @@ end
 @from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
 @from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode
 
-# rrule for solve!
-@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm} true ReverseMode
-@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing} true ReverseMode
+# rrules for solve!
+# NOTE - Avoid Mooncake.prepare_gradient_cache; only use Mooncake.prepare_pullback_cache (and therefore Mooncake.value_and_pullback!!).
+# Calling Mooncake.prepare_gradient_cache for functions containing solve! will trigger the unsupported "Adjoint case" exception in the rrules below.
+# This is because Mooncake.prepare_gradient_cache resets stacks and state by passing a zero gradient through the reverse pass once.
+# However, if one already has a valid cache, one can use Mooncake.value_and_gradient!! directly.
+
+@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm,Vararg}
+@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing,Vararg}
+
+function Mooncake.rrule!!(sig::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache}, _alg::CoDual{Nothing}, args::Vararg{Any,N}; kwargs...) where {N}
+    cache = primal(_cache)
+    assump = OperatorAssumptions()
+    _alg.x = defaultalg(cache.A, cache.b, assump)
+    Mooncake.rrule!!(sig, _cache, _alg, args...; kwargs...)
+end
+
+function Mooncake.rrule!!(::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache}, _alg::CoDual{<:SciMLLinearSolveAlgorithm}, args::Vararg{Any,N}; alias_A=zero_fcodual(LinearSolve.default_alias_A(
+        _alg.x, _cache.x.A, _cache.x.b)), kwargs...) where {N}
+
+    cache = primal(_cache)
+    alg = primal(_alg)
+    _args = map(primal, args)
+
+    (; A, b, sensealg) = cache
+    A_orig = copy(A)
+    b_orig = copy(b)
+
+    @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
+
+    # logic behind caching `A` and `b` for the reverse pass is based on the rrule above for SciMLBase.solve
+    if sensealg.linsolve === missing
+        if !(alg isa LinearSolve.AbstractFactorization || alg isa LinearSolve.AbstractKrylovSubspaceMethod ||
+             alg isa LinearSolve.DefaultLinearSolver)
+            A_ = alias_A ? deepcopy(A) : A
+        end
+    else
+        A_ = deepcopy(A)
+    end
+
+    sol = zero_fcodual(solve!(cache))
+    cache.A = A_orig
+    cache.b = b_orig
+
+    function solve!_adjoint(::NoRData)
+        ∂∅ = NoRData()
+        cachenew = init(LinearProblem(cache.A, cache.b), LUFactorization(), _args...; kwargs...)
+        new_sol = solve!(cachenew)
+        ∂u = sol.dx.data.u
+
+        if sensealg.linsolve === missing
+            λ = if cache.cacheval isa Factorization
+                cache.cacheval' \ ∂u
+            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
+                first(cache.cacheval)' \ ∂u
+            elseif alg isa AbstractKrylovSubspaceMethod
+                invprob = LinearProblem(adjoint(cache.A), ∂u)
+                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            elseif alg isa DefaultLinearSolver
+                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
+            else
+                invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
+                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            end
+        else
+            invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
+            λ = solve(
+                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
+        end
+
+        tu = adjoint(new_sol.u)
+        ∂A = BroadcastArray(@~ .-(λ .* tu))
+        ∂b = λ
+
+        if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
+            error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
+        end
+
+        fdata(_cache.dx).fields.A .+= ∂A
+        fdata(_cache.dx).fields.b .+= ∂b
+        fdata(_cache.dx).fields.u .+= ∂u
+
+        # rdata for cache is a struct with NoRData field values
+        return (∂∅, rdata(_cache.dx), ∂∅, ntuple(_ -> ∂∅, length(args))...)
+    end
+
+    return sol, solve!_adjoint
+end
 
 end
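The NOTE in the hunk above prescribes the pullback path for functions that call solve!. A minimal sketch of that usage, mirroring the pattern in the test file below (illustrative only: the loss function, matrix, and vector here are assumptions, not part of the commit):

    using LinearSolve, Mooncake, LinearAlgebra

    A = rand(4, 4) + 4I
    b = rand(4)

    # a scalar loss that drives the linear solve through the mutating solve! path
    loss(A, b) = norm(solve!(init(LinearProblem(A, b), LUFactorization())).u)

    # build the rule once, then pull back a cotangent of 1.0 through the loss
    rule = Mooncake.build_rrule(loss, A, b)
    val, grads = Mooncake.value_and_pullback!!(rule, 1.0, loss, A, b)
    # grads[2] and grads[3] hold the cotangents with respect to A and b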

src/adjoint.jl

Lines changed: 0 additions & 62 deletions
@@ -115,65 +115,3 @@ function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem,
 
     return init_res, init_adjoint
 end
-
-function CRC.rrule(T::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::Nothing, args...; kwargs...)
-    assump = OperatorAssumptions()
-    alg = defaultalg(cache.A, cache.b, assump)
-    CRC.rrule(T, cache, alg, args...; kwargs)
-end
-
-function CRC.rrule(::typeof(SciMLBase.solve!), cache::LinearSolve.LinearCache, alg::LinearSolve.SciMLLinearSolveAlgorithm, args...; alias_A=default_alias_A(
-        alg, cache.A, cache.b), kwargs...)
-    (; A, sensealg) = cache
-    @assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
-
-    # logic behind caching `A` and `b` for the reverse pass based on rrule above for SciMLBase.solve
-    if sensealg.linsolve === missing
-        if !(alg isa AbstractFactorization || alg isa AbstractKrylovSubspaceMethod ||
-             alg isa DefaultLinearSolver)
-            A_ = alias_A ? deepcopy(A) : A
-        end
-    else
-        A_ = deepcopy(A)
-    end
-
-    sol = solve!(cache)
-    function solve!_adjoint(∂sol)
-        ∂∅ = NoTangent()
-        ∂u = ∂sol.u
-
-        if sensealg.linsolve === missing
-            λ = if cache.cacheval isa Factorization
-                cache.cacheval' \ ∂u
-            elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
-                first(cache.cacheval)' \ ∂u
-            elseif alg isa AbstractKrylovSubspaceMethod
-                invprob = LinearProblem(adjoint(cache.A), ∂u)
-                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
-            elseif alg isa DefaultLinearSolver
-                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
-            else
-                invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
-                solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
-            end
-        else
-            invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
-            λ = solve(
-                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
-        end
-
-        tu = adjoint(sol.u)
-        ∂A = BroadcastArray(@~ .-(λ .* tu))
-        ∂b = λ
-
-        if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
-            error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
-        end
-
-        ∂prob = LinearProblem(∂A, ∂b, ∂∅)
-        ∂cache = LinearSolve.init(∂prob, u=∂u)
-        return (∂∅, ∂cache, ∂∅, ntuple(_ -> ∂∅, length(args))...)
-    end
-
-    return sol, solve!_adjoint
-end
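For reference, both the removed ChainRules rrule above and the new Mooncake rrule implement the standard adjoint of a linear solve: with u = A \ b and an incoming cotangent ∂u, one solves λ = A' \ ∂u, then ∂A = -λ * u' and ∂b = λ, which is what the `∂A = BroadcastArray(@~ .-(λ .* tu))` and `∂b = λ` lines compute. A small standalone check of that identity against a finite difference (illustrative values and step size, not part of the commit):

    using LinearAlgebra

    A = rand(3, 3) + 3I; b = rand(3)
    u  = A \ b
    ∂u = rand(3)                 # arbitrary cotangent for u
    λ  = A' \ ∂u
    ∂A = -λ * u'                 # same values as BroadcastArray(@~ .-(λ .* tu)) with tu = adjoint(u)
    ∂b = λ

    # directional derivative of ⟨∂u, (A + ε*dA) \ (b + ε*db)⟩ at ε = 0
    dA = rand(3, 3); db = rand(3); ε = 1e-6
    lhs = sum(∂A .* dA) + dot(∂b, db)
    rhs = (dot(∂u, (A + ε * dA) \ (b + ε * db)) - dot(∂u, u)) / ε
    @assert isapprox(lhs, rhs; rtol = 1e-3)   # agree up to finite-difference error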

test/nopre/mooncake.jl

Lines changed: 37 additions & 28 deletions
@@ -11,9 +11,7 @@ b1 = rand(n);
 
 function f(A, b1; alg = LUFactorization())
     prob = LinearProblem(A, b1)
-
     sol1 = solve(prob, alg)
-
     s1 = sol1.u
     norm(s1)
 end
@@ -160,7 +158,7 @@ A = rand(n, n);
 b1 = rand(n);
 b2 = rand(n);
 
-function f(A, b1, b2; alg=LUFactorization())
+function f_(A, b1, b2; alg=LUFactorization())
     prob = LinearProblem(A, b1)
     cache = init(prob, alg)
     s1 = copy(solve!(cache).u)
@@ -169,22 +167,23 @@ function f(A, b1, b2; alg=LUFactorization())
     norm(s1 + s2)
 end
 
-f_primal = f(copy(A), copy(b1), copy(b2))
-value, gradient = Mooncake.value_and_gradient!!(
-    prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)),
-    f, copy(A), copy(b1), copy(b2)
+f_primal = f_(copy(A), copy(b1), copy(b2))
+rule = Mooncake.build_rrule(f_, copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_pullback!!(
+    rule, 1.0,
+    f_, copy(A), copy(b1), copy(b2)
 )
 
-dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
-db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
-db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+dA2 = ForwardDiff.gradient(x -> f_(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x -> f_(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f_(eltype(x).(A), eltype(x).(b1), x), copy(b2))
 
 @test value == f_primal
 @test gradient[2] ≈ dA2
 @test gradient[3] ≈ db12
 @test gradient[4] ≈ db22
 
-function f2(A, b1, b2; alg=RFLUFactorization())
+function f_2(A, b1, b2; alg=RFLUFactorization())
     prob = LinearProblem(A, b1)
     cache = init(prob, alg)
     s1 = copy(solve!(cache).u)
@@ -193,18 +192,23 @@ function f2(A, b1, b2; alg=RFLUFactorization())
     norm(s1 + s2)
 end
 
-f_primal = f2(copy(A), copy(b1), copy(b2))
-value, gradient = Mooncake.value_and_gradient!!(
-    prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)),
-    f2, copy(A), copy(b1), copy(b2)
+f_primal = f_2(copy(A), copy(b1), copy(b2))
+rule = Mooncake.build_rrule(f_2, copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_pullback!!(
+    rule, 1.0,
+    f_2, copy(A), copy(b1), copy(b2)
 )
 
+dA2 = ForwardDiff.gradient(x -> f_2(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
 @test value == f_primal
 @test gradient[2] ≈ dA2
 @test gradient[3] ≈ db12
 @test gradient[4] ≈ db22
 
-function f3(A, b1, b2; alg=KrylovJL_GMRES())
+function f_3(A, b1, b2; alg=KrylovJL_GMRES())
     prob = LinearProblem(A, b1)
     cache = init(prob, alg)
     s1 = copy(solve!(cache).u)
@@ -213,18 +217,23 @@ function f3(A, b1, b2; alg=KrylovJL_GMRES())
     norm(s1 + s2)
 end
 
-f_primal = f3(copy(A), copy(b1), copy(b2))
-value, gradient = Mooncake.value_and_gradient!!(
-    prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)),
-    f3, copy(A), copy(b1), copy(b2)
+f_primal = f_3(copy(A), copy(b1), copy(b2))
+rule = Mooncake.build_rrule(f_3, copy(A), copy(b1), copy(b2))
+value, gradient = Mooncake.value_and_pullback!!(
+    rule, 1.0,
+    f_3, copy(A), copy(b1), copy(b2)
 )
 
+dA2 = ForwardDiff.gradient(x -> f_3(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+db12 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+db22 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+
 @test value == f_primal
-@test gradient[2] ≈ dA2 atol = 5e-5
+@test gradient[2] ≈ dA2
 @test gradient[3] ≈ db12
 @test gradient[4] ≈ db22
 
-function f4(A, b1, b2; alg=LUFactorization())
+function f_4(A, b1, b2; alg=LUFactorization())
     prob = LinearProblem(A, b1)
     cache = init(prob, alg)
     solve!(cache)
@@ -238,17 +247,17 @@ end
 A = rand(n, n);
 b1 = rand(n);
 b2 = rand(n);
-# f_primal = f4(copy(A), copy(b1), copy(b2))
+f_primal = f_4(copy(A), copy(b1), copy(b2))
 
-rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2))
+rule = Mooncake.build_rrule(f_4, copy(A), copy(b1), copy(b2))
 @test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
     rule, 1.0,
-    f4, copy(A), copy(b1), copy(b2)
+    f_4, copy(A), copy(b1), copy(b2)
 )
 
-# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
-# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
-# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2))
+# dA2 = ForwardDiff.gradient(x -> f_4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
+# db12 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
+# db22 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), eltype(x).(b1), x), copy(b2))
 
 # @test value == f_primal
 # @test grad[2] ≈ dA2
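The f_4 test above exercises the guard in the new rrule: per the error message, reading the solution back out of the mutated cache after solve! is the pattern that raises "Adjoint case currently not handled", while taking it from the solution object returned by solve! is supported. A sketch of the two patterns (hypothetical loss functions, not part of the commit):

    # unsupported: the result is read back out of the mutated cache, as in f_4
    function loss_bad(A, b)
        cache = init(LinearProblem(A, b), LUFactorization())
        solve!(cache)
        norm(copy(cache.u))
    end

    # supported: the result is taken from the returned solution, as in f_ / f_2 / f_3
    function loss_good(A, b)
        cache = init(LinearProblem(A, b), LUFactorization())
        sol = solve!(cache)
        norm(copy(sol.u))
    end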
