Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TPU] Support PyTorch/XLA FSDP via SPMD #28949

Merged
merged 10 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/transformers/integrations/tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.utils.data import DataLoader

from ..utils import is_torch_tpu_available


def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl

assert isinstance(
dataloader, pl.MpDeviceLoader
), "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."

# This is to support PyTorch/XLA FSDP via SPMD.
# Here we shard the input data's 0th dim across the fsdp axis.
import torch_xla.distributed.spmd as xs

sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
return dataloader
else:
return dataloader
65 changes: 58 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
Expand Down Expand Up @@ -170,6 +171,8 @@
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
Comment on lines +174 to +175
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not super fan of super short names but seems common in trainer!



if is_sagemaker_mp_enabled():
Expand Down Expand Up @@ -635,6 +638,13 @@ def __init__(
if args.torch_compile and not is_torch_compile_available():
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

self.is_fsdp_xla_v2_enabled = args.fsdp_config["xla_fsdp_v2"]
if self.is_fsdp_xla_v2_enabled:
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))

def _activate_neftune(self, model):
r"""
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
Expand Down Expand Up @@ -1385,6 +1395,11 @@ def _wrap_model(self, model, training=True, dataloader=None):
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)

if self.is_fsdp_xla_v2_enabled:
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
SpmdFullyShardedDataParallel as FSDPv2,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this easier by importing FSDPv2 as FSDP instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask what's the benefits of doing so?

)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
Expand Down Expand Up @@ -1416,15 +1431,40 @@ def _wrap_model(self, model, training=True, dataloader=None):
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)
target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
return target_cls(checkpoint_module(m), *args, **kwargs)

# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)
if self.is_fsdp_xla_v2_enabled:

def shard_output(output, mesh):
from .modeling_outputs import CausalLMOutputWithPast

real_output = None
if isinstance(output, torch.Tensor):
real_output = output
elif isinstance(output, tuple):
real_output = output[0]
elif isinstance(output, CausalLMOutputWithPast):
real_output = output.logits

if real_output is None:
raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
xs.mark_sharding(real_output, mesh, ("fsdp", None, None))

self.model = model = FSDPv2(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And then leave the check for down here on what to do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shard_output is not used by FSDPv1. Shouldn't we guard that with the flag too?

model,
shard_output=shard_output,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
)
else:
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)

# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
Expand Down Expand Up @@ -1593,6 +1633,8 @@ def _inner_training_loop(
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
if self.is_fsdp_xla_v2_enabled:
train_dataloader = tpu_spmd_dataloader(train_dataloader)

# Setting up training control variables:
# number of training epochs: num_train_epochs
Expand Down Expand Up @@ -1962,6 +2004,11 @@ def _inner_training_loop(
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

if self.control.should_epoch_stop or self.control.should_training_stop:
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if is_torch_tpu_available():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed a bug here. cc @ArthurZucker @jonb377

xm.mark_step()
break
if step < 0:
logger.warning(
Expand Down Expand Up @@ -2945,6 +2992,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa

def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir

logger.info(f"Saving model checkpoint to {output_dir}")
model = self.model
model.to("cpu")
Expand Down Expand Up @@ -3143,6 +3191,9 @@ def evaluate(
self._memory_tracker.start()

eval_dataloader = self.get_eval_dataloader(eval_dataset)
if self.is_fsdp_xla_v2_enabled:
eval_dataloader = tpu_spmd_dataloader(eval_dataloader)

start_time = time.time()

eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
Expand Down
1 change: 1 addition & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,7 @@ def __post_init__(self):
):
raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
if len(self.fsdp) > 0:
Expand Down
Loading