@@ -355,3 +355,68 @@ func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<
 }
 
 // -----
+
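+// Tests lowering of torch.aten.convolution_backward when only the input and bias
+// gradients are requested (output_mask = [true, false, true]).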
+// CHECK-LABEL: func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,64,64,64],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1,320,64,64,64],f32>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[4,320,3,1,3],f32>,
+// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>) {
+func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(%arg0: !torch.vtensor<[1,4,64,64,64],f32>, %arg1: !torch.vtensor<[1,320,64,64,64],f32>, %arg2: !torch.vtensor<[4,320,3,1,3],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>) {
+// CHECK: %[[CST0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST2:.*]] = arith.constant 2 : index
+// CHECK: %[[CST0F:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[WT:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[4,320,3,1,3],f32> -> tensor<4x320x3x1x3xf32>
+// CHECK: %[[GO:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,64,64,64],f32> -> tensor<1x4x64x64x64xf32>
+// CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x320x3x1x3xf32>
+// CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0F]] : f32) outs(%[[W_EMPTY]] : tensor<4x320x3x1x3xf32>) -> tensor<4x320x3x1x3xf32>
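+// The kernel is flipped along its size-3 spatial dims (2 and 4); dim 3 has size 1, so index 0 is read directly.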
+// CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[WT]] : tensor<4x320x3x1x3xf32>) outs(%[[W_FILLED]] : tensor<4x320x3x1x3xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32):
+// CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index
+// CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index
+// CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index
+// CHECK-NEXT: %[[I4:.*]] = linalg.index 4 : index
+// CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST2]], %[[I2]] : index
+// CHECK-NEXT: %[[R4:.*]] = arith.subi %[[CST2]], %[[I4]] : index
+// CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[WT]][%[[I0]], %[[I1]], %[[R2]], %[[CST0]], %[[R4]]] : tensor<4x320x3x1x3xf32>
+// CHECK-NEXT: linalg.yield %[[EX]] : f32
+// CHECK-NEXT: } -> tensor<4x320x3x1x3xf32>
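+// The input-gradient padding per spatial dim is dilation * (kernelSize - 1) - padding, i.e. [1, 0, 1] here, applied both low and high.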
+// CHECK: %[[PAD:.*]] = tensor.pad %[[GO]] low[0, 0, 1, 0, 1] high[0, 0, 1, 0, 1]
+// CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+// CHECK: tensor.yield %[[CST0F]] : f32
+// CHECK: } : tensor<1x4x64x64x64xf32> to tensor<1x4x66x64x66xf32>
+// CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<1x320x64x64x64xf32>
+// CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0F]] : f32) outs(%[[OUT_EMPTY]] : tensor<1x320x64x64x64xf32>) -> tensor<1x320x64x64x64xf32>
+// CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d1, d6, d7, d8)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[W_REV]] : tensor<1x4x66x64x66xf32>, tensor<4x320x3x1x3xf32>) outs(%[[OUT_FILLED]] : tensor<1x320x64x64x64xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32
+// CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+// CHECK-NEXT: linalg.yield %[[ACC]] : f32
+// CHECK-NEXT: } -> tensor<1x320x64x64x64xf32>
+// CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<1x320x64x64x64xf32> -> !torch.vtensor<[1,320,64,64,64],f32>
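+// The bias gradient sums grad_output over the batch and spatial dims, keeping only the channel dim.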
+// CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<4xf32>
+// CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0F]] : f32) outs(%[[SUM_EMPTY]] : tensor<4xf32>) -> tensor<4xf32>
+// CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction", "reduction"]} ins(%[[GO]] : tensor<1x4x64x64x64xf32>) outs(%[[SUM_FILLED]] : tensor<4xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32):
+// CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32
+// CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+// CHECK-NEXT: } -> tensor<4xf32>
+// CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<4xf32> -> !torch.vtensor<[4],f32>
+// CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int0, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
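+  // Operand order: grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask.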
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %3, %int1, %4 : !torch.vtensor<[1,4,64,64,64],f32>, !torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4,320,3,1,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[1,320,64,64,64],f32>, !torch.none, !torch.vtensor<[4],f32>
+  return %result0, %result2 : !torch.vtensor<[1,320,64,64,64],f32>, !torch.vtensor<[4],f32>
+}
+
+// -----