Commit c5e9f48

allow for hard clamp in fsq, to ready for residual fsq pre-softclamping to the right ranges

1 parent 976c3f2

2 files changed (+24, -9 lines)


tests/test_readme.py (4 additions, 2 deletions)

@@ -247,13 +247,15 @@ def test_directional_reparam():
     quantized, indices, _ = rq(x)
 
 @pytest.mark.parametrize('preserve_symmetry', (True, False))
+@pytest.mark.parametrize('bound_hard_clamp', (True, False))
 def test_fsq(
-    preserve_symmetry
+    preserve_symmetry,
+    bound_hard_clamp
 ):
     from vector_quantize_pytorch import FSQ
 
     levels = [8,5,5,5] # see 4.1 and A.4.1 in the paper
-    quantizer = FSQ(levels, preserve_symmetry = preserve_symmetry)
+    quantizer = FSQ(levels, preserve_symmetry = preserve_symmetry, bound_hard_clamp = bound_hard_clamp)
 
     x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
     xhat, indices = quantizer(x)
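
The new flag is exercised the same way as `preserve_symmetry`; a minimal standalone sketch mirroring the updated test (assumes this commit of vector-quantize-pytorch is installed):

```python
import torch
from vector_quantize_pytorch import FSQ

levels = [8, 5, 5, 5]

# bound with a hard clamp instead of tanh
quantizer = FSQ(levels, bound_hard_clamp = True)

x = torch.randn(1, 1024, 4)  # last dim matches len(levels)
xhat, indices = quantizer(x)

assert xhat.shape == x.shape
```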

vector_quantize_pytorch/finite_scalar_quantization.py (20 additions, 7 deletions)
@@ -11,7 +11,7 @@
 import torch
 import torch.nn as nn
 from torch.nn import Module
-from torch import tensor, Tensor, int32
+from torch import tensor, Tensor, int32, tanh, atanh, clamp
 from torch.amp import autocast
 
 import einx
@@ -30,6 +30,9 @@ def default(*args):
             return arg
     return None
 
+def identity(t):
+    return t
+
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
@@ -73,6 +76,7 @@ def __init__(
         force_quantization_f32 = True,
         preserve_symmetry = False,
         noise_dropout = 0.,
+        bound_hard_clamp = False # for residual fsq, if input is pre-softclamped to the right range
     ):
         super().__init__()
 
@@ -121,22 +125,31 @@ def __init__(
         self.allowed_dtypes = allowed_dtypes
         self.force_quantization_f32 = force_quantization_f32
 
-    def bound(self, z, eps = 1e-3):
+        # allow for a hard clamp
+
+        self.bound_hard_clamp = bound_hard_clamp
+
+    def bound(self, z, eps = 1e-3, hard_clamp = False):
         """ Bound `z`, an array of shape (..., d). """
+        maybe_tanh = tanh if not hard_clamp else partial(clamp, min = -1., max = 1.)
+        maybe_atanh = atanh if not hard_clamp else identity
+
         half_l = (self._levels - 1) * (1 + eps) / 2
         offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
-        shift = (offset / half_l).atanh()
-        bounded_z = (z + shift).tanh() * half_l - offset
+        shift = maybe_atanh(offset / half_l)
+        bounded_z = maybe_tanh(z + shift) * half_l - offset
         half_width = self._levels // 2
         return round_ste(bounded_z) / half_width
 
     # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
 
-    def symmetry_preserving_bound(self, z):
+    def symmetry_preserving_bound(self, z, hard_clamp = False):
         """ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1 """
+        maybe_tanh = tanh if not hard_clamp else partial(clamp, min = -1., max = 1.)
+
         levels_minus_1 = (self._levels - 1)
         scale = 2. / levels_minus_1
-        bracket = (levels_minus_1 * (z.tanh() + 1) / 2.) + 0.5
+        bracket = (levels_minus_1 * (maybe_tanh(z) + 1) / 2.) + 0.5
         bracket = floor_ste(bracket)
         return scale * bracket - 1.
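
To see what the flag changes, here is a self-contained sketch of the two paths through `bound`, lifted out of the class (the `levels` value and input scale below are illustrative only, not from the repo):

```python
import torch
from functools import partial

levels = torch.tensor([8, 5, 5, 5])

def bound(z, eps = 1e-3, hard_clamp = False):
    # soft path: squash with tanh (the offset shift passes through atanh first)
    # hard path: assume z was already softclamped upstream, just clip to [-1, 1]
    maybe_tanh = torch.tanh if not hard_clamp else partial(torch.clamp, min = -1., max = 1.)
    maybe_atanh = torch.atanh if not hard_clamp else (lambda t: t)

    half_l = (levels - 1) * (1 + eps) / 2
    offset = torch.where(levels % 2 == 0, 0.5, 0.0)
    shift = maybe_atanh(offset / half_l)
    return maybe_tanh(z + shift) * half_l - offset

z = torch.randn(1024, 4) * 3.

soft = bound(z)                     # saturates smoothly, extreme levels only approached
hard = bound(z, hard_clamp = True)  # saturates flat, extreme levels reached exactly
```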

@@ -146,7 +159,7 @@ def quantize(self, z):
         shape, device, noise_dropout, preserve_symmetry = z.shape[0], z.device, self.noise_dropout, self.preserve_symmetry
         bound_fn = self.symmetry_preserving_bound if preserve_symmetry else self.bound
 
-        bounded_z = bound_fn(z)
+        bounded_z = bound_fn(z, hard_clamp = self.bound_hard_clamp)
 
         # determine where to add a random offset elementwise
         # if using noise dropout
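
As the commit message notes, this readies residual FSQ: each residual would be pre-softclamped to the right range once, upstream, so the quantizer's own bound only needs a hard clamp rather than a second tanh. A hypothetical sketch of that pattern (the `softclamp` helper and its scale are assumptions here, not part of this commit):

```python
import torch
from vector_quantize_pytorch import FSQ

def softclamp(t, value = 1.):
    # hypothetical upstream soft clamp into (-value, value)
    return (t / value).tanh() * value

quantizer = FSQ([8, 5, 5, 5], bound_hard_clamp = True)

residual = torch.randn(1, 1024, 4)
quantized, indices = quantizer(softclamp(residual))  # tanh applied once, upstream
residual = residual - quantized                      # next quantizer in the stack sees the remainder
```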
