@@ -219,6 +219,24 @@ nvinfer1::ITensor* clamp(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1
219219 return min_itensor;
220220}
221221
222+ // clamp x to [0, input_dim]
223+ nvinfer1::ITensor* clamp_to_input_dim (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* x,
224+ nvinfer1::ITensor* input_dim) {
225+ auto nbdims = input_dim->getDimensions ().d [0 ];
226+ auto zero = torch::zeros ({nbdims}).to (torch::kI32 );
227+ auto zero_itensor = toITensor (ctx, n, &zero);
228+ auto one = torch::ones ({nbdims}).to (torch::kI32 );
229+ auto one_itensor = toITensor (ctx, n, &one);
230+ auto upper_bound_layer = ctx->net ->addElementWise (*input_dim, *one_itensor, nvinfer1::ElementWiseOperation::kSUB );
231+ auto upper_bound = upper_bound_layer->getOutput (0 );
232+ auto max_layer = ctx->net ->addElementWise (*x, *zero_itensor, nvinfer1::ElementWiseOperation::kMAX );
233+ auto max_itensor = max_layer->getOutput (0 );
234+ auto min_layer = ctx->net ->addElementWise (*max_itensor, *upper_bound, nvinfer1::ElementWiseOperation::kMIN );
235+ auto min_itensor = min_layer->getOutput (0 );
236+ return min_itensor;
237+ }
238+
239+
222240// return indices < 0 ? inputDims + indices : indices
223241nvinfer1::ITensor* bump_if_negtive (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_dim,
224242 nvinfer1::ITensor* indices) {
@@ -238,8 +256,10 @@ nvinfer1::ITensor* bump_if_negtive(ConversionCtx* ctx, const torch::jit::Node* n
238256void update_start_and_end (ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in_shape,
239257 nvinfer1::ITensor* in_start, nvinfer1::ITensor* in_end,
240258 nvinfer1::ITensor** out_start, nvinfer1::ITensor** out_end) {
241- *out_start = bump_if_negtive (ctx, n, in_shape, in_start);
242- *out_end = bump_if_negtive (ctx, n, in_shape, in_end);
259+ auto start = bump_if_negtive (ctx, n, in_shape, in_start);
260+ *out_start = clamp_to_input_dim (ctx, n, start, in_shape);
261+ auto end = bump_if_negtive (ctx, n, in_shape, in_end);
262+ *out_end = clamp_to_input_dim (ctx, n, end, in_shape);
243263}
244264
245265bool is_dynamic_shape (nvinfer1::ITensor* tensor) {
0 commit comments