Skip to content

Commit 58ac26e

Browse files
mahmood82 and Mahmood Yassin authored
[CIR] Emit bitcast for equal-width types (#1991)
This patch adds support for emitting a `cir.bitcast` when reinterpreting values that have the **same bit-width**. This enables correct handling of vector reinterpretation in CIR and aligns with the behavior of the LLVM IR lowering.

Previously, equal-width reinterpretations were not handled, which caused assertions or incorrect lowering when working with vector types or other equal-sized aggregates.

Key points:
- Introduces `cir.bitcast` emission when source and destination types have equal width and only reinterpretation is required.
- Avoids unnecessary intermediate casts.
- Enables upcoming work on vector reinterpretation and fixes gaps in the current type conversion pipeline.

Testing:
- Added/updated CIR tests validating equal-width reinterpretation.
- Verified end-to-end lowering through LLVM for vector types.

---------

Co-authored-by: Mahmood Yassin <mahmood.yassin@mobileye.com>
1 parent 5f17769 commit 58ac26e

File tree

5 files changed

+160
-13
lines changed

5 files changed

+160
-13
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
600600

601601
mlir::Value createPtrBitcast(mlir::Value src, mlir::Type newPointeeTy) {
602602
assert(mlir::isa<cir::PointerType>(src.getType()) && "expected ptr src");
603-
return createBitcast(src, getPointerTo(newPointeeTy));
603+
auto srcPtrTy = mlir::cast<cir::PointerType>(src.getType());
604+
mlir::Type newPtrTy = getPointerTo(newPointeeTy, srcPtrTy.getAddrSpace());
605+
return createBitcast(src, newPtrTy);
604606
}
605607

606608
mlir::Value createAddrSpaceCast(mlir::Location loc, mlir::Value src,
@@ -612,6 +614,29 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
612614
return createAddrSpaceCast(src.getLoc(), src, newTy);
613615
}
614616

617+
mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Location loc,
618+
mlir::Value src,
619+
mlir::Type newPointerTy) {
620+
assert(mlir::isa<cir::PointerType>(src.getType()) &&
621+
"expected source pointer");
622+
assert(mlir::isa<cir::PointerType>(newPointerTy) &&
623+
"expected destination pointer type");
624+
625+
auto srcPtrTy = mlir::cast<cir::PointerType>(src.getType());
626+
auto dstPtrTy = mlir::cast<cir::PointerType>(newPointerTy);
627+
628+
mlir::Value addrSpaceCasted = src;
629+
if (srcPtrTy.getAddrSpace() != dstPtrTy.getAddrSpace())
630+
addrSpaceCasted = createAddrSpaceCast(loc, src, dstPtrTy);
631+
632+
return createPtrBitcast(addrSpaceCasted, dstPtrTy.getPointee());
633+
}
634+
635+
mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Value src,
636+
mlir::Type newPointerTy) {
637+
return createPointerBitCastOrAddrSpaceCast(src.getLoc(), src, newPointerTy);
638+
}
639+
615640
mlir::Value createPtrIsNull(mlir::Value ptr) {
616641
return createNot(createPtrToBoolCast(ptr));
617642
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,81 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
831831
mlir::Value VisitObjCDictionaryLiteral(ObjCDictionaryLiteral *E) {
832832
llvm_unreachable("NYI");
833833
}
834-
mlir::Value VisitAsTypeExpr(AsTypeExpr *E) { llvm_unreachable("NYI"); }
834+
835+
// Create cast instructions for converting MLIR value Src to MLIR type DstTy.
836+
// Src has the same size as DstTy. Both are single value types
837+
// but could be scalar or vectors of different lengths, and either can be
838+
// pointer.
839+
mlir::Value createCastsForTypeOfSameSize(mlir::Value Src, mlir::Type DstTy) {
840+
auto SrcTy = Src.getType();
841+
842+
// Case 1: neither Src nor DstTy is a pointer; a plain bitcast suffices.
843+
if (!isa<cir::PointerType>(SrcTy) && !isa<cir::PointerType>(DstTy))
844+
return Builder.createBitcast(Src, DstTy);
845+
846+
// Case 2: both Src and DstTy are pointers; bitcast and/or address-space cast.
847+
if (isa<cir::PointerType>(SrcTy) && isa<cir::PointerType>(DstTy))
848+
return Builder.createPointerBitCastOrAddrSpaceCast(Src, DstTy);
849+
850+
// Case 3: Src is a pointer and DstTy is not (not yet implemented).
851+
if (isa<cir::PointerType>(SrcTy) && !isa<cir::PointerType>(DstTy)) {
852+
// Case 3b.
853+
if (!Builder.isInt(DstTy))
854+
llvm_unreachable("NYI");
855+
// Cases 3a and 3b.
856+
llvm_unreachable("NYI");
857+
}
858+
859+
// Case 4 (Src is not a pointer, DstTy is). Case 4b: Builder.isInt(SrcTy) is false.
860+
if (!Builder.isInt(SrcTy))
861+
llvm_unreachable("NYI");
862+
863+
// Cases 4a and 4b.
864+
llvm_unreachable("NYI");
865+
}
866+
867+
mlir::Value VisitAsTypeExpr(AsTypeExpr *E) {
868+
unsigned numSrcElems = 0;
869+
QualType qualSrcTy = E->getSrcExpr()->getType();
870+
mlir::Type srcTy = CGF.convertType(qualSrcTy);
871+
if (auto v = dyn_cast<cir::VectorType>(srcTy)) {
872+
assert(!cir::MissingFeatures::scalableVectors() &&
873+
"NYI: non-fixed (scalable) vector src");
874+
numSrcElems = v.getSize();
875+
}
876+
877+
unsigned numDstElems = 0;
878+
QualType qualDstTy = E->getType();
879+
mlir::Type dstTy = CGF.convertType(qualDstTy);
880+
if (auto v = dyn_cast<cir::VectorType>(dstTy)) {
881+
assert(!cir::MissingFeatures::scalableVectors() &&
882+
"NYI: non-fixed (scalable) vector dst");
883+
numDstElems = v.getSize();
884+
}
885+
886+
// Use bit vector expansion for ext_vector_type boolean vectors.
887+
if (qualDstTy->isExtVectorBoolType()) {
888+
llvm_unreachable("NYI");
889+
}
890+
891+
// Going from vec3 to non-vec3 is a special case and requires a shuffle
892+
// vector to get a vec4, then a bitcast if the target type is different.
893+
if (numSrcElems == 3 && numDstElems != 3) {
894+
llvm_unreachable("NYI");
895+
}
896+
897+
// Going from non-vec3 to vec3 is a special case and requires a bitcast
898+
// to vec4 if the original type is not vec4, then a shuffle vector to
899+
// get a vec3.
900+
if (numSrcElems != 3 && numDstElems == 3) {
901+
llvm_unreachable("NYI");
902+
}
903+
904+
// Otherwise, fallback to bitcast of same size
905+
mlir::Value src = CGF.emitScalarExpr(E->getSrcExpr());
906+
return createCastsForTypeOfSameSize(src, dstTy);
907+
}
908+
835909
mlir::Value VisitAtomicExpr(AtomicExpr *E) {
836910
return CGF.emitAtomicExpr(E).getScalarVal();
837911
}

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,8 +662,10 @@ LogicalResult cir::CastOp::verify() {
662662
auto resPtrTy = mlir::dyn_cast<cir::PointerType>(resType);
663663
if (!srcPtrTy || !resPtrTy)
664664
return emitOpError() << "requires !cir.ptr type for source and result";
665-
if (srcPtrTy.getPointee() != resPtrTy.getPointee())
666-
return emitOpError() << "requires two types differ in addrspace only";
665+
// Address space verification is sufficient here. The pointee types need not
666+
// be verified as they are handled by bitcast verification logic, which
667+
// ensures address space compatibility. Verifying pointee types would create
668+
// a circular dependency between address space and pointee type casting.
667669
return success();
668670
}
669671
case cir::CastKind::float_to_complex: {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-cir -triple spirv64-unknown-unknown -o %t.ll
2+
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=CIR
3+
4+
// RUN: %clang_cc1 %s -cl-std=CL2.0 -fclangir -emit-llvm -triple spirv64-unknown-unknown -o %t.ll
5+
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=LLVM
6+
7+
// RUN: %clang_cc1 %s -cl-std=CL2.0 -emit-llvm -triple spirv64-unknown-unknown -o %t.ll
8+
// RUN: FileCheck %s --input-file=%t.ll --check-prefix=OG-LLVM
9+
10+
typedef __attribute__(( ext_vector_type(4) )) char char4;
11+
12+
// CIR: cir.func @f4(%{{.*}}: !s32i loc({{.*}})) -> !cir.vector<!s8i x 4>
13+
// CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr<!s32i, addrspace(offload_private)>
14+
// CIR: cir.cast bitcast %[[x]] : !s32i -> !cir.vector<!s8i x 4>
15+
// LLVM: define spir_func <4 x i8> @f4(i32 %[[x:.*]])
16+
// LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8>
17+
// LLVM-NOT: shufflevector
18+
// LLVM: ret <4 x i8> %[[astype]]
19+
// OG-LLVM: define spir_func noundef <4 x i8> @f4(i32 noundef %[[x:.*]])
20+
// OG-LLVM: %[[astype:.*]] = bitcast i32 %[[x]] to <4 x i8>
21+
// OG-LLVM-NOT: shufflevector
22+
// OG-LLVM: ret <4 x i8> %[[astype]]
23+
char4 f4(int x) {
24+
return __builtin_astype(x, char4);
25+
}
26+
27+
// CIR: cir.func @f6(%{{.*}}: !cir.vector<!s8i x 4> loc({{.*}})) -> !s32i
28+
// CIR: %[[x:.*]] = cir.load align(4) %{{.*}} : !cir.ptr<!cir.vector<!s8i x 4>, addrspace(offload_private)>, !cir.vector<!s8i x 4>
29+
// CIR: cir.cast bitcast %[[x]] : !cir.vector<!s8i x 4> -> !s32i
30+
// LLVM: define{{.*}} spir_func i32 @f6(<4 x i8> %[[x:.*]])
31+
// LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32
32+
// LLVM-NOT: shufflevector
33+
// LLVM: ret i32 %[[astype]]
34+
// OG-LLVM: define{{.*}} spir_func noundef i32 @f6(<4 x i8> noundef %[[x:.*]])
35+
// OG-LLVM: %[[astype:.*]] = bitcast <4 x i8> %[[x]] to i32
36+
// OG-LLVM-NOT: shufflevector
37+
// OG-LLVM: ret i32 %[[astype]]
38+
int f6(char4 x) {
39+
return __builtin_astype(x, int);
40+
}
41+
42+
// CIR: cir.func @f4_ptr(%{{.*}}: !cir.ptr<!s32i, addrspace(offload_global)> loc({{.*}})) -> !cir.ptr<!cir.vector<!s8i x 4>, addrspace(offload_local)>
43+
// CIR: %[[x:.*]] = cir.load align(8) %{{.*}} : !cir.ptr<!cir.ptr<!s32i, addrspace(offload_global)>, addrspace(offload_private)>, !cir.ptr<!s32i, addrspace(offload_global)>
44+
// CIR: cir.cast address_space %[[x]] : !cir.ptr<!s32i, addrspace(offload_global)> -> !cir.ptr<!cir.vector<!s8i x 4>, addrspace(offload_local)>
45+
// LLVM: define spir_func ptr addrspace(3) @f4_ptr(ptr addrspace(1) readnone captures(ret: address, provenance) %[[x:.*]])
46+
// LLVM: %[[astype:.*]] = addrspacecast ptr addrspace(1) %[[x]] to ptr addrspace(3)
47+
// LLVM-NOT: shufflevector
48+
// LLVM: ret ptr addrspace(3) %[[astype]]
49+
// OG-LLVM: define spir_func ptr addrspace(3) @f4_ptr(ptr addrspace(1) noundef readnone captures(ret: address, provenance) %[[x:.*]])
50+
// OG-LLVM: %[[astype:.*]] = addrspacecast ptr addrspace(1) %[[x]] to ptr addrspace(3)
51+
// OG-LLVM-NOT: shufflevector
52+
// OG-LLVM: ret ptr addrspace(3) %[[astype]]
53+
__local char4* f4_ptr(__global int* x) {
54+
return __builtin_astype(x, __local char4*);
55+
}

clang/test/CIR/IR/invalid.cir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -300,15 +300,6 @@ cir.func @cast24(%p : !u32i) {
300300

301301
// -----
302302

303-
!u32i = !cir.int<u, 32>
304-
!u64i = !cir.int<u, 64>
305-
cir.func @cast25(%p : !cir.ptr<!u32i, addrspace(target<1>)>) {
306-
%0 = cir.cast address_space %p : !cir.ptr<!u32i, addrspace(target<1>)> -> !cir.ptr<!u64i, addrspace(target<2>)> // expected-error {{requires two types differ in addrspace only}}
307-
cir.return
308-
}
309-
310-
// -----
311-
312303
!u64i = !cir.int<u, 64>
313304
cir.func @cast26(%p : !cir.ptr<!u64i, addrspace(target<1>)>) {
314305
%0 = cir.cast address_space %p : !cir.ptr<!u64i, addrspace(target<1>)> -> !u64i // expected-error {{requires !cir.ptr type for source and result}}

0 commit comments

Comments
 (0)