def exists(val):
    return val is not None

+ def first(it):
+     return it[0]
+
def default(val, d):
    return val if exists(val) else d

@@ -34,6 +37,64 @@ def round_up_multiple(num, mult):
def is_distributed():
    return dist.is_initialized() and dist.get_world_size() > 1

+ # the mlp for generating the neural implicit codebook
+ # from Huijben et al. https://arxiv.org/abs/2401.14732
+
+ class MLP(Module):
+     def __init__(
+         self,
+         dim,
+         dim_hidden = None,
+         depth = 4,              # they used 4 layers in the paper
+         l2norm_output = False
+     ):
+         super().__init__()
+         dim_hidden = default(dim_hidden, dim)
+
+         self.proj_in = nn.Linear(2 * dim, dim)
+
+         layers = ModuleList([])
+
+         for _ in range(depth):
+             layers.append(nn.Sequential(
+                 nn.Linear(dim, dim_hidden),
+                 nn.SiLU(),
+                 nn.Linear(dim_hidden, dim)
+             ))
+
+         self.layers = layers
+         self.l2norm_output = l2norm_output
+
+     def forward(
+         self,
+         codes,
+         *,
+         condition
+     ):
+         one_headed = codes.ndim == 2
+
+         if one_headed:
+             codes = rearrange(codes, 'c d -> 1 c d')
+
+         heads, num_codes, batch, seq_len = codes.shape[0], codes.shape[-2], condition.shape[0], condition.shape[-2]
+
+         codes = repeat(codes, 'h c d -> h b n c d', n = seq_len, b = batch)
+         condition = repeat(condition, 'b n d -> h b n c d', c = num_codes, h = heads)
+
+         x = torch.cat((condition, codes), dim = -1)
+         x = self.proj_in(x)
+
+         for layer in self.layers:
+             x = layer(x) + x
+
+         if self.l2norm_output:
+             x = F.normalize(x, dim = -1)
+
+         if not one_headed:
+             return x
+
+         return rearrange(x, '1 ... -> ...')
+

# main class

class ResidualVQ(Module):
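For orientation (not part of the commit), a minimal sketch of how this MLP is meant to be called: codes are one quantizer's raw codebook vectors and condition is the running quantized output from the preceding quantizers, so every (batch, position) pair gets its own transformed codebook. The shapes below are assumptions read off the forward signature above, and the snippet assumes the file's own imports (torch, nn, einops) are in scope.

    mlp = MLP(dim = 64, depth = 4)

    codes = torch.randn(512, 64)            # (num codes, dim) - raw codebook, one-headed
    condition = torch.randn(2, 1024, 64)    # (batch, seq, dim) - quantized output so far

    per_position_codebooks = mlp(codes, condition = condition)

    assert per_position_codebooks.shape == (2, 1024, 512, 64)   # a refined codebook per position
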
@@ -50,7 +111,9 @@ def __init__(
        quantize_dropout_cutoff_index = 0,
        quantize_dropout_multiple_of = 1,
        accept_image_fmap = False,
-         **kwargs
+         implicit_neural_codebook = False,   # QINCo from https://arxiv.org/abs/2401.14732
+         mlp_kwargs: dict = dict(),
+         **vq_kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'
@@ -65,7 +128,16 @@ def __init__(
        self.num_quantizers = num_quantizers

        self.accept_image_fmap = accept_image_fmap
-         self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
+
+         self.implicit_neural_codebook = implicit_neural_codebook
+
+         if implicit_neural_codebook:
+             vq_kwargs.update(
+                 learnable_codebook = True,
+                 ema_update = False
+             )
+
+         self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])

        assert all([not vq.has_projections for vq in self.layers])

@@ -76,6 +148,12 @@ def __init__(
        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

+         # setting up the MLPs for implicit neural codebooks
+
+         self.mlps = ModuleList([MLP(dim = codebook_dim, l2norm_output = first(self.layers).use_cosine_sim, **mlp_kwargs) for _ in range(num_quantizers - 1)])
+
+         # sharing codebook logic
+
        if not shared_codebook:
            return

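A hypothetical construction call to show how the new flag is meant to be switched on. Only implicit_neural_codebook and mlp_kwargs come from this commit; dim, num_quantizers and codebook_size are assumed to be the existing ResidualVQ arguments from the surrounding file.

    residual_vq = ResidualVQ(
        dim = 256,
        num_quantizers = 8,
        codebook_size = 1024,
        implicit_neural_codebook = True,   # QINCo: learnable codebooks, ema updates turned off
        mlp_kwargs = dict(depth = 4)       # forwarded to every MLP above
    )

Note that only num_quantizers - 1 MLPs are created: the first quantizer keeps its plain codebook, since there is no earlier quantized output to condition on.
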
@@ -120,7 +198,31 @@ def get_codes_from_indices(self, indices):
        mask = indices == -1.
        indices = indices.masked_fill(mask, 0)  # have it fetch a dummy code to be masked out later

-         all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
+         if not self.implicit_neural_codebook:
+             # gather all the codes
+
+             all_codes = get_at('q [c] d, b n q -> q b n d', self.codebooks, indices)
+
+         else:
+             # if using the implicit neural codebook, codes need to be derived layer by layer
+
+             code_transform_mlps = (None, *self.mlps)
+
+             all_codes = []
+             quantized_out = 0.
+
+             for codes, indices, maybe_transform_mlp in zip(self.codebooks, indices.unbind(dim = -1), code_transform_mlps):
+
+                 if exists(maybe_transform_mlp):
+                     codes = maybe_transform_mlp(codes, condition = quantized_out)
+                     layer_codes = get_at('b n [c] d, b n -> b n d', codes, indices)
+                 else:
+                     layer_codes = get_at('[c] d, b n -> b n d', codes, indices)
+
+                 all_codes.append(layer_codes)
+                 quantized_out += layer_codes
+
+             all_codes = torch.stack(all_codes)

        # mask out any codes that were dropout-ed

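Given indices from a forward pass (see the sketch after the final hunk), the reconstruction path above can be exercised as follows; residual_vq is the hypothetical instance constructed earlier, with dim equal to codebook_dim so no output projection is involved.

    codes = residual_vq.get_codes_from_indices(indices)   # (num quantizers, batch, seq, codebook dim)

    # each layer's codes were re-derived by conditioning that layer's codebook on the
    # sum of the previous layers' codes, so summing over quantizers composes the output
    reconstructed = codes.sum(dim = 0)
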
@@ -195,9 +297,16 @@ def forward(
            null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
            null_loss = torch.full((1,), 0., device = device, dtype = x.dtype)

+         # setup the mlps for implicit neural codebook
+
+         maybe_code_transforms = (None,) * len(self.layers)
+
+         if self.implicit_neural_codebook:
+             maybe_code_transforms = (None, *self.mlps)
+
        # go through the layers

-         for quantizer_index, layer in enumerate(self.layers):
+         for quantizer_index, (vq, maybe_mlp) in enumerate(zip(self.layers, maybe_code_transforms)):

            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
@@ -208,12 +317,20 @@ def forward(
            if return_loss:
                layer_indices = indices[..., quantizer_index]

-             quantized, *rest = layer(
+             # setup the code transform function to be passed into VectorQuantize forward
+
+             if exists(maybe_mlp):
+                 maybe_mlp = partial(maybe_mlp, condition = quantized_out)
+
+             # vector quantize forward
+
+             quantized, *rest = vq(
                residual,
                mask = mask,
                indices = layer_indices,
                sample_codebook_temp = sample_codebook_temp,
-                 freeze_codebook = freeze_codebook
+                 freeze_codebook = freeze_codebook,
+                 codebook_transform_fn = maybe_mlp
            )

            residual = residual - quantized.detach()
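Finally, a sketch of the training-time path: each layer's MLP is partially applied with the running quantized_out and handed to VectorQuantize as codebook_transform_fn, which lets the quantizer work against the refined, position-dependent codebook. The return signature shown below (quantized, indices, commit losses) is an assumption based on the rest of the file; residual_vq is the hypothetical instance from earlier.

    x = torch.randn(2, 1024, 256)

    quantized, indices, commit_losses = residual_vq(x)

    # indices: (batch, seq, num quantizers) - these can be fed back into get_codes_from_indices above
    print(indices.shape)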