diff --git a/.gitignore b/.gitignore
index 682b682cda..1cfc0cc7d8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,3 +43,6 @@ my_runs.md
 
 # Watchman config files
 .watchmanconfig
+
+# examples demo files
+examples/input.txt
diff --git a/examples/microGPT.py b/examples/microGPT.py
index 583a2906b6..6f546f002d 100644
--- a/examples/microGPT.py
+++ b/examples/microGPT.py
@@ -273,7 +273,7 @@ def top_k_logits(logits, k):
     REF_BATCH = 512
     BATCH = 256  # adjust depending on the avaiable memory on your machine
     WORKERS = 8
-    EPOCHS = 2
+    EPOCHS = 1
     BLOCK = 128
     WARMUP = 20
 
@@ -298,10 +298,11 @@ def top_k_logits(logits, k):
     model = GPT(
         vocab_size=train_dataset.vocab_size,
         block_size=train_dataset.block_size,
-        attention="scaled_dot_product",
+        attention="nystrom",
         warmup_tokens=REF_BATCH * WARMUP,
         final_tokens=EPOCHS * len(train_dataset) * BLOCK,
     )
+    print(model)
 
     trainer = Trainer(
         gpus=1,
diff --git a/xformers/components/attention/_sputnik_sparse.py b/xformers/components/attention/_sputnik_sparse.py
index 3de2baabc0..9737e90ca5 100644
--- a/xformers/components/attention/_sputnik_sparse.py
+++ b/xformers/components/attention/_sputnik_sparse.py
@@ -201,6 +201,10 @@ def __init__(self, matrix, device=None):
     def device(self):
         return self.values.device
 
+    @property
+    def ndim(self):
+        return len(self.shape)
+
     @property
     def dtype(self):
         return self.values.dtype
diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py
index c42204e486..187cc38bb1 100644
--- a/xformers/components/attention/core.py
+++ b/xformers/components/attention/core.py
@@ -17,7 +17,7 @@
 from ._sputnik_sparse import SparseCS
 
 # NOTE: Could do with a better option on when to use triton and not
-_use_triton = True
+_use_triton = torch.cuda.is_available()
 if _use_triton:
     try:
         from xformers.triton.softmax import softmax as triton_softmax
@@ -195,8 +195,12 @@ def scaled_query_key_softmax(
 
     # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
     q = q / math.sqrt(k.size(-1))
-    att = _matmul_with_mask(q, k.transpose(-2, -1), att_mask)
 
+    # Matmul with mask, if boolean
+    is_bool_mask = att_mask is not None and att_mask.dtype == torch.bool
+    att = _matmul_with_mask(q, k.transpose(-2, -1), att_mask if is_bool_mask else None)
+
+    # Could also be that the mask was additive
     if att_mask is not None and att_mask.dtype != torch.bool:
         att = att + att_mask
 
@@ -219,20 +223,27 @@ def scaled_dot_product_attention(
         or (att_mask is not None and att_mask.is_sparse)
     )
 
-    if att_mask is not None and q.shape[-2] < att_mask.shape[0]:
-        # The sequence is smaller than the mask
+    # Try to handle a case where the sequence is smaller than the mask
+    if (
+        att_mask is not None
+        and q.shape[-2] == k.shape[-2]
+        and q.shape[-2] < att_mask.shape[0]
+    ):
         seq = q.shape[-2]
+        if att_mask.ndim == 2:
+            att_mask = att_mask.unsqueeze(0)
+
         if not att_mask.is_sparse:
-            att_mask = att_mask[:seq, :seq]
+            att_mask = att_mask[:, :seq, :seq]
         else:
             logging.warning(
                 "Mismatching attention mask and sequence length. "
                 "On the fly correction but this will be slow"
             )
             # Loosing sparsity on purpose,
             # expectation is that moving back and forth dense/sparse will negate the speedup
-            att_mask = att_mask.to_dense().squeeze(0)[:seq, :seq]
+            att_mask = att_mask.to_dense().squeeze(0)[:, :seq, :seq]
 
+    # The actual attention
     with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext():
         if autocast_disabled:
             q, k, v = q.float(), k.float(), v.float()
diff --git a/xformers/components/attention/nystrom.py b/xformers/components/attention/nystrom.py
index aacd8327f0..d83eb1113f 100644
--- a/xformers/components/attention/nystrom.py
+++ b/xformers/components/attention/nystrom.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from dataclasses import dataclass
 from typing import Optional
 
@@ -15,7 +16,11 @@
     scaled_dot_product_attention,
     scaled_query_key_softmax,
 )
-from xformers.components.attention.utils import iterative_pinv, reshape_key_padding_mask
+from xformers.components.attention.utils import (
+    bool_mask_to_additive,
+    iterative_pinv,
+    reshape_key_padding_mask,
+)
 
 
 @dataclass
@@ -163,18 +168,26 @@ def forward(
                             (batch size * num_heads, 1, sequence length). If dimensions are not correct,
                             the mask will be ignored.
         """
+
         batched_dim = k.size(0)
         seq_len = k.size(-2)
+        tt = {"dtype": q.dtype, "device": q.device}
 
         if key_padding_mask is not None:
-            assert key_padding_mask.dtype == torch.bool
+            if key_padding_mask.dtype == torch.bool:
+                logging.warning(
+                    "Bool mask found, but an additive mask is expected. Converting but this is slow"
+                )
+                key_padding_mask = bool_mask_to_additive(key_padding_mask)
+
+            assert key_padding_mask is not None  # mypy is drunk
 
             if key_padding_mask.ndim == 2:
                 key_padding_mask = reshape_key_padding_mask(
                     key_padding_mask, batched_dim
                 )
 
-            assert key_padding_mask.size() == (batched_dim, 1, seq_len,), (
+            assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
                 f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
                 f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
             )
@@ -182,13 +195,9 @@ def forward(
         if self.num_landmarks >= seq_len:
             mask: Optional[torch.Tensor] = None
             if self.causal:
-                mask = self._tril_mask(batched_dim, seq_len, seq_len)
+                mask = self._tril_mask(batched_dim, seq_len, seq_len, **tt)
             if key_padding_mask is not None:
-                mask = (
-                    key_padding_mask
-                    if mask is None
-                    else mask.logical_and(key_padding_mask)
-                )
+                mask = key_padding_mask if mask is None else mask + key_padding_mask
             x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
 
         else:
@@ -198,17 +207,17 @@ def forward(
             if self.causal and (
                 self.causal_mask_1 is None
                 or (batched_dim, seq_len, self.num_landmarks)
-                != self.causal_mask_1.size()
+                != self.causal_mask_1.size()[1:]
             ):
                 self.causal_mask_1 = self._tril_mask(
-                    batched_dim, seq_len, self.num_landmarks
-                ).to(q.device)
+                    batched_dim, seq_len, self.num_landmarks, **tt
+                )
                 self.causal_mask_2 = self._tril_mask(
-                    batched_dim, self.num_landmarks, self.num_landmarks
-                ).to(q.device)
+                    batched_dim, self.num_landmarks, self.num_landmarks, **tt
+                )
                 self.causal_mask_3 = self._tril_mask(
-                    batched_dim, self.num_landmarks, seq_len
-                ).to(q.device)
+                    batched_dim, self.num_landmarks, seq_len, **tt
+                )
 
             mask_1: Optional[torch.Tensor] = self.causal_mask_1
             mask_2: Optional[torch.Tensor] = self.causal_mask_2
@@ -217,12 +226,10 @@ def forward(
                 mask_1 = (
                     key_padding_mask.transpose(-2, -1)
                     if mask_1 is None
-                    else mask_1.logical_and(key_padding_mask.transpose(-2, -1))
+                    else mask_1 + key_padding_mask.transpose(-2, -1)
                 )
                 mask_3 = (
-                    key_padding_mask
-                    if mask_3 is None
-                    else mask_3.logical_and(key_padding_mask)
+                    key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
                 )
 
             kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=mask_1)
@@ -258,5 +265,15 @@ def forward(
         x = self.attn_drop(x)
         return x
 
-    def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int):
-        return torch.tril(torch.ones(dim_1, dim_2, dim_3, dtype=torch.bool), diagonal=0)
+    def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
+        device = kwargs["device"]
+        dtype = kwargs["dtype"]
+
+        return (
+            torch.tril(
+                torch.ones(dim_3, dim_2, dtype=dtype, device=device) * float("-inf"),
+                diagonal=-1,
+            )
+            .transpose(0, 1)
+            .expand(dim_1, -1, -1)  # micro optim, save memory on the batch dimension
+        )
diff --git a/xformers/components/attention/utils.py b/xformers/components/attention/utils.py
index c07c1c8e27..f97e3e4b94 100644
--- a/xformers/components/attention/utils.py
+++ b/xformers/components/attention/utils.py
@@ -87,3 +87,15 @@ def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=F
             13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
         )
     return v
+
+
+def bool_mask_to_additive(
+    mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
+):
+    assert (
+        mask.dtype == torch.bool
+    ), "This util is meant to convert in between bool masks and additive ones"
+
+    mask_ = torch.zeros_like(mask, dtype=dtype)
+    mask_[~mask] = float("-inf")
+    return mask_
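
Note (not part of the patch): the changes above move the Nystrom attention path from boolean masks combined with logical_and to additive float masks (0.0 for "keep", -inf for "drop") combined with a plain +. Below is a minimal standalone sketch of that convention. It is illustrative only: the inlined bool_mask_to_additive mirrors the helper added to xformers/components/attention/utils.py, the causal mask is built with torch.triu (equivalent, for the square case, to what the reworked _tril_mask produces), and the batch/sequence sizes and variable names are made up for the demo.

# Illustrative sketch of the additive-mask convention used in the diff above.
# Not part of the patch; shapes and names are assumptions made for the example.
import torch


def bool_mask_to_additive(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # True -> 0.0 (keep), False -> -inf (drop), same convention as the new utility
    assert mask.dtype == torch.bool
    out = torch.zeros_like(mask, dtype=dtype)
    out[~mask] = float("-inf")
    return out


batch, seq = 2, 4

# Boolean key padding mask: the last key of the second sample is padding
key_padding = torch.ones(batch, 1, seq, dtype=torch.bool)
key_padding[1, :, -1] = False
additive_padding = bool_mask_to_additive(key_padding)  # zeros and -inf, shape (2, 1, 4)

# Additive causal mask: -inf strictly above the diagonal, 0 elsewhere
causal = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1).expand(batch, -1, -1)

# With additive masks, combining is a broadcasted sum rather than logical_and,
# mirroring e.g. `mask = key_padding_mask if mask is None else mask + key_padding_mask`
combined = causal + additive_padding  # (2, 4, 4)

scores = torch.randn(batch, seq, seq)
attn = torch.softmax(scores + combined, dim=-1)
print(attn[1])  # padded and future positions receive exactly zero attention weight

One practical upside of this convention, visible throughout the nystrom.py hunks, is that masks of different shapes broadcast and add without dtype juggling, and the causal masks can be built directly in the query's dtype and device (the **tt kwargs) instead of being created as bool and moved with .to(q.device).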