Skip to content

Commit 32ff710

Browse files
authored
[AArch64] Lower v8bf16 FMUL to BFMLAL top/bottom with +sve (#169655)
Assuming the predicate is hoisted, this should have slightly better throughput: https://godbolt.org/z/jb7aP7Efc Note: SVE must be used to convert back to bf16 because the bfmlalb/bfmlalt instructions operate on even/odd lanes, whereas the NEON bfcvtn/bfcvtn2 instructions process the bottom/top halves of vectors.
1 parent 5e3ffd6 commit 32ff710

File tree

3 files changed

+73
-37
lines changed

3 files changed

+73
-37
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18341834
else
18351835
setOperationPromotedToType(ISD::FMUL, VT, PromotedVT);
18361836
}
1837+
1838+
if (Subtarget->hasBF16() && Subtarget->isNeonAvailable())
1839+
setOperationAction(ISD::FMUL, MVT::v8bf16, Custom);
18371840
}
18381841

18391842
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -7742,7 +7745,8 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
77427745
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
77437746

77447747
assert(Subtarget->hasBF16() && "Expected +bf16 for custom FMUL lowering");
7745-
assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16) && "Unexpected FMUL VT");
7748+
assert((VT == MVT::nxv4bf16 || VT == MVT::nxv8bf16 || VT == MVT::v8bf16) &&
7749+
"Unexpected FMUL VT");
77467750

77477751
auto MakeGetIntrinsic = [&](Intrinsic::ID IID) {
77487752
return [&, IID](EVT VT, auto... Ops) {
@@ -7751,37 +7755,56 @@ SDValue AArch64TargetLowering::LowerFMUL(SDValue Op, SelectionDAG &DAG) const {
77517755
};
77527756
};
77537757

7754-
auto ReinterpretCast = [&](SDValue Value, EVT VT) {
7755-
if (VT == Value.getValueType())
7758+
auto Reinterpret = [&](SDValue Value, EVT VT) {
7759+
EVT SrcVT = Value.getValueType();
7760+
if (VT == SrcVT)
77567761
return Value;
7762+
if (SrcVT.isFixedLengthVector())
7763+
return convertToScalableVector(DAG, VT, Value);
7764+
if (VT.isFixedLengthVector())
7765+
return convertFromScalableVector(DAG, VT, Value);
77577766
return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Value);
77587767
};
77597768

7760-
// Create helpers for building intrinsic calls.
7761-
auto BFMLALB = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalb);
7762-
auto BFMLALT = MakeGetIntrinsic(Intrinsic::aarch64_sve_bfmlalt);
7769+
bool UseSVEBFMLAL = VT.isScalableVector();
77637770
auto FCVT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvt_bf16f32_v2);
77647771
auto FCVTNT = MakeGetIntrinsic(Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2);
77657772

7766-
// All intrinsics expect to operate on full bf16 vector types.
7767-
SDValue LHS = ReinterpretCast(Op.getOperand(0), MVT::nxv8bf16);
7768-
SDValue RHS = ReinterpretCast(Op.getOperand(1), MVT::nxv8bf16);
7773+
// Note: The NEON BFMLAL[BT] reads even/odd lanes like the SVE variant.
7774+
// This does not match BFCVTN[2], so we use SVE to convert back to bf16.
7775+
auto BFMLALB =
7776+
MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalb
7777+
: Intrinsic::aarch64_neon_bfmlalb);
7778+
auto BFMLALT =
7779+
MakeGetIntrinsic(UseSVEBFMLAL ? Intrinsic::aarch64_sve_bfmlalt
7780+
: Intrinsic::aarch64_neon_bfmlalt);
77697781

7770-
SDValue Zero =
7771-
DAG.getNeutralElement(ISD::FADD, DL, MVT::nxv4f32, Op->getFlags());
7772-
SDValue Pg = DAG.getConstant(1, DL, MVT::nxv4i1);
7782+
EVT AccVT = UseSVEBFMLAL ? MVT::nxv4f32 : MVT::v4f32;
7783+
SDValue Zero = DAG.getNeutralElement(ISD::FADD, DL, AccVT, Op->getFlags());
7784+
SDValue Pg = getPredicateForVector(DAG, DL, AccVT);
77737785

7774-
// Lower bf16 FMUL as a pair (VT == nxv8bf16) of BFMLAL top/bottom
7786+
// Lower bf16 FMUL as a pair (VT == [nx]v8bf16) of BFMLAL top/bottom
77757787
// instructions. These result in two f32 vectors, which can be converted back
77767788
// to bf16 with FCVT and FCVTNT.
7777-
SDValue BottomF32 = BFMLALB(MVT::nxv4f32, Zero, LHS, RHS);
7789+
SDValue LHS = Op.getOperand(0);
7790+
SDValue RHS = Op.getOperand(1);
7791+
7792+
// All SVE intrinsics expect to operate on full bf16 vector types.
7793+
if (UseSVEBFMLAL) {
7794+
LHS = Reinterpret(LHS, MVT::nxv8bf16);
7795+
RHS = Reinterpret(RHS, MVT::nxv8bf16);
7796+
}
7797+
7798+
SDValue BottomF32 = Reinterpret(BFMLALB(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
77787799
SDValue BottomBF16 =
77797800
FCVT(MVT::nxv8bf16, DAG.getPOISON(MVT::nxv8bf16), Pg, BottomF32);
77807801
// Note: nxv4bf16 only uses even lanes.
77817802
if (VT == MVT::nxv4bf16)
7782-
return ReinterpretCast(BottomBF16, VT);
7783-
SDValue TopF32 = BFMLALT(MVT::nxv4f32, Zero, LHS, RHS);
7784-
return FCVTNT(VT, BottomBF16, Pg, TopF32);
7803+
return Reinterpret(BottomBF16, VT);
7804+
7805+
SDValue TopF32 = Reinterpret(BFMLALT(AccVT, Zero, LHS, RHS), MVT::nxv4f32);
7806+
SDValue TopBF16 = FCVTNT(MVT::nxv8bf16, BottomBF16, Pg, TopF32);
7807+
return Reinterpret(TopBF16, VT);
77857808
}
77867809

77877810
SDValue AArch64TargetLowering::LowerFMA(SDValue Op, SelectionDAG &DAG) const {

llvm/test/CodeGen/AArch64/bf16-v8-instructions.ll

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
22
; RUN: llc < %s -mtriple=aarch64 -mattr=-bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-CVT
3-
; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16
3+
; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16 | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-NOSVE-BF16
4+
; RUN: llc < %s -mtriple=aarch64 -mattr=+bf16,+sve | FileCheck %s --check-prefixes=CHECK,CHECK-BF16,CHECK-SVE-BF16
45

56
define <8 x bfloat> @add_h(<8 x bfloat> %a, <8 x bfloat> %b) {
67
; CHECK-CVT-LABEL: add_h:
@@ -117,17 +118,29 @@ define <8 x bfloat> @mul_h(<8 x bfloat> %a, <8 x bfloat> %b) {
117118
; CHECK-CVT-NEXT: uzp2 v0.8h, v0.8h, v2.8h
118119
; CHECK-CVT-NEXT: ret
119120
;
120-
; CHECK-BF16-LABEL: mul_h:
121-
; CHECK-BF16: // %bb.0: // %entry
122-
; CHECK-BF16-NEXT: shll v2.4s, v1.4h, #16
123-
; CHECK-BF16-NEXT: shll v3.4s, v0.4h, #16
124-
; CHECK-BF16-NEXT: shll2 v1.4s, v1.8h, #16
125-
; CHECK-BF16-NEXT: shll2 v0.4s, v0.8h, #16
126-
; CHECK-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s
127-
; CHECK-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s
128-
; CHECK-BF16-NEXT: bfcvtn v0.4h, v2.4s
129-
; CHECK-BF16-NEXT: bfcvtn2 v0.8h, v1.4s
130-
; CHECK-BF16-NEXT: ret
121+
; CHECK-NOSVE-BF16-LABEL: mul_h:
122+
; CHECK-NOSVE-BF16: // %bb.0: // %entry
123+
; CHECK-NOSVE-BF16-NEXT: shll v2.4s, v1.4h, #16
124+
; CHECK-NOSVE-BF16-NEXT: shll v3.4s, v0.4h, #16
125+
; CHECK-NOSVE-BF16-NEXT: shll2 v1.4s, v1.8h, #16
126+
; CHECK-NOSVE-BF16-NEXT: shll2 v0.4s, v0.8h, #16
127+
; CHECK-NOSVE-BF16-NEXT: fmul v2.4s, v3.4s, v2.4s
128+
; CHECK-NOSVE-BF16-NEXT: fmul v1.4s, v0.4s, v1.4s
129+
; CHECK-NOSVE-BF16-NEXT: bfcvtn v0.4h, v2.4s
130+
; CHECK-NOSVE-BF16-NEXT: bfcvtn2 v0.8h, v1.4s
131+
; CHECK-NOSVE-BF16-NEXT: ret
132+
;
133+
; CHECK-SVE-BF16-LABEL: mul_h:
134+
; CHECK-SVE-BF16: // %bb.0: // %entry
135+
; CHECK-SVE-BF16-NEXT: movi v2.4s, #128, lsl #24
136+
; CHECK-SVE-BF16-NEXT: movi v3.4s, #128, lsl #24
137+
; CHECK-SVE-BF16-NEXT: ptrue p0.s, vl4
138+
; CHECK-SVE-BF16-NEXT: bfmlalb v2.4s, v0.8h, v1.8h
139+
; CHECK-SVE-BF16-NEXT: bfmlalt v3.4s, v0.8h, v1.8h
140+
; CHECK-SVE-BF16-NEXT: bfcvt z2.h, p0/m, z2.s
141+
; CHECK-SVE-BF16-NEXT: bfcvtnt z2.h, p0/m, z3.s
142+
; CHECK-SVE-BF16-NEXT: mov v0.16b, v2.16b
143+
; CHECK-SVE-BF16-NEXT: ret
131144
entry:
132145
%0 = fmul <8 x bfloat> %a, %b
133146
ret <8 x bfloat> %0

llvm/test/CodeGen/AArch64/fixed-length-bf16-arith.ll

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -761,14 +761,14 @@ define <4 x bfloat> @fmul_v4bf16(<4 x bfloat> %a, <4 x bfloat> %b) {
761761
define <8 x bfloat> @fmul_v8bf16(<8 x bfloat> %a, <8 x bfloat> %b) {
762762
; NOB16B16-LABEL: fmul_v8bf16:
763763
; NOB16B16: // %bb.0:
764-
; NOB16B16-NEXT: shll v2.4s, v1.4h, #16
765-
; NOB16B16-NEXT: shll v3.4s, v0.4h, #16
766-
; NOB16B16-NEXT: shll2 v1.4s, v1.8h, #16
767-
; NOB16B16-NEXT: shll2 v0.4s, v0.8h, #16
768-
; NOB16B16-NEXT: fmul v2.4s, v3.4s, v2.4s
769-
; NOB16B16-NEXT: fmul v1.4s, v0.4s, v1.4s
770-
; NOB16B16-NEXT: bfcvtn v0.4h, v2.4s
771-
; NOB16B16-NEXT: bfcvtn2 v0.8h, v1.4s
764+
; NOB16B16-NEXT: movi v2.4s, #128, lsl #24
765+
; NOB16B16-NEXT: movi v3.4s, #128, lsl #24
766+
; NOB16B16-NEXT: ptrue p0.s, vl4
767+
; NOB16B16-NEXT: bfmlalb v2.4s, v0.8h, v1.8h
768+
; NOB16B16-NEXT: bfmlalt v3.4s, v0.8h, v1.8h
769+
; NOB16B16-NEXT: bfcvt z2.h, p0/m, z2.s
770+
; NOB16B16-NEXT: bfcvtnt z2.h, p0/m, z3.s
771+
; NOB16B16-NEXT: mov v0.16b, v2.16b
772772
; NOB16B16-NEXT: ret
773773
;
774774
; B16B16-LABEL: fmul_v8bf16:

0 commit comments

Comments
 (0)