Skip to content

Commit dbcafbe

Browse files
Export symbols (#132)
* Export symbols in the DLL (similar to rusty1s/pytorch_sparse#198) * Export symbols in the DLL (similar to rusty1s/pytorch_sparse#198) * Export symbols in the DLL (similar to rusty1s/pytorch_sparse#198) * Update setup.py Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
1 parent 9472aef commit dbcafbe

30 files changed

+80
-38
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ endif()
1515
find_package(Python3 COMPONENTS Development)
1616
find_package(Torch REQUIRED)
1717

18-
file(GLOB HEADERS csrc/cluster.h)
19-
file(GLOB OPERATOR_SOURCES csrc/cpu/*.h csrc/cpu/*.cpp csrc/*.cpp)
18+
file(GLOB HEADERS csrc/*.h)
19+
file(GLOB OPERATOR_SOURCES csrc/*.* csrc/cpu/*.*)
2020
if(WITH_CUDA)
2121
file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} csrc/cuda/*.h csrc/cuda/*.cu)
2222
endif()

csrc/cluster.h

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,38 @@
22

33
#include <torch/extension.h>
44

5-
int64_t cuda_version();
5+
#include "macros.h"
66

7-
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
7+
namespace cluster {
8+
CLUSTER_API int64_t cuda_version() noexcept;
9+
10+
namespace detail {
11+
CLUSTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
12+
} // namespace detail
13+
} // namespace cluster
14+
15+
CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
816
bool random_start);
917

10-
torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
18+
CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
1119
torch::optional<torch::Tensor> optional_weight);
1220

13-
torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
21+
CLUSTER_API torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
1422
torch::optional<torch::Tensor> optional_start,
1523
torch::optional<torch::Tensor> optional_end);
1624

17-
torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
25+
CLUSTER_API torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
1826
torch::Tensor ptr_y, int64_t k, bool cosine);
1927

20-
torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
28+
CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
2129
torch::Tensor ptr_y);
2230

23-
torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
31+
CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
2432
torch::Tensor ptr_y, double r, int64_t max_num_neighbors);
2533

26-
std::tuple<torch::Tensor, torch::Tensor>
34+
CLUSTER_API std::tuple<torch::Tensor, torch::Tensor>
2735
random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
2836
int64_t walk_length, double p, double q);
2937

30-
torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
38+
CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr,
3139
int64_t count, double factor);

csrc/cpu/fps_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
3+
#include "../extensions.h"
44

55
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
66
bool random_start);

csrc/cpu/graclus_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
3+
#include "../extensions.h"
44

55
torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
66
torch::optional<torch::Tensor> optional_weight);

csrc/cpu/grid_cpu.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
4-
3+
#include "../extensions.h"
54
torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
65
torch::optional<torch::Tensor> optional_start,
76
torch::optional<torch::Tensor> optional_end);

csrc/cpu/knn_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
3+
#include "../extensions.h"
44

55
torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
66
torch::optional<torch::Tensor> ptr_x,

csrc/cpu/radius_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
3+
#include "../extensions.h"
44

55
torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
66
torch::optional<torch::Tensor> ptr_x,

csrc/cpu/rw_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
3+
#include "../extensions.h"
44

55
std::tuple<torch::Tensor, torch::Tensor>
66
random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,

csrc/cpu/sampler_cpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
3+
#include "../extensions.h"
44

55
torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
66
int64_t count, double factor);

csrc/cpu/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22

3-
#include <torch/extension.h>
3+
#include "../extensions.h"
44

55
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
66
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")

0 commit comments

Comments
 (0)