Commit 3ad24a3

fix eval mode for implicit neural codebooks and fix tests

1 parent c3f33b0 commit 3ad24a3

3 files changed: +18 -14 lines

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.16.0"
+version = "1.16.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 3 additions & 3 deletions
@@ -71,14 +71,14 @@ def test_residual_vq(
     from vector_quantize_pytorch import ResidualVQ

     residual_vq = ResidualVQ(
-        dim = 256,
+        dim = 32,
         num_quantizers = 8,
-        codebook_size = 1024,
+        codebook_size = 128,
         implicit_neural_codebook = implicit_neural_codebook,
         use_cosine_sim = use_cosine_sim,
     )

-    x = torch.randn(1, 1024, 256)
+    x = torch.randn(1, 256, 32)

     quantized, indices, commit_loss = residual_vq(x)
     quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

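Outside the diff, here is a minimal sketch of the eval-mode path this commit fixes, reusing the constructor arguments from the updated test above. Passing implicit_neural_codebook = True directly is an assumption for illustration (in the test it arrives as a fixture parameter).

# sketch, not part of the commit: run ResidualVQ in eval mode, the path
# that previously went through the removed batched_embedding helper
import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 32,
    num_quantizers = 8,
    codebook_size = 128,
    implicit_neural_codebook = True
)

x = torch.randn(1, 256, 32)

residual_vq.eval()

with torch.no_grad():
    quantized, indices, commit_loss = residual_vq(x)

assert quantized.shape == (1, 256, 32)
assert indices.shape == (1, 256, 8)   # one code index per quantizer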
vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 14 additions & 10 deletions
@@ -245,12 +245,6 @@ def kmeans(

     return means, bins

-def batched_embedding(indices, embeds):
-    batch, dim = indices.shape[1], embeds.shape[-1]
-    indices = repeat(indices, 'h b n -> h b n d', d = dim)
-    embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
-    return embeds.gather(2, indices)
-
 # distributed helpers

 @cache

@@ -521,17 +515,22 @@ def forward(

         embed_ind = unpack_one(embed_ind, 'h *')

+        if exists(codebook_transform_fn):
+            transformed_embed = unpack_one(transformed_embed, 'h * c d')
+
         if self.training:
             unpacked_onehot = unpack_one(embed_onehot, 'h * c')

             if exists(codebook_transform_fn):
-                transformed_embed = unpack_one(transformed_embed, 'h * c d')
                 quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
             else:
                 quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)

         else:
-            quantize = batched_embedding(embed_ind, embed)
+            if exists(codebook_transform_fn):
+                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+            else:
+                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)

         if self.training and self.ema_update and not freeze_codebook:

@@ -715,17 +714,22 @@ def forward(
         embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
         embed_ind = unpack_one(embed_ind, 'h *')

+        if exists(codebook_transform_fn):
+            transformed_embed = unpack_one(transformed_embed, 'h * c d')
+
         if self.training:
             unpacked_onehot = unpack_one(embed_onehot, 'h * c')

             if exists(codebook_transform_fn):
-                transformed_embed = unpack_one(transformed_embed, 'h * c d')
                 quantize = einsum('h b n c, h b n c d -> h b n d', unpacked_onehot, transformed_embed)
             else:
                 quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)

         else:
-            quantize = batched_embedding(embed_ind, embed)
+            if exists(codebook_transform_fn):
+                quantize = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
+            else:
+                quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)

         if self.training and self.ema_update and not freeze_codebook:
             if exists(mask):

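For context, a standalone sketch (assuming torch, einx and einops are installed; shapes are illustrative) of what the two einx.get_at calls in the new eval branch compute, and why the removed batched_embedding helper was no longer sufficient: it could only gather from a single shared 'h c d' codebook, not from the per-position codebooks produced by the implicit neural codebook transform.

# sketch, not part of the commit
import torch
import einx
from einops import repeat

h, b, n, c, d = 1, 2, 4, 8, 16                  # heads, batch, seq length, codebook size, dim

embed             = torch.randn(h, c, d)        # shared codebook
transformed_embed = torch.randn(h, b, n, c, d)  # per-position codebooks from the transform fn
embed_ind         = torch.randint(0, c, (h, b, n))

# plain lookup into the shared codebook -- matches what batched_embedding
# produced via repeat + gather
quantize = einx.get_at('h [c] d, h b n -> h b n d', embed, embed_ind)

old_indices = repeat(embed_ind, 'h b n -> h b n d', d = d)
old_embeds  = repeat(embed, 'h c d -> h b c d', b = b)
assert torch.allclose(quantize, old_embeds.gather(2, old_indices))

# per-position lookup into the transformed codebooks, which the old helper
# could not express
quantize_mlp = einx.get_at('h b n [c] d, h b n -> h b n d', transformed_embed, embed_ind)
assert quantize_mlp.shape == (h, b, n, d)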