forked from facebookresearch/xformers
[feat] Orthoformer attention (facebookresearch#164)
* Orthoformer attention. Author: Mandela Patrick et al.
* Only select landmarks from among the queries.
1 parent 9273d7c · commit 77883ce
Showing 5 changed files with 310 additions and 11 deletions.
@@ -0,0 +1,295 @@
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
import torch.nn.functional as Fn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import (
    scaled_dot_product_attention,
    scaled_query_key_softmax,
)


class LandmarkSelection(str, Enum):
    Orthogonal = "orthogonal"
    KMeans = "kmeans"
    KMeans_Spherical = "kmeans_spherical"
    Random = "random"


@dataclass
class OrthoformerAttentionConfig(AttentionConfig):
    """
    num_landmarks       Number of landmarks to use for the softmax approximation.
    subsample_fraction  Fraction of the q_samples matrix to sample per iteration.
    landmark_selection  Landmark selection strategy.
    """

    num_landmarks: Optional[int]
    subsample_fraction: Optional[float]
    landmark_selection: Optional[LandmarkSelection]


@register_attention("orthoformer", OrthoformerAttentionConfig)
class OrthoFormerAttention(Attention):
    def __init__(
        self,
        dropout: float,
        num_landmarks: int = 32,
        subsample_fraction: float = 1.0,
        landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
        *args,
        **kwargs,
    ):
        """
        Orthoformer attention mechanism, from
        "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
        Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F.,
        Feichtenhofer, C., Vedaldi, A., Henriques, J. (2021)

        ArXiv: https://arxiv.org/abs/2106.05392
        Reference repository: https://github.com/facebookresearch/Motionformer
        """
        super().__init__()

        self.num_landmarks = num_landmarks
        self.attn_drop = nn.Dropout(dropout)
        self.subsample_fraction = subsample_fraction
        self.landmark_selection = landmark_selection

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        N = k.shape[1]

        if self.num_landmarks == N:
            # Default attention
            x = scaled_dot_product_attention(q, k, v, att_mask)
        else:
            with torch.no_grad(), profiler.record_function("select landmarks"):
                if self.landmark_selection == LandmarkSelection.Orthogonal:
                    landmarks = self._compute_orthogonal_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.Random:
                    half_L = self.num_landmarks // 2
                    landmarks_q = q[
                        :, torch.randint(q.size(1), (half_L,), device=q.device), :
                    ]
                    landmarks_k = k[
                        :, torch.randint(k.size(1), (half_L,), device=k.device), :
                    ]
                    landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
                elif self.landmark_selection == LandmarkSelection.KMeans:
                    landmarks = self._cluster_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
                    landmarks = self._cluster_landmarks(q, spherical=True)

            if att_mask is not None:
                logging.warning(
                    "Orthoformer: an attention mask was passed while landmarks are"
                    " used to reduce dimensions. The two are typically not compatible"
                )
                # FIXME: Should we still accept a mask in that case?
                att_mask = None
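
            # Approximate the full softmax attention by routing it through the
            # M landmarks (a Nystrom-like two-step factorization):
            #   att(q, k, v) ~ softmax(q @ landmarks.T) @ softmax(landmarks @ k.T) @ v
            # which costs O(N * M) per head instead of O(N ** 2), for M << N.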
            kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
            kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
            x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
        x = self.attn_drop(x)
        return x

    def _cluster_landmarks(
        self,
        q: torch.Tensor,
        spherical: bool = False,
        num_iters: int = 6,
    ) -> torch.Tensor:
        """
        Construct a set of landmarks by clustering the (optionally subsampled)
        queries with k-means and using the resulting centroids.
        Returns landmarks with shape (B, M, D).
        """

        if self.subsample_fraction < 1.0:
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
            )  # Need at least M samples of queries
            q_samples = q[
                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
            ]  # (B, S, D)
        else:
            q_samples = q  # (B, N, D)

        if spherical:
            q_samples_normalized = Fn.normalize(
                q_samples, p=2, dim=-1
            )  # may need to change default eps to eps=1e-8 for mixed precision compatibility
            landmarks = self._kmeans_spherical(
                q_samples_normalized, self.num_landmarks, num_iters
            )
        else:
            landmarks = self._kmeans(q_samples, self.num_landmarks, num_iters)
        return landmarks  # (B, M, D)

    def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
        """
        Arguments:
            x: (B, N, D)
            K: number of clusters
            num_iters: the number of k-means updates
        """

        B, N, D = x.size()
        assert K <= N
        c = x[
            :, torch.randperm(N, device=x.device)[:K], :
        ].clone()  # initialisation for the centroids

        with profiler.record_function("kmeans"):
            x_i = x.view(B, N, 1, D)
            c_j = c.view(B, 1, K, D)  # a view of c, so in-place updates are visible
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = ((x_i - c_j) ** 2).sum(-1)  # (B, N, K) squared distances
                cl = D_ij.argmin(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of each point's nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average

        return c

    def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters: int = 10):
        """
        Arguments:
            x: (B, N, D), assumed L2-normalised along the last dimension
        """
        B, N, D = x.size()
        assert K <= N

        # initialisation for the centroids
        c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()

        with profiler.record_function("kmeans_spherical"):
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = torch.matmul(
                    x, c.transpose(-2, -1)
                )  # (B, N, K) cosine similarity
                cl = D_ij.argmax(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of each point's nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average
                c = Fn.normalize(c, p=2, dim=-1)  # renormalise
        return c

    def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
        """
        Construct a set of landmarks by recursively selecting new landmarks
        that are maximally orthogonal to the existing set.
        Returns near-orthogonal landmarks with shape (B, M, D).
        """

        if self.subsample_fraction < 1.0:
            # Need at least M samples of queries
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
            )
            q_samples = q[
                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
            ]
        else:
            # (B, N, D)
            q_samples = q

        # may need to change default eps to eps=1e-8 for mixed precision compatibility
        q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
        B, N, D = q_samples_normalized.shape

        selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
        landmark_mask = torch.ones(
            (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
        )

        # Get an initial random landmark
        random_idx = torch.randint(
            q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
        )
        selected_mask.scatter_(-2, random_idx, landmark_mask)

        # Selected landmarks
        selected_landmarks = torch.empty(
            (B, self.num_landmarks, D),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )
        selected_landmarks[:, 0, :] = q_samples_normalized[
            torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
        ].view(B, D)

        # Store computed cosine similarities
        cos_sims = torch.empty(
            (B, N, self.num_landmarks),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )

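        # Greedily grow the landmark set: at each step, pick the sample whose
        # largest absolute cosine similarity to the already-selected landmarks
        # is smallest, i.e. the sample most orthogonal to the current set.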
        for M in range(1, self.num_landmarks):
            with profiler.record_function("find new landmark"):
                # Calculate the absolute cosine similarity between the latest
                # selected landmark and all samples
                # (B, N, D) * (B, D) -> (B, N)
                cos_sims[:, :, M - 1] = torch.einsum(
                    "b n d, b d -> b n",
                    q_samples_normalized,
                    selected_landmarks[:, M - 1, :],
                ).abs()

                # (B, N, M) cosine similarities of the current set of landmarks
                # with respect to all queries and keys
                cos_sim_set = cos_sims[:, :, :M]

                # Get the most orthogonal landmark, i.e. the one with the
                # smallest absolute cosine similarity to the selected set:
                # set the cosine similarity of already-selected landmarks to > 1
                cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10

                # (B,) index of the sample whose worst-case (max) similarity
                # to the selected set is smallest
                selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)

                # Add the most orthogonal landmark to the selected landmarks:
                selected_landmarks[:, M, :] = q_samples_normalized[
                    torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
                ].view(B, D)

                # Mark the newly selected index in the selection mask:
                selected_mask.scatter_(
                    -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
                )

        # (B, M, D)
        landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
            B, -1, D
        )
        return landmarks  # (B, M, D)
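
A minimal usage sketch, assuming the new module is importable as
xformers.components.attention.ortho (the file path is not visible in this
diff view) and with illustrative tensor shapes:

import torch

from xformers.components.attention.ortho import (
    LandmarkSelection,
    OrthoFormerAttention,
)

# B=2 sequences of N=128 tokens, D=64 channels per head
q = k = v = torch.randn(2, 128, 64)

attention = OrthoFormerAttention(
    dropout=0.0,
    num_landmarks=32,
    landmark_selection=LandmarkSelection.Orthogonal,
)
out = attention(q, k, v)  # (2, 128, 64)

Since the mechanism is registered under "orthoformer", it can also be
instantiated from an OrthoformerAttentionConfig through the xformers
attention registry, like the other attention mechanisms in the library.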