@@ -181,19 +181,24 @@ def test_gromov2_gradients():
181181
182182 if torch :
183183
184- p1 = torch .tensor (p , requires_grad = True )
185- q1 = torch .tensor (q , requires_grad = True )
186- C11 = torch .tensor (C1 , requires_grad = True )
187- C12 = torch .tensor (C2 , requires_grad = True )
184+ devices = [torch .device ("cpu" )]
185+ if torch .cuda .is_available ():
186+ devices .append (torch .device ("cuda" ))
187+ for device in devices :
188+ p1 = torch .tensor (p , requires_grad = True , device = device )
189+ q1 = torch .tensor (q , requires_grad = True , device = device )
190+ C11 = torch .tensor (C1 , requires_grad = True , device = device )
191+ C12 = torch .tensor (C2 , requires_grad = True , device = device )
188192
189- val = ot .gromov_wasserstein2 (C11 , C12 , p1 , q1 )
193+ val = ot .gromov_wasserstein2 (C11 , C12 , p1 , q1 )
190194
191- val .backward ()
195+ val .backward ()
192196
193- assert q1 .shape == q1 .grad .shape
194- assert p1 .shape == p1 .grad .shape
195- assert C11 .shape == C11 .grad .shape
196- assert C12 .shape == C12 .grad .shape
197+ assert val .device == p1 .device
198+ assert q1 .shape == q1 .grad .shape
199+ assert p1 .shape == p1 .grad .shape
200+ assert C11 .shape == C11 .grad .shape
201+ assert C12 .shape == C12 .grad .shape
197202
198203
199204@pytest .skip_backend ("jax" , reason = "test very slow with jax backend" )
@@ -636,21 +641,26 @@ def test_fgw2_gradients():
636641
637642 if torch :
638643
639- p1 = torch .tensor (p , requires_grad = True )
640- q1 = torch .tensor (q , requires_grad = True )
641- C11 = torch .tensor (C1 , requires_grad = True )
642- C12 = torch .tensor (C2 , requires_grad = True )
643- M1 = torch .tensor (M , requires_grad = True )
644-
645- val = ot .fused_gromov_wasserstein2 (M1 , C11 , C12 , p1 , q1 )
646-
647- val .backward ()
648-
649- assert q1 .shape == q1 .grad .shape
650- assert p1 .shape == p1 .grad .shape
651- assert C11 .shape == C11 .grad .shape
652- assert C12 .shape == C12 .grad .shape
653- assert M1 .shape == M1 .grad .shape
644+ devices = [torch .device ("cpu" )]
645+ if torch .cuda .is_available ():
646+ devices .append (torch .device ("cuda" ))
647+ for device in devices :
648+ p1 = torch .tensor (p , requires_grad = True , device = device )
649+ q1 = torch .tensor (q , requires_grad = True , device = device )
650+ C11 = torch .tensor (C1 , requires_grad = True , device = device )
651+ C12 = torch .tensor (C2 , requires_grad = True , device = device )
652+ M1 = torch .tensor (M , requires_grad = True , device = device )
653+
654+ val = ot .fused_gromov_wasserstein2 (M1 , C11 , C12 , p1 , q1 )
655+
656+ val .backward ()
657+
658+ assert val .device == p1 .device
659+ assert q1 .shape == q1 .grad .shape
660+ assert p1 .shape == p1 .grad .shape
661+ assert C11 .shape == C11 .grad .shape
662+ assert C12 .shape == C12 .grad .shape
663+ assert M1 .shape == M1 .grad .shape
654664
655665
656666def test_fgw_barycenter (nx ):
0 commit comments