Cleanup obsolete code #555

Merged: 11 commits, Apr 9, 2024
139 changes: 6 additions & 133 deletions optimum/neuron/accelerate/accelerator.py
@@ -16,7 +16,6 @@

 import collections
 import contextlib
-import inspect
 import os
 import re
 import shutil
@@ -54,7 +53,6 @@
     AutocastBackend,
     ModelParallelismPlugin,
     NeuronDistributedType,
-    NeuronFullyShardedDataParallelPlugin,
     get_tied_parameters_dict,
     patch_accelerate_is_tpu_available,
     tie_parameters,
@@ -124,14 +122,8 @@ def __init__(
         full_kwargs["gradient_accumulation_steps"] = gradient_accumulation_steps

         fsdp_plugin = full_kwargs["fsdp_plugin"]
-        if fsdp_plugin is None:
-            if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
-                fsdp_plugin = NeuronFullyShardedDataParallelPlugin()
-        elif not isinstance(fsdp_plugin, NeuronFullyShardedDataParallelPlugin):
-            raise ValueError(
-                "The fsdp_plugin must be an instance of NeuronFullyShardedDataParallelPlugin to use XLA FSDP with "
-                f"the NeuronAccelerator, but an instance of {type(fsdp_plugin)} was given here."
-            )
+        if fsdp_plugin is not None:
+            raise ValueError("FSDP is not supported.")
         self.fsdp_plugin = fsdp_plugin

         self._model_cpu_parameters_to_xla = {}
@@ -151,15 +143,9 @@ def __init__(
             enabled = self.state.mixed_precision == "bf16" and autocast_backend is AutocastBackend.AMP
             self.autocast_handler = AutocastKwargs(enabled=enabled)

-        if self.fsdp_plugin is not None and self.zero_1:
-            raise ValueError("Either enable XLA ZeRO Stage 1 or XLA FSDP but not both.")
-
         if self.process_index == -1 and self.zero_1:
             raise ValueError("XLA ZeRO Stage 1 can only be enabled in a distributed training setting.")

-        if fsdp_plugin is not None and mp_plugin is not None:
-            raise ValueError("It is not possible to both use neuronx_distributed Tensor Parallelism and XLA FSDP.")
-
         if num_steps != 1:
             self.gradient_accumulation_steps = num_steps

@@ -349,76 +335,6 @@ def patch_model_for_neuron(
         model_patcher.patch()
         return model

-    def prepare_model_for_xla_fsdp(
-        self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
-    ):
-        if device_placement is None:
-            device_placement = self.device_placement
-        self._models.append(model)
-        # We check only for models loaded with `accelerate`
-
-        # Checks if any of the child module has the attribute `hf_device_map`.
-        has_hf_device_map = False
-        for m in model.modules():
-            if hasattr(m, "hf_device_map"):
-                has_hf_device_map = True
-                break
-
-        if getattr(model, "is_loaded_in_8bit", False) and getattr(model, "hf_device_map", False):
-            model_devices = set(model.hf_device_map.values())
-            if len(model_devices) > 1:
-                raise ValueError(
-                    "You can't train a model that has been loaded in 8-bit precision on multiple devices."
-                )
-
-            current_device_index = list(model_devices)[0]
-            if torch.device(current_device_index) != self.device:
-                # if on the first device (GPU 0) we don't care
-                if (self.device.index is not None) or (current_device_index != 0):
-                    raise ValueError(
-                        "You can't train a model that has been loaded in 8-bit precision on a different device than the one "
-                        "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}"
-                        "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}"
-                    )
-
-            if "cpu" in model_devices or "disk" in model_devices:
-                raise ValueError(
-                    "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload."
-                )
-        elif device_placement and not has_hf_device_map:
-            model = model.to(self.device)
-
-        try:
-            from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
-        except ImportError:
-            raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
-
-        if not evaluation_mode:
-            # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
-            # don't wrap it again
-            # TODO: validate which arguments work for XLA FSDP.
-            if type(model) != FSDP:
-                self.state.fsdp_plugin.set_auto_wrap_policy(model)
-                fsdp_plugin = self.state.fsdp_plugin
-                kwargs = {
-                    "sharding_strategy": fsdp_plugin.sharding_strategy,
-                    "cpu_offload": fsdp_plugin.cpu_offload,
-                    "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
-                    "backward_prefetch": fsdp_plugin.backward_prefetch,
-                    "mixed_precision": fsdp_plugin.mixed_precision_policy,
-                    "ignored_modules": fsdp_plugin.ignored_modules,
-                    "device_id": self.device,
-                }
-                signature = inspect.signature(FSDP.__init__).parameters.keys()
-                if "limit_all_gathers" in signature:
-                    kwargs["limit_all_gathers"] = fsdp_plugin.limit_all_gathers
-                if "use_orig_params" in signature:
-                    kwargs["use_orig_params"] = fsdp_plugin.use_orig_params
-                model = FSDP(model, **kwargs)
-        self._models[-1] = model
-
-        return model
-
     @requires_neuronx_distributed
     def _prepare_model_for_mp(
         self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
@@ -502,41 +418,22 @@ def prepare_model(
         model.config.output_attentions = False
         model.config.output_hidden_states = False

-        if self.distributed_type is NeuronDistributedType.XLA_FSDP:
-            return self.prepare_model_for_xla_fsdp(
-                model, device_placement=device_placement, evaluation_mode=evaluation_mode
-            )
-        elif self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
+        if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
             return self._prepare_model_for_mp(
                 model, device_placement=device_placement, evaluation_mode=evaluation_mode
             )
         move_model_to_device(model, xm.xla_device())
         device_placement = False
         return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)

-    def backward_for_xla_fsdp(self, loss, **kwargs):
-        if self.scaler is not None:
-            self.scaler.scale(loss).backward(**kwargs)
-        else:
-            loss.backward(**kwargs)
-
     def backward(self, loss, **kwargs):
         if self.distributed_type != DistributedType.DEEPSPEED:
             loss = loss / self.gradient_accumulation_steps
-        if self.distributed_type is NeuronDistributedType.XLA_FSDP:
-            self.backward_for_xla_fsdp(loss, **kwargs)
-        elif self.scaler is not None:
+        if self.scaler is not None:
             self.scaler.scale(loss).backward(**kwargs)
         else:
             loss.backward(**kwargs)

-    def clip_grad_norm_for_xla_fsdp(self, parameters, max_norm, norm_type: int = 2):
-        self.unscale_gradients()
-        parameters = list(parameters)
-        for model in self._models:
-            if parameters == list(model.parameters()):
-                return model.clip_grad_norm_(max_norm, norm_type)
-
     @contextlib.contextmanager
     def autocast(self, cache_enabled: bool = False, autocast_handler: Optional[AutocastKwargs] = None):
         if cache_enabled:
@@ -575,17 +472,10 @@ def _prepare_clip_grad_norm(self, parameters, max_norm, norm_type: int = 2):
                 return opt.prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)

     def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
-        if self.distributed_type is NeuronDistributedType.XLA_FSDP:
-            return self.clip_grad_norm_for_xla_fsdp(parameters, max_norm, norm_type=norm_type)
-        elif self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM or self.zero_1:
+        if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM or self.zero_1:
             return self._prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
         return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

-    def clip_grad_value_(self, parameters, clip_value):
-        if self.distributed_type is NeuronDistributedType.XLA_FSDP:
-            raise Exception("XLA FSDP does not support `clip_grad_value_`. Use `clip_grad_norm_` instead.")
-        return super().clip_grad_value_(parameters, clip_value)
-
     def _custom_save_state(
         self,
         save_model_func: Optional[Callable[["Accelerator", "PreTrainedModel", Union[str, Path], int], Any]],
@@ -661,21 +551,6 @@ def _inner(folder):
         self.project_configuration.iteration += 1
         return save_location

-    def save_state_for_xla_fsdp(self, output_dir: Optional[str] = None, **save_model_func_kwargs):
-        def save_model_func(accelelerator, model, output_dir, i):
-            logger.info("Saving FSDP model")
-            self.state.fsdp_plugin.save_model(accelelerator, model, output_dir, i)
-            logger.info(f"FSDP Model saved to the directory {output_dir}")
-
-        def save_optimizer_func(accelerator, optimizer, model, output_dir, i):
-            logger.info("Saving FSDP Optimizer")
-            self.state.fsdp_plugin.save_optimizer(accelerator, optimizer, model, output_dir, i)
-            logger.info(f"FSDP Optimizer saved to the directory {output_dir}")
-
-        return self._custom_save_state(
-            save_model_func, save_optimizer_func, output_dir=output_dir, **save_model_func_kwargs
-        )
-
     def save_state_for_mp(self, output_dir: Optional[str] = None, **save_model_func_kwargs):
         # The model is saved at the same time as the optimizer.
         save_model_func = None
@@ -692,9 +567,7 @@ def save_optimizer_func(accelerator, optimizer, model, output_dir, i):

     @patch_within_function(("accelerate.checkpointing.xm", xm), ignore_missing_attributes=True)
     def save_state(self, output_dir: Optional[str] = None, **save_model_func_kwargs) -> str:
-        if self.distributed_type is NeuronDistributedType.XLA_FSDP:
-            return self.save_state_for_xla_fsdp(output_dir=output_dir, **save_model_func_kwargs)
-        elif self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
+        if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
             return self.save_state_for_mp(output_dir=output_dir, **save_model_func_kwargs)
         return super().save_state(output_dir=output_dir, **save_model_func_kwargs)

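To make the user-facing effect of the accelerator.py change concrete, here is a minimal usage sketch. It assumes that NeuronAccelerator is importable from optimum.neuron.accelerate, that fsdp_plugin is still accepted as a constructor keyword (the diff only shows the check on full_kwargs["fsdp_plugin"]), and that the code runs on a Neuron/XLA host; none of this is demonstrated end to end by the diff itself.

# Hedged sketch of the behavior after this PR; not part of the diff.
from optimum.neuron.accelerate import NeuronAccelerator  # assumed import path

# Without an FSDP plugin, construction proceeds as before.
accelerator = NeuronAccelerator()

# Any non-None fsdp_plugin is now rejected up front instead of being wrapped
# into a NeuronFullyShardedDataParallelPlugin as the removed code used to do.
try:
    NeuronAccelerator(fsdp_plugin=object())  # placeholder value, illustration only
except ValueError as err:
    print(err)  # expected: "FSDP is not supported."
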
2 changes: 0 additions & 2 deletions optimum/neuron/accelerate/optimizer.py
@@ -109,8 +109,6 @@ def step(self, closure=None):
                 optimizer_args = {"closure": closure} if closure is not None else {}
                 # By default barrier=False, but making sure it's the case here since we use ParalleLoader.
                 xm.optimizer_step(self.optimizer, optimizer_args=optimizer_args, barrier=False)
-            elif self.accelerator_state.distributed_type is NeuronDistributedType.XLA_FSDP:
-                self.optimizer.step(closure)
             elif self.accelerator_state.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
                 if parallel_layers.parallel_state.get_data_parallel_size() > 1:
                     bucket_allreduce_gradients(xm._fetch_gradients(self.optimizer))
18 changes: 4 additions & 14 deletions optimum/neuron/accelerate/state.py
@@ -32,7 +32,7 @@
     parse_choice_from_env,
     parse_flag_from_env,
 )
-from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin, SageMakerDistributedType
+from accelerate.utils.dataclasses import SageMakerDistributedType

 from ...utils import logging
 from ..utils import is_neuronx_distributed_available, is_torch_xla_available
@@ -41,7 +41,7 @@
     set_common_flags,
     set_neuron_cc_flags_for_torch_amp,
 )
-from .utils import NeuronDistributedType, NeuronFullyShardedDataParallelPlugin
+from .utils import NeuronDistributedType
 from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin


@@ -201,7 +201,7 @@ def __init__(self, cpu: bool = False, **kwargs):
         self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0)

     def wait_for_everyone(self):
-        if self.distributed_type in [NeuronDistributedType.XLA_FSDP, NeuronDistributedType.MODEL_PARALLELISM]:
+        if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
             xm.rendezvous("accelerate.utils.wait_for_everyone")
         else:
             super().wait_for_everyone()
@@ -303,17 +303,7 @@ def __init__(
                         pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
                     )

-            if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
-                self.distributed_type = NeuronDistributedType.XLA_FSDP
-                if self._mixed_precision != "no":
-                    # TODO: do we need that?
-                    fsdp_plugin.set_mixed_precision(self._mixed_precision)
-                if isinstance(fsdp_plugin, FullyShardedDataParallelPlugin) and not isinstance(
-                    fsdp_plugin, NeuronFullyShardedDataParallelPlugin
-                ):
-                    fsdp_plugin.__class__ = NeuronFullyShardedDataParallelPlugin
-                self.fsdp_plugin = fsdp_plugin
-            elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
+            if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
                 self.deepspeed_plugin = deepspeed_plugin
             elif self.distributed_type == DistributedType.MULTI_GPU:
                 if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
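For the state.py change, a short hedged sketch of the environment-variable handling that remains. The NeuronAcceleratorState name and import path are assumptions based on the file being edited, and the snippet only makes sense on a Neuron/XLA host:

# Hedged sketch: ACCELERATE_USE_FSDP no longer switches the Neuron state into an
# XLA FSDP mode; only the DeepSpeed branch keyed on ACCELERATE_USE_DEEPSPEED and
# the default branches shown above remain on this code path.
import os

from optimum.neuron.accelerate import NeuronAcceleratorState  # assumed import path

os.environ["ACCELERATE_USE_FSDP"] = "true"  # previously this selected the XLA FSDP branch
state = NeuronAcceleratorState()
print(state.distributed_type)  # no FSDP-specific distributed type is set anymore
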
1 change: 0 additions & 1 deletion optimum/neuron/accelerate/utils/__init__.py
@@ -17,6 +17,5 @@
     AutocastBackend,
     ModelParallelismPlugin,
     NeuronDistributedType,
-    NeuronFullyShardedDataParallelPlugin,
 )
 from .misc import get_tied_parameters_dict, patch_accelerate_is_tpu_available, tie_parameters