@@ -80,22 +80,22 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
   auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
   auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
   out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
-  out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
+  out_ptr = torch::cat({torch::zeros({1}, ptr.options()), out_ptr}, 0);
 
   torch::Tensor start;
   if (random_start) {
     start = torch::rand(batch_size, src.options());
     start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
   } else {
-    start = torch::zeros(batch_size, ptr.options());
+    start = torch::zeros({batch_size}, ptr.options());
   }
 
   auto dist = torch::full(src.size(0), 5e4, src.options());
 
   auto out_size = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
-  auto out = torch::empty(out_size[0], out_ptr.options());
+  auto out = torch::empty({out_size[0]}, out_ptr.options());
 
   auto stream = at::cuda::getCurrentCUDAStream();
   auto scalar_type = src.scalar_type();
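Each changed line wraps the scalar size argument in braces, so torch::zeros / torch::empty receive an explicit one-element shape (an IntArrayRef) rather than a bare integer. A minimal standalone sketch of the brace-initialized form (assuming a plain libtorch build; not part of this commit):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Brace-initialized sizes are read as shapes: a 1-D tensor of length 1 and one of length 4.
      auto a = torch::zeros({1}, torch::kLong);
      auto b = torch::empty({4}, torch::kLong);
      std::cout << a.numel() << " " << b.numel() << std::endl; // prints "1 4"
      return 0;
    }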