@@ -153,3 +153,179 @@ for alg in (
    @test results[1] ≈ fA(A)
    @test mooncake_gradient ≈ fd_jac rtol = 1e-5
end

# Tests for solve! and init rrules.
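# (Assumes LinearSolve, Mooncake, ForwardDiff, FiniteDiff, and
# LinearAlgebra (for `norm`) are loaded earlier in this test file.)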

n = 4
A = rand(n, n);
b1 = rand(n);
b2 = rand(n);

function f(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f, copy(A), copy(b1), copy(b2)),
    f, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f2(A, b1, b2; alg = RFLUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f2(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f2, copy(A), copy(b1), copy(b2)),
    f2, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f3(A, b1, b2; alg = LUFactorization())
    # Alternative to exercise an iterative solver: alg = KrylovJL_GMRES()
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    s1 = copy(solve!(cache).u)
    cache.b = b2
    s2 = solve!(cache).u
    norm(s1 + s2)
end

f_primal = f3(copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_gradient!!(
    prepare_gradient_cache(f3, copy(A), copy(b1), copy(b2)),
    f3, copy(A), copy(b1), copy(b2)
)

@test value == f_primal
@test gradient[2] ≈ dA2 atol = 5e-5
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

A = rand(n, n);
b1 = rand(n);

function fnice(A, b, alg)
    prob = LinearProblem(A, b)
    sol1 = solve(prob, alg)
    return sum(sol1.u)
end

@testset for alg in (
    LUFactorization(),
    RFLUFactorization(),
    KrylovJL_GMRES()
)
    # Gradient with respect to b
    fb_closure = b -> fnice(A, b, alg)
    fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec

    val, en_jac = Mooncake.value_and_gradient!!(
        prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
        fnice, copy(A), copy(b1), alg
    )
    @test en_jac[3] ≈ fd_jac_b rtol = 1e-5

    # Gradient with respect to A
    fA_closure = A -> fnice(A, b1, alg)
    fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
    A_grad = en_jac[2] |> vec
    @test A_grad ≈ fd_jac_A rtol = 1e-5
end

# The test function below currently fails!
# When differentiating (the adjoint case), avoid reading the solution back
# off the cache, as in `solve!(cache); s1 = copy(cache.u)`.
# Instead capture the returned solution: `sol = solve!(cache); s1 = copy(sol.u)`.
# A corrected sketch follows the definition of `f4` below.

function f4(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    solve!(cache)
    s1 = copy(cache.u)
    cache.b = b2
    solve!(cache)
    s2 = copy(cache.u)
    norm(s1 + s2)
end
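
# A minimal sketch (not part of the original tests) of the recommended
# rewrite of `f4`: capture the value returned by `solve!` rather than
# reading `cache.u` afterwards. The name `f4_fixed` is hypothetical.
function f4_fixed(A, b1, b2; alg = LUFactorization())
    prob = LinearProblem(A, b1)
    cache = init(prob, alg)
    sol1 = solve!(cache)
    s1 = copy(sol1.u)
    cache.b = b2
    sol2 = solve!(cache)
    s2 = copy(sol2.u)
    norm(s1 + s2)
end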

# value, grad = Mooncake.value_and_gradient!!(
#     prepare_gradient_cache(f4, copy(A), copy(b1), copy(b2)),
#     f4, copy(A), copy(b1), copy(b2)
# )
# Currently returns all-zero gradients:
# (0.0, (Mooncake.NoTangent(), [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]))

# dA2 = ForwardDiff.gradient(x -> f4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
# db12 = ForwardDiff.gradient(x -> f4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
# db22 = ForwardDiff.gradient(x -> f4(eltype(x).(A), eltype(x).(b1), x), copy(b2))

# @test value == f_primal
# @test grad[2] ≈ dA2
# @test grad[3] ≈ db12
# @test grad[4] ≈ db22

function testls(A, b, u)
    oa = OperatorAssumptions(
        true, condition = LinearSolve.OperatorCondition.WellConditioned)
    prob = LinearProblem(A, b)
    linsolve = init(prob, LUFactorization(), assumptions = oa)
    sol = solve!(linsolve)
    sum(sol.u)
end

# A = [1.0 2.0; 3.0 4.0]
# b = [1.0, 2.0]
# u = zero(b)
# value, gradient = Mooncake.value_and_gradient!!(
#     prepare_gradient_cache(testls, copy(A), copy(b), copy(u)),
#     testls, copy(A), copy(b), copy(u)
# )

# dA = gradient[2]
# db = gradient[3]
# du = gradient[4]
310+
# Redefines `testls`, this time reading the solution off the cache after
# `solve!` (the pattern flagged above), to compare gradients between the two.
function testls(A, b, u)
    oa = OperatorAssumptions(
        true, condition = LinearSolve.OperatorCondition.WellConditioned)
    prob = LinearProblem(A, b)
    linsolve = init(prob, LUFactorization(), assumptions = oa)
    solve!(linsolve)
    sum(linsolve.u)
end

# value, gradient = Mooncake.value_and_gradient!!(
#     prepare_gradient_cache(testls, copy(A), copy(b), copy(u)),
#     testls, copy(A), copy(b), copy(u)
# )

# dA2 = gradient[2]
# db2 = gradient[3]
# du2 = gradient[4]

# @test dA == dA2
# @test db == db2
# @test du == du2