Skip to content

Commit 62e7846

Browse files
committed
[MLIR] Don't lower tensors that can't be represented by an ArrayType
1 parent 8dee997 commit 62e7846

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
502502
<< type << " illegal: cannot handle zero-element tensors\n");
503503
return nullptr;
504504
}
505+
if (arrayElemCount > std::numeric_limits<unsigned>::max()) {
506+
LLVM_DEBUG(llvm::dbgs()
507+
<< type << " illegal: cannot fit tensor into target type\n");
508+
return nullptr;
509+
}
505510

506511
Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
507512
if (!arrayElemType)

mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,10 @@ func.func @tensor_2d_empty() -> () {
7979
%x = arith.constant dense<> : tensor<2x0xi32>
8080
return
8181
}
82+
83+
// CHECK-LABEL: func @very_large_tensor
84+
// CHECK-NEXT: arith.constant dense<1>
85+
func.func @very_large_tensor() -> () {
86+
%x = arith.constant dense<1> : tensor<4294967296xi32>
87+
return
88+
}

0 commit comments

Comments
 (0)