Commit b4d513e

update residual fsq, always use preserve symmetry
1 parent f570433 · commit b4d513e

2 files changed: +23 −13 lines changed


pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.25.2"
+version = "1.26.0"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -23,7 +23,7 @@ classifiers=[
 ]
 
 dependencies = [
-    "torch>=2.0",
+    "torch>=2.4",
     "einops>=0.8.0",
     "einx>=0.3.0",
 ]

vector_quantize_pytorch/residual_fsq.py

Lines changed: 21 additions & 11 deletions
@@ -1,11 +1,11 @@
+from __future__ import annotations
+
 import random
 from math import ceil
 from functools import partial
 
-from typing import List
-
 import torch
-from torch import nn
+from torch import nn, tensor
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.amp import autocast
@@ -52,14 +52,15 @@ class ResidualFSQ(Module):
     def __init__(
         self,
         *,
-        levels: List[int],
+        levels: list[int],
         num_quantizers,
         dim = None,
         is_channel_first = False,
         quantize_dropout = False,
         quantize_dropout_cutoff_index = 0,
         quantize_dropout_multiple_of = 1,
-        soft_clamp_input_value = None,
+        soft_clamp_input_value: float | list[float] | Tensor | None = None,
+        bound_hard_clamp = True,
         **kwargs
     ):
         super().__init__()
@@ -74,25 +75,23 @@ def __init__(
         self.is_channel_first = is_channel_first
         self.num_quantizers = num_quantizers
 
-        # soft clamping the input value
-
-        self.soft_clamp_input_value = soft_clamp_input_value
-
         # layers
 
         self.levels = levels
         self.layers = nn.ModuleList([])
 
-        levels_tensor = torch.Tensor(levels)
+        levels_tensor = tensor(levels)
+        assert (levels_tensor > 1).all()
 
         scales = []
 
         for ind in range(num_quantizers):
-            scales.append(levels_tensor ** -ind)
+            scales.append(levels_tensor.float() ** -ind)
 
         fsq = FSQ(
             levels = levels,
             dim = codebook_dim,
+            preserve_symmetry = True,
             **kwargs
         )
 
@@ -111,6 +110,17 @@ def __init__(
         self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
         self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4
 
+        # soft clamping the input value
+
+        if bound_hard_clamp:
+            assert not exists(soft_clamp_input_value)
+            soft_clamp_input_value = 1 + (1 / (levels_tensor - 1))
+
+        if isinstance(soft_clamp_input_value, (list, float)):
+            soft_clamp_input_value = tensor(soft_clamp_input_value)
+
+        self.register_buffer('soft_clamp_input_value', soft_clamp_input_value, persistent = False)
+
     @property
     def codebooks(self):
         codebooks = [layer.implicit_codebook for layer in self.layers]
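
For context, a minimal usage sketch of what this commit changes at the call site, assuming the package's top-level ResidualFSQ export and the constructor/forward signatures shown in the project README; the shapes and hyperparameters below are illustrative only:

import torch
from vector_quantize_pytorch import ResidualFSQ

# every underlying FSQ layer is now constructed with preserve_symmetry = True,
# and with the default bound_hard_clamp = True the input clamp bound is derived
# per level as 1 + 1 / (level - 1), e.g. levels [8, 5, 5, 3] -> [8/7, 5/4, 5/4, 3/2]
residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 3],
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)
quantized, indices = residual_fsq(x)  # quantized has the same shape as x

# the derived bound now lives in a non-persistent buffer instead of a plain attribute
print(residual_fsq.soft_clamp_input_value)  # tensor([1.1429, 1.2500, 1.2500, 1.5000])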
