@@ -78,27 +78,28 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
   auto batch_size = ptr.numel() - 1;

   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
-  auto out_ptr = deg.toType(torch::kFloat) * ratio;
+  auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
   out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
   out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);

   torch::Tensor start;
   if (random_start) {
     start = torch::rand(batch_size, src.options());
-    start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
+    start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
   } else {
     start = torch::zeros(batch_size, ptr.options());
   }

-  auto dist = torch::full(src.size(0), 1e38, src.options());
+  auto dist = torch::full(src.size(0), 5e4, src.options());

   auto out_size = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
   auto out = torch::empty(out_size[0], out_ptr.options());

   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
+  auto scalar_type = src.scalar_type();
+  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
     fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
         src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
         out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
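
The two changes work together: `AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, ...)` additionally instantiates `fps_kernel` for `at::Half`, and the initial distance sentinel drops from `1e38` to `5e4` because the fp16 maximum is about 65504, so `1e38` would overflow to `inf` when `src.options()` is half precision. A minimal sketch (not part of this commit, assuming a standard libtorch C++ setup) of how the extended dispatch macro binds `scalar_t` for a half tensor:

```cpp
// Sketch only: shows the dispatch pattern used above, outside the CUDA kernel.
#include <ATen/Dispatch.h>
#include <torch/torch.h>
#include <iostream>

int main() {
  // A half-precision tensor; plain AT_DISPATCH_FLOATING_TYPES would throw
  // "not implemented for 'Half'" here, while the *_AND variant adds that case.
  auto x = torch::ones({4}, torch::kHalf);
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, x.scalar_type(), "demo", [&] {
    // scalar_t is at::Half (2 bytes); the same binding is what lets
    // fps_kernel<scalar_t> compile and launch for half inputs.
    std::cout << sizeof(scalar_t) << " bytes per element\n";
  });
  return 0;
}
```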