diff --git a/README.md b/README.md
index 3d88369baf..cd15a5c689 100644
--- a/README.md
+++ b/README.md
@@ -188,17 +188,19 @@ Some examples, generated with `python3 benchmarks/benchmark_encoder.py --activat
 ### LRA
 The code for this benchmark has been adapted from [this repository](https://github.com/mlpen/Nystromformer/tree/main/LRA). [A dedicated README is available here](benchmarks/LRA/README.md)
-Some results:
-
-| Attention                   | ListOps  | Text      | Retrieval | Image     | Pathfinder | *Avg*  | *Est. Gflops* | *Peak mem (mb)* |
-| --------------------------- | -------- | --------- | --------- | --------- | ---------- | ------ | ------------- | --------------- |
-| _Chance_                    | _10_     | _50_      | _50_      | _10_      | _50_       | _34_   | _0_           | _0_             |
-| Standard                    | **37.5** | 62.66     | 79.24     | 38.69     | **70.37**  | 57.69  | 1.21          | 2291            |
-| Nystromformer-128           | 36.29    | **63.24** | 78.18     | **42.86** | 67.49      | 57.61  | 0.62          | 383             |
-| Favor-256 (redraw)          | 19.56    | 62.76     | **81.1**  | 36.09     | 67.23      | 53.35  | 0.49          | 445             |
-| FourierMix                  | 36.29    | 60.72     | 76.41     | 36.53     | 54.07      | 52.8   | **0.17**      | **87**          |
-| Linformer-seq/4 (no redraw) | 36.69    | 57.39     | 76.41     | 35.57     | 65.12      | 54.2   | 0.67          | 719             |
-| Lambda                      | 19.76    | 62.47     | 79.11     | 35.04     | 49.74      | 49.224 | x             | 1023            |
+
+__Some results:__
+
+| Attention                   | ListOps  | Text      | Retrieval | Image     | Pathfinder | *Avg*     | *Est. Gflops* | *Peak mem (MB)* |
+| --------------------------- | -------- | --------- | --------- | --------- | ---------- | --------- | ------------- | --------------- |
+| _Chance_                    | _10_     | _50_      | _50_      | _10_      | _50_       | _34_      | _0_           | _0_             |
+| Standard                    | **37.5** | 62.66     | 79.24     | 38.69     | **70.37**  | **57.69** | 1.21          | 2291            |
+| Nystromformer-128           | 36.29    | 63.24     | 78.18     | **42.86** | 67.49      | 57.61     | 0.62          | 383             |
+| Favor-256 (redraw)          | 19.56    | 62.76     | **81.1**  | 36.09     | 67.23      | 53.35     | 0.49          | 445             |
+| FourierMix                  | 36.29    | 60.72     | 76.41     | 36.53     | 54.07      | 52.8      | **0.17**      | **87**          |
+| Linformer-seq/4 (no redraw) | 36.69    | 57.39     | 76.41     | 35.57     | 65.12      | 54.2      | 0.67          | 719             |
+| Lambda                      | 19.76    | 62.47     | 79.11     | 35.04     | 49.74      | 49.224    | x             | 1023            |
+| Orthoformer-32              | 27.42    | **63.96** | 77.96     | 34.5      | 67.11      | 54.19     | 0.187         | 155             |
 
 - Contrary to the initial LRA proposal, __we use the same model architecture for all tasks (2 layers).__
 - The training schedule for ListOps has been lengthened, while keeping it the fastest of all tasks, which reduces the seed dependence in the final accuracy figure.
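For reference, here is a minimal usage sketch of the `OrthoFormerAttention` mechanism this diff introduces. The import path and constructor arguments are taken from the `__init__.py` and `ortho.py` changes below; the tensor shapes are illustrative only:

```python
# Minimal sketch (not part of the diff): standalone use of the new attention.
# Tensors follow the (batch, sequence, embedding) layout expected by forward().
import torch

from xformers.components.attention import OrthoFormerAttention
from xformers.components.attention.ortho import LandmarkSelection

attention = OrthoFormerAttention(
    dropout=0.1,
    num_landmarks=32,
    subsample_fraction=1.0,
    landmark_selection=LandmarkSelection.Orthogonal,
)

q = k = v = torch.randn(2, 1024, 64)  # (B, N, D)
out = attention(q, k, v)  # (B, N, D); the landmarks keep the cost linear in N
```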
diff --git a/docs/plots/memory_vs_attention.png b/docs/plots/memory_vs_attention.png
index b77afc23e5..6ee800e7ee 100644
Binary files a/docs/plots/memory_vs_attention.png and b/docs/plots/memory_vs_attention.png differ
diff --git a/docs/plots/runtime_vs_attention.png b/docs/plots/runtime_vs_attention.png
index 54bba059d1..16bd0eae66 100644
Binary files a/docs/plots/runtime_vs_attention.png and b/docs/plots/runtime_vs_attention.png differ
diff --git a/xformers/components/attention/__init__.py b/xformers/components/attention/__init__.py
index bcee2e8eb7..db3f39191b 100644
--- a/xformers/components/attention/__init__.py
+++ b/xformers/components/attention/__init__.py
@@ -84,6 +84,7 @@ def sparsify(matrix):
 from .linformer import LinformerAttention  # noqa
 from .local import LocalAttention  # noqa
 from .nystrom import NystromAttention  # noqa
+from .ortho import OrthoFormerAttention  # noqa
 from .random import RandomAttention  # noqa
 from .scaled_dot_product import ScaledDotProduct  # noqa
@@ -93,6 +94,7 @@ def sparsify(matrix):
     "LinformerAttention",
     "NystromAttention",
     "RandomAttention",
+    "OrthoFormerAttention",
     "GlobalAttention",
     "FavorAttention",
     "Attention",
diff --git a/xformers/components/attention/ortho.py b/xformers/components/attention/ortho.py
new file mode 100644
index 0000000000..2ea0ee2711
--- /dev/null
+++ b/xformers/components/attention/ortho.py
@@ -0,0 +1,295 @@
+import logging
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+
+import torch
+import torch.autograd.profiler as profiler
+import torch.nn as nn
+import torch.nn.functional as Fn
+
+from xformers.components.attention import Attention, AttentionConfig, register_attention
+from xformers.components.attention.core import (
+    scaled_dot_product_attention,
+    scaled_query_key_softmax,
+)
+
+
+class LandmarkSelection(str, Enum):
+    Orthogonal = "orthogonal"
+    KMeans = "kmeans"
+    KMeans_Spherical = "kmeans_spherical"
+    Random = "random"
+
+
+@dataclass
+class OrthoformerAttentionConfig(AttentionConfig):
+    """
+    num_landmarks       Number of landmarks to use for the softmax approximation.
+    subsample_fraction  Fraction of the queries to sample when selecting landmarks
+                        (1.0 means use all of them).
+    landmark_selection  Landmark selection strategy.
+    """
+
+    num_landmarks: Optional[int]
+    subsample_fraction: Optional[float]
+    landmark_selection: Optional[LandmarkSelection]
+
+
+@register_attention("orthoformer", OrthoformerAttentionConfig)
+class OrthoFormerAttention(Attention):
+    def __init__(
+        self,
+        dropout: float,
+        num_landmarks: int = 32,
+        subsample_fraction: float = 1.0,
+        landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
+        *args,
+        **kwargs,
+    ):
+        """
+        Orthoformer attention mechanism, from
+        "
+        Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers
+        Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer, C., Vedaldi, A., Henriques, J.
+        (2021)
+        "
+        ArXiv: https://arxiv.org/abs/2106.05392
+        Reference repository: https://github.com/facebookresearch/Motionformer
+        """
+        super().__init__()
+
+        self.num_landmarks = num_landmarks
+        self.attn_drop = nn.Dropout(dropout)
+        self.subsample_fraction = subsample_fraction
+        self.landmark_selection = landmark_selection
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        att_mask: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ):
+        N = k.shape[1]  # sequence length, assuming a (B, N, D) layout
+
+        if self.num_landmarks == N:
+            # Default attention
+            x = scaled_dot_product_attention(q, k, v, att_mask)
+        else:
+            with torch.no_grad(), profiler.record_function("select landmarks"):
+                if self.landmark_selection == LandmarkSelection.Orthogonal:
+                    landmarks = self._compute_orthogonal_landmarks(q)
+                elif self.landmark_selection == LandmarkSelection.Random:
+                    half_L = self.num_landmarks // 2
+                    landmarks_q = q[
+                        :, torch.randint(q.size(1), (half_L,), device=q.device), :
+                    ]
+                    landmarks_k = k[
+                        :, torch.randint(k.size(1), (half_L,), device=k.device), :
+                    ]
+                    landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
+                elif self.landmark_selection == LandmarkSelection.KMeans:
+                    landmarks = self._cluster_landmarks(q)
+                elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
+                    landmarks = self._cluster_landmarks(q, spherical=True)
+
+            if att_mask is not None:
+                logging.warning(
+                    "Orthoformer: an attention mask was passed while landmarks are"
+                    " used to reduce the sequence length. The two are not compatible,"
+                    " the mask will be ignored."
+                )
+                # FIXME: Should we still accept a mask in that case?
+                att_mask = None
+            kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
+            kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
+            x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
+        x = self.attn_drop(x)
+        return x
+
+    def _cluster_landmarks(
+        self,
+        q: torch.Tensor,
+        spherical: bool = False,
+        num_iters: int = 6,
+    ) -> torch.Tensor:
+        """
+        Construct a set of landmarks by clustering the (optionally sub-sampled)
+        queries with k-means, either in Euclidean space or on the unit sphere.
+        Returns landmarks with shape (B, M, D).
+        """
+
+        if self.subsample_fraction < 1.0:
+            # Need at least num_landmarks samples of the queries
+            num_samples = max(
+                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
+            )
+            q_samples = q[
+                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
+            ]  # (B, num_samples, D)
+        else:
+            q_samples = q  # (B, N, D)
+
+        if spherical:
+            q_samples_normalized = Fn.normalize(
+                q_samples, p=2, dim=-1
+            )  # may need to change default eps to eps=1e-8 for mixed precision compatibility
+            landmarks = self._kmeans_spherical(
+                q_samples_normalized, self.num_landmarks, num_iters
+            )
+        else:
+            landmarks = self._kmeans(q_samples, self.num_landmarks, num_iters)
+        return landmarks  # (B, M, D)
+
+    def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
+        """
+        Arguments:
+            x: (B, N, D)
+            K: number of clusters
+            num_iters: the number of kmeans updates
+        """
+
+        B, N, D = x.size()
+        assert K <= N
+        c = x[
+            :, torch.randperm(N, device=x.device)[:K], :
+        ].clone()  # initialisation for the centroids
+
+        with profiler.record_function("kmeans"):
+            x_i = x.view(B, N, 1, D)
+            c_j = c.view(B, 1, K, D)
+            counts = c.new_zeros(B, K)
+            ones = x.new_ones((B, N))
+
+            for _ in range(num_iters):
+                # E step: assign points to the nearest cluster
+                D_ij = ((x_i - c_j) ** 2).sum(-1)  # (B, N, K) squared distances
+                cl = D_ij.argmin(
+                    dim=-1, keepdim=True
+                ).long()  # (B, N, 1) index of the nearest cluster per point
+
+                # M step: update the centroids
+                c.zero_()
+                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
+                counts.fill_(1e-6)  # avoid div0
+                counts.scatter_add_(
+                    -1, cl.squeeze(-1), ones
+                )  # number of points per cluster
+                c.divide_(counts.unsqueeze(-1))  # compute the average
+
+        return c
+
+    def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters: int = 10):
+        """
+        Same as _kmeans, but x is assumed to be L2-normalized and similarity is
+        measured with the cosine instead of the squared Euclidean distance.
+
+        Arguments:
+            x: (B, N, D)
+            K: number of clusters
+            num_iters: the number of kmeans updates
+        """
+        B, N, D = x.size()
+        assert K <= N
+
+        # initialisation for the centroids
+        c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()
+
+        with profiler.record_function("kmeans_spherical"):
+            counts = c.new_zeros(B, K)
+            ones = x.new_ones((B, N))
+
+            for _ in range(num_iters):
+                # E step: assign points to the nearest cluster
+                D_ij = torch.matmul(
+                    x, c.transpose(-2, -1)
+                )  # (B, N, K) cosine similarity
+                cl = D_ij.argmax(
+                    dim=-1, keepdim=True
+                ).long()  # (B, N, 1) index of the nearest cluster per point
+
+                # M step: update the centroids
+                c.zero_()
+                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
+                counts.fill_(1e-6)  # avoid div0
+                counts.scatter_add_(
+                    -1, cl.squeeze(-1), ones
+                )  # number of points per cluster
+                c.divide_(counts.unsqueeze(-1))  # compute the average
+                c = Fn.normalize(c, p=2, dim=-1)  # renormalise
+        return c
+
+    def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
+        """
+        Construct a set of landmarks by recursively selecting new landmarks
+        that are maximally orthogonal to the existing set.
+        Returns near-orthogonal landmarks with shape (B, M, D).
+        """
+
+        if self.subsample_fraction < 1.0:
+            # Need at least num_landmarks samples of the queries
+            num_samples = max(
+                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
+            )
+            q_samples = q[
+                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
+            ]
+        else:
+            # (B, N, D)
+            q_samples = q
+
+        # may need to change default eps to eps=1e-8 for mixed precision compatibility
+        q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
+        B, N, D = q_samples_normalized.shape
+
+        selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
+        landmark_mask = torch.ones(
+            (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
+        )
+
+        # Get an initial random landmark
+        random_idx = torch.randint(
+            q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
+        )
+        selected_mask.scatter_(-2, random_idx, landmark_mask)
+
+        # Buffer holding the selected landmarks
+        selected_landmarks = torch.empty(
+            (B, self.num_landmarks, D),
+            device=q_samples_normalized.device,
+            dtype=q_samples_normalized.dtype,
+        )
+        selected_landmarks[:, 0, :] = q_samples_normalized[
+            torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
+        ].view(B, D)
+
+        # Buffer storing the computed cosine similarities
+        cos_sims = torch.empty(
+            (B, N, self.num_landmarks),
+            device=q_samples_normalized.device,
+            dtype=q_samples_normalized.dtype,
+        )
+
+        for M in range(1, self.num_landmarks):
+            with profiler.record_function("find new landmark"):
+                # Calculate the absolute cosine similarity between the last selected
+                # landmark and all the candidates
+                # (B, N, D) * (B, D) -> (B, N)
+                cos_sims[:, :, M - 1] = torch.einsum(
+                    "b n d, b d -> b n",
+                    q_samples_normalized,
+                    selected_landmarks[:, M - 1, :],
+                ).abs()
+
+                # (B, N, M) cosine similarities of the current set of landmarks
+                # with respect to all the queries
+                cos_sim_set = cos_sims[:, :, :M]
+
+                # Get the most orthogonal landmark, i.e. the candidate whose maximum
+                # absolute cosine similarity to the selected set is smallest.
+                # Set the similarity of already selected landmarks to > 1 so that
+                # they cannot be picked again
+                cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10
+
+                # (B,) index of the candidate minimizing the maximum similarity
+                selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)
+
+                # Add the most orthogonal landmark to the selected landmarks:
+                selected_landmarks[:, M, :] = q_samples_normalized[
+                    torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
+                ].view(B, D)
+
+                # Mark the new landmark as taken in the selection mask:
+                selected_mask.scatter_(
+                    -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
+                )
+
+        # (B, M, D) - note that masked_select returns the landmarks in index
+        # order, not in selection order
+        landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
+            B, -1, D
+        )
+        return landmarks  # (B, M, D)
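As a side note on the `forward` path above: when landmarks are used, the two `scaled_query_key_softmax` kernels approximate full attention at O(N·M) cost instead of O(N²). A self-contained sketch of that factorization, assuming `scaled_query_key_softmax(a, b, None)` amounts to a row-wise softmax of the scaled dot products, and using placeholder landmarks (the careful landmark selection is the point of this PR):

```python
# Hedged sketch of the landmark factorization, not the library implementation;
# here the first M keys simply stand in as landmarks.
import torch

B, N, M, D = 2, 1024, 32, 64
q, k, v = torch.randn(3, B, N, D).unbind(0)
landmarks = k[:, :M, :]  # placeholder landmarks

scale = D ** -0.5
kernel_1 = torch.softmax(q @ landmarks.transpose(-2, -1) * scale, dim=-1)  # (B, N, M)
kernel_2 = torch.softmax(landmarks @ k.transpose(-2, -1) * scale, dim=-1)  # (B, M, N)
out = kernel_1 @ (kernel_2 @ v)  # (B, N, D); the (N, N) matrix is never built
```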