
Commit ff7aaf0

Updated the flip part; addressed Zach's comment
1 parent 7ed7b5b commit ff7aaf0

4 files changed, +112 -11 lines changed


lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 10 additions & 11 deletions
@@ -1843,19 +1843,18 @@ class ConvertAtenConvolutionBackwardOp
     // If convolution bwd is grouped, `weight` should be expanded
     auto weightExpanded = isGroupedConvBwd ? expandGroups(weight, 0) : weight;
 
-    // Flip weight along spatial dims only if
-    // - kernel size is greater than 1,
-    // - the kernel is not a 1x1 or 1x1x1 kernel.
+    // Flip weight along non-unit spatial dims.
     SmallVector<int64_t> weightDimsInt = makeShapeTorchCompatible(
         cast<RankedTensorType>(weightExpanded.getType()).getShape());
-    bool is1x1Kernel = std::all_of(weightDimsInt.rbegin(),
-                                   weightDimsInt.rbegin() + numSpatialDims,
-                                   [](int64_t dim) { return dim == 1; });
-    if (numSpatialDims > 1 && !is1x1Kernel) {
-      SmallVector<int64_t> weightFlipDims;
-      weightFlipDims.reserve(numSpatialDims);
-      for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
-        weightFlipDims.push_back(spatialStartDimIdx + i);
+    // Collect any non-unit spatial dim indices.
+    SmallVector<int64_t> weightFlipDims;
+    for (auto [idx, dim] : llvm::enumerate(weightDimsInt)) {
+      if (idx >= spatialStartDimIdx && dim != 1) {
+        weightFlipDims.push_back(static_cast<int64_t>(idx));
+      }
+    }
+    // Perform a flip if we have more than one non-trivial spatial dim.
+    if (weightFlipDims.size() > 1) {
       weightExpanded = torch_to_linalg::flipTensor(
           rewriter, loc, weightExpanded, weightFlipDims);
     }

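Note on the new flip logic: only spatial dims whose extent is not 1 are collected into weightFlipDims, so for the [4, 320, 3, 1, 3] weight used in the new test (with spatialStartDimIdx == 2) the flip touches dims 2 and 4 and skips the unit dim 3. Since flipping a size-1 dim is a no-op, the result is the same as flipping every spatial dim. A minimal Python sketch of that dim selection follows; the helper name collect_flip_dims is illustrative, not part of the patch.

# Illustrative sketch, not part of the patch: mirrors the non-unit
# spatial-dim selection added to ConvertAtenConvolutionBackwardOp.
import torch

def collect_flip_dims(weight_shape, spatial_start_dim_idx):
    # Keep only spatial dims whose extent is not 1, like the C++ loop above.
    return [
        idx
        for idx, dim in enumerate(weight_shape)
        if idx >= spatial_start_dim_idx and dim != 1
    ]

weight = torch.randn(4, 320, 3, 1, 3)
flip_dims = collect_flip_dims(weight.shape, spatial_start_dim_idx=2)
print(flip_dims)  # [2, 4]

# Flipping a size-1 dim is a no-op, so flipping only the non-unit spatial
# dims matches flipping all spatial dims.
assert torch.equal(torch.flip(weight, flip_dims), torch.flip(weight, [2, 3, 4]))
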
projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -1136,6 +1136,7 @@
     "Convolution2DStaticModule_basic",
     "Convolution2DSingleIntTupleModule_basic",
     "ConvolutionBackwardModule2DStatic_basic",
+    "ConvolutionBackwardModule3DStatic_basic",
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ConvolutionModule2DTransposeScalarTupleParams_basic",
     "Conv_Transpose1dStaticModule_basic",
@@ -2934,6 +2935,7 @@
     "ConvolutionBackwardModule2DDilated_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStatic_basic",
+    "ConvolutionBackwardModule3DStatic_basic",
     "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
@@ -4343,6 +4345,7 @@
     "ConvolutionBackwardModule2DDilated_basic",
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule2DStatic_basic",
+    "ConvolutionBackwardModule3DStatic_basic",
     "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py

Lines changed: 39 additions & 0 deletions
@@ -158,6 +158,45 @@ def ConvolutionBackwardModule2DStatic_basic(module, tu: TestUtils):
         )
 
 
+class ConvolutionBackwardModule3DStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 4, 64, 64, 64], torch.float32, True),
+            ([1, 320, 64, 64, 64], torch.float32, True),
+            ([4, 320, 3, 1, 3], torch.float32, True),
+        ]
+    )
+    def forward(self, grad_out, input_vec, weight):
+        return torch.ops.aten.convolution_backward(
+            grad_out,
+            input_vec,
+            weight,
+            bias_sizes=[4],
+            stride=[1, 1, 1],
+            padding=[1, 0, 1],
+            dilation=[1, 1, 1],
+            transposed=False,
+            output_padding=[0, 0, 0],
+            groups=1,
+            output_mask=[True, True, True],
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionBackwardModule3DStatic())
+def ConvolutionBackwardModule3DStatic_basic(module, tu: TestUtils):
+    with torch.backends.mkldnn.flags(enabled=False):
+        module.forward(
+            tu.rand(1, 4, 64, 64, 64),
+            tu.rand(1, 320, 64, 64, 64),
+            tu.rand(4, 320, 3, 1, 3),
+        )
+
+
 class ConvolutionBackwardModule2DPadded(torch.nn.Module):
     def __init__(self):
         super().__init__()

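The new e2e case exercises a 3D convolution_backward with a 3x1x3 kernel and asymmetric padding [1, 0, 1], i.e. one spatial dim of the weight has extent 1. As a sanity check on the shapes the test compares against, the same call can be run in eager PyTorch; the sketch below is illustrative and not part of the patch.

# Illustrative sketch, not part of the patch: the eager reference call for
# the shapes used by ConvolutionBackwardModule3DStatic.
import torch

grad_out = torch.randn(1, 4, 64, 64, 64)
input_vec = torch.randn(1, 320, 64, 64, 64)
weight = torch.randn(4, 320, 3, 1, 3)

grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_out,
    input_vec,
    weight,
    bias_sizes=[4],
    stride=[1, 1, 1],
    padding=[1, 0, 1],
    dilation=[1, 1, 1],
    transposed=False,
    output_padding=[0, 0, 0],
    groups=1,
    output_mask=[True, True, True],
)
print(grad_input.shape)   # torch.Size([1, 320, 64, 64, 64])
print(grad_weight.shape)  # torch.Size([4, 320, 3, 1, 3])
print(grad_bias.shape)    # torch.Size([4])
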
test/Conversion/TorchToLinalg/convolution_bwd.mlir

Lines changed: 60 additions & 0 deletions
@@ -355,3 +355,63 @@ func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<
 }
 
 // -----
+
+// CHECK-LABEL:   func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[1,4,64,64,64],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1,320,64,64,64],f32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: !torch.vtensor<[4,320,3,1,3],f32>,
+// CHECK-SAME:      %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>) {
+func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(%arg0: !torch.vtensor<[1,4,64,64,64],f32>, %arg1: !torch.vtensor<[1,320,64,64,64],f32>, %arg2: !torch.vtensor<[4,320,3,1,3],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>) {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CST2:.*]] = arith.constant 2 : index
+  // CHECK: %[[CST0F:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[WT:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[4,320,3,1,3],f32> -> tensor<4x320x3x1x3xf32>
+  // CHECK: %[[GO:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,64,64,64],f32> -> tensor<1x4x64x64x64xf32>
+  // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x320x3x1x3xf32>
+  // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0F]] : f32) outs(%[[W_EMPTY]] : tensor<4x320x3x1x3xf32>) -> tensor<4x320x3x1x3xf32>
+  // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[WT]] : tensor<4x320x3x1x3xf32>) outs(%[[W_FILLED]] : tensor<4x320x3x1x3xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32):
+  // CHECK-NEXT:   %[[I0:.*]] = linalg.index 0 : index
+  // CHECK-NEXT:   %[[I1:.*]] = linalg.index 1 : index
+  // CHECK-NEXT:   %[[I2:.*]] = linalg.index 2 : index
+  // CHECK-NEXT:   %[[I4:.*]] = linalg.index 4 : index
+  // CHECK-NEXT:   %[[R2:.*]] = arith.subi %[[CST2]], %[[I2]] : index
+  // CHECK-NEXT:   %[[R4:.*]] = arith.subi %[[CST2]], %[[I4]] : index
+  // CHECK-NEXT:   %[[EX:.*]] = tensor.extract %[[WT]][%[[I0]], %[[I1]], %[[R2]], %[[CST0]], %[[R4]]] : tensor<4x320x3x1x3xf32>
+  // CHECK-NEXT:   linalg.yield %[[EX]] : f32
+  // CHECK-NEXT: } -> tensor<4x320x3x1x3xf32>
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[GO]] low[0, 0, 1, 0, 1] high[0, 0, 1, 0, 1]
+  // CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+  // CHECK:   tensor.yield %[[CST0F]] : f32
+  // CHECK: } : tensor<1x4x64x64x64xf32> to tensor<1x4x66x64x66xf32>
+  // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<1x320x64x64x64xf32>
+  // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0F]] : f32) outs(%[[OUT_EMPTY]] : tensor<1x320x64x64x64xf32>) -> tensor<1x320x64x64x64xf32>
+  // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d1, d6, d7, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[W_REV]] : tensor<1x4x66x64x66xf32>, tensor<4x320x3x1x3xf32>) outs(%[[OUT_FILLED]] : tensor<1x320x64x64x64xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+  // CHECK-NEXT:   %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+  // CHECK-NEXT:   %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+  // CHECK-NEXT:   linalg.yield %[[ACC]] : f32
+  // CHECK-NEXT: } -> tensor<1x320x64x64x64xf32>
+  // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<1x320x64x64x64xf32> -> !torch.vtensor<[1,320,64,64,64],f32>
+  // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<4xf32>
+  // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0F]] : f32) outs(%[[SUM_EMPTY]] : tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction", "reduction"]} ins(%[[GO]] : tensor<1x4x64x64x64xf32>) outs(%[[SUM_FILLED]] : tensor<4xf32>) {
+  // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+  // CHECK-NEXT:   %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+  // CHECK-NEXT:   linalg.yield %[[B_RES]] : f32
+  // CHECK-NEXT: } -> tensor<4xf32>
+  // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<4xf32> -> !torch.vtensor<[4],f32>
+  // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int0, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %3, %int1, %4 : !torch.vtensor<[1,4,64,64,64],f32>, !torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4,320,3,1,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[1,320,64,64,64],f32>, !torch.none, !torch.vtensor<[4],f32>
+  return %result0, %result2 : !torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>
+}
+
+// -----

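In the lowered IR above, the %[[W_REV]] generic is the weight flip: each flipped dim reads the source at (size - 1) - index (the arith.subi from the constant 2, since the flipped dims have size 3), the unit dim 3 is pinned to index 0, and dims 0 and 1 pass through unchanged. A short Python sketch (illustrative only, not part of the patch) of that reversed-index read and its equivalence to a plain flip:

# Illustrative sketch, not part of the patch: the W_REV generic reads
# w[d0, d1, 2 - d2, 0, 2 - d4], i.e. a flip of dims 2 and 4 only.
import torch

w = torch.randn(4, 320, 3, 1, 3)

# Emulate the reversed-index reads of the linalg.generic body.
w_rev = torch.empty_like(w)
for d2 in range(3):
    for d4 in range(3):
        w_rev[:, :, d2, 0, d4] = w[:, :, 2 - d2, 0, 2 - d4]

assert torch.equal(w_rev, torch.flip(w, dims=[2, 4]))
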