@@ -302,45 +302,105 @@ auto select_registrations TORCHTRT_UNUSED =
     .pattern(
         {"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
          [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-           auto in = args[0].ITensorOrFreeze(ctx);
-           auto axis = args[1].unwrapToInt();
-           auto maxDim = static_cast<int64_t>(in->getDimensions().d[axis]);
-           auto startIdx = 0;
-           auto startIdxIVal = args[2].IValue();
-           if (!startIdxIVal->isNone()) {
-             startIdx = startIdxIVal->toInt();
-           }
-           // Handle case when given tensor index is negative
-           auto start = (startIdx < 0) ? (maxDim + startIdx) : startIdx;
-           // Bound the end index to input tensor dimensions at specified axis
-           auto endIdx = maxDim;
-           auto endIdxIVal = args[3].IValue();
-           if (!endIdxIVal->isNone()) {
-             endIdx = std::min(endIdxIVal->toInt(), maxDim);
-           }
-           auto end = (endIdx < 0) ? (maxDim + endIdx) : endIdx;
-           auto step = args[4].unwrapToInt();
-
-           LOG_DEBUG("Start idx: " << start);
-           LOG_DEBUG("End idx: " << end);
-
-           // indices to be accessed need to be an at::Tensor
-           at::Tensor indices = torch::arange(start, end, step).to(torch::kI32);
-           auto weights = Weights(ctx, indices);
-
-           // IConstantLayer to convert indices from Weights to ITensor
-           auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
-           TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
-           auto const_out = const_layer->getOutput(0);
-
-           // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
-           auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
-           TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-           auto gather_out = gather_layer->getOutput(0);
-
-           auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
-
-           LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
+           auto in = args[0].ITensorOrFreeze(ctx);
+           auto axis = args[1].unwrapToInt();
+           auto maxDim = static_cast<int64_t>(in->getDimensions().d[axis]);
+           bool dynamic_shape = is_dynamic_shape(in);
+           auto input_dim = in->getDimensions();
+           // add Shape Tensor
+           auto ishape_layer = ctx->net->addShape(*in);
+           auto ishape_tensor = ishape_layer->getOutput(0); // input shape
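+           // For dynamic input shapes, getDimensions() reports -1 for any
+           // dimension not known at build time, so the shape tensor above is the
+           // only reliable source for those extents at runtime.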
+
+           auto startIdx = 0;
+           auto startIdxIVal = args[2].IValue();
+           if (!startIdxIVal->isNone()) {
+             startIdx = startIdxIVal->toInt();
+           }
+           // Handle case when given tensor index is negative
+           if (maxDim > 0) { // only for static shape
+             startIdx = (startIdx < 0) ? (maxDim + startIdx) : startIdx;
+           }
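+           // When the dimension is dynamic (maxDim == -1), a negative start cannot
+           // be normalized here; it is carried through and resolved at runtime below.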
+
+           // Bound the end index to input tensor dimensions at specified axis
+           auto endIdx = maxDim; // -1 for dynamic shape
+           auto endIdxIVal = args[3].IValue();
+           if (!endIdxIVal->isNone()) {
+             endIdx = maxDim == -1 ? endIdxIVal->toInt() : std::min(endIdxIVal->toInt(), maxDim);
+           }
+           if (maxDim > 0) {
+             endIdx = (endIdx < 0) ? (maxDim + endIdx) : endIdx;
+           }
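+           // Likewise, a dynamic dimension leaves endIdx unclamped; negative or
+           // out-of-range values are fixed up against the runtime shape in the
+           // dynamic path below.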
+           auto step = args[4].unwrapToInt();
+
+           auto nbdims = in->getDimensions().nbDims;
+           nvinfer1::Dims start_, size_, stride_;
+           start_.nbDims = nbdims;
+           size_.nbDims = nbdims;
+           stride_.nbDims = nbdims;
+           for (int i = 0; i < nbdims; i++) {
+             if (i == axis) {
+               start_.d[i] = startIdx;
+               size_.d[i] = (endIdx - startIdx - 1) / step + 1;
+               stride_.d[i] = step;
+             } else {
+               start_.d[i] = 0;
+               size_.d[i] = input_dim.d[i]; // for static
+               stride_.d[i] = 1;
+             }
+           }
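+           // Note: along `axis` above, (endIdx - startIdx - 1) / step + 1 is the
+           // integer ceiling of (endIdx - startIdx) / step, i.e. the number of
+           // elements a Python-style slice [start:end:step] yields.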
+           auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
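+           // These build-time Dims are only meaningful for static shapes; the
+           // dynamic branch below overrides them with runtime tensors via setInput().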
+
+           if (dynamic_shape) { // dynamic shape
+             LOG_DEBUG("Using dynamic version of slice");
+             // start tensor
+             at::Tensor start_tensor = torch::zeros({nbdims}).to(torch::kI32);
+             start_tensor[axis] = startIdx;
+             auto start_itensor = toITensor(ctx, n, &start_tensor);
+
+             // step tensor
+             at::Tensor stride_tensor = torch::ones({nbdims}).to(torch::kI32);
+             stride_tensor[axis] = step;
+             auto stride_itensor = toITensor(ctx, n, &stride_tensor);
+
+             // end tensor
+             at::Tensor end_tensor = torch::zeros({nbdims}).to(torch::kI32);
+             for (int i = 0; i < nbdims; i++) {
+               if (i == axis) {
+                 end_tensor[i] = endIdxIVal->isNone() ? -1 : endIdx - 1;
+               } else {
+                 end_tensor[i] = input_dim.d[i] == -1 ? -1 : input_dim.d[i] - 1;
+               }
+             }
+             auto end_itensor = toITensor(ctx, n, &end_tensor);
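+             // end_tensor holds *inclusive* end indices, with -1 as a sentinel for
+             // "extent unknown until runtime"; both are resolved against the runtime
+             // shape tensor by update_start_and_end below.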
+
+             // one itensor
+             at::Tensor one_tensor = torch::ones({nbdims}).to(torch::kI32);
+             auto one_itensor = toITensor(ctx, n, &one_tensor);
+
+             // update start and end
+             nvinfer1::ITensor* out_start;
+             nvinfer1::ITensor* out_end;
+             update_start_and_end(
+                 ctx, n, ishape_tensor, start_itensor, end_itensor, &out_start, &out_end);
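+             // update_start_and_end (presumably added elsewhere in this change) maps
+             // negative and sentinel (-1) indices to concrete values using
+             // ishape_tensor, the input's runtime shape.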
+
+             // calculate size
+             auto sub_layer = ctx->net->addElementWise(*out_end, *out_start, nvinfer1::ElementWiseOperation::kSUB);
+             auto sub_itensor = sub_layer->getOutput(0);
+             auto div_layer = ctx->net->addElementWise(*sub_itensor, *stride_itensor, nvinfer1::ElementWiseOperation::kDIV);
+             auto div_itensor = div_layer->getOutput(0);
+             auto add_layer = ctx->net->addElementWise(*div_itensor, *one_itensor, nvinfer1::ElementWiseOperation::kSUM);
+             auto size_itensor = add_layer->getOutput(0);
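+             // size = (end - start) / stride + 1, computed elementwise at runtime;
+             // the +1 is needed because out_end holds inclusive indices.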
+
+             // update slice layer
+             slice_layer->setInput(1, *out_start); // start
+             slice_layer->setInput(2, *size_itensor); // size, must be set if input is dynamic
+           }
+           auto slice_out = slice_layer->getOutput(0);
+           auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_out);
+           LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
 
             return true;
           }})