diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index a25a23a26..624803157 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -16,7 +16,6 @@ import collections import contextlib -import inspect import os import re import shutil @@ -54,7 +53,6 @@ AutocastBackend, ModelParallelismPlugin, NeuronDistributedType, - NeuronFullyShardedDataParallelPlugin, get_tied_parameters_dict, patch_accelerate_is_tpu_available, tie_parameters, @@ -124,14 +122,8 @@ def __init__( full_kwargs["gradient_accumulation_steps"] = gradient_accumulation_steps fsdp_plugin = full_kwargs["fsdp_plugin"] - if fsdp_plugin is None: - if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true": - fsdp_plugin = NeuronFullyShardedDataParallelPlugin() - elif not isinstance(fsdp_plugin, NeuronFullyShardedDataParallelPlugin): - raise ValueError( - "The fsdp_plugin must be an instance of NeuronFullyShardedDataParallelPlugin to use XLA FSDP with " - f"the NeuronAccelerator, but an instance of {type(fsdp_plugin)} was given here." - ) + if fsdp_plugin is not None: + raise ValueError("FSDP is not supported.") self.fsdp_plugin = fsdp_plugin self._model_cpu_parameters_to_xla = {} @@ -151,15 +143,9 @@ def __init__( enabled = self.state.mixed_precision == "bf16" and autocast_backend is AutocastBackend.AMP self.autocast_handler = AutocastKwargs(enabled=enabled) - if self.fsdp_plugin is not None and self.zero_1: - raise ValueError("Either enable XLA ZeRO Stage 1 or XLA FSDP but not both.") - if self.process_index == -1 and self.zero_1: raise ValueError("XLA ZeRO Stage 1 can only be enabled in a distributed training setting.") - if fsdp_plugin is not None and mp_plugin is not None: - raise ValueError("It is not possible to both use neuronx_distributed Tensor Parallelism and XLA FSDP.") - if num_steps != 1: self.gradient_accumulation_steps = num_steps @@ -349,76 +335,6 @@ def patch_model_for_neuron( model_patcher.patch() return model - def prepare_model_for_xla_fsdp( - self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False - ): - if device_placement is None: - device_placement = self.device_placement - self._models.append(model) - # We check only for models loaded with `accelerate` - - # Checks if any of the child module has the attribute `hf_device_map`. - has_hf_device_map = False - for m in model.modules(): - if hasattr(m, "hf_device_map"): - has_hf_device_map = True - break - - if getattr(model, "is_loaded_in_8bit", False) and getattr(model, "hf_device_map", False): - model_devices = set(model.hf_device_map.values()) - if len(model_devices) > 1: - raise ValueError( - "You can't train a model that has been loaded in 8-bit precision on multiple devices." - ) - - current_device_index = list(model_devices)[0] - if torch.device(current_device_index) != self.device: - # if on the first device (GPU 0) we don't care - if (self.device.index is not None) or (current_device_index != 0): - raise ValueError( - "You can't train a model that has been loaded in 8-bit precision on a different device than the one " - "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}" - "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}" - ) - - if "cpu" in model_devices or "disk" in model_devices: - raise ValueError( - "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload." - ) - elif device_placement and not has_hf_device_map: - model = model.to(self.device) - - try: - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP - except ImportError: - raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") - - if not evaluation_mode: - # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, - # don't wrap it again - # TODO: validate which arguments work for XLA FSDP. - if type(model) != FSDP: - self.state.fsdp_plugin.set_auto_wrap_policy(model) - fsdp_plugin = self.state.fsdp_plugin - kwargs = { - "sharding_strategy": fsdp_plugin.sharding_strategy, - "cpu_offload": fsdp_plugin.cpu_offload, - "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, - "backward_prefetch": fsdp_plugin.backward_prefetch, - "mixed_precision": fsdp_plugin.mixed_precision_policy, - "ignored_modules": fsdp_plugin.ignored_modules, - "device_id": self.device, - } - signature = inspect.signature(FSDP.__init__).parameters.keys() - if "limit_all_gathers" in signature: - kwargs["limit_all_gathers"] = fsdp_plugin.limit_all_gathers - if "use_orig_params" in signature: - kwargs["use_orig_params"] = fsdp_plugin.use_orig_params - model = FSDP(model, **kwargs) - self._models[-1] = model - - return model - @requires_neuronx_distributed def _prepare_model_for_mp( self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False @@ -502,11 +418,7 @@ def prepare_model( model.config.output_attentions = False model.config.output_hidden_states = False - if self.distributed_type is NeuronDistributedType.XLA_FSDP: - return self.prepare_model_for_xla_fsdp( - model, device_placement=device_placement, evaluation_mode=evaluation_mode - ) - elif self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: + if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: return self._prepare_model_for_mp( model, device_placement=device_placement, evaluation_mode=evaluation_mode ) @@ -514,29 +426,14 @@ def prepare_model( device_placement = False return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) - def backward_for_xla_fsdp(self, loss, **kwargs): - if self.scaler is not None: - self.scaler.scale(loss).backward(**kwargs) - else: - loss.backward(**kwargs) - def backward(self, loss, **kwargs): if self.distributed_type != DistributedType.DEEPSPEED: loss = loss / self.gradient_accumulation_steps - if self.distributed_type is NeuronDistributedType.XLA_FSDP: - self.backward_for_xla_fsdp(loss, **kwargs) - elif self.scaler is not None: + if self.scaler is not None: self.scaler.scale(loss).backward(**kwargs) else: loss.backward(**kwargs) - def clip_grad_norm_for_xla_fsdp(self, parameters, max_norm, norm_type: int = 2): - self.unscale_gradients() - parameters = list(parameters) - for model in self._models: - if parameters == list(model.parameters()): - return model.clip_grad_norm_(max_norm, norm_type) - @contextlib.contextmanager def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[AutocastKwargs] = None): if cache_enabled: @@ -575,17 +472,10 @@ def _prepare_clip_grad_norm(self, parameters, max_norm, norm_type: int = 2): return opt.prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type) def clip_grad_norm_(self, parameters, max_norm, norm_type=2): - if self.distributed_type is NeuronDistributedType.XLA_FSDP: - return self.clip_grad_norm_for_xla_fsdp(parameters, max_norm, norm_type=norm_type) - elif self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM or self.zero_1: + if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM or self.zero_1: return self._prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type) return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type) - def clip_grad_value_(self, parameters, clip_value): - if self.distributed_type is NeuronDistributedType.XLA_FSDP: - raise Exception("XLA FSDP does not support `clip_grad_value_`. Use `clip_grad_norm_` instead.") - return super().clip_grad_value_(parameters, clip_value) - def _custom_save_state( self, save_model_func: Optional[Callable[["Accelerator", "PreTrainedModel", Union[str, Path], int], Any]], @@ -661,21 +551,6 @@ def _inner(folder): self.project_configuration.iteration += 1 return save_location - def save_state_for_xla_fsdp(self, output_dir: Optional[str] = None, **save_model_func_kwargs): - def save_model_func(accelelerator, model, output_dir, i): - logger.info("Saving FSDP model") - self.state.fsdp_plugin.save_model(accelelerator, model, output_dir, i) - logger.info(f"FSDP Model saved to the directory {output_dir}") - - def save_optimizer_func(accelerator, optimizer, model, output_dir, i): - logger.info("Saving FSDP Optimizer") - self.state.fsdp_plugin.save_optimizer(accelerator, optimizer, model, output_dir, i) - logger.info(f"FSDP Optimizer saved to the directory {output_dir}") - - return self._custom_save_state( - save_model_func, save_optimizer_func, output_dir=output_dir, **save_model_func_kwargs - ) - def save_state_for_mp(self, output_dir: Optional[str] = None, **save_model_func_kwargs): # The model is saved at the same time as the optimizer. save_model_func = None @@ -692,9 +567,7 @@ def save_optimizer_func(accelerator, optimizer, model, output_dir, i): @patch_within_function(("accelerate.checkpointing.xm", xm), ignore_missing_attributes=True) def save_state(self, output_dir: Optional[str] = None, **save_model_func_kwargs) -> str: - if self.distributed_type is NeuronDistributedType.XLA_FSDP: - return self.save_state_for_xla_fsdp(output_dir=output_dir, **save_model_func_kwargs) - elif self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: + if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: return self.save_state_for_mp(output_dir=output_dir, **save_model_func_kwargs) return super().save_state(output_dir=output_dir, **save_model_func_kwargs) diff --git a/optimum/neuron/accelerate/optimizer.py b/optimum/neuron/accelerate/optimizer.py index d62709179..4700a846e 100644 --- a/optimum/neuron/accelerate/optimizer.py +++ b/optimum/neuron/accelerate/optimizer.py @@ -109,8 +109,6 @@ def step(self, closure=None): optimizer_args = {"closure": closure} if closure is not None else {} # By default barrier=False, but making sure it's the case here since we use ParalleLoader. xm.optimizer_step(self.optimizer, optimizer_args=optimizer_args, barrier=False) - elif self.accelerator_state.distributed_type is NeuronDistributedType.XLA_FSDP: - self.optimizer.step(closure) elif self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: if parallel_layers.parallel_state.get_data_parallel_size() > 1: bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer)) diff --git a/optimum/neuron/accelerate/state.py b/optimum/neuron/accelerate/state.py index a03a53707..1f38dead8 100644 --- a/optimum/neuron/accelerate/state.py +++ b/optimum/neuron/accelerate/state.py @@ -32,7 +32,7 @@ parse_choice_from_env, parse_flag_from_env, ) -from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin, SageMakerDistributedType +from accelerate.utils.dataclasses import SageMakerDistributedType from ...utils import logging from ..utils import is_neuronx_distributed_available, is_torch_xla_available @@ -41,7 +41,7 @@ set_common_flags, set_neuron_cc_flags_for_torch_amp, ) -from .utils import NeuronDistributedType, NeuronFullyShardedDataParallelPlugin +from .utils import NeuronDistributedType from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin @@ -201,7 +201,7 @@ def __init__(self, cpu: bool = False, **kwargs): self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0) def wait_for_everyone(self): - if self.distributed_type in [NeuronDistributedType.XLA_FSDP, NeuronDistributedType.MODEL_PARALLELISM]: + if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: xm.rendezvous("accelerate.utils.wait_for_everyone") else: super().wait_for_everyone() @@ -303,17 +303,7 @@ def __init__( pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size, ) - if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true": - self.distributed_type = NeuronDistributedType.XLA_FSDP - if self._mixed_precision != "no": - # TODO: do we need that? - fsdp_plugin.set_mixed_precision(self._mixed_precision) - if isinstance(fsdp_plugin, FullyShardedDataParallelPlugin) and not isinstance( - fsdp_plugin, NeuronFullyShardedDataParallelPlugin - ): - fsdp_plugin.__class__ = NeuronFullyShardedDataParallelPlugin - self.fsdp_plugin = fsdp_plugin - elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu: + if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu: self.deepspeed_plugin = deepspeed_plugin elif self.distributed_type == DistributedType.MULTI_GPU: if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true": diff --git a/optimum/neuron/accelerate/utils/__init__.py b/optimum/neuron/accelerate/utils/__init__.py index 49cea8cf6..5a111e360 100644 --- a/optimum/neuron/accelerate/utils/__init__.py +++ b/optimum/neuron/accelerate/utils/__init__.py @@ -17,6 +17,5 @@ AutocastBackend, ModelParallelismPlugin, NeuronDistributedType, - NeuronFullyShardedDataParallelPlugin, ) from .misc import get_tied_parameters_dict, patch_accelerate_is_tpu_available, tie_parameters diff --git a/optimum/neuron/accelerate/utils/dataclasses.py b/optimum/neuron/accelerate/utils/dataclasses.py index 325b7a088..ca37c12fd 100644 --- a/optimum/neuron/accelerate/utils/dataclasses.py +++ b/optimum/neuron/accelerate/utils/dataclasses.py @@ -15,24 +15,16 @@ """Custom dataclasses for Neuron.""" import enum -import os from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union import torch -from accelerate.utils.constants import MODEL_NAME, OPTIMIZER_NAME -from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin from ...distributed import ParallelizersManager from ...utils import is_torch_xla_available -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP - from torch_xla.distributed.fsdp.state_dict_utils import consolidate_sharded_model_checkpoints - if TYPE_CHECKING: from transformers import PreTrainedModel @@ -42,10 +34,9 @@ class NeuronDistributedType(str, enum.Enum): Represents a type of distributed environment specific to Neuron. Values: - - **XLA_FSDP** -- Fully Shareded Data Parallelism on Neuron cores using `torch_xla`. + - **MODEL_PARALLELISM** -- Tensor and Pipeline Parallelisms using `torch_xla` and `neuronx_distributed`. """ - XLA_FSDP = "XLA_FSDP" MODEL_PARALLELISM = "MODEL_PARALLELISM" @@ -58,96 +49,6 @@ class AutocastBackend(str, enum.Enum): AMP = "amp" -@dataclass -class NeuronFullyShardedDataParallelPlugin(FullyShardedDataParallelPlugin): - # TODO: redefine the post init to do checks on which option is supported. - def save_model(self, accelerator, model, output_dir, model_index=0): - from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType - - state_dict = {"model": model.state_dict(), "shard_metadata": model.get_shard_metadata()} - weights_name = ( - f"{MODEL_NAME}_rank{accelerator.process_index}.pth" - if model_index == 0 - else f"{MODEL_NAME}_{model_index}_rank{accelerator.process_index}.pth" - ) - output_model_file = os.path.join(output_dir, weights_name) - xm.save(state_dict, output_model_file, master_only=False) - xm.rendezvous("saved sharded model checkpoint") - - if self.state_dict_type == StateDictType.FULL_STATE_DICT and accelerator.process_index == 0: - weights_name = f"{MODEL_NAME}.bin" if model_index == 0 else f"{MODEL_NAME}_{model_index}.bin" - output_model_file = os.path.join(output_dir, weights_name) - if accelerator.process_index == 0: - full_state_dict, _ = consolidate_sharded_model_checkpoints( - f"{output_dir}/{MODEL_NAME}_rank", - save_model=False, - ) - torch.save(full_state_dict, output_model_file) - print(f"Model saved to {output_model_file}") - - def load_model(self, accelerator, model, input_dir, model_index=0): - from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType - - accelerator.wait_for_everyone() - if self.state_dict_type == StateDictType.FULL_STATE_DICT: - if type(model) is FSDP: - raise ValueError("Only sharded model weights can be loaded with XLA FSDP.") - if accelerator.process_index == 0: - weights_name = f"{MODEL_NAME}.bin" if model_index == 0 else f"{MODEL_NAME}_{model_index}.bin" - input_model_file = os.path.join(input_dir, weights_name) - accelerator.print(f"Loading model from {input_model_file}") - state_dict = torch.load(input_model_file) - accelerator.print(f"Model loaded from {input_model_file}") - model.load_state_dict(state_dict, False) - else: - weights_name = ( - f"{MODEL_NAME}_rank{accelerator.process_index}.pth" - if model_index == 0 - else f"{MODEL_NAME}_{model_index}_rank{accelerator.process_index}.pth" - ) - input_model_file = os.path.join(input_dir, weights_name) - state_dict = torch.load(input_model_file) - model.load_state_dict(state_dict["model"], False) - - def save_optimizer(self, accelerator, optimizer, model, output_dir, optimizer_index=0, optim_input=None): - # from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP - # optim_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_input) - optim_state = {"optimizer": optimizer.state_dict(), "shard_metadata": model.get_shard_metadata()} - optimizer_path = os.path.join(output_dir, f"{OPTIMIZER_NAME}_rank{accelerator.process_index}.bin") - xm.save(optim_state, optimizer_path, master_only=False) - xm.rendezvous("saved sharded optimizer checkpoint") - - # TODO: save the full optimizer state if possible. - # if accelerator.process_index == 0: - # optim_state_name = ( - # f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" - # ) - # output_optimizer_file = os.path.join(output_dir, optim_state_name) - # print(f"Saving Optimizer state to {output_optimizer_file}") - # torch.save(optim_state, output_optimizer_file) - # print(f"Optimizer state saved in {output_optimizer_file}") - - def load_optimizer(self, accelerator, optimizer, model, input_dir, optimizer_index=0): - accelerator.wait_for_everyone() - # TODO: load full osd support. - # full_osd = None - # if accelerator.process_index == 0: - # optimizer_name = ( - # f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" - # ) - # input_optimizer_file = os.path.join(input_dir, optimizer_name) - # print(f"Loading Optimizer state from {input_optimizer_file}") - # full_osd = torch.load(input_optimizer_file) - # print(f"Optimizer state loaded from {input_optimizer_file}") - # # called from all ranks, though only rank0 has a valid param for full_osd - # sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) - # optimizer.load_state_dict(sharded_osd) - optimizer_path = os.path.join(input_dir, f"{OPTIMIZER_NAME}_rank{accelerator.process_index}.bin") - optim_state = torch.load(optimizer_path) - xm.send_cpu_data_to_device(optim_state, accelerator.device) - optimizer.load_state_dict(optim_state["optimizer"]) - - @dataclass class ModelParallelismPlugin: tensor_parallel_size: int = 1 diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index d0c73ce4c..68df0837b 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -31,10 +31,9 @@ from ...utils import logging from ..utils import is_neuronx_distributed_available, is_torch_xla_available -from ..utils.misc import is_main_worker +from ..utils.misc import is_main_worker, is_precompilation from ..utils.patching import Patcher from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla -from ..utils.training_utils import is_precompilation from .parallel_layers import ( IOSequenceParallelizer, LayerNormSequenceParallelizer, @@ -73,19 +72,6 @@ logger = logging.get_logger() -class SavedModelInTemporaryDirectory: - def __init__(self, model: "PreTrainedModel"): - self.tmpdir = TemporaryDirectory() - self.model = model - - def __enter__(self): - self.model.save_pretrained(self.tmpdir.name) - return self.tmpdir.name - - def __exit__(self, *exc): - self.tmpdir.cleanup() - - class SequenceParallelismSpecs: SEQUENCE_PARALLEL_LAYERNORM_PATTERNS: Optional[List[str]] = None LAYERNORM_TYPE: LayerNormType = LayerNormType.REGULAR @@ -280,6 +266,7 @@ def _parallelize( Returns: `PreTrainedModel`: The parallelized model. """ + pass @classmethod @requires_neuronx_distributed diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index cfef542d9..a72477954 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -22,6 +22,7 @@ import os from dataclasses import dataclass from pathlib import Path +from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Set, Tuple, Type, Union import torch @@ -125,27 +126,6 @@ def __post_init__(self): self.qualified_name = self.qualified_name[len(prefix) :] -@dataclass -class GroupedQueryAttentionInfo: - """ - Describes the information about Grouped Query Attention. - - Attributes: - - num_attention_heads (`int`) -- The number of query heads in the layer. - - num_key_value_heads (`int`) -- The number of key value heads in the layer. - """ - - num_attention_heads: int - num_key_value_heads: int - - def __post_init__(self): - if self.num_attention_heads % self.num_key_value_heads != 0: - raise ValueError( - f"The number of key value heads ({self.num_key_value_heads}) does not divide the number of query heads" - f"({self.num_attention_heads})" - ) - - class FakeProj(torch.nn.Module): """ Dummy layer that replaces a Linear projection by gathering the result from its associated merged @@ -995,91 +975,6 @@ def linear_to_parallel_linear( return parallel_linear_layer -@requires_neuronx_distributed -def gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads( - gqa_info: GroupedQueryAttentionInfo, - linear_layer: "torch.nn.Linear", - linear_layer_weight_info: Optional[WeightInformation] = None, - linear_layer_bias_weight_info: Optional[WeightInformation] = None, - device: Optional["torch.device"] = None, -) -> "torch.nn.Linear": - """ - Helper function that splits key and value projections when performing Grouped Query Attention with the TP size is - smaller than the number of key value heads. - - Args: - gqa_info (`GroupedQueryAttentionInfo`): - The dataclass containing the information related to Grouped Query Attention. - linear_layer (`torch.nn.Linear`): - The linear layer to split. - linear_layer_weight_info (`Optional[torch.nn.Linear]`, defaults to `None`): - Information about which checkpoint file the linear layer weights are stored in. - linear_layer_bias_weight_info (`Optional[WeightInformation]`, defaults to `None`): - Information about which checkpoint file the linear layer bias is stored in. - device (`Optional[torch.device]`, defaults to `None`): - The device where the new split layer should be put. - - Returns: - `torch.nn.Linear`: The split linear layer. - """ - from neuronx_distributed.parallel_layers.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_size, - ) - - tp_size = get_tensor_model_parallel_size() - tp_rank = get_tensor_model_parallel_rank() - if tp_size < gqa_info.num_key_value_heads: - raise ValueError( - f"This function can only be used in the case where the TP size ({tp_size}) is smalled than thue number of " - f"key value heads ({gqa_info.num_key_value_heads})." - ) - num_key_value_heads_x_head_dim, hidden_size = linear_layer.weight.shape - head_dim = num_key_value_heads_x_head_dim // gqa_info.num_key_value_heads - if device is None: - device = linear_layer.weight.device - sliced_linear_layer = torch.nn.Linear( - hidden_size, head_dim, device=device, dtype=linear_layer.weight.dtype, bias=linear_layer.bias is not None - ) - key_value_head_index = gqa_info.num_key_value_heads * tp_rank // tp_size - with torch.no_grad(): - if linear_layer_weight_info is not None: - weight_data = load_tensor_for_weight( - linear_layer_weight_info, - tensor_slices=( - (key_value_head_index * head_dim, (key_value_head_index + 1) * head_dim), - None, - ), - ) - sliced_linear_layer.weight.copy_(weight_data) - mark_parameter_init_status_during_parallelization(sliced_linear_layer.weight, True) - - elif linear_layer.weight.device != torch.device("meta"): - sliced_linear_layer.weight.copy_( - linear_layer.weight[key_value_head_index * head_dim : (key_value_head_index + 1) * head_dim, :] - ) - mark_parameter_init_status_during_parallelization(sliced_linear_layer.weight, True) - else: - mark_parameter_init_status_during_parallelization(sliced_linear_layer.weight, False) - - if linear_layer.bias is not None: - if linear_layer_bias_weight_info is not None: - bias_weight_data = load_tensor_for_weight( - linear_layer_bias_weight_info, - tensor_slices=((key_value_head_index * head_dim, (key_value_head_index + 1) * head_dim),), - ) - sliced_linear_layer.bias.copy_(bias_weight_data) - mark_parameter_init_status_during_parallelization(sliced_linear_layer.bias, True) - elif sliced_linear_layer.bias.device != torch.device("meta"): - sliced_linear_layer.bias.copy_( - linear_layer.bias[key_value_head_index * head_dim : (key_value_head_index + 1) * head_dim] - ) - mark_parameter_init_status_during_parallelization(sliced_linear_layer.bias, True) - else: - mark_parameter_init_status_during_parallelization(sliced_linear_layer.bias, False) - return sliced_linear_layer - - @requires_neuronx_distributed def delete_tensor_model_parallel_attributes(tensor: torch.Tensor): from neuronx_distributed.parallel_layers.utils import _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS @@ -1496,3 +1391,16 @@ def is_sharded(self): class OptimumNeuronFXTracer(HFTracerWrapper): def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: return super().is_leaf_module(m, module_qualified_name) or isinstance(m, FakeProj) + + +class SavedModelInTemporaryDirectory: + def __init__(self, model: "PreTrainedModel"): + self.tmpdir = TemporaryDirectory() + self.model = model + + def __enter__(self): + self.model.save_pretrained(self.tmpdir.name) + return self.tmpdir.name + + def __exit__(self, *exc): + self.tmpdir.cleanup() diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 014e229ad..4793d0923 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -15,7 +15,6 @@ """Defines Trainer subclasses to perform training on AWS Neuron instances.""" import copy -import glob import math import os import random @@ -64,7 +63,7 @@ from transformers.training_args import ParallelMode from transformers.utils import WEIGHTS_NAME, is_apex_available, is_sagemaker_mp_enabled -from ..utils import check_if_transformers_greater, logging +from ..utils import logging from .accelerate import NeuronAccelerator, NeuronDistributedType from .distributed import Parallelizer, ParallelizersManager from .distributed.utils import make_optimizer_constructor_lazy @@ -81,19 +80,16 @@ has_write_access_to_repo, ) from .utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache -from .utils.misc import is_main_worker +from .utils.misc import is_main_worker, is_precompilation, torch_xla_safe_save_file from .utils.patching import patch_everywhere from .utils.require_utils import requires_neuronx_distributed, requires_torch_neuronx from .utils.training_utils import ( - TRANSFORMERS_MIN_VERSION_USE_ACCELERATE, get_model_param_count, is_main_worker_for_metrics, is_main_worker_for_metrics_method, - is_precompilation, is_topology_supported, patch_generation_mixin_to_neuron_generation_mixin, skip_first_batches, - torch_xla_safe_save_file, ) from .utils.version_utils import get_neuronxcc_version @@ -153,15 +149,9 @@ def __init__(self, *args, **kwargs): if training_args.half_precision_backend == "amp": self.use_amp = True - self.validate_args(training_args) if is_precompilation(): self.prepare_args_for_precompilation(training_args) - if check_if_transformers_greater(TRANSFORMERS_MIN_VERSION_USE_ACCELERATE): - import transformers - - transformers.trainer.Accelerator = NeuronAccelerator - super().__init__(*args, **kwargs) # We need to change which process can be seen as "world process zero" to make sure the proper metrics @@ -171,13 +161,6 @@ def __init__(self, *args, **kwargs): is_world_process_zero=is_main_worker_for_metrics(), ) - # That's the case for Transformers < 4.30.0 - if not hasattr(self, "is_fsdp_enabled"): - self.is_fsdp_enabled = False - - if self.is_fsdp_enabled and self.args.do_eval: - raise ValueError("Evaluation is not supported with XLA FSDP yet.") - if self.args.local_rank <= 0: logger.setLevel(logging.INFO) @@ -240,9 +223,6 @@ def prepare_args_for_precompilation(self, args: "TrainingArguments"): logger.info("Disabling prediction during precompilation as this is not well supported yet.") args.do_predict = False - def validate_args(self, args: "TrainingArguments"): - pass - def create_accelerator_and_postprocess(self): # create accelerator object self.accelerator = NeuronAccelerator( @@ -307,7 +287,7 @@ def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args) lazy_load = args.mp_plugin.should_parallelize or args.zero_1 - if check_if_transformers_greater("4.30.0") and lazy_load: + if lazy_load: optimizer_cls = make_optimizer_constructor_lazy(optimizer_cls) return optimizer_cls, optimizer_kwargs @@ -615,34 +595,14 @@ def _save_checkpoint(self, model, trial, metrics=None): def _load_from_checkpoint(self, resume_from_checkpoint, model=None): # It has been handled during model parallelization. - # TODO: how to handle pp? if self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: return super()._load_from_checkpoint(resume_from_checkpoint, model=model) - def _load_optimizer_and_scheduler_for_xla_fsdp(self, checkpoint): - checkpoint_file_exists = ( - glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") - if is_sagemaker_mp_enabled() - else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) - ) - if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): - self.accelerator.state.fsdp_plugin.load_optimizer(self.accelerator, self.optimizer, self.model, checkpoint) - - with warnings.catch_warnings(record=True) as caught_warnings: - lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") - reissue_pt_warnings(caught_warnings) - xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) - self.lr_scheduler.load_state_dict(lr_scheduler_state) - - # TODO: load grad scaling? - def _load_optimizer_and_scheduler(self, checkpoint): if checkpoint is None: return - if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP: - return self._load_optimizer_and_scheduler_for_xla_fsdp(checkpoint) - elif self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: + if self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) self.lr_scheduler.load_state_dict(lr_scheduler_state) diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py index 051b8289c..0d5b131fb 100644 --- a/optimum/neuron/training_args.py +++ b/optimum/neuron/training_args.py @@ -15,29 +15,23 @@ """Defines a TrainingArguments class compatible with Neuron.""" import os -import warnings from dataclasses import dataclass, field -from datetime import timedelta from typing import Optional import torch -from accelerate.utils import DistributedType -from packaging import version from transformers.trainer_utils import get_last_checkpoint -from transformers.training_args import ParallelMode, TrainingArguments +from transformers.training_args import TrainingArguments from transformers.training_args_seq2seq import Seq2SeqTrainingArguments from transformers.utils import ( cached_property, - is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, - requires_backends, ) from ..utils import logging from .accelerate import NeuronAcceleratorState, NeuronPartialState from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_tpu_available -from .utils import is_accelerate_available, is_main_worker, is_torch_xla_available -from .utils.patching import Patcher +from .utils import is_main_worker +from .utils.patching import Patcher, patch_within_function from .utils.torch_xla_and_neuronx_initialization import set_neuron_cc_optlevel @@ -120,8 +114,7 @@ def __post_init__(self): patch_accelerate_is_tpu_available() if self.fsdp != "": - # Disabling FSDP until next release because it is still very experimental and not validated. - raise RuntimeError("FSDP is not supported yet.") + raise RuntimeError("FSDP is not supported.") if self.fp16: raise ValueError("The fp16 data type is not supported in Neuron, please use bf16 instead.") @@ -179,99 +172,14 @@ def __post_init__(self): super().__post_init__() @cached_property + @patch_within_function( + [ + ("transformers.training_args.PartialState", NeuronPartialState), + ("transformers.training_args.AcceleratorState", NeuronAcceleratorState), + ] + ) def _setup_devices(self) -> "torch.device": - - requires_backends(self, ["torch"]) - logger.info("PyTorch: setting up devices") - NeuronAcceleratorState._reset_state() - NeuronPartialState._reset_state() - if not is_sagemaker_mp_enabled() and not is_accelerate_available(): - raise ImportError( - "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install " - "transformers[torch]` or `pip install accelerate -U`" - ) - self.distributed_state = None - if self.no_cuda: - self.distributed_state = NeuronPartialState(cpu=True, backend=self.ddp_backend) - self._n_gpu = 0 - elif is_sagemaker_mp_enabled(): - local_rank = smp.local_rank() - device = torch.device("cuda", local_rank) - self._n_gpu = 1 - torch.cuda.set_device(device) - elif is_sagemaker_dp_enabled(): - self.distributed_state = NeuronPartialState(_use_sagemaker_dp=True) - self._n_gpu = 1 - elif self.deepspeed: - # Need to do similar for Accelerator init - os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - self.distributed_state = NeuronPartialState(timeout=timedelta(seconds=self.ddp_timeout)) - del os.environ["ACCELERATE_USE_DEEPSPEED"] - self._n_gpu = 1 - else: - self.distributed_state = NeuronPartialState(backend=self.ddp_backend) - self._n_gpu = 1 - if not is_sagemaker_mp_enabled(): - device = self.distributed_state.device - self.local_rank = self.distributed_state.local_process_index - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - and self.parallel_mode not in [ParallelMode.DISTRIBUTED, ParallelMode.TPU] - ): - logger.warning( - "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED and " - "parallel_mode != ParallelMode.TPU. " - "In order to use Torch DDP / XLA FSDP, launch your script with `python -m torch.distributed.launch" - ) - if is_torch_xla_available(): - device = self.distributed_state.device - self._n_gpu = 0 - elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled(): - # Already set _n_gpu - pass - elif self.distributed_state.distributed_type == DistributedType.NO: - if self.use_mps_device: - if not torch.backends.mps.is_available(): - if not torch.backends.mps.is_built(): - raise AssertionError( - "MPS not available because the current PyTorch install was not " - "built with MPS enabled. Please install torch version >=1.12.0 on " - "your Apple silicon Mac running macOS 12.3 or later with a native " - "version (arm64) of Python" - ) - else: - raise AssertionError( - "MPS not available because the current MacOS version is not 12.3+ " - "and/or you do not have an MPS-enabled device on this machine." - ) - else: - if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"): - warnings.warn( - "We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)" - " on your MacOS machine. It has major fixes related to model correctness and performance" - " improvements for transformer based models. Please refer to" - " https://github.com/pytorch/pytorch/issues/82707 for more details." - ) - device = torch.device("mps") - self._n_gpu = 1 - elif self.no_cuda: - device = torch.device("cpu") - self._n_gpu = 0 - else: - # if n_gpu is > 1 we'll use nn.DataParallel. - # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` - # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will - # trigger an error that a device index is missing. Index 0 takes into account the - # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` - # will use the first GPU in that env, i.e. GPU#1 - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at - # the default value. - self._n_gpu = torch.cuda.device_count() - if device.type == "cuda": - torch.cuda.set_device(device) - return device + return super()._setup_devices @property def place_model_on_device(self): diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index c3ac7920e..148d6d0a0 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -40,13 +40,12 @@ check_if_weights_replacable, get_stable_diffusion_configs, is_main_worker, + is_precompilation, replace_weights, ) from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function from .training_utils import ( - FirstAndLastDataset, is_model_officially_supported, - is_precompilation, patch_transformers_for_neuron_sdk, ) diff --git a/optimum/neuron/utils/argument_utils.py b/optimum/neuron/utils/argument_utils.py index ebc5b9b52..ecf57e621 100644 --- a/optimum/neuron/utils/argument_utils.py +++ b/optimum/neuron/utils/argument_utils.py @@ -145,7 +145,7 @@ def store_compilation_config( inline_weights_to_neff: bool, optlevel: str, model_type: Optional[str] = None, - task: str = None, + task: Optional[str] = None, input_names: Optional[List[str]] = None, output_names: Optional[List[str]] = None, output_attentions: bool = False, diff --git a/optimum/neuron/utils/misc.py b/optimum/neuron/utils/misc.py index de9ee3383..a69eae70c 100644 --- a/optimum/neuron/utils/misc.py +++ b/optimum/neuron/utils/misc.py @@ -42,7 +42,7 @@ from ...utils import is_diffusers_available, logging from .import_utils import is_torch_neuronx_available, is_torch_xla_available -from .require_utils import requires_safetensors, requires_torch_xla +from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla if is_torch_neuronx_available(): @@ -58,6 +58,10 @@ logger = logging.get_logger() +def is_precompilation() -> bool: + return os.environ.get("NEURON_PARALLEL_COMPILE") == "1" + + def is_main_worker(global_main: bool = True) -> bool: if torch.distributed.is_initialized() and is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -200,6 +204,28 @@ def convert_checkpoint_to_safetensors( return safetensors_path +@requires_neuronx_distributed +@requires_safetensors +def torch_xla_safe_save_file( + tensors: Dict[str, torch.Tensor], + filename: Union[str, os.PathLike], + metadata: Optional[Dict[str, str]] = None, + master_only: bool = True, + global_master: bool = False, +): + """ + Torch XLA compatible implementation of `safetensors.torch.save_file`. + """ + from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu + from safetensors.torch import save_file + from torch_xla.core.xla_model import is_master_ordinal + + should_write_data = not master_only or is_master_ordinal(local=not global_master) + cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data) + if should_write_data: + save_file(cpu_data, filename, metadata=metadata) + + @requires_torch_xla @functools.wraps(cached_file) def distributed_friendly_cached_file(*args, **kwargs): diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py index 64158534a..7866dc689 100644 --- a/optimum/neuron/utils/training_utils.py +++ b/optimum/neuron/utils/training_utils.py @@ -14,14 +14,11 @@ # limitations under the License. """Training utilities""" -import os -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Type, Union import torch import transformers from accelerate import skip_first_batches as accelerate_skip_first_batches -from torch.utils._pytree import tree_map -from torch.utils.data import DataLoader, Dataset, IterableDataset from transformers import GenerationMixin from transformers.models.auto.modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, @@ -47,8 +44,8 @@ from ...utils.logging import set_verbosity as set_verbosity_optimum from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin -from . import is_neuronx_distributed_available, is_torch_xla_available -from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla +from . import is_neuronx_distributed_available +from .require_utils import requires_neuronx_distributed, requires_torch_xla if is_neuronx_distributed_available(): @@ -59,10 +56,6 @@ from transformers import PreTrainedModel -TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP = "4.30.0.dev0" -TRANSFORMERS_MIN_VERSION_USE_ACCELERATE = "4.30.0.dev0" - - def _generate_supported_model_class_names( model_type: str, supported_tasks: Optional[Union[str, List[str]]] = None, @@ -126,24 +119,8 @@ def _generate_supported_model_class_names( _SUPPORTED_MODEL_NAMES.update(_generate_supported_model_class_names(*model_type)) -def is_precompilation() -> bool: - return os.environ.get("NEURON_PARALLEL_COMPILE") == "1" - - def is_model_officially_supported(model: "PreTrainedModel") -> bool: - # In theory the type annotation is not correct since we can have also a XlaFullyShardedDataParallel - # but let's ignore it here. - if not is_torch_xla_available(): - raise RuntimeError( - "is_model_officially_supported requires torch_xla to run, please install it by running: " - "pip install torch_xla" - ) - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel - - if isinstance(model, XlaFullyShardedDataParallel): - class_name = model.module.__class__.__name__ - else: - class_name = model.__class__.__name__ + class_name = model.__class__.__name__ return class_name in _SUPPORTED_MODEL_NAMES @@ -156,77 +133,11 @@ def is_topology_supported() -> bool: return num_devices in allowed_number_of_devices or num_devices % 32 == 0 -class FirstAndLastDataset(Dataset): - def __init__( - self, dataloader: DataLoader, num_repeat: int = 10, gradient_accumulation_steps: int = 1, world_size: int = 1 - ): - self.dataloader = dataloader - self.num_repeat = num_repeat * gradient_accumulation_steps * world_size - self.samples = self.create_samples() - - def _create_samples_for_map_style_dataset(self): - samples = [] - num_samples = len(self.dataloader.dataset) - batch_size = self.dataloader.batch_size - if batch_size is None and self.dataloader.batch_sampler is not None: - batch_size = self.dataloader.batch_sampler.batch_size - - # TODO: validate that. - if batch_size is None: - samples = [self.dataloader.dataset[0]] * self.num_repeat + [self.dataloader.dataset[-1]] * self.num_repeat - return samples - - num_batches = num_samples // batch_size - remaining = num_samples % batch_size - - iterator = iter(self.dataloader) - first_batch = next(iterator) - samples = [first_batch] * self.num_repeat - - if num_batches >= 1 and remaining != 0: - - def map_fn(example): - if isinstance(example, torch.Tensor): - return example[:remaining] - else: - return example - - last_batch = tree_map(map_fn, first_batch) - samples += [last_batch] * self.num_repeat - - return samples - - def _create_samples_for_iterable_dataset(self): - # Will not work if the iterable dataset yields dynamic batch sizes. - iterator = iter(self.dataloader) - first_batch = next(iterator) - samples = [first_batch] * self.num_repeat - last_batch = None - while True: - try: - last_batch = next(iterator) - except StopIteration: - if last_batch is not None: - samples += [last_batch] * self.num_repeat - break - return samples - - def create_samples(self): - if isinstance(self.dataloader.dataset, IterableDataset): - return self._create_samples_for_iterable_dataset() - else: - return self._create_samples_for_map_style_dataset() - - def __getitem__(self, idx: int): - return self.samples[idx] - - def __len__(self): - return len(self.samples) - - -def patch_generation_mixin_to_neuron_generation_mixin(model: "PreTrainedModel"): +def patch_generation_mixin_to_neuron_generation_mixin( + model: "PreTrainedModel", neuron_generation_mixin_cls: Type = NeuronGenerationMixin +): """ - Changes the vanilla `GenerationMixin` class from Transformers to `NeuronGenerationMixin` in the model's + Changes the vanilla `GenerationMixin` class from Transformers to `neuron_generation_mixin_cls` in the model's inheritance. This allows to make the model Neuron-compatible for generation without much hassle. """ to_visit = [model.__class__] @@ -240,9 +151,9 @@ def patch_generation_mixin_to_neuron_generation_mixin(model: "PreTrainedModel"): for base in bases: to_visit.append(base) if base == GenerationMixin: - new_bases.append(NeuronGenerationMixin) + new_bases.append(neuron_generation_mixin_cls) should_stop = True - elif base == NeuronGenerationMixin: + elif base == neuron_generation_mixin_cls: should_stop = True new_bases.append(base) else: @@ -255,25 +166,9 @@ def patch_generation_mixin_to_general_neuron_generation_mixin(model: "PreTrained Changes the vanilla `GenerationMixin` class from Transformers to `GeneralNeuronGenerationMixin` in the model's inheritance. This allows to make the model Neuron-compatible for generation without much hassle. """ - to_visit = [model.__class__] - should_stop = False - while to_visit and not should_stop: - cls = to_visit.pop(0) - if cls is object: - continue - bases = cls.__bases__ - new_bases = [] - for base in bases: - to_visit.append(base) - if base == GenerationMixin: - new_bases.append(GeneralNeuronGenerationMixin) - should_stop = True - elif base == GeneralNeuronGenerationMixin: - should_stop = True - new_bases.append(base) - else: - new_bases.append(base) - cls.__bases__ = tuple(new_bases) + return patch_generation_mixin_to_neuron_generation_mixin( + model, neuron_generation_mixin_cls=GeneralNeuronGenerationMixin + ) def set_verbosity(verbosity: int): @@ -303,28 +198,6 @@ def skip_first_batches(dataloader, num_batches=0): return dataloader -@requires_neuronx_distributed -@requires_safetensors -def torch_xla_safe_save_file( - tensors: Dict[str, torch.Tensor], - filename: Union[str, os.PathLike], - metadata: Optional[Dict[str, str]] = None, - master_only: bool = True, - global_master: bool = False, -): - """ - Torch XLA compatible implementation of `safetensors.torch.save_file`. - """ - from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu - from safetensors.torch import save_file - from torch_xla.core.xla_model import is_master_ordinal - - should_write_data = not master_only or is_master_ordinal(local=not global_master) - cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data) - if should_write_data: - save_file(cpu_data, filename, metadata=metadata) - - @requires_neuronx_distributed def get_model_param_count(model: Union[torch.nn.Module, "NxDPPModel"], trainable_only: bool = False): """Counts the number of parameters of the model.""" diff --git a/tests/test_utils.py b/tests/test_utils.py index d10082ccf..b5664ae6c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,17 +14,13 @@ # limitations under the License. """Tests for utility functions and classes.""" -from typing import Dict, Literal, Union -from unittest import TestCase -import torch -from torch.utils.data import DataLoader, Dataset, IterableDataset from transformers import BertConfig, BertForSequenceClassification, PreTrainedModel, Wav2Vec2Config, Wav2Vec2Model from optimum.neuron.accelerate.accelerator import MODEL_PATCHING_SPECS from optimum.neuron.utils import ModelPatcher from optimum.neuron.utils.testing_utils import is_trainium_test -from optimum.neuron.utils.training_utils import FirstAndLastDataset, is_model_officially_supported +from optimum.neuron.utils.training_utils import is_model_officially_supported @is_trainium_test @@ -45,125 +41,6 @@ class Child(BertForSequenceClassification): assert is_model_officially_supported(bert_model) is True -class FirstAndLastDatasetTest(TestCase): - def _create_dataset(self, num_samples: int, dataset_type: Union[Literal["map"], Literal["iterable"]]) -> Dataset: - random_sample = {"my_sample": torch.rand(4, 3, 24, 24)} - - class MapStyle(Dataset): - def __init__(self, num_samples: int): - self.num_samples = num_samples - - def __getitem__(self, key) -> Dict[str, torch.Tensor]: - return random_sample - - def __len__(self) -> int: - return self.num_samples - - class IterableStyle(IterableDataset): - def __init__(self, num_samples: int): - self.num_samples = num_samples - - def __iter__(self): - count = 0 - while count < self.num_samples: - yield random_sample - count += 1 - - dataset_class = MapStyle if dataset_type == "map" else IterableStyle - return dataset_class(num_samples) - - def test_map_style_dataset(self): - batch_size = 16 - gradient_accumulation_steps = 4 - world_size = 2 - non_divisible_num_samples = batch_size * 200 + 1 - divisible_num_samples = batch_size * 200 - num_repeat = 10 - - # Case 1: the batch size does not divide the number of samples. - dataloader = DataLoader(self._create_dataset(non_divisible_num_samples, "map"), batch_size=batch_size) - first_and_last = FirstAndLastDataset(dataloader, num_repeat=num_repeat) - self.assertEqual(len(first_and_last), num_repeat * 2) - - # Case 2: the batch size divides the number of samples. - dataloader = DataLoader(self._create_dataset(divisible_num_samples, "map"), batch_size=batch_size) - first_and_last = FirstAndLastDataset(dataloader, num_repeat=num_repeat) - self.assertEqual(len(first_and_last), num_repeat) - - # Case 3: the batch size does not divide the number of samples and we have gradient accumulation / multiple processes. - dataloader = DataLoader(self._create_dataset(non_divisible_num_samples, "map"), batch_size=batch_size) - first_and_last = FirstAndLastDataset( - dataloader, - num_repeat=num_repeat, - gradient_accumulation_steps=gradient_accumulation_steps, - world_size=world_size, - ) - self.assertEqual(len(first_and_last) / (gradient_accumulation_steps * world_size), num_repeat * 2) - - # Case 4: the batch size divides the number of samples and we have gradient accumulation / multiple processes. - dataloader = DataLoader(self._create_dataset(divisible_num_samples, "map"), batch_size=batch_size) - first_and_last = FirstAndLastDataset( - dataloader, - num_repeat=num_repeat, - gradient_accumulation_steps=gradient_accumulation_steps, - world_size=world_size, - ) - self.assertEqual(len(first_and_last) / (gradient_accumulation_steps * world_size), num_repeat) - - def test_iterable_style_dataset(self): - batch_size = 16 - gradient_accumulation_steps = 4 - world_size = 2 - non_divisible_num_samples = batch_size * 200 + 1 - divisible_num_samples = batch_size * 200 - num_repeat = 10 - - # Case 1: the batch size does not divide the number of samples. - dataloader = DataLoader(self._create_dataset(non_divisible_num_samples, "iterable"), batch_size=batch_size) - first_and_last = FirstAndLastDataset(dataloader, num_repeat=num_repeat) - self.assertEqual(len(first_and_last), num_repeat * 2) - - # Case 2: the batch size divides the number of samples. - dataloader = DataLoader(self._create_dataset(divisible_num_samples, "iterable"), batch_size=batch_size) - first_and_last = FirstAndLastDataset(dataloader, num_repeat=num_repeat) - self.assertEqual(len(first_and_last), num_repeat * 2) - - # Case 3: the batch size does not divide the number of samples and we have gradient accumulation / multiple processes. - dataloader = DataLoader(self._create_dataset(non_divisible_num_samples, "iterable"), batch_size=batch_size) - first_and_last = FirstAndLastDataset( - dataloader, - num_repeat=num_repeat, - gradient_accumulation_steps=gradient_accumulation_steps, - world_size=world_size, - ) - self.assertEqual(len(first_and_last) / (gradient_accumulation_steps * world_size), num_repeat * 2) - - # Case 4: the batch size divides the number of samples and we have gradient accumulation / multiple processes. - dataloader = DataLoader(self._create_dataset(divisible_num_samples, "iterable"), batch_size=batch_size) - first_and_last = FirstAndLastDataset( - dataloader, - num_repeat=num_repeat, - gradient_accumulation_steps=gradient_accumulation_steps, - world_size=world_size, - ) - self.assertEqual(len(first_and_last) / (gradient_accumulation_steps * world_size), num_repeat * 2) - - # Case 5: only one batch. - dataloader = DataLoader(self._create_dataset(batch_size, "iterable"), batch_size=batch_size) - first_and_last = FirstAndLastDataset(dataloader, num_repeat=num_repeat) - self.assertEqual(len(first_and_last), num_repeat) - - # Case 6: only one batch with gradient accumulation / multiple processes. - dataloader = DataLoader(self._create_dataset(batch_size, "iterable"), batch_size=batch_size) - first_and_last = FirstAndLastDataset( - dataloader, - num_repeat=num_repeat, - gradient_accumulation_steps=gradient_accumulation_steps, - world_size=world_size, - ) - self.assertEqual(len(first_and_last) / (gradient_accumulation_steps * world_size), num_repeat) - - def test_patch_model(): bert_model = BertForSequenceClassification(BertConfig()) patching_specs = []