@@ -155,7 +155,6 @@ for alg in (
155155end
156156
157157# Tests for solve! and init rrules.
158-
159158n = 4
160159A = rand (n, n);
161160b1 = rand (n);
@@ -205,8 +204,7 @@ value, gradient = Mooncake.value_and_gradient!!(
205204@test gradient[3 ] ≈ db12
206205@test gradient[4 ] ≈ db22
207206
208- function f3 (A, b1, b2; alg= LUFactorization ())
209- # alg = KrylovJL_GMRES())
207+ function f3 (A, b1, b2; alg= KrylovJL_GMRES ())
210208 prob = LinearProblem (A, b1)
211209 cache = init (prob, alg)
212210 s1 = copy (solve! (cache). u)
@@ -226,37 +224,6 @@ value, gradient = Mooncake.value_and_gradient!!(
226224@test gradient[3 ] ≈ db12
227225@test gradient[4 ] ≈ db22
228226
229- A = rand (n, n);
230- b1 = rand (n);
231-
232- function fnice (A, b, alg)
233- prob = LinearProblem (A, b)
234- sol1 = solve (prob, alg)
235- return sum (sol1. u)
236- end
237-
238- @testset for alg in (
239- LUFactorization (),
240- RFLUFactorization (),
241- KrylovJL_GMRES ()
242- )
243- # for B
244- fb_closure = b -> fnice (A, b, alg)
245- fd_jac_b = FiniteDiff. finite_difference_jacobian (fb_closure, b1) |> vec
246-
247- val, en_jac = Mooncake. value_and_gradient!! (
248- prepare_gradient_cache (fnice, copy (A), copy (b1), alg),
249- fnice, copy (A), copy (b1), alg
250- )
251- @test en_jac[3 ] ≈ fd_jac_b rtol = 1e-5
252-
253- # For A
254- fA_closure = A -> fnice (A, b1, alg)
255- fd_jac_A = FiniteDiff. finite_difference_jacobian (fA_closure, A) |> vec
256- A_grad = en_jac[2 ] |> vec
257- @test A_grad ≈ fd_jac_A rtol = 1e-5
258- end
259-
260227function f4 (A, b1, b2; alg= LUFactorization ())
261228 prob = LinearProblem (A, b1)
262229 cache = init (prob, alg)
@@ -286,4 +253,35 @@ rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2))
286253# @test value == f_primal
287254# @test grad[2] ≈ dA2
288255# @test grad[3] ≈ db12
289- # @test grad[4] ≈ db22
256+ # @test grad[4] ≈ db22
257+
258+ A = rand (n, n);
259+ b1 = rand (n);
260+
261+ function fnice (A, b, alg)
262+ prob = LinearProblem (A, b)
263+ sol1 = solve (prob, alg)
264+ return sum (sol1. u)
265+ end
266+
267+ @testset for alg in (
268+ LUFactorization (),
269+ RFLUFactorization (),
270+ KrylovJL_GMRES ()
271+ )
272+ # for B
273+ fb_closure = b -> fnice (A, b, alg)
274+ fd_jac_b = FiniteDiff. finite_difference_jacobian (fb_closure, b1) |> vec
275+
276+ val, en_jac = Mooncake. value_and_gradient!! (
277+ prepare_gradient_cache (fnice, copy (A), copy (b1), alg),
278+ fnice, copy (A), copy (b1), alg
279+ )
280+ @test en_jac[3 ] ≈ fd_jac_b rtol = 1e-5
281+
282+ # For A
283+ fA_closure = A -> fnice (A, b1, alg)
284+ fd_jac_A = FiniteDiff. finite_difference_jacobian (fA_closure, A) |> vec
285+ A_grad = en_jac[2 ] |> vec
286+ @test A_grad ≈ fd_jac_A rtol = 1e-5
287+ end
0 commit comments