Patch the torch autocast (#38)
anw90 authored Dec 12, 2024
1 parent cd67d31 commit d9ec909
Showing 3 changed files with 55 additions and 4 deletions.
37 changes: 37 additions & 0 deletions tests/utils/test_patch.py
@@ -0,0 +1,37 @@
import unittest

import torch

import torchacc as ta


class PatchAutocastTest(unittest.TestCase):

def _matmul_with_autocast(self, lhs, rhs, first_device, second_device):
with torch.autocast(device_type=first_device, dtype=torch.bfloat16):
first = torch.matmul(lhs, rhs)
with torch.autocast(device_type=second_device, enabled=False):
second = torch.matmul(lhs, rhs)
return (first, second)

def test_patch_autocast(self):
device = ta.lazy_device()

lhs = torch.rand([2, 2], device=device)
rhs = torch.rand([2, 2], device=device)

first, second = self._matmul_with_autocast(lhs, rhs, 'cuda', 'cuda')
self.assertEqual(first.dtype, torch.bfloat16)
self.assertEqual(second.dtype, torch.float32)

first, second = self._matmul_with_autocast(lhs, rhs, 'xla', 'xla')
self.assertEqual(first.dtype, torch.bfloat16)
self.assertEqual(second.dtype, torch.float32)

first, second = self._matmul_with_autocast(lhs, rhs, 'cuda', 'xla')
self.assertEqual(first.dtype, torch.bfloat16)
self.assertEqual(second.dtype, torch.float32)

first, second = self._matmul_with_autocast(lhs, rhs, 'xla', 'cuda')
self.assertEqual(first.dtype, torch.bfloat16)
self.assertEqual(second.dtype, torch.float32)
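
One possible way to run just this test with the standard unittest runner (a sketch; assumes an XLA-enabled torchacc install and the repository root as the working directory):

python -m unittest discover -s tests/utils -p "test_patch.py" -v
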
1 change: 1 addition & 0 deletions torchacc/__init__.py
@@ -133,5 +133,6 @@ def _set_env():


patch.patch_fa()
patch.patch_autocast()
decompose.replace_decompose()
_set_env()
21 changes: 17 additions & 4 deletions torchacc/utils/patch.py
@@ -1,4 +1,5 @@
import inspect
import os
from functools import wraps

import torch
@@ -62,7 +63,9 @@ def patch_fa():
Replace `transformers.modeling_flash_attention_utils._flash_attention_forward` with
`torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`
'''
from .logger import logger
if os.getenv('TORCHACC_PATCH_FA', '1') not in ['1', 'true', 'True']:
return

try:
import transformers
from packaging import version
@@ -222,16 +225,14 @@ def update_causal_mask(
def patch_qwen(use_flash_attn):
'''
Modify the calculation of `rotary_seq_len` in `Qwen2FlashAttention2.forward` to avoid the xla graph being executed.
Replace `transformers.models.qwen2.modeling_qwen2.Qwen2Model._update_causal_mask` with `return None`
and replace flash_attn with the interface in torchacc. This requires transformers>=4.41.0.
'''
import inspect

import transformers
from packaging import version

from .logger import logger

if use_flash_attn:
from transformers.cache_utils import Cache
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
@@ -275,3 +276,15 @@ def update_causal_mask(
exec(src, qwen2.__dict__)
except Exception as e:
logger.warning(f"patch qwen2 failed due to: {e}")


def patch_autocast():
if os.getenv('TORCHACC_PATCH_TORCH_AUTOCAST', '1') in ['1', 'true', 'True']:
original_init = torch.autocast.__init__

def patched_init(self, device_type: str, *args, **kwargs):
if device_type == 'xla':
device_type = 'cuda'
original_init(self, device_type, *args, **kwargs)

torch.autocast.__init__ = patched_init
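
For reference, the effect of the patch is that `device_type='xla'` is transparently rewritten to `'cuda'`, so the CUDA autocast policy applies to lazy tensors, which is exactly what the new test exercises. A minimal sketch of the behavior once `patch_autocast()` has run (assumes an XLA lazy device is available through torchacc):

import torch
import torchacc as ta  # importing torchacc applies patch_autocast()

device = ta.lazy_device()
lhs = torch.rand([2, 2], device=device)
rhs = torch.rand([2, 2], device=device)

# The patched torch.autocast.__init__ maps 'xla' to 'cuda', so the
# matmul is autocast to bfloat16 just as it would be on a CUDA device.
with torch.autocast(device_type='xla', dtype=torch.bfloat16):
    out = torch.matmul(lhs, rhs)  # expected: out.dtype == torch.bfloat16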
