@@ -2138,9 +2138,11 @@ class ConvertAtenConvolutionBackwardOp
21382138 AffineMap::get (numIterators, 0 , lhsExprs, context),
21392139 AffineMap::get (numIterators, 0 , rhsExprs, context),
21402140 AffineMap::get (numIterators, 0 , outExprs, context)};
2141- SmallVector<IT> iteratorTypes = SmallVector<IT>(numIterators, IT::parallel);
2142- std::fill (iteratorTypes.rbegin (),
2143- iteratorTypes.rbegin () + (numSpatialDims + 1 ), IT::reduction);
2141+
2142+ int64_t numReductionDims = numSpatialDims + 1 ;
2143+ SmallVector<IT> iteratorTypes =
2144+ SmallVector<IT>(numIterators - numReductionDims, IT::parallel);
2145+ iteratorTypes.append (numReductionDims, IT::reduction);
21442146
21452147 return createConvAsGenericOp (rewriter, loc, gradOutput, weight,
21462148 gradInputInit, indexingMaps, iteratorTypes);
@@ -2216,9 +2218,11 @@ class ConvertAtenConvolutionBackwardOp
22162218 AffineMap::get (numIterators, 0 , lhsExprs, context),
22172219 AffineMap::get (numIterators, 0 , rhsExprs, context),
22182220 AffineMap::get (numIterators, 0 , outExprs, context)};
2219- SmallVector<IT> iteratorTypes = SmallVector<IT>(numIterators, IT::parallel);
2220- std::fill (iteratorTypes.rbegin (),
2221- iteratorTypes.rbegin () + (numSpatialDims + 1 ), IT::reduction);
2221+
2222+ int64_t numReductionDims = numSpatialDims + 1 ;
2223+ SmallVector<IT> iteratorTypes =
2224+ SmallVector<IT>(numIterators - numReductionDims, IT::parallel);
2225+ iteratorTypes.append (numReductionDims, IT::reduction);
22222226
22232227 return createConvAsGenericOp (rewriter, loc, input, gradOutput,
22242228 gradWeightInit, indexingMaps, iteratorTypes);
0 commit comments