Skip to content

Commit 73e6154

Browse files
committed
Remove dependency sklearn
1 parent cb6bdc5 commit 73e6154

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

ot/utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import numpy as np
1515
from scipy.spatial.distance import cdist
16-
from sklearn.metrics.pairwise import euclidean_distances
1716
import sys
1817
import warnings
1918
try:
@@ -77,6 +76,33 @@ def clean_zeros(a, b, M):
7776
b2 = b[b > 0]
7877
return a2, b2, M2
7978

79+
def euclidean_distances(X, Y, squared=False):
80+
"""
81+
Considering the rows of X (and Y=X) as vectors, compute the
82+
distance matrix between each pair of vectors.
83+
Parameters
84+
----------
85+
X : {array-like}, shape (n_samples_1, n_features)
86+
Y : {array-like}, shape (n_samples_2, n_features)
87+
squared : boolean, optional
88+
Return squared Euclidean distances.
89+
Returns
90+
-------
91+
distances : {array}, shape (n_samples_1, n_samples_2)
92+
"""
93+
XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
94+
YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
95+
distances = np.dot(X, Y.T)
96+
distances *= -2
97+
distances += XX
98+
distances += YY
99+
np.maximum(distances, 0, out=distances)
100+
if X is Y:
101+
# Ensure that distances between vectors and themselves are set to 0.0.
102+
# This may not be the case due to floating point rounding errors.
103+
distances.flat[::distances.shape[0] + 1] = 0.0
104+
return distances if squared else np.sqrt(distances, out=distances)
105+
80106

81107
def dist(x1, x2=None, metric='sqeuclidean'):
82108
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist

0 commit comments

Comments
 (0)