diff --git a/gemma/tensor_stats.cc b/gemma/tensor_stats.cc
index 53203b6e..62dcb98c 100644
--- a/gemma/tensor_stats.cc
+++ b/gemma/tensor_stats.cc
@@ -107,6 +107,98 @@ void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
                      bytes_per_row * row_idx);
 }
 
+constexpr size_t kGroupSize = 128;  // subchannel
+
+void QuantizeGroup(const float* HWY_RESTRICT in,
+                   TensorStatsAccumulator& my_stats) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  const hn::ScalableTag<float> df;
+  using VF = hn::Vec<decltype(df)>;
+  using MF = hn::Mask<decltype(df)>;
+  const hn::ScalableTag<double> dd;
+  using VD = hn::Vec<decltype(dd)>;
+  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
+  HWY_ALIGN float enc[kGroupSize];
+  HWY_ALIGN float dec[kGroupSize];
+  HWY_ALIGN float all_snr[kGroupSize];
+  HWY_DASSERT(kGroupSize % NF == 0);  // No remainder handling required.
+
+  const VF k0 = hn::Zero(df);
+  const VF k1 = hn::Set(df, 1.0f);
+
+  // Scan for min/max for quantization.
+  VF vmin = hn::Set(df, hwy::HighestValue<float>());
+  VF vmax = hn::Set(df, hwy::LowestValue<float>());
+  for (size_t i = 0; i < kGroupSize; i += NF) {
+    const VF v = hn::Load(df, in + i);
+    vmin = hn::Min(vmin, v);
+    vmax = hn::Max(vmax, v);
+  }
+  const float min = hn::ReduceMin(df, vmin);
+  const float max = hn::ReduceMax(df, vmax);
+  // Avoid division by zero during quantization.
+  if (max == min) return;
+
+  // Distortion stats.
+  VF vsum_err = hn::Zero(df);
+  VD sum_log_snr0 = hn::Zero(dd);
+  VD sum_log_snr1 = hn::Zero(dd);
+  size_t num_snr = 0;
+
+  // Unclipped asymmetric quantization (for activations).
+  const VF scale = hn::Set(df, 255.0f / (max - min));
+  const VF inv_scale = hn::Div(k1, scale);
+  const VF zeropoint = hn::Sub(hn::Round(hn::Mul(hn::Set(df, -min), scale)),
+                               hn::Set(df, 128.0f));
+  const VF dq_sub = hn::Mul(zeropoint, inv_scale);  // For MulSub.
+  for (size_t i = 0; i < kGroupSize; i += NF) {
+    const VF v = hn::Load(df, in + i);
+    const VF q = hn::Round(hn::MulAdd(v, scale, zeropoint));
+    hn::Store(q, df, enc + i);
+    // Dequantize.
+    const VF d = hn::MulSub(q, inv_scale, dq_sub);
+    hn::Store(d, df, dec + i);
+
+    const VF err = hn::AbsDiff(v, d);  // L1
+    vsum_err = hn::Add(vsum_err, err);
+
+    // For preventing division by zero. However, we still want to
+    // clamp snr because it could be very high (>1E3 when most
+    // elements are lossless).
+    const MF has_err = hn::Gt(err, k0);
+    const VF rel = hn::MaskedDivOr(k0, has_err, hn::Abs(v), err);
+    // SNR = 1 + abs/L1, with cap on the latter term.
+    const VF snr = hn::Add(k1, hn::Min(rel, hn::Set(df, 300.f)));
+    hn::Store(snr, df, all_snr + i);
+    // Where `has_err` is false, `snr` elements are 1 and log(1) is zero, hence
+    // they do not affect sum_log. However, very high errors also result in
+    // snr=1, which drags down the average because `num_snr` is increased.
+    num_snr += hn::CountTrue(df, has_err);
+
+    const VD log_snr0 = hn::Log(dd, hn::PromoteLowerTo(dd, snr));
+    const VD log_snr1 = hn::Log(dd, hn::PromoteUpperTo(dd, snr));
+    sum_log_snr0 = hn::Add(sum_log_snr0, log_snr0);
+    sum_log_snr1 = hn::Add(sum_log_snr1, log_snr1);
+  }
+
+  const float sum_err = hn::ReduceSum(df, vsum_err);
+  const float avg_L1 = sum_err / static_cast<float>(kGroupSize);
+  const double sum_log = hn::ReduceSum(dd, hn::Add(sum_log_snr0, sum_log_snr1));
+  // SNR >= 1, hence log >= 0.
+  HWY_ASSERT(sum_log >= 0.0);
+  if (num_snr == 0) {  // Avoid division by zero.
+    // It can happen that dequantization is lossless, i.e. SNR is
+    // infinite; skip such groups.
+    HWY_ASSERT(sum_err == 0.0f);
+    return;
+  }
+  // Signal to noise ratio (Shannon's channel capacity, NOT the
+  // L2-based and logarithmic PSNR)
+  const float snr = std::exp(sum_log / static_cast<double>(num_snr));
+
+  my_stats.NotifyGroup(avg_L1, snr);
+}
+
 // First dispatch to the type, then parallel over rows, then vectorized
 // decompress and Notify for each value.
 void UpdateStatsT(TensorStats& stats, size_t layer_idx,
@@ -138,29 +230,30 @@ void UpdateStatsT(TensorStats& stats, size_t layer_idx,
     my_stats.NotifyCond(ConditionNumber(row, cols));
 
     namespace hn = hwy::HWY_NAMESPACE;
-    hn::ScalableTag<float> df;
+    const hn::ScalableTag<float> df;
     using VF = hn::Vec<decltype(df)>;
    HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
-    HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
+    HWY_ALIGN float buf[kGroupSize];
+    size_t buf_filled = 0;
 
     size_t packed_ofs = 0;
     if (cols >= 2 * NF) {
       for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
         VF v0, v1;
         Decompress2(df, packed, packed_ofs, v0, v1);
-        hn::Store(v0, df, buf);
-        hn::Store(v1, df, buf + NF);
-        const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
-        const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
-        const float min = hn::ReduceMin(df, min_mag);
-        if (min != 0.0f) {  // Avoid division by zero.
-          my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
-        }
+        hn::Store(v0, df, buf + buf_filled);
+        hn::Store(v1, df, buf + buf_filled + NF);
+        buf_filled += 2 * NF;
+        if (buf_filled == kGroupSize) {
+          QuantizeGroup(buf, my_stats);
+
+          for (size_t i = 0; i < kGroupSize; ++i) {
+            my_stats.Notify(buf[i], row_idx, packed_ofs + i);
+          }
+          my_stats.NotifyCorr(Correlation(buf, kGroupSize));
 
-        for (size_t i = 0; i < 2 * NF; ++i) {
-          my_stats.Notify(buf[i], row_idx, packed_ofs + i);
+          buf_filled = 0;
         }
-        my_stats.NotifyCorr(Correlation(buf, 2 * NF));
       }
     }
 
@@ -168,7 +261,7 @@ void UpdateStatsT(TensorStats& stats, size_t layer_idx,
     for (; packed_ofs < cols; packed_ofs += NF) {
       const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
       DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
-      // Skip NotifyGroup for this partial group.
+      // Skip QuantizeGroup because it requires full groups.
       for (size_t i = 0; i < remaining; ++i) {
         my_stats.Notify(buf[i], row_idx, packed_ofs + i);
       }
diff --git a/gemma/tensor_stats.h b/gemma/tensor_stats.h
index 6975ab5f..338f41fc 100644
--- a/gemma/tensor_stats.h
+++ b/gemma/tensor_stats.h
@@ -68,7 +68,12 @@ struct TensorStatsAcrossLayers {
       fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str());
     }
     fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str());
-    fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str());
+    fprintf(stderr, "err_avg %s\n", s_grp_err_avg.ToString(skip).c_str());
+    fprintf(stderr, "err_std %s\n", s_grp_err_std.ToString(skip).c_str());
+    fprintf(stderr, "err_max %s\n", s_grp_err_max.ToString(skip).c_str());
+    fprintf(stderr, "snr_1   %s\n", s_grp_snr1.ToString(skip).c_str());
+    fprintf(stderr, "snr_avg %s\n", s_grp_snr_avg.ToString(skip).c_str());
+    fprintf(stderr, "snr_std %s\n", s_grp_snr_std.ToString(skip).c_str());
     fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str());
     fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str());
     fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str());
@@ -112,7 +117,12 @@ struct TensorStatsAcrossLayers {
   hwy::Stats s_corr_avg;
   hwy::Stats s_corr_max;
 
-  hwy::Stats s_range_avg;
+  hwy::Stats s_grp_err_avg;
+  hwy::Stats s_grp_err_std;
+  hwy::Stats s_grp_err_max;
+  hwy::Stats s_grp_snr1;
+  hwy::Stats s_grp_snr_avg;
+  hwy::Stats s_grp_snr_std;
 
   hwy::Stats s_exp_min;
   hwy::Stats s_exp_max;
@@ -151,13 +161,11 @@ class TensorStatsAccumulator {
   void DoNotPrint() { skip_.fetch_or(1); }
   bool ShouldPrint() const { return skip_.load() == 0; }
 
-  // Vector code computed the min/max of a group (= two vectors); this is
-  // faster than doing it in `Notify`.
-  void NotifyGroup(float min, float max) {
-    s_group_min_.Notify(min);
-    s_group_max_.Notify(max);
-    // Caller ensures min != 0.
-    s_group_range_.Notify(max / min);
+  // Computed by vector code, much faster than doing it in `Notify`.
+  void NotifyGroup(float avg_L1, float snr) {
+    s_group_err_.Notify(avg_L1);
+    s_group_snr_.Notify(snr);
+    num_snr1_ += (snr == 1.0f);
   }
 
   void NotifyCorr(float corr) { s_corr_.Notify(corr); }
@@ -173,9 +181,9 @@ class TensorStatsAccumulator {
     s_val_.Assimilate(other.s_val_);
     s_mag_.Assimilate(other.s_mag_);
     s_corr_.Assimilate(other.s_corr_);
-    s_group_min_.Assimilate(other.s_group_min_);
-    s_group_max_.Assimilate(other.s_group_max_);
-    s_group_range_.Assimilate(other.s_group_range_);
+    s_group_err_.Assimilate(other.s_group_err_);
+    s_group_snr_.Assimilate(other.s_group_snr_);
+    num_snr1_ += other.num_snr1_;
   }
 
   // Called on the per-layer representative after reducing across threads.
@@ -197,7 +205,12 @@ class TensorStatsAccumulator {
     s.s_corr_avg.Notify(s_corr_.Mean());
     s.s_corr_max.Notify(s_corr_.Max());
 
-    s.s_range_avg.Notify(s_group_range_.Mean());
+    s.s_grp_err_avg.Notify(s_group_err_.Mean());
+    s.s_grp_err_std.Notify(s_group_err_.StandardDeviation());
+    s.s_grp_err_max.Notify(s_group_err_.Max());
+    s.s_grp_snr1.Notify(static_cast<float>(num_snr1_));
+    s.s_grp_snr_avg.Notify(s_group_snr_.Mean());
+    s.s_grp_snr_std.Notify(s_group_snr_.StandardDeviation());
 
     const uint32_t subnormals = b_exp256_.Bin(0);
     // Prevent subnormals from hiding the min exponent.
@@ -222,13 +235,12 @@ class TensorStatsAccumulator {
   void PrintAll() {
     fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_));
     const int skip = hwy::Stats::kNoGeomean;
-    fprintf(stderr, "cnd         %s\n", s_cond_.ToString(skip).c_str());
-    fprintf(stderr, "val         %s\n", s_val_.ToString(skip).c_str());
-    fprintf(stderr, "mag         %s\n", s_mag_.ToString(skip).c_str());
-    fprintf(stderr, "corr        %s\n", s_corr_.ToString(skip).c_str());
-    fprintf(stderr, "group_min   %s\n", s_group_min_.ToString(skip).c_str());
-    fprintf(stderr, "group_max   %s\n", s_group_max_.ToString(skip).c_str());
-    fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str());
+    fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
+    fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
+    fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
+    fprintf(stderr, "crr %s\n", s_corr_.ToString(skip).c_str());
+    fprintf(stderr, "err %s\n", s_group_err_.ToString(skip).c_str());
+    fprintf(stderr, "snr %s\n", s_group_snr_.ToString(skip).c_str());
     b_exp256_.Print("exp");
     PrintBinRanges(b_big_row_, "big row");
     PrintBinRanges(b_big_col_, "big col");
@@ -244,30 +256,25 @@
     }
     if (total == 0) return;
 
-    // If all bins are at least 10% of a uniform distribution, print the range
-    // to vastly reduce the log size.
+    fprintf(stderr, "%s total %zu: \n", name, total);
+    // Group together runs to reduce the log size.
     const size_t min = HWY_MAX(1, total / (N * 10));
-    size_t last = 0;
-    for (; last < N; ++last) {
-      if (b.Bin(last) < min) break;
-    }
-    if (last >= N / 2) {
-      // Also require all subsequent bins to be zero, otherwise we should
-      // print the outlier bins.
-      bool all_zero = true;
-      for (size_t i = last + 1; i < N; ++i) {
-        if (b.Bin(last) != 0) {
-          all_zero = false;
-          break;
-        }
+    for (size_t i = 0; i < N; ++i) {
+      if (b.Bin(i) == 0) continue;
+      if (b.Bin(i) < min) {
+        fprintf(stderr, " %3zu: %zu\n", i, b.Bin(i));
+        continue;
       }
-      if (all_zero) {
-        fprintf(stderr, "%s: uniform up to %zu\n", name, last);
-        return;
+      const size_t first = i;
+      while (i + 1 < N && b.Bin(i + 1) >= min) {
+        i++;
+      }
+      if (first == i) {
+        fprintf(stderr, " %3zu: %zu\n", i, b.Bin(i));
+      } else {
+        fprintf(stderr, " [%3zu, %3zu]\n", first, i);
       }
     }
-
-    b.Print(name, /*skip_zero=*/true);
   }
 
   double sum_sq_ = 0.0;  // for Frobenius norm
@@ -278,9 +285,9 @@
   hwy::Stats s_val_;
   hwy::Stats s_mag_;
   hwy::Stats s_cond_;   // condition number
   hwy::Stats s_corr_;   // lag-1 autocorrelation
-  hwy::Stats s_group_min_;
-  hwy::Stats s_group_max_;
-  hwy::Stats s_group_range_;
+  hwy::Stats s_group_err_;
+  hwy::Stats s_group_snr_;
+  size_t num_snr1_ = 0;
   std::atomic<uint32_t> skip_{0};
 };
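For reference, the following standalone scalar sketch (not part of the patch; names and the cast to double are illustrative only) shows the math the SIMD code in QuantizeGroup implements per 128-value group: unclipped asymmetric 8-bit quantization with scale = 255 / (max - min) and zeropoint = round(-min * scale) - 128, the mean L1 round-trip error, and the capped per-element SNR = 1 + min(|x| / err, 300) combined via a geometric mean (lossless elements are excluded, matching the num_snr logic above).

// Scalar reference sketch for the per-group statistics computed by
// QuantizeGroup in the patch above. Illustration only.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Returns false if the group is skipped (constant input or lossless
// round-trip), mirroring the early returns in QuantizeGroup.
bool GroupStats(const std::vector<float>& in, float& avg_L1, float& snr) {
  const auto [min_it, max_it] = std::minmax_element(in.begin(), in.end());
  const float min = *min_it, max = *max_it;
  if (max == min) return false;  // Avoid division by zero.

  // Unclipped asymmetric quantization to signed 8-bit codes.
  const float scale = 255.0f / (max - min);
  const float zeropoint = std::round(-min * scale) - 128.0f;

  double sum_err = 0.0, sum_log_snr = 0.0;
  size_t num_snr = 0;
  for (const float v : in) {
    const float q = std::round(v * scale + zeropoint);  // Encode.
    const float d = (q - zeropoint) / scale;            // Decode.
    const float err = std::abs(v - d);                  // L1 distortion.
    sum_err += err;
    if (err > 0.0f) {  // Lossless elements contribute log(1) = 0.
      // SNR = 1 + |x|/err, capped so nearly-lossless elements do not
      // dominate the geometric mean.
      sum_log_snr += std::log(1.0 + std::min(std::abs(v) / err, 300.0f));
      ++num_snr;
    }
  }
  if (num_snr == 0) return false;  // Entire group is lossless.
  avg_L1 = static_cast<float>(sum_err / in.size());
  snr = static_cast<float>(std::exp(sum_log_snr / num_snr));  // Geometric mean.
  return true;
}

int main() {
  std::vector<float> group(128);
  for (size_t i = 0; i < group.size(); ++i) group[i] = 0.01f * i - 0.5f;
  float avg_L1, snr;
  if (GroupStats(group, avg_L1, snr)) {
    printf("avg_L1 %.3E  snr %.1f\n", avg_L1, snr);
  }
  return 0;
}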