Skip to content

Commit 630a11b

Browse files
koparasyandykaylor
andauthored
[CIR][HIP] Proper Handling of address spaces in ptr-diff (#1994)
Co-authored-by: Andy Kaylor <akaylor@nvidia.com>
1 parent db506fe commit 630a11b

File tree

3 files changed

+208
-10
lines changed

3 files changed

+208
-10
lines changed

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,17 +1524,59 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) {
15241524
if (!mlir::isa<cir::PointerType>(Ops.RHS.getType()))
15251525
return emitPointerArithmetic(CGF, Ops, /*isSubtraction=*/true);
15261526

1527-
// Otherwise, this is a pointer subtraction
1528-
15291527
// Do the raw subtraction part.
1530-
//
1531-
// TODO(cir): note for LLVM lowering out of this; when expanding this into
1532-
// LLVM we shall take VLA's, division by element size, etc.
1533-
//
1534-
// See more in `EmitSub` in CGExprScalar.cpp.
1528+
mlir::Value lhs = Ops.LHS;
1529+
mlir::Value rhs = Ops.RHS;
1530+
1531+
cir::PointerType lhsPtrTy = mlir::dyn_cast<cir::PointerType>(lhs.getType());
1532+
cir::PointerType rhsPtrTy = mlir::dyn_cast<cir::PointerType>(rhs.getType());
1533+
1534+
if (lhsPtrTy && rhsPtrTy) {
1535+
cir::AddressSpace lhsAS = lhsPtrTy.getAddrSpace();
1536+
cir::AddressSpace rhsAS = rhsPtrTy.getAddrSpace();
1537+
1538+
if (lhsAS != rhsAS) {
1539+
// Different address spaces → use addrspacecast
1540+
rhs = Builder.createAddrSpaceCast(rhs, lhsPtrTy);
1541+
} else if (lhsPtrTy != rhsPtrTy) {
1542+
// Same addrspace but different pointee/type → bitcast is fine
1543+
rhs = Builder.createBitcast(rhs, lhsPtrTy);
1544+
}
1545+
}
1546+
15351547
assert(!cir::MissingFeatures::llvmLoweringPtrDiffConsidersPointee());
1536-
return cir::PtrDiffOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.PtrDiffTy,
1537-
Ops.LHS, Ops.RHS);
1548+
mlir::Value diff = cir::PtrDiffOp::create(Builder, CGF.getLoc(Ops.Loc),
1549+
CGF.PtrDiffTy, lhs, rhs);
1550+
1551+
const BinaryOperator *expr = cast<BinaryOperator>(Ops.E);
1552+
QualType elementType = expr->getLHS()->getType()->getPointeeType();
1553+
1554+
mlir::Location loc = CGF.getLoc(Ops.Loc);
1555+
mlir::Value divisor;
1556+
1557+
// Check if this is a VLA pointee type.
1558+
if (const auto *vla = CGF.getContext().getAsVariableArrayType(elementType)) {
1559+
auto vlaSize = CGF.getVLASize(vla);
1560+
elementType = vlaSize.Type;
1561+
divisor = vlaSize.NumElts;
1562+
1563+
CharUnits eltSize = CGF.getContext().getTypeSizeInChars(elementType);
1564+
if (!eltSize.isOne()) {
1565+
cir::IntType cirIntTy = llvm::cast<cir::IntType>(CGF.PtrDiffTy);
1566+
cir::IntAttr eltSizeAttr =
1567+
cir::IntAttr::get(cirIntTy, eltSize.getQuantity());
1568+
1569+
if (divisor.getType() != CGF.PtrDiffTy)
1570+
divisor = Builder.createIntCast(divisor, CGF.PtrDiffTy);
1571+
}
1572+
} else {
1573+
// cir::ptrdiff correctly computes the ABI difference of 2 pointers. We
1574+
// do not need to compute anything else here. We just return it.
1575+
return diff;
1576+
}
1577+
1578+
return cir::BinOp::create(Builder, loc, CGF.PtrDiffTy, cir::BinOpKind::Div,
1579+
diff, divisor);
15381580
}
15391581

15401582
// Helper to apply OpenCL-style shift masking. It handles both vector and scalar
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "cuda.h"
2+
3+
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
4+
// RUN: -fcuda-is-device -fhip-new-launch-api \
5+
// RUN: -I%S/../Inputs/ -emit-cir %s -o %t.ll
6+
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.ll %s
7+
8+
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
9+
// RUN: -fcuda-is-device -fhip-new-launch-api \
10+
// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll
11+
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s
12+
13+
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip \
14+
// RUN: -fcuda-is-device -fhip-new-launch-api \
15+
// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll
16+
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s
17+
18+
__device__ int ptr_diff() {
19+
const char c_str[] = "c-string";
20+
const char* len = c_str;
21+
return c_str - len;
22+
}
23+
24+
25+
// CIR-DEVICE: %[[#LenLocalAddr:]] = cir.alloca !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>, ["len", init]
26+
// CIR-DEVICE: %[[#GlobalPtr:]] = cir.get_global @_ZZ8ptr_diffvE5c_str : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)>
27+
// CIR-DEVICE: %[[#CastDecay:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)>
28+
// CIR-DEVICE: %[[#LenLocalAddrCast:]] = cir.cast bitcast %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>> -> !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>>
29+
// CIR-DEVICE: cir.store align(8) %[[#CastDecay]], %[[#LenLocalAddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)>, !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>>
30+
// CIR-DEVICE: %[[#CStr:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)> -> !cir.ptr<!s8i, addrspace(offload_constant)>
31+
// CIR-DEVICE: %[[#LoadedLenAddr:]] = cir.load align(8) %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i> loc(#loc7)
32+
// CIR-DEVICE: %[[#AddrCast:]] = cir.cast address_space %[[#LoadedLenAddr]] : !cir.ptr<!s8i> -> !cir.ptr<!s8i, addrspace(offload_constant)>
33+
// CIR-DEVICE: %[[#DIFF:]] = cir.ptr_diff %[[#CStr]], %[[#AddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)>
34+
35+
// LLVM-DEVICE: define dso_local i32 @_Z8ptr_diffv()
36+
// LLVM-DEVICE: %[[#GlobalPtrAddr:]] = alloca i32, i64 1, align 4, addrspace(5)
37+
// LLVM-DEVICE: %[[#GlobalPtrCast:]] = addrspacecast ptr addrspace(5) %[[#GlobalPtrAddr]] to ptr
38+
// LLVM-DEVICE: %[[#LenLocalAddr:]] = alloca ptr, i64 1, align 8, addrspace(5)
39+
// LLVM-DEVICE: %[[#LenLocalAddrCast:]] = addrspacecast ptr addrspace(5) %[[#LenLocalAddr]] to ptr
40+
// LLVM-DEVICE: store ptr addrspace(4) @_ZZ8ptr_diffvE5c_str, ptr %[[#LenLocalAddrCast]], align 8
41+
// LLVM-DEVICE: %[[#LoadedAddr:]] = load ptr, ptr %[[#LenLocalAddrCast]], align 8
42+
// LLVM-DEVICE: %[[#CastedVal:]] = addrspacecast ptr %[[#LoadedAddr]] to ptr addrspace(4)
43+
// LLVM-DEVICE: %[[#IntVal:]] = ptrtoint ptr addrspace(4) %[[#CastedVal]] to i64
44+
// LLVM-DEVICE: %[[#SubVal:]] = sub i64 ptrtoint (ptr addrspace(4) @_ZZ8ptr_diffvE5c_str to i64), %[[#IntVal]]
45+
46+
// OGCG-DEVICE: define dso_local noundef i32 @_Z8ptr_diffv() #0
47+
// OGCG-DEVICE: %[[RETVAL:.*]] = alloca i32, align 4, addrspace(5)
48+
// OGCG-DEVICE: %[[C_STR:.*]] = alloca [9 x i8], align 1, addrspace(5)
49+
// OGCG-DEVICE: %[[LEN:.*]] = alloca ptr, align 8, addrspace(5)
50+
// OGCG-DEVICE: %[[RETVAL_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RETVAL]] to ptr
51+
// OGCG-DEVICE: %[[C_STR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[C_STR]] to ptr
52+
// OGCG-DEVICE: %[[LEN_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[LEN]] to ptr
53+
// OGCG-DEVICE: %[[ARRAYDECAY:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0
54+
// OGCG-DEVICE: store ptr %[[ARRAYDECAY]], ptr %[[LEN_ASCAST]], align 8
55+
// OGCG-DEVICE: %[[ARRAYDECAY1:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0
56+
// OGCG-DEVICE: %[[LOADED:.*]] = load ptr, ptr %[[LEN_ASCAST]], align 8
57+
// OGCG-DEVICE: %[[LHS:.*]] = ptrtoint ptr %[[ARRAYDECAY1]] to i64
58+
// OGCG-DEVICE: %[[RHS:.*]] = ptrtoint ptr %[[LOADED]] to i64
59+
// OGCG-DEVICE: %[[SUB:.*]] = sub i64 %[[LHS]], %[[RHS]]
60+
// OGCG-DEVICE: %[[CONV:.*]] = trunc i64 %[[SUB]] to i32

clang/test/CIR/CodeGen/vla.c

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --check-prefix=CHECK --input-file=%t.cir %s
3+
4+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
5+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
6+
7+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
8+
// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s
29

310
// CHECK: cir.func dso_local @f0(%arg0: !s32i
411
// CHECK: [[TMP0:%.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["len", init] {alignment = 4 : i64}
@@ -93,3 +100,92 @@ int f9(unsigned n, char (*p)[n][n+1][6]) {
93100

94101
return p2 - p;
95102
}
103+
104+
long f10(int n) {
105+
int (*p)[n];
106+
int (*q)[n];
107+
return q - p;
108+
}
109+
// CHECK-LABEL: cir.func dso_local @f10
110+
// CHECK: %[[Q_VAL:[0-9]+]] = cir.load {{.*}} %{{.*}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
111+
// CHECK: %[[P_VAL:[0-9]+]] = cir.load {{.*}} %{{.*}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
112+
// CHECK: %[[PTRDIFF:[0-9]+]] = cir.ptr_diff %[[Q_VAL]], %[[P_VAL]] : !cir.ptr<!s32i> -> !s64i
113+
// CHECK: %[[N_S64:[0-9]+]] = cir.cast integral %{{.*}} : !u64i -> !s64i
114+
// CHECK: %[[DIV:[0-9]+]] = cir.binop(div, %[[PTRDIFF]], %[[N_S64]]) : !s64i
115+
116+
// LLVM-LABEL: @f10(
117+
// LLVM: %[[QI:[0-9]+]] = ptrtoint ptr %{{.*}} to i64
118+
// LLVM: %[[PI:[0-9]+]] = ptrtoint ptr %{{.*}} to i64
119+
// LLVM: %[[DIFF_BYTES:[0-9]+]] = sub i64 %[[QI]], %[[PI]]
120+
// LLVM: %[[PTRDIFF_INTS:[0-9]+]] = sdiv i64 %[[DIFF_BYTES]], 4
121+
// LLVM: %[[RESULT:[0-9]+]] = sdiv i64 %[[PTRDIFF_INTS]], %{{.*}}
122+
123+
// OGCG-LABEL: @f10(
124+
// OGCG: %{{.*}} = ptrtoint ptr %{{.*}} to i64
125+
// OGCG: %{{.*}} = ptrtoint ptr %{{.*}} to i64
126+
// OGCG: %{{.*}} = sub i64 %{{.*}}, %{{.*}}
127+
// OGCG: %{{.*}} = mul nuw i64 4, %{{.*}}
128+
// OGCG: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}}
129+
130+
long f11(int n, int m) {
131+
int (*p)[n][m];
132+
int (*q)[n][m];
133+
return q - p;
134+
}
135+
// CHECK-LABEL: cir.func dso_local @f11
136+
137+
// # allocas
138+
// CHECK: %[[N_ADDR:[0-9]+]] = cir.alloca !s32i, !cir.ptr<!s32i>
139+
// CHECK: %[[M_ADDR:[0-9]+]] = cir.alloca !s32i, !cir.ptr<!s32i>
140+
// CHECK: %[[RET:[0-9]+]] = cir.alloca !s64i, !cir.ptr<!s64i>
141+
// CHECK: %[[P:[0-9]+]] = cir.alloca !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
142+
// CHECK: %[[Q:[0-9]+]] = cir.alloca !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
143+
144+
// # store n, m
145+
// CHECK: cir.store %arg0, %[[N_ADDR]] : !s32i, !cir.ptr<!s32i>
146+
// CHECK: cir.store %arg1, %[[M_ADDR]] : !s32i, !cir.ptr<!s32i>
147+
148+
// # load n and cast to u64
149+
// CHECK: %[[N_LOAD:[0-9]+]] = cir.load {{.*}} %[[N_ADDR]] : !cir.ptr<!s32i>, !s32i
150+
// CHECK: %[[N_U64:[0-9]+]] = cir.cast integral %[[N_LOAD]] : !s32i -> !u64i
151+
152+
// # load m and cast to u64
153+
// CHECK: %[[M_LOAD:[0-9]+]] = cir.load {{.*}} %[[M_ADDR]] : !cir.ptr<!s32i>, !s32i
154+
// CHECK: %[[M_U64:[0-9]+]] = cir.cast integral %[[M_LOAD]] : !s32i -> !u64i
155+
156+
// # load q and p
157+
// CHECK: %[[Q_VAL:[0-9]+]] = cir.load {{.*}} %[[Q]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
158+
// CHECK: %[[P_VAL:[0-9]+]] = cir.load {{.*}} %[[P]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
159+
//
160+
// # ptrdiff → (byte_diff / 4)
161+
// CHECK: %[[PTRDIFF:[0-9]+]] = cir.ptr_diff %[[Q_VAL]], %[[P_VAL]] : !cir.ptr<!s32i> -> !s64i
162+
163+
// # compute n*m
164+
// CHECK: %[[NM_U64:[0-9]+]] = cir.binop(mul, %[[N_U64]], %[[M_U64]]) : !u64i
165+
// CHECK: %[[NM_S64:[0-9]+]] = cir.cast integral %[[NM_U64]] : !u64i -> !s64i
166+
167+
// # divide ptrdiff_ints by (n*m)
168+
// CHECK: %[[RESULT:[0-9]+]] = cir.binop(div, %[[PTRDIFF]], %[[NM_S64]]) : !s64i
169+
170+
// # store + return
171+
// CHECK: cir.store %[[RESULT]], %[[RET]] : !s64i, !cir.ptr<!s64i>
172+
// CHECK: %[[RETVAL:[0-9]+]] = cir.load %[[RET]] : !cir.ptr<!s64i>, !s64i
173+
// CHECK: cir.return %[[RETVAL]] : !s64i
174+
175+
176+
// LLVM-LABEL: @f11(
177+
// # load q and p
178+
// LLVM: %[[QI:[0-9]+]] = ptrtoint ptr %{{.*}} to i64
179+
// LLVM: %[[PI:[0-9]+]] = ptrtoint ptr %{{.*}} to i64
180+
// LLVM: %[[DIFF_BYTES:[0-9]+]] = sub i64 %[[QI]], %[[PI]]
181+
// LLVM: %[[PTRDIFF_INTS:[0-9]+]] = sdiv i64 %[[DIFF_BYTES]], 4
182+
// LLVM: %[[NM:[0-9]+]] = mul i64 %{{.*}}, %{{.*}}
183+
// LLVM: %[[RESULT:[0-9]+]] = sdiv i64 %[[PTRDIFF_INTS]], %[[NM]]
184+
185+
// OGCG-LABEL: @f11(
186+
// OGCG: %{{.*}} = ptrtoint ptr %{{.*}} to i64
187+
// OGCG: %{{.*}} = ptrtoint ptr %{{.*}} to i64
188+
// OGCG: %{{.*}} = sub i64 %{{.*}}, %{{.*}}
189+
// OGCG: %{{.*}} = mul nuw i64 %{{.*}}, %{{.*}}
190+
// OGCG: %{{.*}} = mul nuw i64 4, %{{.*}}
191+
// OGCG: %{{.*}} = sdiv exact i64 %{{.*}}, %{{.*}}

0 commit comments

Comments
 (0)