@@ -25,21 +25,16 @@ __global__ void grid_cuda_kernel(
2525
2626at::Tensor grid_cuda (at::Tensor pos, at::Tensor size, at::Tensor start,
2727 at::Tensor end) {
28- size = size.toType (pos.type ());
29- start = start.toType (pos.type ());
30- end = end.toType (pos.type ());
31-
3228 const auto num_nodes = pos.size (0 );
3329 auto cluster = at::empty (pos.type ().toScalarType (at::kLong ), {num_nodes});
3430
3531 AT_DISPATCH_ALL_TYPES (pos.type (), " grid_cuda_kernel" , [&] {
36- auto cluster_data = cluster.data <int64_t >();
37- auto pos_info = at::cuda::detail::getTensorInfo<scalar_t , int >(pos);
38- auto size_data = size.data <scalar_t >();
39- auto start_data = start.data <scalar_t >();
40- auto end_data = end.data <scalar_t >();
4132 grid_cuda_kernel<scalar_t ><<<BLOCKS(num_nodes), THREADS>>> (
42- cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
33+ cluster.data <int64_t >(),
34+ at::cuda::detail::getTensorInfo<scalar_t , int >(pos),
35+ size.toType (pos.type ()).data <scalar_t >(),
36+ start..toType (pos.type ()).data <scalar_t >(),
37+ end.toType (pos.type ()).data <scalar_t >(), num_nodes);
4338 });
4439
4540 return cluster;
0 commit comments