@@ -245,12 +245,6 @@ def kmeans(
 
     return means, bins
 
-def batched_embedding(indices, embeds):
-    batch, dim = indices.shape[1], embeds.shape[-1]
-    indices = repeat(indices, 'h b n -> h b n d', d = dim)
-    embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
-    return embeds.gather(2, indices)
-
 # distributed helpers
 
 @cache
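Not part of the commit: a minimal sketch, assuming torch, einops and einx are installed, checking that the einx.get_at call introduced further down reproduces the removed batched_embedding helper (gather an (h, c, d) codebook with (h, b, n) indices into an (h, b, n, d) output).

    import torch
    import einx
    from einops import repeat

    def batched_embedding(indices, embeds):
        # removed helper: broadcast indices and codebook, then gather along the code axis
        batch, dim = indices.shape[1], embeds.shape[-1]
        indices = repeat(indices, 'h b n -> h b n d', d = dim)
        embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
        return embeds.gather(2, indices)

    h, b, n, c, d = 2, 3, 5, 7, 4
    embed = torch.randn(h, c, d)                 # codebook per head
    embed_ind = torch.randint(0, c, (h, b, n))   # chosen code per position

    old = batched_embedding(embed_ind, embed)
    new = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
    assert torch.allclose(old, new)              # both yield the (h, b, n, d) quantized output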
@@ -521,17 +515,22 @@ def forward(
 
         embed_ind = unpack_one(embed_ind, 'h *')
 
+        if exists(codebook_transform_fn):
+            transformed_embed = unpack_one(transformed_embed, 'h * c d')
+
         if self.training:
             unpacked_onehot = unpack_one(embed_onehot, 'h * c')
 
             if exists(codebook_transform_fn):
-                transformed_embed = unpack_one(transformed_embed, 'h * c d')
                 quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
             else:
                 quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
 
         else:
-            quantize = batched_embedding(embed_ind, embed)
+            if exists(codebook_transform_fn):
+                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+            else:
+                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
 
         if self.training and self.ema_update and not freeze_codebook:
 
@@ -715,17 +714,22 @@ def forward(
         embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
         embed_ind = unpack_one(embed_ind, 'h *')
 
+        if exists(codebook_transform_fn):
+            transformed_embed = unpack_one(transformed_embed, 'h * c d')
+
         if self.training:
             unpacked_onehot = unpack_one(embed_onehot, 'h * c')
 
             if exists(codebook_transform_fn):
-                transformed_embed = unpack_one(transformed_embed, 'h * c d')
                 quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
             else:
                 quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
 
         else:
-            quantize = batched_embedding(embed_ind, embed)
+            if exists(codebook_transform_fn):
+                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+            else:
+                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)
 
         if self.training and self.ema_update and not freeze_codebook:
             if exists(mask):
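Also not part of the commit: a small sketch of the new eval-time branch used in both hunks above when codebook_transform_fn is set. The codebook has already been transformed per position, so transformed_embed carries shape (h, b, n, c, d) and the gather picks one code per (head, batch, position). Names and shapes follow the hunks; the random tensors are illustrative only.

    import torch
    import einx

    h, b, n, c, d = 2, 3, 5, 7, 4
    transformed_embed = torch.randn(h, b, n, c, d)   # per-position transformed codebook
    embed_ind = torch.randint(0, c, (h, b, n))       # selected code index per position

    # index the [c] axis with embed_ind, one lookup per (h, b, n)
    quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
    assert quantize.shape == (h, b, n, d)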