@@ -218,8 +218,8 @@ nvinfer1::ITensor* clamp(
218218}
219219
220220// clamp x to [0, input_dim]
221- nvinfer1::ITensor* clamp_to_input_dim (ConversionCtx* ctx, nvinfer1::ITensor* x, nvinfer1::ITensor* input_dim) {
222- auto nbdims = input_dim->getDimensions ().d [0 ];
221+ nvinfer1::ITensor* clamp_to_input_dim (ConversionCtx* ctx, nvinfer1::ITensor* x, nvinfer1::ITensor* input_dim, int nbdims ) {
222+ // auto nbdims = input_dim->getDimensions().d[0];
223223 auto zero = torch::zeros ({nbdims}).to (torch::kI32 );
224224 auto zero_itensor = tensor_to_const (ctx, zero);
225225 auto one = torch::ones ({nbdims}).to (torch::kI32 );
@@ -243,8 +243,7 @@ nvinfer1::ITensor* clamp_to_input_dim(ConversionCtx* ctx, nvinfer1::ITensor* x,
243243}
244244
245245// return indices < 0 ? inputDims + indices : indices
246- nvinfer1::ITensor* bump_if_negtive (ConversionCtx* ctx, nvinfer1::ITensor* input_dim, nvinfer1::ITensor* indices) {
247- auto nbdims = input_dim->getDimensions ().d [0 ];
246+ nvinfer1::ITensor* bump_if_negtive (ConversionCtx* ctx, nvinfer1::ITensor* input_dim, nvinfer1::ITensor* indices, int nbdims) {
248247 auto zero = torch::zeros ({nbdims}).to (torch::kI32 );
249248 auto neg = -torch::ones ({nbdims}).to (torch::kI32 );
250249 auto zero_itensor = tensor_to_const (ctx, zero);
@@ -270,11 +269,12 @@ std::vector<nvinfer1::ITensor*> update_start_and_end(
270269 ConversionCtx* ctx,
271270 nvinfer1::ITensor* in_shape,
272271 nvinfer1::ITensor* in_start,
273- nvinfer1::ITensor* in_end) {
274- auto start = bump_if_negtive (ctx, in_shape, in_start);
275- auto out_start = clamp_to_input_dim (ctx, start, in_shape);
276- auto end = bump_if_negtive (ctx, in_shape, in_end);
277- auto out_end = clamp_to_input_dim (ctx, end, in_shape);
272+ nvinfer1::ITensor* in_end,
273+ int nbdims) {
274+ auto start = bump_if_negtive (ctx, in_shape, in_start, nbdims);
275+ auto out_start = clamp_to_input_dim (ctx, start, in_shape, nbdims);
276+ auto end = bump_if_negtive (ctx, in_shape, in_end, nbdims);
277+ auto out_end = clamp_to_input_dim (ctx, end, in_shape, nbdims);
278278 std::vector<nvinfer1::ITensor*> outputs;
279279 outputs.push_back (out_start);
280280 outputs.push_back (out_end);
0 commit comments