Skip to content

Commit

Permalink
Fix gaussian distribution and sampling for large number of categories
Browse files Browse the repository at this point in the history
  • Loading branch information
fmassa authored May 2, 2021
1 parent 2816f1b commit 9f2110a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/source/2d_attention_patterns.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@
}
],
"source": [
"gaus_2d_dist = AP.local_2d_gausian_distribution(H, W, sigma=0.5)\n",
"gaus_2d_dist = AP.local_2d_gausian_distribution(H, W, sigma=2)\n",
"\n",
"fig, axs = plt.subplots(1, 2, figsize=(15, 7))\n",
"# full sparse matrix mask between every two points\n",
Expand Down Expand Up @@ -404,7 +404,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.8"
}
},
"nbformat": 4,
Expand Down
21 changes: 18 additions & 3 deletions xformers/components/attention/attention_patterns.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import torch


Expand Down Expand Up @@ -65,8 +66,8 @@ def local_2d_distance(H, W, p=2.0):


def local_2d_gausian_distribution(H, W, sigma=1):
    """Gaussian affinity over pairwise 2D grid distances.

    Builds the pairwise Euclidean (p=2) distance matrix for an ``H x W``
    grid and maps each distance ``d`` to ``exp(-0.5 * d**2 / sigma**2)``,
    i.e. an (unnormalized) Gaussian kernel centered at distance zero.
    """
    squared_dist = local_2d_distance(H, W, p=2.0) ** 2
    return torch.exp(-0.5 * sigma ** (-2.0) * squared_dist)


Expand All @@ -83,6 +84,20 @@ def axial_2d_pattern(H, W):

def random_pattern_from_probability_matrix(dist_matrix, nnz):
    """Sample a boolean attention mask from a probability matrix.

    Draws ``nnz`` positions without replacement, each with probability
    proportional to its value in ``dist_matrix``, and returns a boolean
    tensor of the same shape with exactly those positions set to True.
    """
    mask = torch.zeros_like(dist_matrix, dtype=torch.bool)
    flat_probs = dist_matrix.flatten()
    # PyTorch multinomial wrongly doesn't support sampling when the number of
    # categories is > 2^24, arguing that it's the max representable consecutive
    # element in fp32 and that the kernels use float32. This is actually not
    # true, and the kernels should work fine if a double tensor is passed on
    # CPU. This is a bug that was introduced in
    # https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66
    # when unifying the checks between CPU and CUDA. For now, just fall back
    # to numpy for the large case.
    if flat_probs.numel() <= 2 ** 24:
        idxs = torch.multinomial(flat_probs, nnz, replacement=False)
    else:
        # Work on a double-precision copy so the caller's tensor is untouched;
        # numpy requires the probabilities to be explicitly normalized.
        probs = flat_probs.double()
        probs /= probs.sum()
        sampled = np.random.choice(probs.numel(), nnz, p=probs, replace=False)
        idxs = torch.as_tensor(sampled)
    mask.view(-1)[idxs] = True
    return mask

0 comments on commit 9f2110a

Please sign in to comment.