@@ -18,7 +18,14 @@ auto cast_registrations TORCHTRT_UNUSED =
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
          auto output_dtype = args[1].unwrapToScalar().to<int64_t>();
-         auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+         auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+         nvinfer1::DataType trt_dtype;
+         if (scalar_dtype == at::kLong) {
+           LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+           trt_dtype = nvinfer1::DataType::kINT32;
+         } else {
+           trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+         }
          auto casted_itensor = castITensor(ctx, self, trt_dtype);
          auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
          LOG_DEBUG("[aten::to.dtype] Output tensor shape: " << output->getDimensions());
@@ -33,9 +40,14 @@ auto cast_registrations TORCHTRT_UNUSED =
          // later shape analysis phase of fallback
          auto self = args[0].ITensorOrFreeze(ctx);
          auto output_dtype = args[2].unwrapToScalar().to<int64_t>();
-
-         auto trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
-
+         auto scalar_dtype = static_cast<at::ScalarType>(output_dtype);
+         nvinfer1::DataType trt_dtype;
+         if (scalar_dtype == at::kLong) {
+           LOG_WARNING("Truncating aten::to output type from at::kLong to at::kInt");
+           trt_dtype = nvinfer1::DataType::kINT32;
+         } else {
+           trt_dtype = util::ScalarTypeToTRTDataType(static_cast<at::ScalarType>(output_dtype));
+         }
          auto casted_itensor = castITensor(ctx, self, trt_dtype);
          auto output = ctx->AssociateValueAndTensor(n->outputs()[0], casted_itensor);
          LOG_DEBUG("[aten::to.device] Output tensor shape: " << output->getDimensions());
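
Both hunks apply the same change: before mapping the requested at::ScalarType to a TensorRT dtype, they special-case at::kLong, log a warning, and demote it to nvinfer1::DataType::kINT32, since the TensorRT versions targeted here have no 64-bit integer tensor type. Below is a minimal, self-contained sketch of that rule; the helper name and the switch-based fallback mapping are assumptions for illustration, as the actual patch keeps the logic inline and calls util::ScalarTypeToTRTDataType for the non-kLong cases.

    // Sketch only: a hypothetical free function capturing the demotion rule
    // from the diff above. Assumes ATen and TensorRT headers are available.
    #include <ATen/ATen.h>
    #include <NvInfer.h>

    #include <stdexcept>

    // Map an ATen scalar type to a TensorRT dtype, demoting 64-bit integers
    // to 32-bit because the targeted TensorRT versions have no int64 tensors.
    nvinfer1::DataType ScalarTypeToTRTWithTruncation(at::ScalarType dtype) {
      if (dtype == at::kLong) {
        // The converters above emit LOG_WARNING here before demoting.
        return nvinfer1::DataType::kINT32;
      }
      // Assumed fallback mapping standing in for util::ScalarTypeToTRTDataType.
      switch (dtype) {
        case at::kInt:
          return nvinfer1::DataType::kINT32;
        case at::kHalf:
          return nvinfer1::DataType::kHALF;
        case at::kFloat:
          return nvinfer1::DataType::kFLOAT;
        case at::kChar:
          return nvinfer1::DataType::kINT8;
        case at::kBool:
          return nvinfer1::DataType::kBOOL;
        default:
          throw std::invalid_argument("Unsupported ATen scalar type for TensorRT");
      }
    }

Factoring the rule into a helper like this would avoid duplicating it between the aten::to.dtype and aten::to.device converters; the patch itself repeats the inline if/else in both.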