Skip to content

Commit e8620a8

Browse files
authored
Half-precision support (#119)
* half support
* deprecation
* typo
* test half
* fix test
1 parent 0d735d7 commit e8620a8

File tree

12 files changed

+25
-18
lines changed

12 files changed

+25
-18
lines changed

csrc/cpu/graclus_cpu.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
4646
}
4747
} else {
4848
auto weight = optional_weight.value();
49-
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
49+
auto scalar_type = weight.scalar_type();
50+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
5051
auto weight_data = weight.data_ptr<scalar_t>();
5152

5253
for (auto n = 0; n < num_nodes; n++) {

csrc/cpu/knn_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
2525

2626
std::vector<size_t> out_vec = std::vector<size_t>();
2727

28-
AT_DISPATCH_ALL_TYPES(x.scalar_type(), "knn_cpu", [&] {
28+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
2929
// See: nanoflann/examples/vector_of_vectors_example.cpp
3030

3131
auto x_data = x.data_ptr<scalar_t>();

csrc/cpu/radius_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
2525

2626
std::vector<size_t> out_vec = std::vector<size_t>();
2727

28-
AT_DISPATCH_ALL_TYPES(x.scalar_type(), "radius_cpu", [&] {
28+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
2929
// See: nanoflann/examples/vector_of_vectors_example.cpp
3030

3131
auto x_data = x.data_ptr<scalar_t>();

csrc/cuda/fps_cuda.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,28 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
7878
auto batch_size = ptr.numel() - 1;
7979

8080
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
81-
auto out_ptr = deg.toType(torch::kFloat) * ratio;
81+
auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
8282
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
8383
out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
8484

8585
torch::Tensor start;
8686
if (random_start) {
8787
start = torch::rand(batch_size, src.options());
88-
start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
88+
start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
8989
} else {
9090
start = torch::zeros(batch_size, ptr.options());
9191
}
9292

93-
auto dist = torch::full(src.size(0), 1e38, src.options());
93+
auto dist = torch::full(src.size(0), 5e4, src.options());
9494

9595
auto out_size = (int64_t *)malloc(sizeof(int64_t));
9696
cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
9797
cudaMemcpyDeviceToHost);
9898
auto out = torch::empty(out_size[0], out_ptr.options());
9999

100100
auto stream = at::cuda::getCurrentCUDAStream();
101-
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
101+
auto scalar_type = src.scalar_type();
102+
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
102103
fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
103104
src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
104105
out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),

csrc/cuda/graclus_cuda.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
113113
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
114114
} else {
115115
auto weight = optional_weight.value();
116-
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
116+
auto scalar_type = weight.scalar_type();
117+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
117118
weighted_propose_kernel<scalar_t>
118119
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
119120
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
@@ -201,7 +202,8 @@ void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
201202
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
202203
} else {
203204
auto weight = optional_weight.value();
204-
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
205+
auto scalar_type = weight.scalar_type();
206+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
205207
weighted_respond_kernel<scalar_t>
206208
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
207209
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),

csrc/cuda/grid_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
6161
auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
6262

6363
auto stream = at::cuda::getCurrentCUDAStream();
64-
AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
64+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] {
6565
grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
6666
pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
6767
start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),

csrc/cuda/knn_cuda.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
4545
int64_t best_idx[100];
4646

4747
for (int e = 0; e < k; e++) {
48-
best_dist[e] = 1e10;
48+
best_dist[e] = 5e4;
4949
best_idx[e] = -1;
5050
}
5151

@@ -121,7 +121,8 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
121121
dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
122122

123123
auto stream = at::cuda::getCurrentCUDAStream();
124-
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
124+
auto scalar_type = x.scalar_type();
125+
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
125126
knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
126127
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
127128
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),

csrc/cuda/nearest_cuda.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
7979
auto out = torch::empty({x.size(0)}, ptr_x.options());
8080

8181
auto stream = at::cuda::getCurrentCUDAStream();
82-
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
82+
auto scalar_type = x.scalar_type();
83+
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
8384
nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
8485
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
8586
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),

csrc/cuda/radius_cuda.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
8080
dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
8181

8282
auto stream = at::cuda::getCurrentCUDAStream();
83-
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
83+
auto scalar_type = x.scalar_type();
84+
AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
8485
radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
8586
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
8687
ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),

test/test_knn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_knn_graph(dtype, device):
6767
(3, 2), (0, 3), (2, 3)])
6868

6969

70-
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
70+
@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
7171
def test_knn_graph_large(dtype, device):
7272
x = torch.randn(1000, 3, dtype=dtype, device=device)
7373

0 commit comments

Comments (0)