From 7744608c4c455a86b21f4ce0642e6c19bc1cebf4 Mon Sep 17 00:00:00 2001
From: ebsmothers
Date: Fri, 11 Oct 2024 22:32:29 -0400
Subject: [PATCH] Torchao version check changes/BC import of TensorCoreTiledLayout (#1812)

---
Review notes (commentary only; nothing in this section is part of the change):
* _import_guard.py compares torchao_version >= "0.6.0" as plain strings, so
  the comparison is lexicographic: a release such as "0.10.0" sorts below
  "0.6.0". torch_version_ge in _version.py uses the same string comparison.
* torchao.__version__ is now read unconditionally at import time. The
  deleted _get_torchao_version() returned early under fbcode and fell back
  to importlib metadata on AttributeError before using the attribute.
* _nightly_version_ge() indexes split("dev")[1], so it needs a version
  string that contains "dev"; the one call site added here checks that first.

 torchtune/modules/low_precision/_utils.py | 57 -----------------------
 torchtune/training/quantization.py        | 10 +++-
 torchtune/utils/_import_guard.py          | 10 +++-
 torchtune/utils/_version.py               | 21 +++++++++
 4 files changed, 38 insertions(+), 60 deletions(-)
 delete mode 100644 torchtune/modules/low_precision/_utils.py

diff --git a/torchtune/modules/low_precision/_utils.py b/torchtune/modules/low_precision/_utils.py
deleted file mode 100644
index 30f02911e6..0000000000
--- a/torchtune/modules/low_precision/_utils.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from datetime import datetime
-from importlib.metadata import PackageNotFoundError, version
-from typing import Optional, Tuple
-
-import torch
-
-import torchao
-
-
-def _is_fbcode():
-    return not hasattr(torch.version, "git_version")
-
-
-def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
-    """
-    Compare a torchao nightly version to a date of the form
-    %Y-%m-%d.
-
-    Returns True if the nightly version is greater than or equal to
-    the date, False otherwise
-    """
-    ao_datetime = datetime.strptime(ao_version_str.split("+")[0], "%Y.%m.%d")
-    return ao_datetime >= datetime.strptime(date, "%Y-%m-%d")
-
-
-def _get_torchao_version() -> Tuple[Optional[str], Optional[bool]]:
-    """
-    Get torchao version. Returns a tuple of two elements, the first element
-    is the version string, the second element is whether it's a nightly version.
-    For fbcode usage, return None, None.
-
-    Checks:
-    1) is_fbcode, then
-    3) torchao.__version__ (only defined for torchao >= 0.3.0), then
-    4) importlib's version(torchao)
-
-
-    If none of these work, raise an error.
-
-    """
-    if _is_fbcode():
-        return None, None
-    try:
-        ao_version = torchao.__version__
-    except AttributeError:
-        try:
-            ao_version = version("torchao")
-        except Exception as e:
-            raise PackageNotFoundError("Could not find torchao version") from e
-    is_nightly = "dev" in ao_version
-    return ao_version, is_nightly
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index debe49ab15..465e987981 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -6,7 +6,13 @@
 
 from typing import Callable, Optional
 
-from torchao.dtypes import TensorCoreTiledLayoutType
+from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API
+
+if _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API:
+    from torchao.dtypes import TensorCoreTiledLayout
+else:
+    from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
+
 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
@@ -88,7 +94,7 @@ def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8):
         self.inner_k_tiles = inner_k_tiles
 
     def quantize(self, model):
-        layout_type = TensorCoreTiledLayoutType(self.inner_k_tiles)
+        layout_type = TensorCoreTiledLayout(self.inner_k_tiles)
         quantize_fn = int4_weight_only(self.groupsize, layout_type)
         quantize_(model, quantize_fn)
         return model
diff --git a/torchtune/utils/_import_guard.py b/torchtune/utils/_import_guard.py
index c0779271fb..fd1ba7c6c2 100644
--- a/torchtune/utils/_import_guard.py
+++ b/torchtune/utils/_import_guard.py
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from torchtune.utils._version import torch_version_ge
+import torchao
+from torchtune.utils._version import _is_fbcode, _nightly_version_ge, torch_version_ge
 
 # We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
 _SUPPORTS_FLEX_ATTENTION = (
@@ -13,3 +14,10 @@
     and torch.cuda.is_available()
     and torch.cuda.get_device_capability() >= (7, 5)
 )
+
+torchao_version = torchao.__version__
+
+_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and (
+    ("dev" not in torchao_version and torchao_version >= "0.6.0")
+    or ("dev" in torchao_version and _nightly_version_ge(torchao_version, "2024-10-10"))
+)
diff --git a/torchtune/utils/_version.py b/torchtune/utils/_version.py
index 830a8ba079..9dcbd8e450 100644
--- a/torchtune/utils/_version.py
+++ b/torchtune/utils/_version.py
@@ -3,6 +3,9 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+from datetime import datetime
+
 import torch
 
 
@@ -23,3 +26,21 @@ def torch_version_ge(version: str) -> bool:
         True
     """
     return version in torch.__version__ or torch.__version__ >= version
+
+
+def _is_fbcode():
+    return not hasattr(torch.version, "git_version")
+
+
+def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
+    """
+    Compare a torchao nightly version to a date of the form
+    %Y-%m-%d.
+
+    Returns True if the nightly version is greater than or equal to
+    the date, False otherwise
+    """
+    ao_datetime = datetime.strptime(
+        ao_version_str.split("+")[0].split("dev")[1], "%Y%m%d"
+    )
+    return ao_datetime >= datetime.strptime(date, "%Y-%m-%d")