Skip to content

Commit 40670b1

Browse files
committed
Replaced std::fill with .append addressing Zach's comment
1 parent ff7aaf0 commit 40670b1

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,9 +2138,11 @@ class ConvertAtenConvolutionBackwardOp
21382138
AffineMap::get(numIterators, 0, lhsExprs, context),
21392139
AffineMap::get(numIterators, 0, rhsExprs, context),
21402140
AffineMap::get(numIterators, 0, outExprs, context)};
2141-
SmallVector<IT> iteratorTypes = SmallVector<IT>(numIterators, IT::parallel);
2142-
std::fill(iteratorTypes.rbegin(),
2143-
iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction);
2141+
2142+
int64_t numReductionDims = numSpatialDims + 1;
2143+
SmallVector<IT> iteratorTypes =
2144+
SmallVector<IT>(numIterators - numReductionDims, IT::parallel);
2145+
iteratorTypes.append(numReductionDims, IT::reduction);
21442146

21452147
return createConvAsGenericOp(rewriter, loc, gradOutput, weight,
21462148
gradInputInit, indexingMaps, iteratorTypes);
@@ -2216,9 +2218,11 @@ class ConvertAtenConvolutionBackwardOp
22162218
AffineMap::get(numIterators, 0, lhsExprs, context),
22172219
AffineMap::get(numIterators, 0, rhsExprs, context),
22182220
AffineMap::get(numIterators, 0, outExprs, context)};
2219-
SmallVector<IT> iteratorTypes = SmallVector<IT>(numIterators, IT::parallel);
2220-
std::fill(iteratorTypes.rbegin(),
2221-
iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction);
2221+
2222+
int64_t numReductionDims = numSpatialDims + 1;
2223+
SmallVector<IT> iteratorTypes =
2224+
SmallVector<IT>(numIterators - numReductionDims, IT::parallel);
2225+
iteratorTypes.append(numReductionDims, IT::reduction);
22222226

22232227
return createConvAsGenericOp(rewriter, loc, input, gradOutput,
22242228
gradWeightInit, indexingMaps, iteratorTypes);

0 commit comments

Comments
 (0)