Skip to content

Commit 547759a

Browse files
committed
GPU build
1 parent 4e2e69b commit 547759a

File tree

7 files changed

+10
-9
lines changed

7 files changed

+10
-9
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ install:
5959
- conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
6060
- source script/torch.sh
6161
- pip install flake8 codecov
62-
- python setup.py install
62+
- pip install .[test]
6363
script:
6464
- flake8 .
6565
- python setup.py test

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cmake_minimum_required(VERSION 3.0)
22
project(torchcluster)
33
set(CMAKE_CXX_STANDARD 14)
4-
set(TORCHCLUSTER_VERSION 1.5.4)
4+
set(TORCHCLUSTER_VERSION 1.5.5)
55

66
option(WITH_CUDA "Enable CUDA support" OFF)
77

csrc/cuda/knn_cuda.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
101101
CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
102102

103103
auto dist = torch::full(y.size(0) * k, 1e38, y.options());
104-
auto row = torch::empty(y.size(0) * k, ptr_y.options());
105-
auto col = torch::full(y.size(0) * k, -1, ptr_y.options());
104+
auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
105+
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
106106

107107
auto stream = at::cuda::getCurrentCUDAStream();
108108
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
109-
knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
109+
knn_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
110110
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
111111
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
112112
dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),

csrc/cuda/radius_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
7575

7676
auto stream = at::cuda::getCurrentCUDAStream();
7777
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
78-
radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
78+
radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
7979
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
8080
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
8181
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,

csrc/cuda/radius_cuda.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
#include <torch/extension.h>
44

55
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
6-
torch::optiona<torch::Tensor> ptr_x,
6+
torch::optional<torch::Tensor> ptr_x,
77
torch::optional<torch::Tensor> ptr_y, double r,
88
int64_t max_num_neighbors);

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_extensions():
6363

6464
setup(
6565
name='torch_cluster',
66-
version='1.5.4',
66+
version='1.5.5',
6767
author='Matthias Fey',
6868
author_email='matthias.fey@tu-dortmund.de',
6969
url='https://github.com/rusty1s/pytorch_cluster',
@@ -80,6 +80,7 @@ def get_extensions():
8080
install_requires=install_requires,
8181
setup_requires=setup_requires,
8282
tests_require=tests_require,
83+
extras_require={'test': tests_require},
8384
ext_modules=get_extensions() if not BUILD_DOCS else [],
8485
cmdclass={
8586
'build_ext':

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
__version__ = '1.5.4'
6+
__version__ = '1.5.5'
77

88
for library in [
99
'_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',

0 commit comments

Comments
 (0)