Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ namespace gcpp {
namespace HWY_NAMESPACE {

static constexpr size_t kNFx8HTileSize = 8;

// Large-magnitude negative finite float (the most negative float divided by
// 64) used in place of negative infinity when seeding running-max vectors.
// NOTE(review): the reason for the divisor 64 is not visible here; presumably
// it leaves headroom so later arithmetic on the value stays finite - confirm.
static constexpr float kNegInf = -std::numeric_limits<float>::max() / 64.0f;
// Transposes q into q_t.
// Both are 4D tensors stuffed into a 2-D MatPtrT.
// q has shape [batch, qbatch][head, qkv_dim].
Expand Down Expand Up @@ -467,7 +467,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
const DF4 df4;
using VF4 = hn::Vec<DF4>;
static_assert(kNumQueries >= 1 && kNumQueries <= 4);
VF4 new_max = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
VF4 new_max = hn::Set(df4, kNegInf);
VF max_0, max_1, max_2, max_3 = hn::Zero(df);
max_0 = hn::Max(x_0_p0, x_0_p1);
if constexpr (kNumQueries >= 2) {
Expand All @@ -490,38 +490,36 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
}
VF4 old_max_vf = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
VF4 old_max_vf = hn::Set(df4, kNegInf);
old_max_vf = hn::LoadU(df4, old_max);
new_max = hn::Max(new_max, old_max_vf);
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
// TODO: figure out what was wrong with using broadcasts here and switch to them.
HWY_ALIGN float tmp_max[4];
hn::Store(new_max, df4, tmp_max);
hn::StoreU(new_max, df4, old_max);
if constexpr (kNumQueries >= 1) {
const VF new_max_0 = hn::Set(df, tmp_max[0]);
const VF new_max_0 = hn::Set(df, old_max[0]);
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0, new_max_0));
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
}
if constexpr (kNumQueries >= 2) {
const VF new_max_0 = hn::Set(df, tmp_max[1]);
const VF new_max_0 = hn::Set(df, old_max[1]);
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0));
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0));
}
if constexpr (kNumQueries >= 3) {
const VF new_max_0 = hn::Set(df, tmp_max[2]);
const VF new_max_0 = hn::Set(df, old_max[2]);
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0));
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0));
}
if constexpr (kNumQueries >= 4) {
const VF new_max_0 = hn::Set(df, tmp_max[3]);
const VF new_max_0 = hn::Set(df, old_max[3]);
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0));
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0));
}
VF4 old_d_vf = hn::Set(df4, 0.0f);
old_d_vf = hn::LoadU(df4, old_d);
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));

hn::StoreU(new_max, df4, old_max);

VF4 x_sum = hn::Zero(df4);
if constexpr (kNumQueries == 1) {
x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
Expand All @@ -539,12 +537,12 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
const VF4 zero4 = hn::Zero(df4);
const VF4 one_over_d =
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
float tmp_one_over_d[4];
HWY_ALIGN float tmp_one_over_d[4];
hn::Store(one_over_d, df4, tmp_one_over_d);
hn::Store(old_d_vf, df4, old_d);
hn::BlendedStore(old_d_vf, changed_max, df4, old_d);
scale = hn::Mul(scale, one_over_d);
hn::Store(scale, df4, scales);
if (hn::ExtractLane(old_d_vf, 0) > 0.0f) {
hn::BlendedStore(scale, changed_max, df4, scales);
if (hn::ExtractLane(old_d_vf, 0) > 0.0f && scales[0] != 1.0f) {
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
Expand All @@ -553,7 +551,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
x_0_p1 = zero;
}
if constexpr (kNumQueries >= 2) {
if (hn::ExtractLane(old_d_vf, 1) > 0.0f) {
if (hn::ExtractLane(old_d_vf, 1) > 0.0f && scales[1] != 1.0f) {
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
Expand All @@ -563,7 +561,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
}
}
if constexpr (kNumQueries >= 3) {
if (hn::ExtractLane(old_d_vf, 2) > 0.0f) {
if (hn::ExtractLane(old_d_vf, 2) > 0.0f && scales[2] != 1.0f) {
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
Expand All @@ -573,7 +571,7 @@ static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
}
}
if constexpr (kNumQueries >= 4) {
if (hn::ExtractLane(old_d_vf, 3) > 0.0f) {
if (hn::ExtractLane(old_d_vf, 3) > 0.0f && scales[3] != 1.0f) {
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);
Expand Down
Loading