Skip to content

Commit fff675e

Browse files
committed
so much faster
1 parent ef96f7a commit fff675e

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

benchmark/benchmark.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import time
2+
3+
import torch
4+
from torch_cluster import sparse_grid_cluster
5+
6+
n = 90000000
7+
s = 1 / 64
8+
9+
print('GPU ===================')
10+
11+
t = time.perf_counter()
12+
pos = torch.cuda.FloatTensor(n, 3).uniform_(0, 1)
13+
size = torch.cuda.FloatTensor([s, s, s])
14+
torch.cuda.synchronize()
15+
print('Init:', time.perf_counter() - t)
16+
17+
t_all = time.perf_counter()
18+
sparse_grid_cluster(pos, size)
19+
torch.cuda.synchronize()
20+
t_all = time.perf_counter() - t_all
21+
print('All:', t_all)
22+
23+
print('CPU ===================')
24+
25+
pos = pos.cpu()
26+
size = size.cpu()
27+
28+
t_all = time.perf_counter()
29+
sparse_grid_cluster(pos, size)
30+
t_all = time.perf_counter() - t_all
31+
print('All:', t_all)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from setuptools import setup, find_packages
44

5-
__version__ = '0.2.3'
5+
__version__ = '0.2.4'
66
url = 'https://github.com/rusty1s/pytorch_cluster'
77

88
install_requires = ['cffi', 'torch-unique']

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .functions.grid import sparse_grid_cluster, dense_grid_cluster
22

3-
__version__ = '0.2.3'
3+
__version__ = '0.2.4'
44

55
__all__ = ['sparse_grid_cluster', 'dense_grid_cluster', '__version__']

torch_cluster/functions/grid.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ def _preprocess(position, size, batch=None, start=None):
1818

1919
# Translate to minimal positive positions if no start was passed.
2020
if start is None:
21-
position = position - position.min(dim=-2, keepdim=True)[0]
22-
else:
21+
min = []
22+
for i in range(position.size(-1)):
23+
min.append(position[:, i].min())
24+
position = position - position.new(min)
25+
elif start != 0:
2326
position = position - start
24-
assert position.min() >= 0, (
25-
'Passed origin resulting in unallowed negative positions')
2627

2728
# If given, append batch to position tensor.
2829
if batch is not None:
@@ -37,10 +38,10 @@ def _preprocess(position, size, batch=None, start=None):
3738

3839

3940
def _minimal_cluster_size(position, size):
40-
max = position.max(dim=0)[0]
41-
while max.dim() > 1:
42-
max = max.max(dim=0)[0]
43-
cluster_size = (max / size).long() + 1
41+
max = []
42+
for i in range(position.size(-1)):
43+
max.append(position[:, i].max())
44+
cluster_size = (size.new(max) / size).long() + 1
4445
return cluster_size
4546

4647

0 commit comments

Comments
 (0)