Skip to content

Commit 518da5a

Browse files
authored
[CIR] Backport ComplexImagOp on scalar type (#2022)
Backporting ComplexImagOp on scalar type from the upstream
1 parent 58ac26e commit 518da5a

File tree

6 files changed

+54
-12
lines changed

6 files changed

+54
-12
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
337337
}
338338

339339
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
340-
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
341-
return cir::ComplexImagOp::create(*this, loc, operandTy.getElementType(),
342-
operand);
340+
auto resultType = operand.getType();
341+
if (auto complexResultType = mlir::dyn_cast<cir::ComplexType>(resultType))
342+
resultType = complexResultType.getElementType();
343+
return cir::ComplexImagOp::create(*this, loc, resultType, operand);
343344
}
344345

345346
mlir::Value createComplexBinOp(mlir::Location loc, mlir::Value lhs,

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,18 +1539,20 @@ def CIR_ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
15391539
def CIR_ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
15401540
let summary = "Extract the imaginary part of a complex value";
15411541
let description = [{
1542-
`cir.complex.imag` operation takes an operand of `!cir.complex` type and
1543-
yields the imaginary part of it.
1542+
`cir.complex.imag` operation takes an operand of `!cir.complex`, `!cir.int`
1543+
or `!cir.float`. If the operand is `!cir.complex`, the imag part of it will
1544+
be returned, otherwise a zero value will be returned.
15441545

15451546
Example:
15461547

15471548
```mlir
1548-
%1 = cir.complex.imag %0 : !cir.complex<!cir.float> -> !cir.float
1549+
%imag = cir.complex.imag %complex : !cir.complex<!cir.float> -> !cir.float
1550+
%imag = cir.complex.imag %scalar : !cir.float -> !cir.float
15491551
```
15501552
}];
15511553

15521554
let results = (outs CIR_AnyIntOrFloatType:$result);
1553-
let arguments = (ins CIR_ComplexType:$operand);
1555+
let arguments = (ins CIR_AnyComplexOrIntOrFloatType:$operand);
15541556

15551557
let assemblyFormat = [{
15561558
$operand `:` qualified(type($operand)) `->` qualified(type($result))

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,8 +2306,8 @@ mlir::Value ScalarExprEmitter::VisitImag(const UnaryOperator *E) {
23062306
// TODO(cir): handle scalar promotion.
23072307

23082308
Expr *Op = E->getSubExpr();
2309+
mlir::Location Loc = CGF.getLoc(E->getExprLoc());
23092310
if (Op->getType()->isAnyComplexType()) {
2310-
mlir::Location Loc = CGF.getLoc(E->getExprLoc());
23112311

23122312
// If it's an l-value, load through the appropriate subobject l-value.
23132313
// Note that we have to ask E because Op might be an l-value that
@@ -2321,7 +2321,7 @@ mlir::Value ScalarExprEmitter::VisitImag(const UnaryOperator *E) {
23212321
return Builder.createComplexImag(Loc, CGF.emitComplexExpr(Op));
23222322
}
23232323

2324-
return Visit(Op);
2324+
return Builder.createComplexImag(Loc, Visit(Op));
23252325
}
23262326

23272327
// Conversion from bool, integral, or floating-point to integral or

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,14 +971,22 @@ OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
971971
}
972972

973973
LogicalResult cir::ComplexImagOp::verify() {
974-
if (getType() != getOperand().getType().getElementType()) {
974+
mlir::Type operandTy = getOperand().getType();
975+
if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy))
976+
operandTy = complexOperandTy.getElementType();
977+
978+
if (getType() != operandTy) {
975979
emitOpError() << ": result type does not match operand type";
976980
return failure();
977981
}
982+
978983
return success();
979984
}
980985

981986
OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
987+
if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
988+
return nullptr;
989+
982990
if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
983991
return complexCreateOp.getOperand(1);
984992

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2597,8 +2597,19 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
25972597
cir::ComplexImagOp op, OpAdaptor adaptor,
25982598
mlir::ConversionPatternRewriter &rewriter) const {
25992599
auto resultLLVMTy = getTypeConverter()->convertType(op.getType());
2600-
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2601-
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{1});
2600+
mlir::Value operand = adaptor.getOperand();
2601+
mlir::Location loc = op.getLoc();
2602+
2603+
if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
2604+
operand = mlir::LLVM::ExtractValueOp::create(
2605+
rewriter, loc, resultLLVMTy, operand, llvm::ArrayRef<std::int64_t>{1});
2606+
} else {
2607+
mlir::TypedAttr zeroAttr = rewriter.getZeroAttr(resultLLVMTy);
2608+
operand =
2609+
mlir::LLVM::ConstantOp::create(rewriter, loc, resultLLVMTy, zeroAttr);
2610+
}
2611+
2612+
rewriter.replaceOp(op, operand);
26022613
return mlir::success();
26032614
}
26042615

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,23 @@ void real_on_scalar_glvalue() {
390390
// OGCG: %[[B_ADDR:.*]] = alloca float, align 4
391391
// OGCG: %[[TMP_A:.*]] = load float, ptr %[[A_ADDR]], align 4
392392
// OGCG: store float %[[TMP_A]], ptr %[[B_ADDR]], align 4
393+
394+
void imag_on_scalar_glvalue() {
395+
float a;
396+
float b = __imag__ a;
397+
}
398+
399+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["a"]
400+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["b", init]
401+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.float>, !cir.float
402+
// CIR: %[[A_IMAG:.*]] = cir.complex.imag %[[TMP_A]] : !cir.float -> !cir.float
403+
// CIR: cir.store{{.*}} %[[A_IMAG]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
404+
405+
// LLVM: %[[A_ADDR:.*]] = alloca float, i64 1, align 4
406+
// LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4
407+
// LLVM: %[[TMP_A:.*]] = load float, ptr %[[A_ADDR]], align 4
408+
// LLVM: store float 0.000000e+00, ptr %[[B_ADDR]], align 4
409+
410+
// OGCG: %[[A_ADDR:.*]] = alloca float, align 4
411+
// OGCG: %[[B_ADDR:.*]] = alloca float, align 4
412+
// OGCG: store float 0.000000e+00, ptr %[[B_ADDR]], align 4

0 commit comments

Comments
 (0)