1111import torch
1212import torch .nn as nn
1313from torch .nn import Module
14- from torch import tensor , Tensor , int32
14+ from torch import tensor , Tensor , int32 , tanh , atanh , clamp
1515from torch .amp import autocast
1616
1717import einx
@@ -30,6 +30,9 @@ def default(*args):
3030 return arg
3131 return None
3232
33+ def identity (t ):
34+ return t
35+
3336def maybe (fn ):
3437 @wraps (fn )
3538 def inner (x , * args , ** kwargs ):
@@ -73,6 +76,7 @@ def __init__(
7376 force_quantization_f32 = True ,
7477 preserve_symmetry = False ,
7578 noise_dropout = 0. ,
79+ bound_hard_clamp = False # for residual fsq, if input is pre-softclamped to the right range
7680 ):
7781 super ().__init__ ()
7882
@@ -121,22 +125,31 @@ def __init__(
121125 self .allowed_dtypes = allowed_dtypes
122126 self .force_quantization_f32 = force_quantization_f32
123127
124- def bound (self , z , eps = 1e-3 ):
128+ # allow for a hard clamp
129+
130+ self .bound_hard_clamp = bound_hard_clamp
131+
132+ def bound (self , z , eps = 1e-3 , hard_clamp = False ):
125133 """ Bound `z`, an array of shape (..., d). """
134+ maybe_tanh = tanh if not hard_clamp else partial (clamp , min = - 1. , max = 1. )
135+ maybe_atanh = atanh if not hard_clamp else identity
136+
126137 half_l = (self ._levels - 1 ) * (1 + eps ) / 2
127138 offset = torch .where (self ._levels % 2 == 0 , 0.5 , 0.0 )
128- shift = (offset / half_l ). atanh ( )
129- bounded_z = (z + shift ). tanh ( ) * half_l - offset
139+ shift = maybe_atanh (offset / half_l )
140+ bounded_z = maybe_tanh (z + shift ) * half_l - offset
130141 half_width = self ._levels // 2
131142 return round_ste (bounded_z ) / half_width
132143
133144 # symmetry-preserving and noise-approximated quantization, section 3.2 in https://arxiv.org/abs/2411.19842
134145
135- def symmetry_preserving_bound (self , z ):
146+ def symmetry_preserving_bound (self , z , hard_clamp = False ):
136147 """ QL(x) = 2 / (L - 1) * [(L - 1) * (tanh(x) + 1) / 2 + 0.5] - 1 """
148+ maybe_tanh = tanh if not hard_clamp else partial (clamp , min = - 1. , max = 1. )
149+
137150 levels_minus_1 = (self ._levels - 1 )
138151 scale = 2. / levels_minus_1
139- bracket = (levels_minus_1 * (z . tanh ( ) + 1 ) / 2. ) + 0.5
152+ bracket = (levels_minus_1 * (maybe_tanh ( z ) + 1 ) / 2. ) + 0.5
140153 bracket = floor_ste (bracket )
141154 return scale * bracket - 1.
142155
@@ -146,7 +159,7 @@ def quantize(self, z):
146159 shape , device , noise_dropout , preserve_symmetry = z .shape [0 ], z .device , self .noise_dropout , self .preserve_symmetry
147160 bound_fn = self .symmetry_preserving_bound if preserve_symmetry else self .bound
148161
149- bounded_z = bound_fn (z )
162+ bounded_z = bound_fn (z , hard_clamp = self . bound_hard_clamp )
150163
151164 # determine where to add a random offset elementwise
152165 # if using noise dropout
0 commit comments