@@ -11,9 +11,7 @@ b1 = rand(n);
1111
1212function f (A, b1; alg = LUFactorization ())
1313 prob = LinearProblem (A, b1)
14-
1514 sol1 = solve (prob, alg)
16-
1715 s1 = sol1. u
1816 norm (s1)
1917end
@@ -160,7 +158,7 @@ A = rand(n, n);
160158b1 = rand (n);
161159b2 = rand (n);
162160
163- function f (A, b1, b2; alg= LUFactorization ())
161+ function f_ (A, b1, b2; alg= LUFactorization ())
164162 prob = LinearProblem (A, b1)
165163 cache = init (prob, alg)
166164 s1 = copy (solve! (cache). u)
@@ -169,22 +167,23 @@ function f(A, b1, b2; alg=LUFactorization())
169167 norm (s1 + s2)
170168end
171169
172- f_primal = f (copy (A), copy (b1), copy (b2))
173- value, gradient = Mooncake. value_and_gradient!! (
174- prepare_gradient_cache (f, copy (A), copy (b1), copy (b2)),
175- f, copy (A), copy (b1), copy (b2)
170+ f_primal = f_ (copy (A), copy (b1), copy (b2))
171+ rule = Mooncake. build_rrule (f_, copy (A), copy (b1), copy (b2))
172+ value, gradient = Mooncake. value_and_pullback!! (
173+ rule, 1.0 ,
174+ f_, copy (A), copy (b1), copy (b2)
176175)
177176
178- dA2 = ForwardDiff. gradient (x -> f (x, eltype (x).(b1), eltype (x).(b2)), copy (A))
179- db12 = ForwardDiff. gradient (x -> f (eltype (x).(A), x, eltype (x).(b2)), copy (b1))
180- db22 = ForwardDiff. gradient (x -> f (eltype (x).(A), eltype (x).(b1), x), copy (b2))
177+ dA2 = ForwardDiff. gradient (x -> f_ (x, eltype (x).(b1), eltype (x).(b2)), copy (A))
178+ db12 = ForwardDiff. gradient (x -> f_ (eltype (x).(A), x, eltype (x).(b2)), copy (b1))
179+ db22 = ForwardDiff. gradient (x -> f_ (eltype (x).(A), eltype (x).(b1), x), copy (b2))
181180
182181@test value == f_primal
183182@test gradient[2 ] ≈ dA2
184183@test gradient[3 ] ≈ db12
185184@test gradient[4 ] ≈ db22
186185
187- function f2 (A, b1, b2; alg= RFLUFactorization ())
186+ function f_2 (A, b1, b2; alg= RFLUFactorization ())
188187 prob = LinearProblem (A, b1)
189188 cache = init (prob, alg)
190189 s1 = copy (solve! (cache). u)
@@ -193,18 +192,23 @@ function f2(A, b1, b2; alg=RFLUFactorization())
193192 norm (s1 + s2)
194193end
195194
196- f_primal = f2 (copy (A), copy (b1), copy (b2))
197- value, gradient = Mooncake. value_and_gradient!! (
198- prepare_gradient_cache (f2, copy (A), copy (b1), copy (b2)),
199- f2, copy (A), copy (b1), copy (b2)
195+ f_primal = f_2 (copy (A), copy (b1), copy (b2))
196+ rule = Mooncake. build_rrule (f_2, copy (A), copy (b1), copy (b2))
197+ value, gradient = Mooncake. value_and_pullback!! (
198+ rule, 1.0 ,
199+ f_2, copy (A), copy (b1), copy (b2)
200200)
201201
202+ dA2 = ForwardDiff. gradient (x -> f_2 (x, eltype (x).(b1), eltype (x).(b2)), copy (A))
203+ db12 = ForwardDiff. gradient (x -> f_2 (eltype (x).(A), x, eltype (x).(b2)), copy (b1))
204+ db22 = ForwardDiff. gradient (x -> f_2 (eltype (x).(A), eltype (x).(b1), x), copy (b2))
205+
202206@test value == f_primal
203207@test gradient[2 ] ≈ dA2
204208@test gradient[3 ] ≈ db12
205209@test gradient[4 ] ≈ db22
206210
207- function f3 (A, b1, b2; alg= KrylovJL_GMRES ())
211+ function f_3 (A, b1, b2; alg= KrylovJL_GMRES ())
208212 prob = LinearProblem (A, b1)
209213 cache = init (prob, alg)
210214 s1 = copy (solve! (cache). u)
@@ -213,18 +217,23 @@ function f3(A, b1, b2; alg=KrylovJL_GMRES())
213217 norm (s1 + s2)
214218end
215219
216- f_primal = f3 (copy (A), copy (b1), copy (b2))
217- value, gradient = Mooncake. value_and_gradient!! (
218- prepare_gradient_cache (f3, copy (A), copy (b1), copy (b2)),
219- f3, copy (A), copy (b1), copy (b2)
220+ f_primal = f_3 (copy (A), copy (b1), copy (b2))
221+ rule = Mooncake. build_rrule (f_3, copy (A), copy (b1), copy (b2))
222+ value, gradient = Mooncake. value_and_pullback!! (
223+ rule, 1.0 ,
224+ f_3, copy (A), copy (b1), copy (b2)
220225)
221226
227+ dA2 = ForwardDiff. gradient (x -> f_3 (x, eltype (x).(b1), eltype (x).(b2)), copy (A))
228+ db12 = ForwardDiff. gradient (x -> f_3 (eltype (x).(A), x, eltype (x).(b2)), copy (b1))
229+ db22 = ForwardDiff. gradient (x -> f_3 (eltype (x).(A), eltype (x).(b1), x), copy (b2))
230+
222231@test value == f_primal
223- @test gradient[2 ] ≈ dA2 atol = 5e-5
232+ @test gradient[2 ] ≈ dA2
224233@test gradient[3 ] ≈ db12
225234@test gradient[4 ] ≈ db22
226235
227- function f4 (A, b1, b2; alg= LUFactorization ())
236+ function f_4 (A, b1, b2; alg= LUFactorization ())
228237 prob = LinearProblem (A, b1)
229238 cache = init (prob, alg)
230239 solve! (cache)
@@ -238,17 +247,17 @@ end
238247A = rand (n, n);
239248b1 = rand (n);
240249b2 = rand (n);
241- # f_primal = f4 (copy(A), copy(b1), copy(b2))
250+ f_primal = f_4 (copy (A), copy (b1), copy (b2))
242251
243- rule = Mooncake. build_rrule (f4 , copy (A), copy (b1), copy (b2))
252+ rule = Mooncake. build_rrule (f_4 , copy (A), copy (b1), copy (b2))
244253@test_throws " Adjoint case currently not handled" Mooncake. value_and_pullback!! (
245254 rule, 1.0 ,
246- f4 , copy (A), copy (b1), copy (b2)
255+ f_4 , copy (A), copy (b1), copy (b2)
247256)
248257
249- # dA2 = ForwardDiff.gradient(x -> f4 (x, eltype(x).(b1), eltype(x).(b2)), copy(A))
250- # db12 = ForwardDiff.gradient(x -> f4 (eltype(x).(A), x, eltype(x).(b2)), copy(b1))
251- # db22 = ForwardDiff.gradient(x -> f4 (eltype(x).(A), eltype(x).(b1), x), copy(b2))
258+ # dA2 = ForwardDiff.gradient(x -> f_4 (x, eltype(x).(b1), eltype(x).(b2)), copy(A))
259+ # db12 = ForwardDiff.gradient(x -> f_4 (eltype(x).(A), x, eltype(x).(b2)), copy(b1))
260+ # db22 = ForwardDiff.gradient(x -> f_4 (eltype(x).(A), eltype(x).(b1), x), copy(b2))
252261
253262# @test value == f_primal
254263# @test grad[2] ≈ dA2
0 commit comments