Skip to content

Commit 79fbb05

Browse files
some code formatting
1 parent 80b0cb6 commit 79fbb05

File tree

1 file changed

+33
-35
lines changed

1 file changed

+33
-35
lines changed

test/nopre/mooncake.jl

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ for alg in (
155155
end
156156

157157
# Tests for solve! and init rrules.
158-
159158
n = 4
160159
A = rand(n, n);
161160
b1 = rand(n);
@@ -205,8 +204,7 @@ value, gradient = Mooncake.value_and_gradient!!(
205204
@test gradient[3] db12
206205
@test gradient[4] db22
207206

208-
function f3(A, b1, b2; alg=LUFactorization())
209-
# alg = KrylovJL_GMRES())
207+
function f3(A, b1, b2; alg=KrylovJL_GMRES())
210208
prob = LinearProblem(A, b1)
211209
cache = init(prob, alg)
212210
s1 = copy(solve!(cache).u)
@@ -226,37 +224,6 @@ value, gradient = Mooncake.value_and_gradient!!(
226224
@test gradient[3] db12
227225
@test gradient[4] db22
228226

229-
A = rand(n, n);
230-
b1 = rand(n);
231-
232-
function fnice(A, b, alg)
233-
prob = LinearProblem(A, b)
234-
sol1 = solve(prob, alg)
235-
return sum(sol1.u)
236-
end
237-
238-
@testset for alg in (
239-
LUFactorization(),
240-
RFLUFactorization(),
241-
KrylovJL_GMRES()
242-
)
243-
# for B
244-
fb_closure = b -> fnice(A, b, alg)
245-
fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
246-
247-
val, en_jac = Mooncake.value_and_gradient!!(
248-
prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
249-
fnice, copy(A), copy(b1), alg
250-
)
251-
@test en_jac[3] fd_jac_b rtol = 1e-5
252-
253-
# For A
254-
fA_closure = A -> fnice(A, b1, alg)
255-
fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
256-
A_grad = en_jac[2] |> vec
257-
@test A_grad fd_jac_A rtol = 1e-5
258-
end
259-
260227
function f4(A, b1, b2; alg=LUFactorization())
261228
prob = LinearProblem(A, b1)
262229
cache = init(prob, alg)
@@ -286,4 +253,35 @@ rule = Mooncake.build_rrule(f4, copy(A), copy(b1), copy(b2))
286253
# @test value == f_primal
287254
# @test grad[2] ≈ dA2
288255
# @test grad[3] ≈ db12
289-
# @test grad[4] ≈ db22
256+
# @test grad[4] ≈ db22
257+
258+
A = rand(n, n);
259+
b1 = rand(n);
260+
261+
function fnice(A, b, alg)
262+
prob = LinearProblem(A, b)
263+
sol1 = solve(prob, alg)
264+
return sum(sol1.u)
265+
end
266+
267+
@testset for alg in (
268+
LUFactorization(),
269+
RFLUFactorization(),
270+
KrylovJL_GMRES()
271+
)
272+
# for B
273+
fb_closure = b -> fnice(A, b, alg)
274+
fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec
275+
276+
val, en_jac = Mooncake.value_and_gradient!!(
277+
prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
278+
fnice, copy(A), copy(b1), alg
279+
)
280+
@test en_jac[3] fd_jac_b rtol = 1e-5
281+
282+
# For A
283+
fA_closure = A -> fnice(A, b1, alg)
284+
fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
285+
A_grad = en_jac[2] |> vec
286+
@test A_grad fd_jac_A rtol = 1e-5
287+
end

0 commit comments

Comments
 (0)