diff --git a/examples/microGPT.py b/examples/microGPT.py
index 2a9d9b38a2..048ae962a6 100644
--- a/examples/microGPT.py
+++ b/examples/microGPT.py
@@ -21,7 +21,7 @@


 class GPT(pl.LightningModule):
-    """ the full GPT language model, with a context size of block_size """
+    """the full GPT language model, with a context size of block_size"""

     def __init__(
         self,
diff --git a/tests/test_triton_blocksparse.py b/tests/test_triton_blocksparse.py
index b76a3478fb..a0ffe78595 100644
--- a/tests/test_triton_blocksparse.py
+++ b/tests/test_triton_blocksparse.py
@@ -98,7 +98,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=32, H=2, M=512, N=384, K


 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
-@pytest.mark.parametrize("BLOCK", [32])
+@pytest.mark.parametrize("BLOCK", [32, 128])
 @pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
 @pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
 def test_softmax(BLOCK, WIDTH, DTYPE):
@@ -127,12 +127,12 @@ def test_softmax(BLOCK, WIDTH, DTYPE):


 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
-@pytest.mark.parametrize("block", [32, 43])  # 16, 32,
+@pytest.mark.parametrize("block", [32, 43, 128])  # 16, 32,
 def test_attention_fwd_bwd(
     block,
     input_scale=1.0,
     scale=1 / 8.0,
-    n_ctx=256,
+    n_ctx=384,
     dtype=torch.float16,
     batch_size=2,
     n_heads=2,
@@ -152,12 +152,14 @@ def loss_fn(x):

     # Triton:
     n_blocks = n_ctx // block
-    layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))
+    layout = torch.tril(
+        torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long), diagonal=-1
+    )
     query, key, value = [x.clone() for x in qkvs]
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-    if block not in [16, 32, 64]:
+    if block not in [16, 32, 64, 128]:
         # Check that unsupported dimensions are caught
         with pytest.raises(AssertionError):
             _ = BlockSparseAttention(layout, block)
diff --git a/tests/test_triton_fused_linear.py b/tests/test_triton_fused_linear.py
index dadb254b5a..a7789824d5 100644
--- a/tests/test_triton_fused_linear.py
+++ b/tests/test_triton_fused_linear.py
@@ -38,7 +38,7 @@
     "dtype", [torch.float32]
 )  # Triton use tensor cores, which return slightly different results to pytorch mm
 def test_fused_matmul(shape, dtype):
-    """ Check that the matrix multiply kernel and Pytorch's give the same results"""
+    """Check that the matrix multiply kernel and Pytorch's give the same results"""
     torch.random.manual_seed(0)

     # Raw fused matrix multiply first, to catch gross errors
diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py
index 5d8fd263ec..4598bc20ab 100644
--- a/xformers/benchmarks/utils.py
+++ b/xformers/benchmarks/utils.py
@@ -26,7 +26,7 @@


 def pretty_print(results, title, units):
-    """ Printout the contents of a dict as a human-readable and Markdown compatible array"""
+    """Printout the contents of a dict as a human-readable and Markdown compatible array"""
     print(title)
     header = " Units: {:<45}".format(units)
     print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index 041fc6d19c..12094fdd57 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -82,7 +82,8 @@ def __init__(
             16,
             32,
             64,
-        ), "Only block sizes in [16, 32, 64] are supported"
+            128,
+        ), "Only block sizes in [16, 32, 64, 128] are supported"

         super().__init__()

@@ -112,6 +113,32 @@ def update_mask_type(self, mask: torch.Tensor):
             )
             mask = bool_mask_to_additive(mask)

+    def create_triton_kernels(self, device):
+        # blocksparse operators
+        self.sparse_dot_sdd = blocksparse_matmul(
+            self.layout,
+            self.block_size,
+            "sdd",
+            trans_a=False,
+            trans_b=True,
+            device=device,
+        )
+
+        self.sparse_dot_dsd = blocksparse_matmul(
+            self.layout,
+            self.block_size,
+            "dsd",
+            trans_a=False,
+            trans_b=False,
+            device=device,
+        )
+
+        self.sparse_softmax = blocksparse_softmax(
+            self.layout,
+            self.block_size,
+            device=device,
+        )
+
     def forward(
         self,
         q: torch.Tensor,
@@ -132,31 +159,9 @@ def forward(
         """

         # Delayed triton init, to make sure that we get the right device
+        # Infer device from query
         if not hasattr(self, "sparse_dot_sdd"):
-            # blocksparse operators
-            self.sparse_dot_sdd = blocksparse_matmul(
-                self.layout,
-                self.block_size,
-                "sdd",
-                trans_a=False,
-                trans_b=True,
-                device=q.device,
-            )
-
-            self.sparse_dot_dsd = blocksparse_matmul(
-                self.layout,
-                self.block_size,
-                "dsd",
-                trans_a=False,
-                trans_b=False,
-                device=q.device,
-            )
-
-            self.sparse_softmax = blocksparse_softmax(
-                self.layout,
-                self.block_size,
-                device=q.device,
-            )
+            self.create_triton_kernels(q.device)

         assert (
             q.shape[-2] == k.shape[-2]
@@ -169,10 +174,10 @@ def forward(
             k.shape[-2] == self.layout.shape[-2] * self.block_size
         ), "Actual sequence size and layout are inconsistent"

-        assert math.log(
-            q.shape[-2], 2
-        ).is_integer(), (
-            "For now blocksparse only works on power-of-two sequence lengths"
+        assert (
+            q.shape[-2] % self.block_size
+        ) == 0, "Sequence length {} must be a multiple of block size {}".format(
+            q.shape[-2], self.block_size
         )

         # Blocksparse only works on fp16
diff --git a/xformers/utils.py b/xformers/utils.py
index f2db7e808e..7a7dd44bae 100644
--- a/xformers/utils.py
+++ b/xformers/utils.py
@@ -91,7 +91,7 @@ def rmf(filename: str) -> None:


 @contextlib.contextmanager
 def temp_files_ctx(num: int) -> Generator:
-    """ A context to get tempfiles and ensure they are cleaned up. """
+    """A context to get tempfiles and ensure they are cleaned up."""
     files = [tempfile.mkstemp()[1] for _ in range(num)]
     yield tuple(files)
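
For reference, a minimal usage sketch of what this patch enables: building a BlockSparseAttention with the newly supported block size of 128 and a sequence length (384) that is a multiple of the block size but no longer required to be a power of two. The import path, the tensor layout (batch, heads, seq, head_dim), and the head dimension of 64 are assumptions inferred from the file paths and test code in this diff, not guaranteed by the patch itself; as in the tests, a recent CUDA GPU with Triton is required and the kernels run in fp16.

```python
import torch

# Assumed import path, based on xformers/components/attention/blocksparse.py in this diff
from xformers.components.attention import BlockSparseAttention

block = 128   # newly allowed by this patch (previously limited to 16/32/64)
n_ctx = 384   # multiple of the block size, but not a power of two
batch, n_heads, head_dim = 2, 2, 64  # head_dim is an assumption for illustration
n_blocks = n_ctx // block

# Causal block layout, as in the test above
layout = torch.tril(torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long))

# Constructor called as in the test: BlockSparseAttention(layout, block)
attention = BlockSparseAttention(layout, block)

# Blocksparse only works on fp16, and Triton needs a CUDA device;
# the (batch, heads, seq, head_dim) shape is an assumption, not from this diff
q, k, v = (
    torch.randn(batch, n_heads, n_ctx, head_dim, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# Triton kernels are created lazily on q.device, per create_triton_kernels above
out = attention(q, k, v)
```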