From d9ec9099198d9b2175dca978bc40f9886baff6f7 Mon Sep 17 00:00:00 2001
From: Ang Wang
Date: Thu, 12 Dec 2024 15:50:19 +0800
Subject: [PATCH] Patch the torch autocast (#38)

---
 tests/utils/test_patch.py | 37 +++++++++++++++++++++++++++++++++++++
 torchacc/__init__.py      |  1 +
 torchacc/utils/patch.py   | 21 +++++++++++++++++----
 3 files changed, 55 insertions(+), 4 deletions(-)
 create mode 100644 tests/utils/test_patch.py

diff --git a/tests/utils/test_patch.py b/tests/utils/test_patch.py
new file mode 100644
index 0000000..33d3064
--- /dev/null
+++ b/tests/utils/test_patch.py
@@ -0,0 +1,37 @@
+import unittest
+
+import torch
+
+import torchacc as ta
+
+
+class PatchAutocastTest(unittest.TestCase):
+
+    def _matmul_with_autocast(self, lhs, rhs, first_device, second_device):
+        with torch.autocast(device_type=first_device, dtype=torch.bfloat16):
+            first = torch.matmul(lhs, rhs)
+            with torch.autocast(device_type=second_device, enabled=False):
+                second = torch.matmul(lhs, rhs)
+        return (first, second)
+
+    def test_patch_autocast(self):
+        device = ta.lazy_device()
+
+        lhs = torch.rand([2, 2], device=device)
+        rhs = torch.rand([2, 2], device=device)
+
+        first, second = self._matmul_with_autocast(lhs, rhs, 'cuda', 'cuda')
+        self.assertEqual(first.dtype, torch.bfloat16)
+        self.assertEqual(second.dtype, torch.float32)
+
+        first, second = self._matmul_with_autocast(lhs, rhs, 'xla', 'xla')
+        self.assertEqual(first.dtype, torch.bfloat16)
+        self.assertEqual(second.dtype, torch.float32)
+
+        first, second = self._matmul_with_autocast(lhs, rhs, 'cuda', 'xla')
+        self.assertEqual(first.dtype, torch.bfloat16)
+        self.assertEqual(second.dtype, torch.float32)
+
+        first, second = self._matmul_with_autocast(lhs, rhs, 'xla', 'cuda')
+        self.assertEqual(first.dtype, torch.bfloat16)
+        self.assertEqual(second.dtype, torch.float32)
diff --git a/torchacc/__init__.py b/torchacc/__init__.py
index 187f3b7..af787cd 100644
--- a/torchacc/__init__.py
+++ b/torchacc/__init__.py
@@ -133,5 +133,6 @@ def _set_env():
 
 
 patch.patch_fa()
+patch.patch_autocast()
 decompose.replace_decompose()
 _set_env()
diff --git a/torchacc/utils/patch.py b/torchacc/utils/patch.py
index bb04aa2..9fee5ab 100644
--- a/torchacc/utils/patch.py
+++ b/torchacc/utils/patch.py
@@ -1,4 +1,5 @@
 import inspect
+import os
 from functools import wraps
 
 import torch
@@ -62,7 +63,9 @@ def patch_fa():
     Replace `transformers.modeling_flash_attention_utils._flash_attention_forward` with
     `torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`
     '''
-    from .logger import logger
+    if os.getenv('TORCHACC_PATCH_FA', '1') not in ['1', 'true', 'True']:
+        return
+
     try:
         import transformers
         from packaging import version
@@ -222,7 +225,7 @@ def update_causal_mask(
 def patch_qwen(use_flash_attn):
     '''
     Modify the calculation of `rotary_seq_len` in `Qwen2FlashAttention2.forward` to avoid xla graph be executed.
-    Replace `transformers.models.qwen.modeling_qwen2.Qwen2Model._update_causal_mask` with `return None`
+    Replace `transformers.models.qwen.modeling_qwen2.Qwen2Model._update_causal_mask` with `return None`
     and replace flash_attn with the interface in torchacc. This requires transformers>=4.41.0.
     '''
     import inspect
@@ -230,8 +233,6 @@ def patch_qwen(use_flash_attn):
     import transformers
     from packaging import version
 
-    from .logger import logger
-
     if use_flash_attn:
         from transformers.cache_utils import Cache
         from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
@@ -275,3 +276,15 @@ def update_causal_mask(
             exec(src, qwen2.__dict__)
         except Exception as e:
             logger.warning(f"patch qwen2 failed due to: {e}")
+
+
+def patch_autocast():
+    if os.getenv('TORCHACC_PATCH_TORCH_AUTOCAST', '1') in ['1', 'true', 'True']:
+        original_init = torch.autocast.__init__
+
+        def patched_init(self, device_type: str, *args, **kwargs):
+            if device_type == 'xla':
+                device_type = 'cuda'
+            original_init(self, device_type, *args, **kwargs)
+
+        torch.autocast.__init__ = patched_init
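
Note: below is a minimal usage sketch of what this patch enables, mirroring tests/utils/test_patch.py; it is illustrative only and not part of the patch. It assumes a torchacc build with this change applied is importable as `torchacc` (the import runs patch.patch_autocast(), so torch.autocast accepts device_type='xla' by remapping it to 'cuda'); the TORCHACC_PATCH_TORCH_AUTOCAST / TORCHACC_PATCH_FA environment variables introduced above can be set to '0' before the import to opt out.

import os

# Optional opt-out: the env vars added in this patch are read at import time.
# os.environ['TORCHACC_PATCH_TORCH_AUTOCAST'] = '0'

import torch

import torchacc as ta  # importing torchacc applies patch.patch_autocast()

device = ta.lazy_device()
lhs = torch.rand([2, 2], device=device)
rhs = torch.rand([2, 2], device=device)

# With the patch, device_type='xla' is forwarded to torch.autocast.__init__
# as 'cuda', so autocast runs the matmul in bfloat16 on the lazy device.
with torch.autocast(device_type='xla', dtype=torch.bfloat16):
    out = torch.matmul(lhs, rhs)

print(out.dtype)  # expected: torch.bfloat16, as asserted in tests/utils/test_patch.py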