add autocast patch
anw90 committed Dec 10, 2024
1 parent 12d6a21 commit c98dc44
Showing 6 changed files with 58 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
@@ -35,6 +35,6 @@ jobs:
pip install -r requirements/requirements-test.txt && \
git config --global --add safe.directory $PWD && \
pip install -e . && \
bash tests/run_ut.sh'
make test'
env:
UT_IMAGE: ${{ secrets.UT_IMAGE }}
4 changes: 4 additions & 0 deletions Makefile
@@ -3,3 +3,7 @@
format:
isort torchacc/
yapf -i -r *.py torchacc/ tests/ benchmarks/

test:
PJRT_USE_TORCH_ALLOCATOR=true python -m pytest -v -k 'not flash_attn' ./tests/
PJRT_USE_TORCH_ALLOCATOR=true python -m pytest -v -n 4 -k 'flash_attn' ./tests/
16 changes: 0 additions & 16 deletions tests/run_ut.sh

This file was deleted.

35 changes: 35 additions & 0 deletions tests/utils/test_patch.py
@@ -0,0 +1,35 @@
import unittest
import torch
import torchacc as ta


class PatchAutocastTest(unittest.TestCase):

    def _matmul_with_autocast(self, lhs, rhs, first_device, second_device):
        with torch.autocast(device_type=first_device, dtype=torch.bfloat16):
            first = torch.matmul(lhs, rhs)
        with torch.autocast(device_type=second_device, enabled=False):
            second = torch.matmul(lhs, rhs)
        return (first, second)

    def test_patch_autocast(self):
        device = ta.lazy_device()

        t1 = torch.rand([2, 2], device=device)
        t2 = torch.rand([2, 2], device=device)

        first, second = self._matmul_with_autocast(t1, t2, 'cuda', 'cuda')
        assert first.dtype == torch.bfloat16
        assert second.dtype == torch.float32

        first, second = self._matmul_with_autocast(t1, t2, 'xla', 'xla')
        assert first.dtype == torch.bfloat16
        assert second.dtype == torch.float32

        first, second = self._matmul_with_autocast(t1, t2, 'cuda', 'xla')
        assert first.dtype == torch.bfloat16
        assert second.dtype == torch.float32

        first, second = self._matmul_with_autocast(t1, t2, 'xla', 'cuda')
        assert first.dtype == torch.bfloat16
        assert second.dtype == torch.float32
1 change: 1 addition & 0 deletions torchacc/__init__.py
@@ -133,5 +133,6 @@ def _set_env():


patch.patch_fa()
patch.patch_autocast()
decompose.replace_decompose()
_set_env()
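
Since `patch.patch_autocast()` is now invoked here at import time, user code picks up the patch automatically. A minimal opt-out sketch (an illustration only; it relies on the `PATCH_TORCH_AUTOCAST` check in `patch.py` below being evaluated when `torchacc` is imported):

import os

# Any value outside '1'/'true'/'True' skips the patch (see patch_autocast below).
os.environ['PATCH_TORCH_AUTOCAST'] = '0'

import torchacc as ta  # torch.autocast is left untouched in this case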
20 changes: 17 additions & 3 deletions torchacc/utils/patch.py
@@ -1,5 +1,6 @@
import inspect
from functools import wraps
import os

import torch

@@ -62,7 +63,9 @@ def patch_fa():
Replace `transformers.modeling_flash_attention_utils._flash_attention_forward` with
`torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`
'''
from .logger import logger
if os.getenv('PATCH_FA', '1') not in ['1', 'true', 'True']:
return

try:
import transformers
from packaging import version
@@ -222,15 +225,14 @@ def update_causal_mask(
def patch_qwen(use_flash_attn):
'''
Modify the calculation of `rotary_seq_len` in `Qwen2FlashAttention2.forward` to avoid the xla graph being executed.
Replace `transformers.models.qwen.modeling_qwen2.Qwen2Model._update_causal_mask` with `return None`
and replace flash_attn with the interface in torchacc. This requires transformers>=4.41.0.
'''
import inspect

import transformers
from packaging import version

from .logger import logger

if use_flash_attn:
from transformers.cache_utils import Cache
@@ -275,3 +277,15 @@ def update_causal_mask(
exec(src, qwen2.__dict__)
except Exception as e:
logger.warning(f"patch qwen2 failed due to: {e}")


def patch_autocast():
    '''
    Patch `torch.autocast` so that `device_type='xla'` is treated as 'cuda',
    making autocast contexts written for either device type behave the same
    on the lazy device. Set the env PATCH_TORCH_AUTOCAST=0 to disable.
    '''
    if os.getenv('PATCH_TORCH_AUTOCAST', '1') in ['1', 'true', 'True']:
        original_init = torch.autocast.__init__

        def patched_init(self, device_type: str, *args, **kwargs):
            # Rewrite the XLA device type to CUDA before calling the original __init__.
            if device_type == 'xla':
                device_type = 'cuda'
            original_init(self, device_type, *args, **kwargs)

        torch.autocast.__init__ = patched_init
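
To illustrate the effect of the patch, here is a minimal usage sketch mirroring the new test above (an illustration only, assuming a torchacc installation with a working lazy device and GPU backend):

import torch
import torchacc as ta  # importing torchacc applies patch_autocast()

device = ta.lazy_device()
a = torch.rand([2, 2], device=device)
b = torch.rand([2, 2], device=device)

# The patched __init__ rewrites 'xla' to 'cuda', so this context behaves like
# CUDA autocast and the matmul result comes out in bfloat16.
with torch.autocast(device_type='xla', dtype=torch.bfloat16):
    out = torch.matmul(a, b)

assert out.dtype == torch.bfloat16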
