Skip to content

Commit d6d4215

Browse files
authored
[CIR] Backport ComplexRealOp on scalar type (#2014)
Backporting ComplexRealOp on scalar type from the upstream
1 parent 0485f44 commit d6d4215

File tree

7 files changed

+55
-15
lines changed

7 files changed

+55
-15
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,10 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
330330
}
331331

332332
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
333-
auto operandTy = mlir::cast<cir::ComplexType>(operand.getType());
334-
return cir::ComplexRealOp::create(*this, loc, operandTy.getElementType(),
335-
operand);
333+
auto resultType = operand.getType();
334+
if (auto complexResultType = mlir::dyn_cast<cir::ComplexType>(resultType))
335+
resultType = complexResultType.getElementType();
336+
return cir::ComplexRealOp::create(*this, loc, resultType, operand);
336337
}
337338

338339
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1525,7 +1525,7 @@ def CIR_ComplexRealOp : CIR_Op<"complex.real", [Pure]> {
15251525
}];
15261526

15271527
let results = (outs CIR_AnyIntOrFloatType:$result);
1528-
let arguments = (ins CIR_ComplexType:$operand);
1528+
let arguments = (ins CIR_AnyComplexOrIntOrFloatType:$operand);
15291529

15301530
let assemblyFormat = [{
15311531
$operand `:` qualified(type($operand)) `->` qualified(type($result))

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
169169

170170
def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
171171

172+
def CIR_AnyComplexOrIntOrFloatType : AnyTypeOf<[
173+
CIR_AnyComplexType, CIR_AnyFloatType, CIR_AnyIntType
174+
], "complex, integer or floating point type"> {
175+
let cppFunctionName = "isComplexOrIntegerOrFloatingPointType";
176+
}
177+
172178
//===----------------------------------------------------------------------===//
173179
// Record Type predicates
174180
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,8 +2105,8 @@ mlir::Value ScalarExprEmitter::VisitReal(const UnaryOperator *E) {
21052105
// TODO(cir): handle scalar promotion.
21062106

21072107
Expr *Op = E->getSubExpr();
2108+
mlir::Location Loc = CGF.getLoc(E->getExprLoc());
21082109
if (Op->getType()->isAnyComplexType()) {
2109-
mlir::Location Loc = CGF.getLoc(E->getExprLoc());
21102110

21112111
// If it's an l-value, load through the appropriate subobject l-value.
21122112
// Note that we have to ask E because Op might be an l-value that
@@ -2120,7 +2120,7 @@ mlir::Value ScalarExprEmitter::VisitReal(const UnaryOperator *E) {
21202120
return Builder.createComplexReal(Loc, CGF.emitComplexExpr(Op));
21212121
}
21222122

2123-
return Visit(Op);
2123+
return Builder.createComplexReal(Loc, Visit(Op));
21242124
}
21252125

21262126
mlir::Value ScalarExprEmitter::VisitImag(const UnaryOperator *E) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ void cir::ConditionOp::getSuccessorRegions(
360360
regions.emplace_back(&await.getSuspend(), await.getSuspend().getArguments());
361361
}
362362

363-
MutableOperandRange cir::ConditionOp::getMutableSuccessorOperands(
364-
RegionSuccessor /*successor*/) {
363+
MutableOperandRange
364+
cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor /*successor*/) {
365365
// No values are yielded to the successor region.
366366
return MutableOperandRange(getOperation(), 0, 0);
367367
}
@@ -944,14 +944,23 @@ OpFoldResult cir::ComplexCreateOp::fold(FoldAdaptor adaptor) {
944944
//===----------------------------------------------------------------------===//
945945

946946
LogicalResult cir::ComplexRealOp::verify() {
947-
if (getType() != getOperand().getType().getElementType()) {
947+
mlir::Type operandTy = getOperand().getType();
948+
if (auto complexOperandTy = mlir::dyn_cast<cir::ComplexType>(operandTy)) {
949+
operandTy = complexOperandTy.getElementType();
950+
}
951+
952+
if (getType() != operandTy) {
948953
emitOpError() << ": result type does not match operand type";
949954
return failure();
950955
}
956+
951957
return success();
952958
}
953959

954960
OpFoldResult cir::ComplexRealOp::fold(FoldAdaptor adaptor) {
961+
if (!mlir::isa<cir::ComplexType>(getOperand().getType()))
962+
return nullptr;
963+
955964
if (auto complexCreateOp = getOperand().getDefiningOp<cir::ComplexCreateOp>())
956965
return complexCreateOp.getOperand(0);
957966

@@ -1525,8 +1534,7 @@ void cir::ScopeOp::getSuccessorRegions(
15251534
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
15261535
// The only region always branch back to the parent operation.
15271536
if (!point.isParent()) {
1528-
regions.push_back(
1529-
RegionSuccessor(getOperation(), this->getODSResults(0)));
1537+
regions.push_back(RegionSuccessor(getOperation(), this->getODSResults(0)));
15301538
return;
15311539
}
15321540

@@ -1787,8 +1795,8 @@ void cir::TernaryOp::build(
17871795
// YieldOp
17881796
//===----------------------------------------------------------------------===//
17891797

1790-
MutableOperandRange cir::YieldOp::getMutableSuccessorOperands(
1791-
RegionSuccessor successor) {
1798+
MutableOperandRange
1799+
cir::YieldOp::getMutableSuccessorOperands(RegionSuccessor successor) {
17921800
Operation *op = getOperation();
17931801
if (auto loop = dyn_cast<LoopOpInterface>(op->getParentOp())) {
17941802
if (op->getParentRegion() == &loop.getCond())

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2583,8 +2583,13 @@ mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
25832583
cir::ComplexRealOp op, OpAdaptor adaptor,
25842584
mlir::ConversionPatternRewriter &rewriter) const {
25852585
auto resultLLVMTy = getTypeConverter()->convertType(op.getType());
2586-
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2587-
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0});
2586+
mlir::Value operand = adaptor.getOperand();
2587+
if (mlir::isa<cir::ComplexType>(op.getOperand().getType())) {
2588+
operand = mlir::LLVM::ExtractValueOp::create(
2589+
rewriter, op.getLoc(), resultLLVMTy, operand,
2590+
llvm::ArrayRef<std::int64_t>{0});
2591+
}
2592+
rewriter.replaceOp(op, operand);
25882593
return mlir::success();
25892594
}
25902595

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,23 @@ void calling_function_with_default_arg() {
370370
// OGCG: %[[TMP_DEFAULT_ARG:.*]] = load <2 x float>, ptr %[[DEFAULT_ARG_ADDR]], align 4
371371
// OGCG: call void @_Z33function_with_complex_default_argCf(<2 x float> {{.*}} %[[TMP_DEFAULT_ARG]])
372372

373+
void real_on_scalar_glvalue() {
374+
float a;
375+
float b = __real__ a;
376+
}
377+
378+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["a"]
379+
// CIR: %[[B_ADDR:.*]] = cir.alloca !cir.float, !cir.ptr<!cir.float>, ["b", init]
380+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[A_ADDR]] : !cir.ptr<!cir.float>, !cir.float
381+
// CIR: %[[A_REAL:.*]] = cir.complex.real %[[TMP_A]] : !cir.float -> !cir.float
382+
// CIR: cir.store{{.*}} %[[A_REAL]], %[[B_ADDR]] : !cir.float, !cir.ptr<!cir.float>
383+
384+
// LLVM: %[[A_ADDR:.*]] = alloca float, i64 1, align 4
385+
// LLVM: %[[B_ADDR:.*]] = alloca float, i64 1, align 4
386+
// LLVM: %[[TMP_A:.*]] = load float, ptr %[[A_ADDR]], align 4
387+
// LLVM: store float %[[TMP_A]], ptr %[[B_ADDR]], align 4
388+
389+
// OGCG: %[[A_ADDR:.*]] = alloca float, align 4
390+
// OGCG: %[[B_ADDR:.*]] = alloca float, align 4
391+
// OGCG: %[[TMP_A:.*]] = load float, ptr %[[A_ADDR]], align 4
392+
// OGCG: store float %[[TMP_A]], ptr %[[B_ADDR]], align 4

0 commit comments

Comments
 (0)