@@ -1572,13 +1572,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15721572 nbclasses = len (np .unique (Ys [0 ]))
15731573 nbdomains = len (Xs )
15741574
1575- # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
1576- all_domains = []
1577-
15781575 # log dictionary
15791576 if log :
1580- log = {'niter' : 0 , 'err' : [], 'all_domains' : []}
1577+ log = {'niter' : 0 , 'err' : [], 'M' : [], 'D1' : [], 'D2' : []}
1578+
1579+ K = []
1580+ M = []
1581+ D1 = []
1582+ D2 = []
15811583
1584+ # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
15821585 for d in range (nbdomains ):
15831586 dom = {}
15841587 nsk = Xs [d ].shape [0 ] # get number of elements for this domain
@@ -1591,28 +1594,26 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
15911594 classes = np .unique (Ys [d ])
15921595
15931596 # build the corresponding D_1 and D_2 matrices
1594- D1 = np .zeros ((nbclasses , nsk ))
1595- D2 = np .zeros ((nbclasses , nsk ))
1597+ Dtmp1 = np .zeros ((nbclasses , nsk ))
1598+ Dtmp2 = np .zeros ((nbclasses , nsk ))
15961599
15971600 for c in classes :
15981601 nbelemperclass = np .sum (Ys [d ] == c )
15991602 if nbelemperclass != 0 :
1600- D1 [int (c ), Ys [d ] == c ] = 1.
1601- D2 [int (c ), Ys [d ] == c ] = 1. / (nbelemperclass )
1602- dom [ 'D1' ] = D1
1603- dom [ 'D2' ] = D2
1603+ Dtmp1 [int (c ), Ys [d ] == c ] = 1.
1604+ Dtmp2 [int (c ), Ys [d ] == c ] = 1. / (nbelemperclass )
1605+ D1 . append ( Dtmp1 )
1606+ D2 . append ( Dtmp2 )
16041607
16051608 # build the cost matrix and the Gibbs kernel
1606- M = dist (Xs [d ], Xt , metric = metric )
1607- M = M / np .median (M )
1608- dom ['M' ] = M
1609-
1610- K = np .empty (M .shape , dtype = M .dtype )
1611- np .divide (M , - reg , out = K )
1612- np .exp (K , out = K )
1613- dom ['K' ] = K
1609+ Mtmp = dist (Xs [d ], Xt , metric = metric )
1610+ Mtmp = Mtmp / np .median (Mtmp )
1611+ M .append (M )
16141612
1615- all_domains .append (dom )
1613+ Ktmp = np .empty (Mtmp .shape , dtype = Mtmp .dtype )
1614+ np .divide (Mtmp , - reg , out = Ktmp )
1615+ np .exp (Ktmp , out = Ktmp )
1616+ K .append (Ktmp )
16161617
16171618 # uniform target distribution
16181619 a = unif (np .shape (Xt )[0 ])
@@ -1627,16 +1628,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16271628
16281629 # update coupling matrices for marginal constraints w.r.t. uniform target distribution
16291630 for d in range (nbdomains ):
1630- all_domains [d ][ 'K' ] = projC (all_domains [ d ][ 'K' ], a )
1631- other = np .sum (all_domains [ d ][ 'K' ], axis = 1 )
1632- bary = bary + np .log (np .dot (all_domains [ d ][ 'D1' ], other )) / nbdomains
1631+ K [d ] = projC (K [ d ], a )
1632+ other = np .sum (K [ d ], axis = 1 )
1633+ bary = bary + np .log (np .dot (D1 [ d ], other )) / nbdomains
16331634
16341635 bary = np .exp (bary )
16351636
16361637 # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
16371638 for d in range (nbdomains ):
1638- new = np .dot (all_domains [ d ][ 'D2' ].T , bary )
1639- all_domains [d ][ 'K' ] = projR (all_domains [ d ][ 'K' ], new )
1639+ new = np .dot (D2 [ d ].T , bary )
1640+ K [d ] = projR (K [ d ], new )
16401641
16411642 err = np .linalg .norm (bary - old_bary )
16421643 cpt = cpt + 1
@@ -1651,14 +1652,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
16511652 print ('{:5d}|{:8e}|' .format (cpt , err ))
16521653
16531654 bary = bary / np .sum (bary )
1654- couplings = [all_domains [d ]['K' ] for d in range (nbdomains )]
16551655
16561656 if log :
16571657 log ['niter' ] = cpt
1658- log ['all_domains' ] = all_domains
1659- return couplings , bary , log
1658+ log ['M' ] = M
1659+ log ['D1' ] = D1
1660+ log ['D2' ] = D2
1661+ return K , bary , log
16601662 else :
1661- return couplings , bary
1663+ return K , bary
16621664
16631665
16641666def empirical_sinkhorn (X_s , X_t , reg , a = None , b = None , metric = 'sqeuclidean' ,
0 commit comments