Commit b229c37

[feat]: support slice with dynamic shape
Signed-off-by: inocsin <vcheungyi@163.com>
1 parent c2fb43b commit b229c37

File tree: 6 files changed, +344 −40 lines changed
core/conversion/converters/converter_util.cpp

Lines changed: 54 additions & 0 deletions
@@ -199,6 +199,60 @@ nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::
   return out;
 }
 
+nvinfer1::ITensor* toITensor(ConversionCtx* ctx, const torch::jit::Node* n, at::Tensor* input) {
+  auto weights = Weights(ctx, *input);
+  // IConstantLayer to convert indices from Weights to ITensor
+  auto const_layer = ctx->net->addConstant(weights.shape, weights.data); // shouldn't use constant
+  TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
+  auto const_out = const_layer->getOutput(0);
+  return const_out;
+}
+
+// clamp x to [lower_bound, upper_bound]
+nvinfer1::ITensor* clamp(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
+                         nvinfer1::ITensor* lower_bound, nvinfer1::ITensor* upper_bound) {
+  auto max_layer = ctx->net->addElementWise(*x, *lower_bound, nvinfer1::ElementWiseOperation::kMAX);
+  auto max_itensor = max_layer->getOutput(0);
+  auto min_layer = ctx->net->addElementWise(*max_itensor, *upper_bound, nvinfer1::ElementWiseOperation::kMIN);
+  auto min_itensor = min_layer->getOutput(0);
+  return min_itensor;
+}
+
+// return indices < 0 ? input_dim + indices : indices
+nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_dim,
+                                   nvinfer1::ITensor* indices) {
+  auto nbdims = input_dim->getDimensions().d[0];
+  auto zero = torch::zeros({nbdims}).to(torch::kI32);
+  auto neg = -torch::ones({nbdims}).to(torch::kI32);
+  auto zero_itensor = toITensor(ctx, n, &zero);
+  auto neg_itensor = toITensor(ctx, n, &neg);
+  // signs is -1 where indices < 0 and 0 elsewhere
+  auto signs = clamp(ctx, n, indices, neg_itensor, zero_itensor);
+  auto mul = ctx->net->addElementWise(*signs, *input_dim, nvinfer1::ElementWiseOperation::kPROD);
+  auto mul_itensor = mul->getOutput(0);
+  // indices - signs * input_dim adds input_dim to each negative index
+  auto sub = ctx->net->addElementWise(*indices, *mul_itensor, nvinfer1::ElementWiseOperation::kSUB);
+  auto sub_itensor = sub->getOutput(0);
+  return sub_itensor;
+}
+
+void update_start_and_end(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in_shape,
+                          nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end,
+                          nvinfer1::ITensor** out_start, nvinfer1::ITensor** out_end) {
+  *out_start = bump_if_negtive(ctx, n, in_shape, in_start);
+  *out_end = bump_if_negtive(ctx, n, in_shape, in_end);
+}
+
+bool is_dynamic_shape(nvinfer1::ITensor* tensor) {
+  auto dim = tensor->getDimensions();
+  auto ndims = dim.nbDims;
+  for (int i = 0; i < ndims; i++) {
+    if (dim.d[i] == -1) {
+      return true;
+    }
+  }
+  return false;
+}
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
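
The sign trick in bump_if_negtive is easiest to verify on scalars: clamp(idx, -1, 0) is -1 exactly when the index is negative, so idx - signs * dim leaves non-negative indices alone and shifts negative ones up by the dimension size. A minimal scalar sketch of the same arithmetic (plain C++, no TensorRT; norm_index is a hypothetical name, not part of the commit):

#include <algorithm>
#include <cassert>

// Scalar version of bump_if_negtive:
//   signs  = clamp(idx, -1, 0)   -> -1 for negative idx, 0 otherwise
//   result = idx - signs * dim   -> idx + dim when idx < 0, idx unchanged otherwise
int norm_index(int idx, int dim) {
  int signs = std::min(std::max(idx, -1), 0);
  return idx - signs * dim;
}

int main() {
  assert(norm_index(2, 5) == 2);   // non-negative index untouched
  assert(norm_index(-1, 5) == 4);  // -1 maps to dim - 1
  assert(norm_index(-5, 5) == 0);  // -dim maps to 0
  return 0;
}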

core/conversion/converters/converter_util.h

Lines changed: 14 additions & 0 deletions
@@ -50,6 +50,20 @@ nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nv
 // Freeze an at::Tensor in a IConstant layer
 nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());
 
+nvinfer1::ITensor* toITensor(ConversionCtx* ctx, const torch::jit::Node* n, at::Tensor* input);
+
+nvinfer1::ITensor* clamp(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
+                         nvinfer1::ITensor* lower_bound, nvinfer1::ITensor* upper_bound);
+
+nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_dim,
+                                   nvinfer1::ITensor* indices);
+
+void update_start_and_end(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in_shape,
+                          nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end,
+                          nvinfer1::ITensor** out_start, nvinfer1::ITensor** out_end);
+
+bool is_dynamic_shape(nvinfer1::ITensor* tensor);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
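
is_dynamic_shape treats a -1 extent in any axis as dynamic, matching TensorRT's wildcard convention for inputs compiled with optimization profiles. A minimal sketch of the same predicate applied directly to nvinfer1::Dims (assumes NvInfer.h is on the include path; dims_are_dynamic is a hypothetical stand-in, since the real helper takes an ITensor*):

#include <NvInfer.h>
#include <cassert>

// Same loop as is_dynamic_shape, but over Dims rather than an ITensor.
bool dims_are_dynamic(const nvinfer1::Dims& dim) {
  for (int i = 0; i < dim.nbDims; i++) {
    if (dim.d[i] == -1) {
      return true; // any wildcard extent marks the tensor as dynamic
    }
  }
  return false;
}

int main() {
  nvinfer1::Dims d;
  d.nbDims = 3;
  d.d[0] = -1; // dynamic batch axis
  d.d[1] = 3;
  d.d[2] = 224;
  assert(dims_are_dynamic(d));
  return 0;
}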

core/conversion/converters/impl/select.cpp

Lines changed: 99 additions & 39 deletions
@@ -302,45 +302,105 @@ auto select_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> Tensor(a)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensorOrFreeze(ctx);
-               auto axis = args[1].unwrapToInt();
-               auto maxDim = static_cast<int64_t>(in->getDimensions().d[axis]);
-               auto startIdx = 0;
-               auto startIdxIVal = args[2].IValue();
-               if (!startIdxIVal->isNone()) {
-                 startIdx = startIdxIVal->toInt();
-               }
-               // Handle case when given tensor index is negative
-               auto start = (startIdx < 0) ? (maxDim + startIdx) : startIdx;
-               // Bound the end index to input tensor dimensions at specified axis
-               auto endIdx = maxDim;
-               auto endIdxIVal = args[3].IValue();
-               if (!endIdxIVal->isNone()) {
-                 endIdx = std::min(endIdxIVal->toInt(), maxDim);
-               }
-               auto end = (endIdx < 0) ? (maxDim + endIdx) : endIdx;
-               auto step = args[4].unwrapToInt();
-
-               LOG_DEBUG("Start idx: " << start);
-               LOG_DEBUG("End idx: " << end);
-
-               // indices to be accessed need to be an at::Tensor
-               at::Tensor indices = torch::arange(start, end, step).to(torch::kI32);
-               auto weights = Weights(ctx, indices);
-
-               // IConstantLayer to convert indices from Weights to ITensor
-               auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
-               TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
-               auto const_out = const_layer->getOutput(0);
-
-               // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
-               auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
-               TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-               auto gather_out = gather_layer->getOutput(0);
-
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
-
-               LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto axis = args[1].unwrapToInt();
+               auto maxDim = static_cast<int64_t>(in->getDimensions().d[axis]);
+               bool dynamic_shape = is_dynamic_shape(in);
+               auto input_dim = in->getDimensions();
+               // add Shape Tensor
+               auto ishape_layer = ctx->net->addShape(*in);
+               auto ishape_tensor = ishape_layer->getOutput(0); // input shape
+
+               auto startIdx = 0;
+               auto startIdxIVal = args[2].IValue();
+               if (!startIdxIVal->isNone()) {
+                 startIdx = startIdxIVal->toInt();
+               }
+               // Handle case when given tensor index is negative
+               if (maxDim > 0) { // only for static shape
+                 startIdx = (startIdx < 0) ? (maxDim + startIdx) : startIdx;
+               }
+
+               // Bound the end index to input tensor dimensions at specified axis
+               auto endIdx = maxDim; // -1 for dynamic shape
+               auto endIdxIVal = args[3].IValue();
+               if (!endIdxIVal->isNone()) {
+                 endIdx = maxDim == -1 ? endIdxIVal->toInt() : std::min(endIdxIVal->toInt(), maxDim);
+               }
+               if (maxDim > 0) {
+                 endIdx = (endIdx < 0) ? (maxDim + endIdx) : endIdx;
+               }
+               auto step = args[4].unwrapToInt();
+
+               auto nbdims = in->getDimensions().nbDims;
+               nvinfer1::Dims start_, size_, stride_;
+               start_.nbDims = nbdims;
+               size_.nbDims = nbdims;
+               stride_.nbDims = nbdims;
+               for (int i = 0; i < nbdims; i++) {
+                 if (i == axis) {
+                   start_.d[i] = startIdx;
+                   size_.d[i] = (endIdx - startIdx - 1) / step + 1;
+                   stride_.d[i] = step;
+                 } else {
+                   start_.d[i] = 0;
+                   size_.d[i] = input_dim.d[i]; // for static
+                   stride_.d[i] = 1;
+                 }
+               }
+               auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
+
+               if (dynamic_shape) { // dynamic shape
+                 LOG_DEBUG("Using dynamic version of slice");
+                 // start tensor
+                 at::Tensor start_tensor = torch::zeros({nbdims}).to(torch::kI32);
+                 start_tensor[axis] = startIdx;
+                 auto start_itensor = toITensor(ctx, n, &start_tensor);
+
+                 // step tensor
+                 at::Tensor stride_tensor = torch::ones({nbdims}).to(torch::kI32);
+                 stride_tensor[axis] = step;
+                 auto stride_itensor = toITensor(ctx, n, &stride_tensor);
+
+                 // end tensor
+                 at::Tensor end_tensor = torch::zeros({nbdims}).to(torch::kI32);
+                 for (int i = 0; i < nbdims; i++) {
+                   if (i == axis) {
+                     end_tensor[i] = endIdxIVal->isNone() ? -1 : endIdx - 1;
+                   } else {
+                     end_tensor[i] = input_dim.d[i] == -1 ? -1 : input_dim.d[i] - 1;
+                   }
+                 }
+                 auto end_itensor = toITensor(ctx, n, &end_tensor);
+
+                 // one itensor
+                 at::Tensor one_tensor = torch::ones({nbdims}).to(torch::kI32);
+                 auto one_itensor = toITensor(ctx, n, &one_tensor);
+
+                 // update start and end to handle negative indices
+                 nvinfer1::ITensor* out_start;
+                 nvinfer1::ITensor* out_end;
+                 update_start_and_end(ctx, n, ishape_tensor, start_itensor, end_itensor, &out_start, &out_end);
+
+                 // calculate size: (end - start) / stride + 1
+                 auto sub_layer = ctx->net->addElementWise(*out_end, *out_start, nvinfer1::ElementWiseOperation::kSUB);
+                 auto sub_itensor = sub_layer->getOutput(0);
+                 auto div_layer = ctx->net->addElementWise(*sub_itensor, *stride_itensor, nvinfer1::ElementWiseOperation::kDIV);
+                 auto div_itensor = div_layer->getOutput(0);
+                 auto add_layer = ctx->net->addElementWise(*div_itensor, *one_itensor, nvinfer1::ElementWiseOperation::kSUM);
+                 auto size_itensor = add_layer->getOutput(0);
+
+                 // update slice layer
+                 slice_layer->setInput(1, *out_start); // start
+                 slice_layer->setInput(2, *size_itensor); // size, must be set if input is dynamic
+               }
+               auto slice_out = slice_layer->getOutput(0);
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_out);
+               LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
 
                return true;
              }})
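
Both branches of the converter compute the same slice length. The static branch evaluates (endIdx - startIdx - 1) / step + 1 at conversion time; the dynamic branch builds the equivalent (end - start) / stride + 1 out of elementwise layers, which agrees because end_tensor already holds endIdx - 1. A scalar check of the two formulas (plain C++, no TensorRT):

#include <cassert>

int main() {
  int start = 1, end = 8, step = 3; // slice [1, 8) with stride 3 selects indices 1, 4, 7
  // static branch: size_.d[axis] = (endIdx - startIdx - 1) / step + 1
  int static_size = (end - start - 1) / step + 1;
  // dynamic branch: size = ((endIdx - 1) - startIdx) / step + 1, built from kSUB/kDIV/kSUM layers
  int dynamic_size = ((end - 1) - start) / step + 1;
  assert(static_size == 3);
  assert(dynamic_size == 3);
  return 0;
}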

core/conversion/var/Var.cpp

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
     out = ptr_.tensor;
   }
 
+  LOG_DEBUG("ITensor name: " << out->getName());
   LOG_DEBUG("ITensor shape: " << out->getDimensions());
   LOG_DEBUG("ITensor type: " << out->getType());
   return out;
