diff --git a/src/transformers/integrations/tpu.py b/src/transformers/integrations/tpu.py
new file mode 100644
index 00000000000000..f2943dcf12df3e
--- /dev/null
+++ b/src/transformers/integrations/tpu.py
@@ -0,0 +1,36 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from torch.utils.data import DataLoader
+
+from ..utils import is_torch_tpu_available
+
+
+def tpu_spmd_dataloader(dataloader: DataLoader):
+    if is_torch_tpu_available():
+        import torch_xla.distributed.parallel_loader as pl
+
+        assert isinstance(
+            dataloader, pl.MpDeviceLoader
+        ), "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."
+
+        # This is to support PyTorch/XLA FSDP via SPMD.
+        # Here we shard the input data's 0th dim across the fsdp axis.
+        import torch_xla.distributed.spmd as xs
+
+        sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
+        dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
+        return dataloader
+    else:
+        return dataloader
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index bbf5d4abf8a924..4667d141ede999 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -60,6 +60,7 @@
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from .integrations.tpu import tpu_spmd_dataloader
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
@@ -170,6 +171,8 @@
 if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
+    import torch_xla.distributed.spmd as xs
+    import torch_xla.runtime as xr
 
 
 if is_sagemaker_mp_enabled():
@@ -635,6 +638,13 @@ def __init__(
         if args.torch_compile and not is_torch_compile_available():
             raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
 
+        self.is_fsdp_xla_v2_enabled = args.fsdp_config["xla_fsdp_v2"]
+        if self.is_fsdp_xla_v2_enabled:
+            # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
+            # Tensor axis is just a placeholder where it will not be used in FSDPv2.
+            num_devices = xr.global_runtime_device_count()
+            xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+
     def _activate_neftune(self, model):
         r"""
         Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
@@ -1385,6 +1395,11 @@ def _wrap_model(self, model, training=True, dataloader=None):
                     size_based_auto_wrap_policy,
                     transformer_auto_wrap_policy,
                 )
+
+                if self.is_fsdp_xla_v2_enabled:
+                    from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
+                        SpmdFullyShardedDataParallel as FSDPv2,
+                    )
             except ImportError:
                 raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
             auto_wrap_policy = None
@@ -1416,15 +1431,40 @@ def _wrap_model(self, model, training=True, dataloader=None):
             if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
                 # Apply gradient checkpointing to auto-wrapped sub-modules if specified
                 def auto_wrapper_callable(m, *args, **kwargs):
-                    return FSDP(checkpoint_module(m), *args, **kwargs)
+                    target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
+                    return target_cls(checkpoint_module(m), *args, **kwargs)
 
             # Wrap the base model with an outer FSDP wrapper
-            self.model = model = FSDP(
-                model,
-                auto_wrap_policy=auto_wrap_policy,
-                auto_wrapper_callable=auto_wrapper_callable,
-                **fsdp_kwargs,
-            )
+            if self.is_fsdp_xla_v2_enabled:
+
+                def shard_output(output, mesh):
+                    from .modeling_outputs import CausalLMOutputWithPast
+
+                    real_output = None
+                    if isinstance(output, torch.Tensor):
+                        real_output = output
+                    elif isinstance(output, tuple):
+                        real_output = output[0]
+                    elif isinstance(output, CausalLMOutputWithPast):
+                        real_output = output.logits
+
+                    if real_output is None:
+                        raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
+                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                self.model = model = FSDPv2(
+                    model,
+                    shard_output=shard_output,
+                    auto_wrap_policy=auto_wrap_policy,
+                    auto_wrapper_callable=auto_wrapper_callable,
+                )
+            else:
+                self.model = model = FSDP(
+                    model,
+                    auto_wrap_policy=auto_wrap_policy,
+                    auto_wrapper_callable=auto_wrapper_callable,
+                    **fsdp_kwargs,
+                )
 
             # Patch `xm.optimizer_step` should not reduce gradients in this case,
             # as FSDP does not need gradient reduction over sharded parameters.
@@ -1593,6 +1633,8 @@ def _inner_training_loop(
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
+        if self.is_fsdp_xla_v2_enabled:
+            train_dataloader = tpu_spmd_dataloader(train_dataloader)
 
         # Setting up training control variables:
         # number of training epochs: num_train_epochs
@@ -1962,6 +2004,11 @@ def _inner_training_loop(
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
+                    # PyTorch/XLA relies on the data loader to insert the mark_step for
+                    # each step. Since we are breaking the loop early, we need to manually
+                    # insert the mark_step here.
+                    if is_torch_tpu_available():
+                        xm.mark_step()
                     break
             if step < 0:
                 logger.warning(
@@ -2945,6 +2992,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
 
     def _save_tpu(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
+
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = self.model
         model.to("cpu")
@@ -3143,6 +3191,9 @@ def evaluate(
         self._memory_tracker.start()
 
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        if self.is_fsdp_xla_v2_enabled:
+            eval_dataloader = tpu_spmd_dataloader(eval_dataloader)
+
         start_time = time.time()
 
         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index e51cf41106ee80..4ec9424396178f 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1684,6 +1684,7 @@ def __post_init__(self):
             ):
                 raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
+        self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
         if self.fsdp_config["xla"]:
            if len(self.fsdp) > 0:
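
Note: a minimal usage sketch of the new `xla_fsdp_v2` switch from the user side, assuming a TPU host with a torch-xla build that ships SPMD support (>= 2.0 per the import guard above). Only the `fsdp_config` keys (`xla`, `xla_fsdp_v2`, `xla_fsdp_grad_ckpt`) come from this diff; the output directory, batch size, and the `full_shard` strategy are illustrative placeholders.

    from transformers import TrainingArguments

    # Depending on the torch-xla version, SPMD mode may also need to be turned on
    # in the launch script (e.g. torch_xla.runtime.use_spmd()); that is outside this diff.
    args = TrainingArguments(
        output_dir="./out",                 # placeholder
        per_device_train_batch_size=8,      # placeholder
        fsdp="full_shard",                  # placeholder FSDP strategy, not prescribed by this diff
        fsdp_config={
            "xla": True,                    # existing switch for the PyTorch/XLA FSDP integration
            "xla_fsdp_v2": True,            # new: wrap with SpmdFullyShardedDataParallel (FSDPv2)
            "xla_fsdp_grad_ckpt": True,     # optional gradient checkpointing, unchanged behavior
        },
    )

With `xla_fsdp_v2` set, the trainer builds a global (num_devices, 1) mesh with axes ("fsdp", "tensor") and `tpu_spmd_dataloader` shards each batch's 0th dimension along the "fsdp" axis, as in the hunks above.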