Commit 6f626d7

blefaudeux committed Nov 4, 2021
1 parent 962db66
Showing 6 changed files with 78 additions and 30 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -43,3 +43,6 @@ my_runs.md
 
 # Watchman config files
 .watchmanconfig
+
+# examples demo files
+examples/input.txt
5 changes: 3 additions & 2 deletions examples/microGPT.py
@@ -273,7 +273,7 @@ def top_k_logits(logits, k):
     REF_BATCH = 512
     BATCH = 256  # adjust depending on the available memory on your machine
     WORKERS = 8
-    EPOCHS = 2
+    EPOCHS = 1
     BLOCK = 128
     WARMUP = 20
 
@@ -298,10 +298,11 @@ def top_k_logits(logits, k):
     model = GPT(
         vocab_size=train_dataset.vocab_size,
         block_size=train_dataset.block_size,
-        attention="scaled_dot_product",
+        attention="nystrom",
         warmup_tokens=REF_BATCH * WARMUP,
         final_tokens=EPOCHS * len(train_dataset) * BLOCK,
     )
+    print(model)
 
     trainer = Trainer(
         gpus=1,
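The example now asks for the Nystrom attention by name instead of the vanilla scaled dot product. Outside of the example, the same mechanism is reachable through the attention registry; below is a minimal sketch, where the head count, landmark count and tensor sizes are illustrative values rather than anything taken from this commit:

import torch

from xformers.components.attention import build_attention

# Illustrative settings (not from this commit): 8 heads, 64 landmarks.
nystrom_cfg = {
    "name": "nystrom",  # the string the GPT example now passes as `attention`
    "dropout": 0.0,
    "num_heads": 8,
    "num_landmarks": 64,
    "causal": True,
}

attention = build_attention(nystrom_cfg)

# The attention blocks consume already-projected (batch * heads, seq, head_dim) tensors.
batch_heads, seq, head_dim = 2 * 8, 128, 32
q = k = v = torch.randn(batch_heads, seq, head_dim)
out = attention(q=q, k=k, v=v)
print(out.shape)  # torch.Size([16, 128, 32])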
4 changes: 4 additions & 0 deletions xformers/components/attention/_sputnik_sparse.py
@@ -201,6 +201,10 @@ def __init__(self, matrix, device=None):
     def device(self):
         return self.values.device
 
+    @property
+    def ndim(self):
+        return len(self.shape)
+
     @property
     def dtype(self):
         return self.values.dtype
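The new `ndim` property mirrors `torch.Tensor.ndim`, so a `SparseCS` mask can go through the same dimensionality checks as a dense mask (for instance the `att_mask.ndim == 2` test added to core.py in this commit). A toy sketch of the pattern, using a hypothetical stand-in class rather than the real `SparseCS`:

import torch


class ToySparseMask:
    """Hypothetical stand-in exposing tensor-like metadata, as SparseCS does."""

    def __init__(self, values: torch.Tensor, shape: tuple):
        self.values = values
        self.shape = shape

    @property
    def ndim(self) -> int:
        # Rank is simply the number of entries in the shape, like torch.Tensor.ndim
        return len(self.shape)


mask = ToySparseMask(torch.ones(6), (2, 3))
assert mask.ndim == torch.ones(2, 3).ndim  # both report 2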
23 changes: 17 additions & 6 deletions xformers/components/attention/core.py
@@ -17,7 +17,7 @@
 from ._sputnik_sparse import SparseCS
 
 # NOTE: Could do with a better option on when to use triton and not
-_use_triton = True
+_use_triton = torch.cuda.is_available()
 if _use_triton:
     try:
         from xformers.triton.softmax import softmax as triton_softmax
@@ -195,8 +195,12 @@ def scaled_query_key_softmax(
 
     # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
     q = q / math.sqrt(k.size(-1))
-    att = _matmul_with_mask(q, k.transpose(-2, -1), att_mask)
 
+    # Matmul with mask, if boolean
+    is_bool_mask = att_mask is not None and att_mask.dtype == torch.bool
+    att = _matmul_with_mask(q, k.transpose(-2, -1), att_mask if is_bool_mask else None)
+
+    # Could also be that the mask was additive
     if att_mask is not None and att_mask.dtype != torch.bool:
         att = att + att_mask
 
@@ -219,20 +223,27 @@ def scaled_dot_product_attention(
         or (att_mask is not None and att_mask.is_sparse)
     )
 
-    if att_mask is not None and q.shape[-2] < att_mask.shape[0]:
-        # The sequence is smaller than the mask
+    # Try to handle a case where the sequence is smaller than the mask
+    if (
+        att_mask is not None
+        and q.shape[-2] == k.shape[-2]
+        and q.shape[-2] < att_mask.shape[0]
+    ):
         seq = q.shape[-2]
+        if att_mask.ndim == 2:
+            att_mask = att_mask.unsqueeze(0)
+
         if not att_mask.is_sparse:
-            att_mask = att_mask[:seq, :seq]
+            att_mask = att_mask[:, :seq, :seq]
         else:
             logging.warning(
                 "Mismatching attention mask and sequence length. On the fly correction but this will be slow"
             )
             # Losing sparsity on purpose,
             # expectation is that moving back and forth dense/sparse will negate the speedup
-            att_mask = att_mask.to_dense().squeeze(0)[:seq, :seq]
+            att_mask = att_mask.to_dense().squeeze(0)[:, :seq, :seq]
 
     # The actual attention
     with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext():
         if autocast_disabled:
             q, k, v = q.float(), k.float(), v.float()
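With this change, `scaled_query_key_softmax` only hands boolean masks to `_matmul_with_mask` and simply adds float ("additive") masks onto the raw scores, while `scaled_dot_product_attention` broadcasts 2D masks to a batch dimension before slicing them down to the sequence length. Here is a self-contained sketch of the two mask conventions, written with plain torch ops rather than the private xformers helpers (the `toy_scores` name is ours):

import math

import torch


def toy_scores(q, k, att_mask=None):
    # Condensed version of the dispatch above: boolean masks select entries,
    # additive (float) masks are summed onto the logits before the softmax.
    att = (q / math.sqrt(k.size(-1))) @ k.transpose(-2, -1)
    if att_mask is not None:
        if att_mask.dtype == torch.bool:
            att = att.masked_fill(~att_mask, float("-inf"))
        else:
            att = att + att_mask
    return torch.softmax(att, dim=-1)


q = k = torch.randn(1, 4, 8)
bool_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
additive_mask = torch.zeros(4, 4).masked_fill(~bool_mask, float("-inf"))

# Both conventions produce the same causal attention pattern.
assert torch.allclose(toy_scores(q, k, bool_mask), toy_scores(q, k, additive_mask))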
61 changes: 39 additions & 22 deletions xformers/components/attention/nystrom.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from dataclasses import dataclass
 from typing import Optional
 
@@ -15,7 +16,11 @@
     scaled_dot_product_attention,
     scaled_query_key_softmax,
 )
-from xformers.components.attention.utils import iterative_pinv, reshape_key_padding_mask
+from xformers.components.attention.utils import (
+    bool_mask_to_additive,
+    iterative_pinv,
+    reshape_key_padding_mask,
+)
 
 
 @dataclass
@@ -163,32 +168,36 @@ def forward(
         (batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will
         be ignored.
         """
 
         batched_dim = k.size(0)
         seq_len = k.size(-2)
+        tt = {"dtype": q.dtype, "device": q.device}
 
         if key_padding_mask is not None:
-            assert key_padding_mask.dtype == torch.bool
+            if key_padding_mask.dtype == torch.bool:
+                logging.warning(
+                    "Bool mask found, but an additive mask is expected. Converting but this is slow"
+                )
+                key_padding_mask = bool_mask_to_additive(key_padding_mask)
+
+            assert key_padding_mask is not None  # mypy is drunk
+
             if key_padding_mask.ndim == 2:
                 key_padding_mask = reshape_key_padding_mask(
                     key_padding_mask, batched_dim
                 )
 
-            assert key_padding_mask.size() == (batched_dim, 1, seq_len,), (
+            assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
                 f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
                 f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
            )
 
         if self.num_landmarks >= seq_len:
             mask: Optional[torch.Tensor] = None
             if self.causal:
-                mask = self._tril_mask(batched_dim, seq_len, seq_len)
+                mask = self._tril_mask(batched_dim, seq_len, seq_len, **tt)
             if key_padding_mask is not None:
-                mask = (
-                    key_padding_mask
-                    if mask is None
-                    else mask.logical_and(key_padding_mask)
-                )
+                mask = key_padding_mask if mask is None else mask + key_padding_mask
             x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
 
         else:
@@ -198,17 +207,17 @@ def forward(
             if self.causal and (
                 self.causal_mask_1 is None
                 or (batched_dim, seq_len, self.num_landmarks)
-                != self.causal_mask_1.size()
+                != self.causal_mask_1.size()[1:]
             ):
                 self.causal_mask_1 = self._tril_mask(
-                    batched_dim, seq_len, self.num_landmarks
-                ).to(q.device)
+                    batched_dim, seq_len, self.num_landmarks, **tt
+                )
                 self.causal_mask_2 = self._tril_mask(
-                    batched_dim, self.num_landmarks, self.num_landmarks
-                ).to(q.device)
+                    batched_dim, self.num_landmarks, self.num_landmarks, **tt
+                )
                 self.causal_mask_3 = self._tril_mask(
-                    batched_dim, self.num_landmarks, seq_len
-                ).to(q.device)
+                    batched_dim, self.num_landmarks, seq_len, **tt
+                )
 
             mask_1: Optional[torch.Tensor] = self.causal_mask_1
             mask_2: Optional[torch.Tensor] = self.causal_mask_2
@@ -217,12 +226,10 @@ def forward(
                 mask_1 = (
                     key_padding_mask.transpose(-2, -1)
                     if mask_1 is None
-                    else mask_1.logical_and(key_padding_mask.transpose(-2, -1))
+                    else mask_1 + key_padding_mask.transpose(-2, -1)
                 )
                 mask_3 = (
-                    key_padding_mask
-                    if mask_3 is None
-                    else mask_3.logical_and(key_padding_mask)
+                    key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
                 )
 
             kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=mask_1)
@@ -258,5 +265,15 @@ def forward(
         x = self.attn_drop(x)
         return x
 
-    def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int):
-        return torch.tril(torch.ones(dim_1, dim_2, dim_3, dtype=torch.bool), diagonal=0)
+    def _tril_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
+        device = kwargs["device"]
+        dtype = kwargs["dtype"]
+
+        return (
+            torch.tril(
+                torch.ones(dim_3, dim_2, dtype=dtype, device=device) * float("-inf"),
+                diagonal=-1,
+            )
+            .transpose(0, 1)
+            .expand(dim_1, -1, -1)  # micro optim, save memory on the batch dimension
+        )
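The rewritten `_tril_mask` now returns an additive causal mask (zeros on and below the diagonal, `-inf` strictly above) instead of a boolean one, which is why the key padding mask is combined with `+` rather than `logical_and` above. A standalone sketch of what the new construction produces, reusing the expression from the diff with default dtype and device:

import torch

dim_1, dim_2, dim_3 = 2, 4, 4  # (batch * heads, query length, key length)

additive_causal = (
    torch.tril(torch.ones(dim_3, dim_2) * float("-inf"), diagonal=-1)
    .transpose(0, 1)
    .expand(dim_1, -1, -1)
)

print(additive_causal[0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

Adding this mask to the attention logits before the softmax suppresses future positions, which is what the boolean lower-triangular mask used to express through `logical_and`.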
12 changes: 12 additions & 0 deletions xformers/components/attention/utils.py
@@ -87,3 +87,15 @@ def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=F
             13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
         )
     return v
+
+
+def bool_mask_to_additive(
+    mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
+):
+    assert (
+        mask.dtype == torch.bool
+    ), "This util is meant to convert in between bool masks and additive ones"
+
+    mask_ = torch.zeros_like(mask, dtype=dtype)
+    mask_[~mask] = float("-inf")
+    return mask_
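`bool_mask_to_additive` is the helper the Nystrom attention now calls when it still receives a boolean key padding mask: kept positions (True) map to 0 and padded positions (False) map to -inf, so the result can be summed onto the attention logits. A quick usage sketch, assuming the convention that True means "attend to this position":

import torch

from xformers.components.attention.utils import bool_mask_to_additive

# Two sequences of length 4; the last position of the second one is padding.
key_padding_mask = torch.tensor(
    [[True, True, True, True],
     [True, True, True, False]]
)

additive = bool_mask_to_additive(key_padding_mask)
print(additive)
# tensor([[0., 0., 0., 0.],
#         [0., 0., 0., -inf]])

# This additive mask is added onto the attention scores instead of being combined
# with logical_and, which is the convention this commit moves to.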
