|
2 | 2 |
|
3 | 3 | #include <torch/extension.h> |
4 | 4 |
|
5 | | -int64_t cuda_version(); |
| 5 | +#include "macros.h" |
6 | 6 |
|
7 | | -torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio, |
| 7 | +namespace cluster { |
| 8 | +CLUSTER_API int64_t cuda_version() noexcept; |
| 9 | + |
| 10 | +namespace detail { |
| 11 | +CLUSTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version(); |
| 12 | +} // namespace detail |
| 13 | +} // namespace cluster |
| 14 | + |
| 15 | +CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio, |
8 | 16 | bool random_start); |
9 | 17 |
|
10 | | -torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col, |
| 18 | +CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col, |
11 | 19 | torch::optional<torch::Tensor> optional_weight); |
12 | 20 |
|
13 | | -torch::Tensor grid(torch::Tensor pos, torch::Tensor size, |
| 21 | +CLUSTER_API torch::Tensor grid(torch::Tensor pos, torch::Tensor size, |
14 | 22 | torch::optional<torch::Tensor> optional_start, |
15 | 23 | torch::optional<torch::Tensor> optional_end); |
16 | 24 |
|
17 | | -torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, |
| 25 | +CLUSTER_API torch::Tensor knn(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, |
18 | 26 | torch::Tensor ptr_y, int64_t k, bool cosine); |
19 | 27 |
|
20 | | -torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, |
| 28 | +CLUSTER_API torch::Tensor nearest(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, |
21 | 29 | torch::Tensor ptr_y); |
22 | 30 |
|
23 | | -torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, |
| 31 | +CLUSTER_API torch::Tensor radius(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x, |
24 | 32 | torch::Tensor ptr_y, double r, int64_t max_num_neighbors); |
25 | 33 |
|
26 | | -std::tuple<torch::Tensor, torch::Tensor> |
| 34 | +CLUSTER_API std::tuple<torch::Tensor, torch::Tensor> |
27 | 35 | random_walk(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, |
28 | 36 | int64_t walk_length, double p, double q); |
29 | 37 |
|
30 | | -torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr, |
| 38 | +CLUSTER_API torch::Tensor neighbor_sampler(torch::Tensor start, torch::Tensor rowptr, |
31 | 39 | int64_t count, double factor); |
0 commit comments