Skip to content

Commit 77bcf83

Browse files
committed
add clean zeros function for sparse distributions
1 parent 2bcc24a commit 77bcf83

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

ot/bregman.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ def sinkhorn(a,b, M, reg,method='sinkhorn', numItermax = 1000, stopThr=1e-9, ver
108108
return sink()
109109

110110

111+
112+
111113
def sinkhorn_knopp(a,b, M, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
112114
"""
113115
Solve the entropic regularization optimal transport problem and return the OT matrix

ot/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ def unif(n):
5050
"""
5151
return np.ones((n,))/n
5252

53+
def clean_zeros(a,b,M):
54+
""" Remove all components with zeros weights in a and b
55+
"""
56+
M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd)
57+
a2=a[a>0]
58+
b2=b[b>0]
59+
return a2,b2,M2
5360

5461
def dist(x1,x2=None,metric='sqeuclidean'):
5562
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist

0 commit comments

Comments
 (0)