Skip to content

Commit cb0e5f6

Browse files
committed
cleaner
1 parent b992389 commit cb0e5f6

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

aten/cuda/cluster_kernel.cu

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,16 @@ __global__ void grid_cuda_kernel(
2525

2626
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
2727
at::Tensor end) {
28-
size = size.toType(pos.type());
29-
start = start.toType(pos.type());
30-
end = end.toType(pos.type());
31-
3228
const auto num_nodes = pos.size(0);
3329
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
3430

3531
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
36-
auto cluster_data = cluster.data<int64_t>();
37-
auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
38-
auto size_data = size.data<scalar_t>();
39-
auto start_data = start.data<scalar_t>();
40-
auto end_data = end.data<scalar_t>();
4132
grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>(
42-
cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
33+
cluster.data<int64_t>(),
34+
at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
35+
size.toType(pos.type()).data<scalar_t>(),
36+
start..toType(pos.type()).data<scalar_t>(),
37+
end.toType(pos.type()).data<scalar_t>(), num_nodes);
4338
});
4439

4540
return cluster;

0 commit comments

Comments
 (0)