Skip to content

Commit

Permalink
chore(training): Allow training on torch xla > 2.3.0, add warning (#48)
Browse files Browse the repository at this point in the history
* chore(training): Allow training on torch xla > 2.3.0, add warning

When fine-tuning Gemma-7B on PyTorch XLA 2.3.0, we saw and reported an
issue. This seems to have been fixed on nightly. This commit relaxes
dependency versions and displays a warning when getting FSDP training
args for Gemma on 2.3.0.

* chore(build): remove warning about subpackage
  • Loading branch information
tengomucho authored Jun 4, 2024
1 parent 292bd41 commit df7884a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
7 changes: 7 additions & 0 deletions optimum/tpu/fsdp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""
Utility functions to provide FSDPv2 configuration for TPU training.
"""
import logging
from typing import Dict, List, Union

from transformers.modeling_utils import PreTrainedModel
Expand Down Expand Up @@ -84,6 +85,12 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
from .modeling_gemma import GemmaForCausalLM

if isinstance(model, GemmaForCausalLM):
logger = logging.get_logger(__name__)
from torch_xla import __version__ as xla_version
if xla_version == "2.3.0":
logger.log_once("Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any issues "
"consider using the nightly version, and report the issue on the optimum-tpu GitHub "
"repository: https://github.com/huggingface/optimum-tpu/issues/new.")
cls_to_wrap = "GemmaDecoderLayer"
matched_model = True
elif model_type == "llama":
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ keywords = [

dependencies = [
"transformers == 4.41.1",
"torch ~= 2.3.0",
"torch-xla[tpu] ~= 2.3.0",
"torch >= 2.3.0, <= 2.4.0",
"torch-xla[tpu] >= 2.3.0, <= 2.4.0",
"loguru == 0.6.0"
]

Expand All @@ -63,7 +63,7 @@ Repository = "https://github.com/huggingface/optimum-tpu"
Issues = "https://github.com/huggingface/optimum-tpu/issues"

[tool.setuptools.packages.find]
include = ["optimum.tpu"]
include = ["optimum.tpu*"]

[tool.black]
line-length = 119
Expand Down

0 comments on commit df7884a

Please sign in to comment.