Skip to content

Commit 13dabd4

Browse files
committed
graclus weight gpu bugfix
1 parent 7985cdd commit 13dabd4

File tree

5 files changed

+6
-6
lines changed

5 files changed

+6
-6
lines changed

aten/THC/THCNumerics.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616
template<typename T>
1717
struct THCNumerics {
1818
static inline __host__ __device__ T div(T a, T b) { return a / b; }
19-
static inline __host__ __device__ bool gt(T a, T b) { return a > b; }
19+
static inline __host__ __device__ bool gte(T a, T b) { return a >= b; }
2020
};
2121

2222
#ifdef CUDA_HALF_TENSOR
2323
template<>
2424
struct THCNumerics<half> {
2525
static inline __host__ __device__ half div(half a, half b) { return f2h(h2f(a) / h2f(b)); }
26-
static inline __host__ __device__ bool gt(half a, half b) { return h2f(a) > h2f(b); }
26+
static inline __host__ __device__ bool gte(half a, half b) { return h2f(a) >= h2f(b); }
2727
};
2828
#endif // CUDA_HALF_TENSOR
2929

aten/THC/THCPropose.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ __global__ void weightedProposeKernel(int64_t *color, int64_t *prop, int64_t *ro
3232
tmp = weight[e];
3333
if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found.
3434
// Find maximum weighted red neighbor.
35-
if (color[c] == -2 && THCNumerics<T>::gt(tmp, maxWeight)) {
35+
if (color[c] == -2 && THCNumerics<T>::gte(tmp, maxWeight)) {
3636
matchedValue = c;
3737
maxWeight = tmp;
3838
}

aten/THC/THCResponse.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ __global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *r
3535
tmp = weight[e];
3636
if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found.
3737
// Find maximum weighted blue neighbor, who proposed to i.
38-
if (color[c] == -1 && prop[c] == i && THCNumerics<T>::gt(tmp, maxWeight)) {
38+
if (color[c] == -1 && prop[c] == i && THCNumerics<T>::gte(tmp, maxWeight)) {
3939
matchedValue = c;
4040
maxWeight = tmp;
4141
}

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from setuptools import setup, find_packages
44

5-
__version__ = '1.0.2'
5+
__version__ = '1.0.3'
66
url = 'https://github.com/rusty1s/pytorch_cluster'
77

88
install_requires = ['cffi']

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .graclus import graclus_cluster
22
from .grid import grid_cluster
33

4-
__version__ = '1.0.2'
4+
__version__ = '1.0.3'
55

66
__all__ = ['graclus_cluster', 'grid_cluster', '__version__']

0 commit comments

Comments
 (0)