Skip to content

Commit 2738738

Browse files
authored
Add bf16 support for knn_cpu, radius_cpu and graclus_cpu (#144)
1 parent eea2fc5 commit 2738738

File tree

4 files changed

+5
-4
lines changed

4 files changed

+5
-4
lines changed

csrc/cpu/graclus_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
4747
} else {
4848
auto weight = optional_weight.value();
4949
auto scalar_type = weight.scalar_type();
50-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
50+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "graclus_cpu", [&] {
5151
auto weight_data = weight.data_ptr<scalar_t>();
5252

5353
for (auto n = 0; n < num_nodes; n++) {

csrc/cpu/knn_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ torch::Tensor knn_cpu(torch::Tensor x, torch::Tensor y,
2525

2626
std::vector<size_t> out_vec = std::vector<size_t>();
2727

28-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
28+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "knn_cpu", [&] {
2929
// See: nanoflann/examples/vector_of_vectors_example.cpp
3030

3131
auto x_data = x.data_ptr<scalar_t>();

csrc/cpu/radius_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ torch::Tensor radius_cpu(torch::Tensor x, torch::Tensor y,
2525

2626
std::vector<size_t> out_vec = std::vector<size_t>();
2727

28-
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "_", [&] {
28+
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "radius_cpu", [&] {
2929
// See: nanoflann/examples/vector_of_vectors_example.cpp
3030

3131
auto x_data = x.data_ptr<scalar_t>();

test/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22

3-
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
3+
dtypes = [torch.half, torch.bfloat16, torch.float, torch.double,
4+
torch.int, torch.long]
45
grad_dtypes = [torch.half, torch.float, torch.double]
56

67
devices = [torch.device('cpu')]

0 commit comments

Comments
 (0)