Skip to content

Commit d82e6eb

Browse files
committed
Fix convolutional_barycenter kernel for non-symmetric images
Add authorship
1 parent 0baf83b commit d82e6eb

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

ot/bregman.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Titouan Vayer <titouan.vayer@irisa.fr>
1010
# Hicham Janati <hicham.janati@inria.fr>
1111
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
12+
# Alexander Tong <alexander.tong@yale.edu>
1213
#
1314
# License: MIT License
1415

@@ -1346,12 +1347,17 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
13461347
err = 1
13471348

13481349
# build the convolution operator
1350+
# this is equivalent to blurring on horizontal then vertical directions
13491351
t = np.linspace(0, 1, A.shape[1])
13501352
[Y, X] = np.meshgrid(t, t)
13511353
xi1 = np.exp(-(X - Y)**2 / reg)
13521354

1355+
t = np.linspace(0, 1, A.shape[2])
1356+
[Y, X] = np.meshgrid(t, t)
1357+
xi2 = np.exp(-(X - Y)**2 / reg)
1358+
13531359
def K(x):
1354-
return np.dot(np.dot(xi1, x), xi1)
1360+
return np.dot(np.dot(xi1, x), xi2)
13551361

13561362
while (err > stopThr and cpt < numItermax):
13571363

test/test_bregman.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,10 @@ def test_screenkhorn():
351351
# check marginals
352352
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
353353
np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02)
354+
355+
356+
def test_convolutional_barycenter_non_square():
357+
# test for image with height not equal width
358+
A = np.ones((2, 2, 3)) / (2 * 3)
359+
b = ot.bregman.convolutional_barycenter2d(A, 1e-03)
360+
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)

0 commit comments

Comments
 (0)