121 changes: 107 additions & 14 deletions gemma/tensor_stats.cc
@@ -107,6 +107,98 @@ void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
bytes_per_row * row_idx);
}

constexpr size_t kGroupSize = 128; // subchannel

void QuantizeGroup(const float* HWY_RESTRICT in,
TensorStatsAccumulator& my_stats) {
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
using MF = hn::Mask<decltype(df)>;
const hn::ScalableTag<double> dd;
using VD = hn::Vec<decltype(dd)>;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
HWY_ALIGN float enc[kGroupSize];
HWY_ALIGN float dec[kGroupSize];
HWY_ALIGN float all_snr[kGroupSize];
HWY_DASSERT(kGroupSize % NF == 0); // No remainder handling required.

const VF k0 = hn::Zero(df);
const VF k1 = hn::Set(df, 1.0f);

// Scan for min/max for quantization.
VF vmin = hn::Set(df, hwy::HighestValue<float>());
VF vmax = hn::Set(df, hwy::LowestValue<float>());
for (size_t i = 0; i < kGroupSize; i += NF) {
const VF v = hn::Load(df, in + i);
vmin = hn::Min(vmin, v);
vmax = hn::Max(vmax, v);
}
const float min = hn::ReduceMin(df, vmin);
const float max = hn::ReduceMax(df, vmax);
// Avoid division by zero during quantization.
if (max == min) return;

// Distortion stats.
VF vsum_err = hn::Zero(df);
VD sum_log_snr0 = hn::Zero(dd);
VD sum_log_snr1 = hn::Zero(dd);
size_t num_snr = 0;

// Unclipped asymmetric quantization (for activations).
const VF scale = hn::Set(df, 255.0f / (max - min));
const VF inv_scale = hn::Div(k1, scale);
const VF zeropoint = hn::Sub(hn::Round(hn::Mul(hn::Set(df, -min), scale)),
hn::Set(df, 128.0f));
const VF dq_sub = hn::Mul(zeropoint, inv_scale); // For MulSub.
for (size_t i = 0; i < kGroupSize; i += NF) {
const VF v = hn::Load(df, in + i);
const VF q = hn::Round(hn::MulAdd(v, scale, zeropoint));
hn::Store(q, df, enc + i);
// Dequantize.
const VF d = hn::MulSub(q, inv_scale, dq_sub);
hn::Store(d, df, dec + i);

const VF err = hn::AbsDiff(v, d); // L1
vsum_err = hn::Add(vsum_err, err);

// The mask prevents division by zero. We still cap the ratio because it
// could be very high (>1E3 when most elements are lossless).
const MF has_err = hn::Gt(err, k0);
const VF rel = hn::MaskedDivOr(k0, has_err, hn::Abs(v), err);
// SNR = 1 + |v|/err, with a cap on the latter term.
const VF snr = hn::Add(k1, hn::Min(rel, hn::Set(df, 300.f)));
hn::Store(snr, df, all_snr + i);
// Where `has_err` is false, `snr` elements are 1 and log(1) is zero, hence
// they do not affect sum_log. However, very high errors also result in
// snr near 1, which drags down the average because `num_snr` is incremented
// while `sum_log` barely increases.
num_snr += hn::CountTrue(df, has_err);

const VD log_snr0 = hn::Log(dd, hn::PromoteLowerTo(dd, snr));
const VD log_snr1 = hn::Log(dd, hn::PromoteUpperTo(dd, snr));
sum_log_snr0 = hn::Add(sum_log_snr0, log_snr0);
sum_log_snr1 = hn::Add(sum_log_snr1, log_snr1);
}

const float sum_err = hn::ReduceSum(df, vsum_err);
const float avg_L1 = sum_err / static_cast<float>(kGroupSize);
const double sum_log = hn::ReduceSum(dd, hn::Add(sum_log_snr0, sum_log_snr1));
// SNR >= 1, hence log >= 0.
HWY_ASSERT(sum_log >= 0.0);
if (num_snr == 0) { // Avoid division by zero.
// It can happen that dequantization is lossless, i.e. SNR is
// infinite; skip such groups.
HWY_ASSERT(sum_err == 0.0f);
return;
}
// Signal to noise ratio (Shannon's channel capacity, NOT the
// L2-based and logarithmic PSNR)
const float snr = std::exp(sum_log / static_cast<double>(num_snr));

my_stats.NotifyGroup(avg_L1, snr);
}
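
Not part of the diff: a minimal scalar sketch of the metric `QuantizeGroup` computes, for readers without Highway at hand. The names `QuantizeGroupScalar` and `GroupMetrics` are hypothetical; the constants (255, -128, the cap of 300, SNR = 1 + |v|/err, geometric mean over elements with nonzero error) follow the vector code above.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

struct GroupMetrics {
  float avg_L1;  // mean absolute reconstruction error
  float snr;     // geometric mean of per-element SNR, >= 1
};

// Scalar equivalent of the per-group metric: unclipped asymmetric 8-bit
// quantization, then the average L1 error and the geometric mean of
// per-element SNR = 1 + |v| / |v - dequant(v)|. Caller ensures max != min.
inline GroupMetrics QuantizeGroupScalar(const float* in, size_t n) {
  float min = in[0], max = in[0];
  for (size_t i = 1; i < n; ++i) {
    min = std::min(min, in[i]);
    max = std::max(max, in[i]);
  }
  const float scale = 255.0f / (max - min);
  const float zeropoint = std::round(-min * scale) - 128.0f;
  double sum_err = 0.0, sum_log = 0.0;
  size_t num_snr = 0;
  for (size_t i = 0; i < n; ++i) {
    const float q = std::round(in[i] * scale + zeropoint);  // not clamped
    const float d = (q - zeropoint) / scale;                // dequantize
    const float err = std::abs(in[i] - d);
    sum_err += err;
    if (err > 0.0f) {  // lossless elements contribute log(1) = 0
      sum_log += std::log(1.0f + std::min(std::abs(in[i]) / err, 300.0f));
      ++num_snr;
    }
  }
  const float avg_L1 = static_cast<float>(sum_err / n);
  // The vector code skips entirely lossless groups; report SNR = 1 here.
  const float snr =
      num_snr ? static_cast<float>(std::exp(sum_log / num_snr)) : 1.0f;
  return GroupMetrics{avg_L1, snr};
}
```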

// First dispatch to the type, then parallel over rows, then vectorized
// decompress and Notify for each value.
void UpdateStatsT(TensorStats& stats, size_t layer_idx,
@@ -138,37 +230,38 @@ void UpdateStatsT(TensorStats& stats, size_t layer_idx,
my_stats.NotifyCond(ConditionNumber(row, cols));

namespace hn = hwy::HWY_NAMESPACE;
hn::ScalableTag<float> df;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf[kGroupSize];
size_t buf_filled = 0;

size_t packed_ofs = 0;
if (cols >= 2 * NF) {
for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
VF v0, v1;
Decompress2(df, packed, packed_ofs, v0, v1);
hn::Store(v0, df, buf);
hn::Store(v1, df, buf + NF);
const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
const float min = hn::ReduceMin(df, min_mag);
if (min != 0.0f) { // Avoid division by zero.
my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
}
hn::Store(v0, df, buf + buf_filled);
hn::Store(v1, df, buf + buf_filled + NF);
buf_filled += 2 * NF;
if (buf_filled == kGroupSize) {
QuantizeGroup(buf, my_stats);

for (size_t i = 0; i < kGroupSize; ++i) {
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
}
my_stats.NotifyCorr(Correlation(buf, kGroupSize));

for (size_t i = 0; i < 2 * NF; ++i) {
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
buf_filled = 0;
}
my_stats.NotifyCorr(Correlation(buf, 2 * NF));
}
}

// Zero to two vectors remaining.
for (; packed_ofs < cols; packed_ofs += NF) {
const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
// Skip NotifyGroup for this partial group.
// Skip QuantizeGroup because it requires full groups.
for (size_t i = 0; i < remaining; ++i) {
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
}
93 changes: 50 additions & 43 deletions gemma/tensor_stats.h
@@ -68,7 +68,12 @@ struct TensorStatsAcrossLayers {
fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str());
}
fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str());
fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str());
fprintf(stderr, "err_avg %s\n", s_grp_err_avg.ToString(skip).c_str());
fprintf(stderr, "err_std %s\n", s_grp_err_std.ToString(skip).c_str());
fprintf(stderr, "err_max %s\n", s_grp_err_max.ToString(skip).c_str());
fprintf(stderr, "snr_1 %s\n", s_grp_snr1.ToString(skip).c_str());
fprintf(stderr, "snr_avg %s\n", s_grp_snr_avg.ToString(skip).c_str());
fprintf(stderr, "snr_std %s\n", s_grp_snr_std.ToString(skip).c_str());
fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str());
fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str());
fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str());
@@ -112,7 +117,12 @@ struct TensorStatsAcrossLayers {
hwy::Stats s_corr_avg;
hwy::Stats s_corr_max;

hwy::Stats s_range_avg;
hwy::Stats s_grp_err_avg;
hwy::Stats s_grp_err_std;
hwy::Stats s_grp_err_max;
hwy::Stats s_grp_snr1;
hwy::Stats s_grp_snr_avg;
hwy::Stats s_grp_snr_std;

hwy::Stats s_exp_min;
hwy::Stats s_exp_max;
@@ -151,13 +161,11 @@ class TensorStatsAccumulator {
void DoNotPrint() { skip_.fetch_or(1); }
bool ShouldPrint() const { return skip_.load() == 0; }

// Vector code computed the min/max of a group (= two vectors); this is
// faster than doing it in `Notify`.
void NotifyGroup(float min, float max) {
s_group_min_.Notify(min);
s_group_max_.Notify(max);
// Caller ensures min != 0.
s_group_range_.Notify(max / min);
// Computed by vector code, much faster than doing it in `Notify`.
void NotifyGroup(float avg_L1, float snr) {
s_group_err_.Notify(avg_L1);
s_group_snr_.Notify(snr);
num_snr1_ += (snr == 1.0f);
}

void NotifyCorr(float corr) { s_corr_.Notify(corr); }
@@ -173,9 +181,9 @@ class TensorStatsAccumulator {
s_val_.Assimilate(other.s_val_);
s_mag_.Assimilate(other.s_mag_);
s_corr_.Assimilate(other.s_corr_);
s_group_min_.Assimilate(other.s_group_min_);
s_group_max_.Assimilate(other.s_group_max_);
s_group_range_.Assimilate(other.s_group_range_);
s_group_err_.Assimilate(other.s_group_err_);
s_group_snr_.Assimilate(other.s_group_snr_);
num_snr1_ += other.num_snr1_;
}

// Called on the per-layer representative after reducing across threads.
@@ -197,7 +205,12 @@ class TensorStatsAccumulator {
s.s_corr_avg.Notify(s_corr_.Mean());
s.s_corr_max.Notify(s_corr_.Max());

s.s_range_avg.Notify(s_group_range_.Mean());
s.s_grp_err_avg.Notify(s_group_err_.Mean());
s.s_grp_err_std.Notify(s_group_err_.StandardDeviation());
s.s_grp_err_max.Notify(s_group_err_.Max());
s.s_grp_snr1.Notify(static_cast<float>(num_snr1_));
s.s_grp_snr_avg.Notify(s_group_snr_.Mean());
s.s_grp_snr_std.Notify(s_group_snr_.StandardDeviation());

const uint32_t subnormals = b_exp256_.Bin(0);
// Prevent subnormals from hiding the min exponent.
@@ -222,13 +235,12 @@ class TensorStatsAccumulator {
void PrintAll() {
fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_));
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str());
fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str());
fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str());
fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str());
fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
fprintf(stderr, "crr %s\n", s_corr_.ToString(skip).c_str());
fprintf(stderr, "err %s\n", s_group_err_.ToString(skip).c_str());
fprintf(stderr, "snr %s\n", s_group_snr_.ToString(skip).c_str());
b_exp256_.Print("exp");
PrintBinRanges(b_big_row_, "big row");
PrintBinRanges(b_big_col_, "big col");
@@ -244,30 +256,25 @@ class TensorStatsAccumulator {
}
if (total == 0) return;

// If all bins are at least 10% of a uniform distribution, print the range
// to vastly reduce the log size.
fprintf(stderr, "%s total %zu: \n", name, total);
// Group together runs to reduce the log size.
const size_t min = HWY_MAX(1, total / (N * 10));
size_t last = 0;
for (; last < N; ++last) {
if (b.Bin(last) < min) break;
}
if (last >= N / 2) {
// Also require all subsequent bins to be zero, otherwise we should
// print the outlier bins.
bool all_zero = true;
for (size_t i = last + 1; i < N; ++i) {
if (b.Bin(last) != 0) {
all_zero = false;
break;
}
for (size_t i = 0; i < N; ++i) {
if (b.Bin(i) == 0) continue;
if (b.Bin(i) < min) {
fprintf(stderr, " %3zu: %zu\n", i, b.Bin(i));
continue;
}
if (all_zero) {
fprintf(stderr, "%s: uniform up to %zu\n", name, last);
return;
const size_t first = i;
while (i + 1 < N && b.Bin(i + 1) >= min) {
i++;
}
if (first == i) {
fprintf(stderr, " %3zu: %zu\n", i, b.Bin(i));
} else {
fprintf(stderr, " [%3zu, %3zu]\n", first, i);
}
}

b.Print(name, /*skip_zero=*/true);
}
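
Not part of the diff: a standalone sketch of the new run-grouping print logic, with a worked example. `PrintRuns` and the hard-coded bins are hypothetical; the threshold `total / (N * 10)` (at least 10% of a uniform share, floored at 1) and the output format follow the code above.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Group consecutive bins that reach at least 10% of a uniform share
// (total / (N * 10), floored at 1) into [first, last] ranges; print small
// non-zero bins individually and skip empty ones.
void PrintRuns(const size_t* bins, size_t N) {
  size_t total = 0;
  for (size_t i = 0; i < N; ++i) total += bins[i];
  if (total == 0) return;
  const size_t min = std::max<size_t>(1, total / (N * 10));
  for (size_t i = 0; i < N; ++i) {
    if (bins[i] == 0) continue;
    if (bins[i] < min) {
      std::printf("  %3zu: %zu\n", i, bins[i]);
      continue;
    }
    const size_t first = i;
    while (i + 1 < N && bins[i + 1] >= min) ++i;
    if (first == i) {
      std::printf("  %3zu: %zu\n", i, bins[i]);
    } else {
      std::printf("  [%3zu, %3zu]\n", first, i);
    }
  }
}

int main() {
  // Bins 3..6 form a dense run; bins 2 and 10 are small outliers.
  const size_t bins[12] = {0, 0, 3, 400, 520, 470, 390, 0, 0, 0, 5, 0};
  PrintRuns(bins, 12);  // "2: 3", "[3, 6]", "10: 5"
}
```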

double sum_sq_ = 0.0; // for Frobenius norm
@@ -278,9 +285,9 @@ class TensorStatsAccumulator {
hwy::Stats s_mag_;
hwy::Stats s_cond_; // condition number
hwy::Stats s_corr_; // lag-1 autocorrelation
hwy::Stats s_group_min_;
hwy::Stats s_group_max_;
hwy::Stats s_group_range_;
hwy::Stats s_group_err_;
hwy::Stats s_group_snr_;
size_t num_snr1_ = 0;
std::atomic<int> skip_{0};
};
