@@ -307,37 +307,38 @@ auto select_registrations TORCHTRT_UNUSED =
307307 {" aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)" ,
308308 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
309309 auto in = args[0 ].ITensorOrFreeze (ctx);
310- auto axis = args[1 ].unwrapToInt ();
311- auto maxDim = static_cast <int64_t >(in->getDimensions ().d [axis]);
310+ int axis = args[1 ].unwrapToInt ();
311+ int maxDim = static_cast <int32_t >(in->getDimensions ().d [axis]);
312312 bool dynamic_shape = is_dynamic_shape (in);
313313 auto input_dim = in->getDimensions ();
314314 // add Shape Tensor
315315 auto ishape_layer = ctx->net ->addShape (*in);
316316 auto ishape_tensor = ishape_layer->getOutput (0 ); // input shape
317317
318- auto startIdx = 0 ;
318+ int startIdx = 0 ;
319319 auto startIdxIVal = args[2 ].IValue ();
320320 if (!startIdxIVal->isNone ()) {
321- startIdx = startIdxIVal->toInt ();
321+ startIdx = std::min (( int64_t )std::numeric_limits< int32_t >:: max (), startIdxIVal->toInt () );
322322 }
323323 // Handle case when given tensor index is negative
324324 if (maxDim > 0 ) { // only for static shape
325325 startIdx = (startIdx < 0 ) ? (maxDim + startIdx) : startIdx;
326326 }
327327
328328 // Bound the end index to input tensor dimensions at specified axis
329- auto endIdx = maxDim; // -1 for dynamic shape
329+ int endIdx = maxDim; // -1 for dynamic shape
330330 auto endIdxIVal = args[3 ].IValue ();
331331 if (!endIdxIVal->isNone ()) {
332- endIdx = maxDim == -1 ? endIdxIVal->toInt () : std::min (endIdxIVal->toInt (), maxDim);
332+ int truncate_value = std::min ((int64_t )std::numeric_limits<int32_t >::max (), endIdxIVal->toInt ());
333+ endIdx = maxDim == -1 ? truncate_value : std::min (truncate_value, maxDim);
333334 }
334335 if (maxDim > 0 ) {
335336 endIdx = (endIdx < 0 ) ? (maxDim + endIdx) : endIdx;
336337 }
337- auto step = args[4 ].unwrapToInt ();
338+ int step = args[4 ].unwrapToInt ();
338339
339340 // update start, end, stride for static shape
340- auto nbdims = in->getDimensions ().nbDims ;
341+ int nbdims = in->getDimensions ().nbDims ;
341342 nvinfer1::Dims start_, size_, stride_;
342343 start_.nbDims = nbdims;
343344 size_.nbDims = nbdims;
0 commit comments