Skip to content

Commit

Permalink
document the new SimVQ and ResidualSimVQ
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 12, 2024
1 parent f883a07 commit 919a1b8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,47 @@ indices = quantizer(x)

This repository should also automatically synchronizing the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`

### Sim VQ

<img src="./images/simvq.png" width="400px"></img>

A <a href="https://arxiv.org/abs/2411.02038">new ICLR 2025 paper</a> proposes a scheme where the codebook is frozen, and the codebook is implicitly generated through a linear projection. The author claims this setup leads to less codebook collapse as well as easier convergence. I have found this to perform even better when paired with <a href="https://arxiv.org/abs/2410.06424">rotation trick</a> from Fifty et al., and expanding the linear projection to a small one layer MLP. You can experiment with it as so

```python
import torch
from vector_quantize_pytorch import SimVQ

sim_vq = SimVQ(
dim = 512,
codebook_size = 1024
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = sim_vq(x)

assert x.shape == quantized.shape
assert torch.allclose(quantized, sim_vq.indices_to_codes(indices), atol = 1e-6)
```

For the residual flavor, just import `ResidualSimVQ` instead

```python
import torch
from vector_quantize_pytorch import ResidualSimVQ

residual_sim_vq = ResidualSimVQ(
dim = 512,
num_quantizers = 4,
codebook_size = 1024
)

x = torch.randn(1, 1024, 512)
quantized, indices, commit_loss = residual_sim_vq(x)

assert x.shape == quantized.shape
assert torch.allclose(quantized, residual_sim_vq.get_output_from_indices(indices), atol = 1e-6)
```

### Finite Scalar Quantization

<img src="./images/fsq.png" width="500px"></img>
Expand Down
Binary file added images/simvq.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 919a1b8

Please sign in to comment.