2.6.3 is faster than 2.7.0 for flash-attn v2 CUDA fwd/bwd #1335

Open
ds-kczerski opened this issue Nov 14, 2024 · 3 comments

ds-kczerski commented Nov 14, 2024

Hey, I have observed in my timing tests that version 2.6.3 is faster than some later commits (including 2.7.0.post2) for the input sizes below. For example, for small batch sizes (== 2) and relatively short sequences, 2.6.3 is as much as 2x faster for me in the forward pass.

My setup: 4070 Laptop (CUDA 12) and A100 (CUDA 11), Torch 2.4. Both flash-attn versions were installed via pip install directly from PyPI. Below are results measured with a custom Python script with proper CUDA synchronization.

[Screenshot from 2024-11-14 16-42-57: timing results]

Minimal instructions to replicate:

# set-up environment for flash attention, install torch
pip install loguru
pip install pytest
pip install flash-attn==2.6.3 --no-build-isolation
pytest -s test_min_example.py 
pip uninstall -y flash-attn
pip install flash-attn==2.7.0.post2 --no-build-isolation
pytest -s test_min_example.py 

test_min_example.py

import time
from abc import ABC, abstractmethod
from math import sqrt
from typing import Optional, Tuple

import pytest
from loguru import logger

import torch
from torch import Tensor
from torch.nn import Module

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    logger.error(f"ImportError: {e}")
    flash_attn_func = None

# Tensor dimensions: (batch_size, seq_len (same for q and k), num_heads, embed_dim)
dimensions = [
    (512, 50, 16, 256),
    (512, 150, 16, 256),
    (512, 300, 16, 256),
    (2, 50, 16, 256),
    (2, 150, 16, 256),
    (2, 300, 16, 256),
]

dtypes = [
    torch.float16,
    #torch.bfloat16
]

class AttentionBackend(Module, ABC):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = .1):
        """
        :param embed_dim: the size of each embedding vector
        :param num_heads: number of heads
        :param dropout: attention dropout
        """
        assert not embed_dim % num_heads, 'embed_dim must be divisible by num_heads'
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self._scale = 1 / sqrt(embed_dim / num_heads)

    def unflatten(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        q = q.unflatten(2, (self.num_heads, -1))
        k = k.unflatten(2, (self.num_heads, -1))
        v = v.unflatten(2, (self.num_heads, -1))

        return q, k, v

    @abstractmethod
    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[
        Tensor, Optional[Tensor]]:
        raise NotImplementedError("Forward method not implemented in subclass.")
    
    
class DummyFlashAttentionBackend(AttentionBackend):
    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tuple[
        Tensor, Optional[Tensor]]:
        q, k, v = self.unflatten(q, k, v)
        o = flash_attn_func(q, k, v, softmax_scale=self._scale, dropout_p=self.dropout)
        o = o.flatten(2)
        return o, None

def measure_forward_time(q,k,v, backend, num_runs = 10):
    times = []
    for _ in range(2):  # Warm-up runs
        output, _ = backend(q, k, v)
        torch.cuda.synchronize()

    for _ in range(num_runs):  # Timed runs
        torch.cuda.synchronize()
        start_time = time.time()
        output, _ = backend(q, k, v)
        torch.cuda.synchronize()
        end_time = time.time()
        times.append(end_time - start_time)

    secs_to_microseconds = 1000000
    return (sum(times) / num_runs) * secs_to_microseconds

def measure_backward_time(q,k,v, backend, num_runs = 10):
    times = []
    for _ in range(2):  # Warm-up runs
        output, _ = backend(q, k, v)
        loss = torch.sum(output)  # Dummy loss for backward pass
        loss.backward(retain_graph=True)
        torch.cuda.synchronize()

    for _ in range(num_runs):  # Timed runs
        torch.cuda.synchronize()
        output, _ = backend(q, k, v)
        loss = torch.sum(output)  # Dummy loss for backward pass
        torch.cuda.synchronize()
        start_time = time.time()
        loss.backward(retain_graph=True)
        torch.cuda.synchronize()
        end_time = time.time()
        times.append(end_time - start_time)

    secs_to_microseconds = 1000000
    return (sum(times) / num_runs) * secs_to_microseconds

def create_random_tensors_with_embeddings(batch_size, seq_len, embed_dim, num_heads, device, dtype):
    q = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype, requires_grad=True)
    k = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype, requires_grad=True)
    v = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype, requires_grad=True)
    return q, k, v

@pytest.mark.parametrize("batch_size, seq_len, num_heads, embed_dim", dimensions)
@pytest.mark.parametrize("dtype", dtypes)
def test_timing_forward_pass(batch_size, seq_len, num_heads, embed_dim, dtype):
    # Compare flash attention timings
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    torch.manual_seed(42)
    device = torch.device('cuda')

    dropout = 0  # disable dropout

    q, k, v = create_random_tensors_with_embeddings(
        batch_size, seq_len, embed_dim, num_heads, device, dtype
    )

    dummy_kernel_backend = DummyFlashAttentionBackend(embed_dim, num_heads, dropout=dropout).to(device)
    avg_dummy_time = measure_forward_time(q, k, v, dummy_kernel_backend)
    del dummy_kernel_backend

    # Log the results
    logger.info(f"Configuration: batch_size={batch_size}, seq_len={seq_len}, "
                f"num_heads={num_heads}, embed_dim={embed_dim}, dtype={dtype}")
    logger.info(f"[Forward] Average time for Dummy backend: {avg_dummy_time:.2f} microsecs")

    # Pass the test since it's for timing comparison, no correctness checking
    assert True  # As noted, the output comparison is not relevant here

@pytest.mark.parametrize("batch_size, seq_len, num_heads, embed_dim", dimensions)
@pytest.mark.parametrize("dtype", dtypes)
def test_timing_backward_pass(batch_size, seq_len, num_heads, embed_dim, dtype):
    # Compare flash attention timings
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    torch.manual_seed(42)
    device = torch.device('cuda')

    dropout = 0  # disable dropout

    q, k, v = create_random_tensors_with_embeddings(
        batch_size, seq_len, embed_dim, num_heads, device, dtype
    )

    dummy_kernel_backend = DummyFlashAttentionBackend(embed_dim, num_heads, dropout=dropout).to(device)
    avg_dummy_time = measure_backward_time(q, k, v, dummy_kernel_backend)
    del dummy_kernel_backend

    # Log the results
    logger.info(f"Configuration: batch_size={batch_size}, seq_len={seq_len}, "
                f"num_heads={num_heads}, embed_dim={embed_dim}, dtype={dtype}")
    logger.info(f"[Backward] Average time for Dummy backend: {avg_dummy_time:.2f} nanosecs")

    # Pass the test since it's for timing comparison, no correctness checking
    assert True  # As noted, the output comparison is not relevant here

Could you please help me understand what might be the source of these timing differences? Going through the source code, it seems to me that the kernel code is the same, the CUTLASS submodule pointer is the same, and the only changes are in the C++/Python API, relating to head, head_size_og, and padding. Also, my embedding sizes and head counts are divisible by 8.

tridao (Contributor) commented Nov 14, 2024

I'm guessing it's because we moved some of the checks and padding (i.e. checking whether headdim is not a multiple of 8) from C++ to Python for compatibility with torch.compile. This might add a bit more Python overhead, so it's noticeable for small batches and short sequences (since the kernel itself is very fast there).
You can try compiling it with torch.compile to reduce the overhead in this case.
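A minimal sketch of what that could look like with the test script above, reusing DummyFlashAttentionBackend and measure_forward_time from test_min_example.py (illustrative only; it assumes flash-attn 2.7.0's ops trace cleanly under torch.compile):

import torch

from test_min_example import DummyFlashAttentionBackend, measure_forward_time

# Sketch: compile the wrapper module so the Python-side checks/padding around
# flash_attn_func are traced once rather than re-executed eagerly on every call.
backend = DummyFlashAttentionBackend(embed_dim=256, num_heads=16, dropout=0.0).to("cuda")
compiled_backend = torch.compile(backend)

q = torch.randn(2, 300, 256, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 300, 256, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 300, 256, device="cuda", dtype=torch.float16, requires_grad=True)

_ = compiled_backend(q, k, v)  # first call triggers compilation; keep it out of the timed region
avg_us = measure_forward_time(q, k, v, compiled_backend)
print(f"[Forward] compiled backend: {avg_us:.2f} microsecs")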

tridao (Contributor) commented Nov 14, 2024

What would be helpful is to get a profiler result (e.g. the PyTorch profiler or Nsight Systems) to see the kernel time. For example, if the kernel time stays the same, then we can say it's Python overhead; if the kernel time is very different, then we'll need to investigate.
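A minimal sketch of pulling per-kernel CUDA times with the PyTorch profiler, assuming q, k, v and the backend are set up as in test_min_example.py above (illustrative only; the exact table contents depend on the torch version):

import torch
from torch.profiler import ProfilerActivity, profile

# Profile a handful of forward calls and print the ops sorted by CUDA time.
# `backend`, `q`, `k`, `v` are assumed to be constructed as in the test script.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        out, _ = backend(q, k, v)
    torch.cuda.synchronize()

# The flash-attn forward kernel should appear near the top; comparing its
# self CUDA time across 2.6.3 and 2.7.0 separates kernel time from Python overhead.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))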

ds-kczerski (Author) commented Nov 15, 2024

Hey, thanks for the quick reply!

I’ve been profiling with nsys on the A100 and can conclude that it’s likely Python overhead, as the kernel times appear identical for both versions 2.6.3 and 2.7.0post2. I’m checking forward/backward passes for the same dimensions as mentioned earlier. Unfortunately, it seems that Python overhead becomes quite significant, especially when targeting smaller Q/K lengths and/or batch sizes.
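For reference, an Nsight Systems invocation along these lines captures the CUDA kernel timeline for the test above (an illustrative command, not necessarily the exact flags used here):

nsys profile --trace=cuda,nvtx -o flash_attn_timing python -m pytest -s test_min_example.py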

> You can try compiling it with torch.compile to reduce the overhead in this case.

Yeah, we should introduce it as a baseline, I guess. I'll test it soon. For now, this thread can be closed :) Thanks!
