Update docs for sdpa_kernel (#35410)
update: sdp_kernel -> sdpa_kernel
jla524 authored Dec 30, 2024
1 parent 5cabc75 commit b5f9797
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions docs/source/en/perf_infer_gpu_one.md
@@ -332,10 +332,11 @@ In that case, you should see a warning message and we will fall back to the (slo

</Tip>

-By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
+By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.nn.attention.sdpa_kernel`](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) as a context manager:

```diff
import torch
++ from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
@@ -344,7 +345,7 @@ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=to
input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

-+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
++ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
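
For reference, here is how the first snippet reads after this change (a minimal, self-contained sketch; the fp16 `torch_dtype` and `.to("cuda")` on the truncated model-loading line are assumptions based on the surrounding docs, not part of this diff):

```py
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# fp16 on CUDA is assumed here; the diff truncates this line
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")

input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

# restrict SDPA to the FlashAttention backend for this block only
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```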
@@ -518,6 +519,7 @@ It is often possible to combine several of the optimization techniques described

```py
import torch
+from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# load model in 4-bit
@@ -536,7 +538,7 @@ input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

# enable FlashAttention
-with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
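
As a usage note, `sdpa_kernel` can also be given a list of allowed backends, which is useful when FlashAttention is not available for a particular dtype or head size (a small sketch, assuming the `model` and `inputs` from the example above, and assuming a recent PyTorch release that accepts a list here, roughly 2.3+):

```py
from torch.nn.attention import SDPBackend, sdpa_kernel

# allow FlashAttention or the memory-efficient backend; SDPA picks whichever
# supports the current inputs (list argument assumed to need PyTorch >= 2.3)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    outputs = model.generate(**inputs)
```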
