+from __future__ import annotations
+
 import random
 from math import ceil
 from functools import partial
 
-from typing import List
-
 import torch
-from torch import nn
+from torch import nn, tensor, Tensor
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.amp import autocast
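A quick note on the typing change here and in the `__init__` signature below (my reading of the diff): with `from __future__ import annotations` in effect, annotations are never evaluated at runtime, so builtin generics like `list[int]` and `|` unions are safe even on Python versions predating PEP 585 / PEP 604, and `typing.List` can be dropped.

```python
from __future__ import annotations  # annotations become lazy strings (PEP 563)

# builtin generics and | unions then parse fine inside annotations,
# even on Python < 3.9 / 3.10 where they would otherwise raise
def example(levels: list[int], clamp: float | None = None) -> None:
    ...
```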
@@ -52,14 +52,15 @@ class ResidualFSQ(Module):
     def __init__(
         self,
         *,
-        levels: List[int],
+        levels: list[int],
         num_quantizers,
         dim = None,
         is_channel_first = False,
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
-        soft_clamp_input_value = None,
+        soft_clamp_input_value: float | list[float] | Tensor | None = None,
+        bound_hard_clamp = True,
         **kwargs
     ):
         super().__init__()
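Between the hunks, a minimal usage sketch of the new arguments (the `vector_quantize_pytorch` import path and the `(quantized, indices)` return follow the package README and are assumptions here, not part of this diff). With the default `bound_hard_clamp = True` the clamp bound is derived from `levels`, so `soft_clamp_input_value` must stay `None`; pass `bound_hard_clamp = False` to supply your own value.

```python
import torch
from vector_quantize_pytorch import ResidualFSQ  # assumed import path

# default bound_hard_clamp = True derives the clamp bound from `levels`,
# so soft_clamp_input_value must be left as None (the constructor asserts this)
residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 5],
    num_quantizers = 4
)

x = torch.randn(1, 1024, 256)
quantized, indices = residual_fsq(x)
assert quantized.shape == x.shape

# to pass an explicit clamp value instead, disable the derived bound
residual_fsq_manual = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 5],
    num_quantizers = 4,
    bound_hard_clamp = False,
    soft_clamp_input_value = 1.2  # hypothetical value, for illustration
)
```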
@@ -74,25 +75,23 @@ def __init__(
         self.is_channel_first = is_channel_first
         self.num_quantizers = num_quantizers
 
-        # soft clamping the input value
-
-        self.soft_clamp_input_value = soft_clamp_input_value
-
         # layers
 
         self.levels = levels
         self.layers = nn.ModuleList([])
 
-        levels_tensor = torch.Tensor(levels)
+        levels_tensor = tensor(levels)
+        assert (levels_tensor > 1).all()
 
         scales = []
 
         for ind in range(num_quantizers):
-            scales.append(levels_tensor ** -ind)
+            scales.append(levels_tensor.float() ** -ind)
 
             fsq = FSQ(
                 levels = levels,
                 dim = codebook_dim,
+                preserve_symmetry = True,
                 **kwargs
             )
 
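Two things in this hunk are worth unpacking. First, `tensor(levels)` yields an int64 tensor (unlike the old `torch.Tensor(levels)`, which was float32), and PyTorch refuses to raise integer tensors to negative integer powers, so the added `.float()` is what keeps the scale computation working; each successive quantizer then sees its residual on a grid shrunk geometrically by `levels ** -ind`. A small sketch with the diff's own numbers:

```python
import torch

levels_tensor = torch.tensor([8, 5, 5, 5])  # int64, unlike torch.Tensor(levels)

# levels_tensor ** -1 would raise
# "Integers to negative integer powers are not allowed"
scales = [levels_tensor.float() ** -ind for ind in range(3)]

print(scales[0])  # tensor([1., 1., 1., 1.])  first quantizer covers the full range
print(scales[1])  # tensor([0.1250, 0.2000, 0.2000, 0.2000])
print(scales[2])  # tensor([0.0156, 0.0400, 0.0400, 0.0400])
```

Second, every `FSQ` layer is now constructed with `preserve_symmetry = True`, which presumably keeps the implicit codes symmetric about zero; the bound derived in the next hunk reads most naturally under that assumption.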
@@ -111,6 +110,17 @@ def __init__(
         self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
 
+        # soft clamping the input value
+
+        if bound_hard_clamp:
+            assert not exists(soft_clamp_input_value)
+            soft_clamp_input_value = 1 + (1 / (levels_tensor - 1))
+
+        if isinstance(soft_clamp_input_value, (list, float)):
+            soft_clamp_input_value = tensor(soft_clamp_input_value)
+
+        self.register_buffer('soft_clamp_input_value', soft_clamp_input_value, persistent = False)
+
     @property
     def codebooks(self):
         codebooks = [layer.implicit_codebook for layer in self.layers]
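To make the derived bound concrete, a small worked example (the half-step reading is an inference from `preserve_symmetry`, not something the diff states): with symmetric codes spanning `[-1, 1]` at spacing `2 / (level - 1)`, the bound `1 + 1 / (level - 1)` sits one half step beyond the outermost code, so clamping there never moves an input across a rounding boundary.

```python
import torch

levels_tensor = torch.tensor([8, 5, 5, 5])

# true division of an int tensor already returns float, so no cast is needed,
# and the earlier assert (levels > 1) guarantees no division by zero
bound = 1 + (1 / (levels_tensor - 1))

print(bound)  # tensor([1.1429, 1.2500, 1.2500, 1.2500])
```

Registering the result as a non-persistent buffer keeps it out of the state dict while still letting it move with the module on `.to(device)`.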