Skip to content

Commit 5f17769

Browse files
authored
[CIR][HIP] Introduce createPointerBitCastOrAddrSpaceCast and apply on device function calls (#2008)
Create a `createPointerBitCastOrAddrSpaceCast` and apply it now on function calls. OG has such functionality and it applies it sporadically in their code. I expect this to be used more in the future.
1 parent 630a11b commit 5f17769

File tree

3 files changed

+266
-6
lines changed

3 files changed

+266
-6
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,30 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
385385
/*tbaa=*/cir::TBAAAttr{});
386386
}
387387

388+
mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Value src,
389+
mlir::Type dstTy) {
390+
auto srcTy = src.getType();
391+
392+
if (auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcTy))
393+
if (auto dstPtrTy = mlir::dyn_cast<cir::PointerType>(dstTy))
394+
if (srcPtrTy.getAddrSpace() != dstPtrTy.getAddrSpace())
395+
return createAddrSpaceCast(src, dstTy);
396+
397+
return createBitcast(src, dstTy);
398+
}
399+
400+
mlir::Value createPointerBitCastOrAddrSpaceCast(mlir::Location loc,
401+
mlir::Value src,
402+
mlir::Type dstTy) {
403+
auto srcTy = src.getType();
404+
if (auto srcPtrTy = mlir::dyn_cast<cir::PointerType>(srcTy))
405+
if (auto dstPtrTy = mlir::dyn_cast<cir::PointerType>(dstTy))
406+
if (srcPtrTy.getAddrSpace() != dstPtrTy.getAddrSpace())
407+
return createAddrSpaceCast(loc, src, dstTy);
408+
409+
return createBitcast(loc, src, dstTy);
410+
}
411+
388412
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
389413
mlir::Type type, llvm::StringRef name,
390414
mlir::IntegerAttr alignment,
@@ -544,9 +568,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
544568
return createCast(cir::CastKind::ptr_to_bool, v, getBoolTy());
545569
}
546570

547-
// TODO(cir): the following function was introduced to keep in sync with LLVM
548-
// codegen. CIR does not have "zext" operations. It should eventually be
549-
// renamed or removed. For now, we just add whatever cast is required here.
571+
// TODO(cir): the following function was introduced to keep in sync with
572+
// LLVM codegen. CIR does not have "zext" operations. It should eventually
573+
// be renamed or removed. For now, we just add whatever cast is required
574+
// here.
550575
mlir::Value createZExtOrBitCast(mlir::Location loc, mlir::Value src,
551576
mlir::Type newTy) {
552577
auto srcTy = src.getType();
@@ -610,8 +635,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
610635
// Alignment and size helpers
611636
//
612637

613-
// Note that mlir::IntegerType is used instead of cir::IntType here because we
614-
// don't need sign information for these to be useful, so keep it simple.
638+
// Note that mlir::IntegerType is used instead of cir::IntType here
639+
// because we don't need sign information for these to be useful, so keep
640+
// it simple.
615641

616642
// For 0 alignment, any overload of `getAlignmentAttr` returns an empty
617643
// attribute.

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,8 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
491491
// can happen due to trivial type mismatches.
492492
if (ArgNo < CIRFuncTy.getNumInputs() &&
493493
V.getType() != CIRFuncTy.getInput(ArgNo))
494-
V = builder.createBitcast(V, CIRFuncTy.getInput(ArgNo));
494+
V = builder.createPointerBitCastOrAddrSpaceCast(
495+
V, CIRFuncTy.getInput(ArgNo));
495496

496497
CIRCallArgs.push_back(V);
497498
} else {

clang/test/CIR/CodeGen/HIP/addrspace-lowering.cpp

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
#include "cuda.h"
22

3+
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
4+
// RUN: -fcuda-is-device -fhip-new-launch-api -emit-cir \
5+
// RUN: -I%S/../Inputs/ %s -o %t.cir
6+
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.cir %s
7+
38
// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
49
// RUN: -fcuda-is-device -fhip-new-launch-api \
510
// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll
@@ -11,6 +16,11 @@
1116
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s
1217

1318

19+
20+
// ------------------------------------------------------------
21+
// CHECK POINTER ARGUMENT LOWERING (bitcast or addrspacecast)
22+
// ------------------------------------------------------------
23+
1424
__shared__ int a;
1525
// LLVM-DEVICE: @a = addrspace(3) global i32 undef, align 4
1626
// OGCG-DEVICE: @a = addrspace(3) global i32 undef, align 4
@@ -23,3 +33,226 @@ __constant__ int c;
2333
// LLVM-DEVICE: @c = addrspace(4) externally_initialized constant i32 0, align 4
2434
// OGCG-DEVICE: @c = addrspace(4) externally_initialized constant i32 0, align 4
2535

36+
// Forward decls in various address spaces.
37+
extern "C" __device__ void bar(const char *p);
38+
extern "C" __device__ void takes_global(float *p);
39+
extern "C" __device__ void takes_shared(int *p);
40+
extern "C" __device__ void takes_void(void *p);
41+
extern "C" __device__ void nullfun(int *p);
42+
extern "C" __device__ void takeS(struct S s);
43+
extern "C" __device__ void call_fp(void (*f)(int));
44+
45+
__constant__ int CC[12];
46+
__device__ float GArr[8];
47+
__device__ void fp_target(int);
48+
49+
// A struct that contains a pointer
50+
struct S { int *p; };
51+
52+
// ------------------------------------------------------------
53+
// 1. local → generic: expected bitcast or AS0 match
54+
// ------------------------------------------------------------
55+
__device__ void test_local() {
56+
int x = 42;
57+
bar((const char*)&x);
58+
}
59+
// CIR-DEVICE-LABEL: @_Z10test_localv
60+
// CIR-DEVICE: cir.alloca
61+
// CIR-DEVICE: cir.store
62+
// CIR-DEVICE: cir.cast bitcast {{.*}} -> !cir.ptr<!s8i>
63+
// CIR-DEVICE: cir.call @bar
64+
// CIR-DEVICE: cir.return
65+
66+
// LLVM-DEVICE-LABEL: @_Z10test_localv
67+
// LLVM-DEVICE: alloca i32
68+
// LLVM-DEVICE: addrspacecast ptr addrspace(5) {{.*}} to ptr
69+
// LLVM-DEVICE: store i32 42
70+
// LLVM-DEVICE: call void @bar(ptr {{.*}})
71+
// LLVM-DEVICE: ret void
72+
73+
// OGCG-DEVICE-LABEL: @_Z10test_localv
74+
// OGCG-DEVICE: alloca i32, align 4, addrspace(5)
75+
// OGCG-DEVICE: addrspacecast ptr addrspace(5) {{.*}} to ptr
76+
// OGCG-DEVICE: store i32 42
77+
// OGCG-DEVICE: call void @bar(ptr noundef {{.*}})
78+
// OGCG-DEVICE: ret void
79+
80+
// ------------------------------------------------------------
81+
// 2. global AS → generic param
82+
// ------------------------------------------------------------
83+
__device__ void test_global() {
84+
takes_global(GArr);
85+
}
86+
// CIR-DEVICE-LABEL: @_Z11test_globalv
87+
// CIR-DEVICE: cir.get_global @GArr
88+
// CIR-DEVICE: cir.cast array_to_ptrdecay
89+
// CIR-DEVICE: cir.cast address_space
90+
// CIR-DEVICE: cir.call @takes_global
91+
// CIR-DEVICE: cir.return
92+
93+
// LLVM-DEVICE-LABEL: @_Z11test_globalv
94+
// LLVM-DEVICE: call void @takes_global(ptr addrspacecast (ptr addrspace(1) @GArr to ptr))
95+
// LLVM-DEVICE: ret void
96+
97+
// OGCG-DEVICE-LABEL: @_Z11test_globalv
98+
// OGCG-DEVICE: call void @takes_global(ptr noundef addrspacecast (ptr addrspace(1) @GArr to ptr))
99+
// OGCG-DEVICE: ret void
100+
101+
// ------------------------------------------------------------
102+
// 3. shared AS(3) → generic param (requires addrspacecast)
103+
// ------------------------------------------------------------
104+
__device__ void test_shared() {
105+
__shared__ int s[2];
106+
takes_shared(s);
107+
}
108+
// CIR-DEVICE-LABEL: @_Z11test_sharedv
109+
// CIR-DEVICE: cir.get_global @_ZZ11test_sharedvE1s
110+
// CIR-DEVICE: cir.cast array_to_ptrdecay
111+
// CIR-DEVICE: cir.cast address_space
112+
// CIR-DEVICE: cir.call @takes_shared
113+
// CIR-DEVICE: cir.return
114+
115+
// LLVM-DEVICE-LABEL: @_Z11test_sharedv
116+
// LLVM-DEVICE: call void @takes_shared(ptr addrspacecast (ptr addrspace(3) @_ZZ11test_sharedvE1s to ptr))
117+
// LLVM-DEVICE: ret void
118+
119+
// OGCG-DEVICE-LABEL: @_Z11test_sharedv
120+
// OGCG-DEVICE: call void @takes_shared(ptr noundef addrspacecast (ptr addrspace(3) @_ZZ11test_sharedvE1s to ptr))
121+
// OGCG-DEVICE: ret void
122+
123+
// ------------------------------------------------------------
124+
// 4. mismatched pointee types but same AS: bitcast only
125+
// ------------------------------------------------------------
126+
__device__ void test_void_bitcast() {
127+
int x = 7;
128+
takes_void((void*)&x);
129+
}
130+
// CIR-DEVICE-LABEL: @_Z17test_void_bitcastv
131+
// CIR-DEVICE: cir.alloca
132+
// CIR-DEVICE: cir.store
133+
// CIR-DEVICE: cir.cast bitcast {{.*}} -> !cir.ptr<!void>
134+
// CIR-DEVICE: cir.call @takes_void
135+
// CIR-DEVICE: cir.return
136+
137+
// LLVM-DEVICE-LABEL: @_Z17test_void_bitcastv
138+
// LLVM-DEVICE: alloca i32
139+
// LLVM-DEVICE: addrspacecast ptr addrspace(5) {{.*}} to ptr
140+
// LLVM-DEVICE: store i32 7
141+
// LLVM-DEVICE: call void @takes_void(ptr {{.*}})
142+
// LLVM-DEVICE: ret void
143+
144+
// OGCG-DEVICE-LABEL: @_Z17test_void_bitcastv
145+
// OGCG-DEVICE: alloca i32, align 4, addrspace(5)
146+
// OGCG-DEVICE: addrspacecast ptr addrspace(5) {{.*}} to ptr
147+
// OGCG-DEVICE: store i32 7
148+
// OGCG-DEVICE: call void @takes_void(ptr noundef {{.*}})
149+
// OGCG-DEVICE: ret void
150+
151+
// ------------------------------------------------------------
152+
// 5. nullptr: ensure correct null pointer cast is emitted
153+
// ------------------------------------------------------------
154+
__device__ void test_null() {
155+
nullfun(nullptr);
156+
}
157+
// CIR-DEVICE-LABEL: @_Z9test_nullv
158+
// CIR-DEVICE: cir.const #cir.ptr<null>
159+
// CIR-DEVICE: cir.call @nullfun
160+
// CIR-DEVICE: cir.return
161+
162+
// LLVM-DEVICE-LABEL: @_Z9test_nullv
163+
// LLVM-DEVICE: call void @nullfun(ptr null)
164+
// LLVM-DEVICE: ret void
165+
166+
// OGCG-DEVICE-LABEL: @_Z9test_nullv
167+
// OGCG-DEVICE: call void @nullfun(ptr noundef null)
168+
// OGCG-DEVICE: ret void
169+
170+
// ------------------------------------------------------------
171+
// 6. Struct containing a pointer
172+
// ------------------------------------------------------------
173+
__device__ void test_struct() {
174+
int x = 5;
175+
S s{&x};
176+
takeS(s);
177+
}
178+
// CIR-DEVICE-LABEL: @_Z11test_structv
179+
// CIR-DEVICE: cir.alloca !s32i
180+
// CIR-DEVICE: cir.alloca !rec_S
181+
// CIR-DEVICE: cir.get_member {{.*}} "p"
182+
// CIR-DEVICE: cir.store {{.*}} : !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
183+
// CIR-DEVICE: cir.copy
184+
// CIR-DEVICE: cir.call @takeS
185+
// CIR-DEVICE: cir.return
186+
187+
// LLVM-DEVICE-LABEL: @_Z11test_structv
188+
// LLVM-DEVICE: alloca i32
189+
// LLVM-DEVICE: alloca %struct.S
190+
// LLVM-DEVICE: getelementptr %struct.S
191+
// LLVM-DEVICE: store ptr {{.*}}, ptr {{.*}}
192+
// LLVM-DEVICE: call void @llvm.memcpy
193+
// LLVM-DEVICE: load %struct.S
194+
// LLVM-DEVICE: call void @takeS(%struct.S {{.*}})
195+
// LLVM-DEVICE: ret void
196+
197+
// OGCG-DEVICE-LABEL: @_Z11test_structv
198+
// OGCG-DEVICE: alloca i32, align 4, addrspace(5)
199+
// OGCG-DEVICE: alloca %struct.S, align 8, addrspace(5)
200+
// OGCG-DEVICE: alloca %struct.S, align 8, addrspace(5)
201+
// OGCG-DEVICE: addrspacecast ptr addrspace(5) {{.*}} to ptr
202+
// OGCG-DEVICE: store i32 5
203+
// OGCG-DEVICE: getelementptr inbounds nuw %struct.S
204+
// OGCG-DEVICE: store ptr {{.*}}, ptr {{.*}}
205+
// OGCG-DEVICE: call void @llvm.memcpy.p0.p0.i64
206+
// OGCG-DEVICE: load ptr
207+
// OGCG-DEVICE: call void @takeS(ptr {{.*}})
208+
// OGCG-DEVICE: ret void
209+
210+
// ------------------------------------------------------------
211+
// 7. Function pointer argument
212+
// ------------------------------------------------------------
213+
__device__ void test_fp() {
214+
call_fp(fp_target);
215+
}
216+
// CIR-DEVICE-LABEL: @_Z7test_fpv
217+
// CIR-DEVICE: cir.get_global @_Z9fp_targeti
218+
// CIR-DEVICE: cir.call @call_fp
219+
// CIR-DEVICE: cir.return
220+
221+
// LLVM-DEVICE-LABEL: @_Z7test_fpv
222+
// LLVM-DEVICE: call void @call_fp(ptr @_Z9fp_targeti)
223+
// LLVM-DEVICE: ret void
224+
225+
// OGCG-DEVICE-LABEL: @_Z7test_fpv
226+
// OGCG-DEVICE: call void @call_fp(ptr noundef @_Z9fp_targeti)
227+
// OGCG-DEVICE: ret void
228+
229+
// ------------------------------------------------------------
230+
// 8. Original test from previous patch: string literal → char*
231+
// ------------------------------------------------------------
232+
__device__ void foo() {
233+
char cchar[] = "const char.\n";
234+
bar(cchar);
235+
}
236+
// CIR-DEVICE-LABEL: @_Z3foov
237+
// CIR-DEVICE: cir.alloca
238+
// CIR-DEVICE: cir.get_global @__const._Z3foov.cchar
239+
// CIR-DEVICE: cir.copy
240+
// CIR-DEVICE: cir.cast array_to_ptrdecay
241+
// CIR-DEVICE: cir.call @bar
242+
// CIR-DEVICE: cir.return
243+
244+
// LLVM-DEVICE-LABEL: @_Z3foov
245+
// LLVM-DEVICE: alloca [13 x i8]
246+
// LLVM-DEVICE: addrspacecast ptr addrspace(5) {{.*}} to ptr
247+
// LLVM-DEVICE: call void @llvm.memcpy.p0.p0.i32
248+
// LLVM-DEVICE: getelementptr i8
249+
// LLVM-DEVICE: call void @bar(ptr {{.*}})
250+
// LLVM-DEVICE: ret void
251+
252+
// OGCG-DEVICE-LABEL: @_Z3foov
253+
// OGCG-DEVICE: alloca [13 x i8], align 1, addrspace(5)
254+
// OGCG-DEVICE: addrspacecast ptr addrspace(5) {{.*}} to ptr
255+
// OGCG-DEVICE: call void @llvm.memcpy.p0.p4.i64
256+
// OGCG-DEVICE: getelementptr inbounds [13 x i8]
257+
// OGCG-DEVICE: call void @bar(ptr noundef {{.*}})
258+
// OGCG-DEVICE: ret void

0 commit comments

Comments
 (0)