Skip to content

Commit d52a78d

Browse files
author
ievred
committed
pep bregman
1 parent ed34704 commit d52a78d

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

ot/bregman.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,13 +1572,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15721572
nbclasses = len(np.unique(Ys[0]))
15731573
nbdomains = len(Xs)
15741574

1575-
# For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
1576-
all_domains = []
1577-
15781575
# log dictionary
15791576
if log:
1580-
log = {'niter': 0, 'err': [], 'all_domains': []}
1577+
log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []}
1578+
1579+
K = []
1580+
M = []
1581+
D1 = []
1582+
D2 = []
15811583

1584+
# For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
15821585
for d in range(nbdomains):
15831586
dom = {}
15841587
nsk = Xs[d].shape[0] # get number of elements for this domain
@@ -1591,28 +1594,26 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15911594
classes = np.unique(Ys[d])
15921595

15931596
# build the corresponding D_1 and D_2 matrices
1594-
D1 = np.zeros((nbclasses, nsk))
1595-
D2 = np.zeros((nbclasses, nsk))
1597+
Dtmp1 = np.zeros((nbclasses, nsk))
1598+
Dtmp2 = np.zeros((nbclasses, nsk))
15961599

15971600
for c in classes:
15981601
nbelemperclass = np.sum(Ys[d] == c)
15991602
if nbelemperclass != 0:
1600-
D1[int(c), Ys[d] == c] = 1.
1601-
D2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
1602-
dom['D1'] = D1
1603-
dom['D2'] = D2
1603+
Dtmp1[int(c), Ys[d] == c] = 1.
1604+
Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
1605+
D1.append(Dtmp1)
1606+
D2.append(Dtmp2)
16041607

16051608
# build the cost matrix and the Gibbs kernel
1606-
M = dist(Xs[d], Xt, metric=metric)
1607-
M = M / np.median(M)
1608-
dom['M'] = M
1609-
1610-
K = np.empty(M.shape, dtype=M.dtype)
1611-
np.divide(M, -reg, out=K)
1612-
np.exp(K, out=K)
1613-
dom['K'] = K
1609+
Mtmp = dist(Xs[d], Xt, metric=metric)
1610+
Mtmp = Mtmp / np.median(Mtmp)
1611+
M.append(M)
16141612

1615-
all_domains.append(dom)
1613+
Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
1614+
np.divide(Mtmp, -reg, out=Ktmp)
1615+
np.exp(Ktmp, out=Ktmp)
1616+
K.append(Ktmp)
16161617

16171618
# uniform target distribution
16181619
a = unif(np.shape(Xt)[0])
@@ -1627,16 +1628,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16271628

16281629
# update coupling matrices for marginal constraints w.r.t. uniform target distribution
16291630
for d in range(nbdomains):
1630-
all_domains[d]['K'] = projC(all_domains[d]['K'], a)
1631-
other = np.sum(all_domains[d]['K'], axis=1)
1632-
bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains
1631+
K[d] = projC(K[d], a)
1632+
other = np.sum(K[d], axis=1)
1633+
bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
16331634

16341635
bary = np.exp(bary)
16351636

16361637
# update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
16371638
for d in range(nbdomains):
1638-
new = np.dot(all_domains[d]['D2'].T, bary)
1639-
all_domains[d]['K'] = projR(all_domains[d]['K'], new)
1639+
new = np.dot(D2[d].T, bary)
1640+
K[d] = projR(K[d], new)
16401641

16411642
err = np.linalg.norm(bary - old_bary)
16421643
cpt = cpt + 1
@@ -1651,14 +1652,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16511652
print('{:5d}|{:8e}|'.format(cpt, err))
16521653

16531654
bary = bary / np.sum(bary)
1654-
couplings = [all_domains[d]['K'] for d in range(nbdomains)]
16551655

16561656
if log:
16571657
log['niter'] = cpt
1658-
log['all_domains'] = all_domains
1659-
return couplings, bary, log
1658+
log['M'] = M
1659+
log['D1'] = D1
1660+
log['D2'] = D2
1661+
return K, bary, log
16601662
else:
1661-
return couplings, bary
1663+
return K, bary
16621664

16631665

16641666
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',

0 commit comments

Comments
 (0)