FEAT / Trainer: LOMO optimizer support #30178
First file in the diff (the `Trainer` implementation):
```diff
@@ -153,6 +153,7 @@
     is_galore_torch_available,
     is_in_notebook,
     is_ipex_available,
+    is_lomo_available,
     is_peft_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
```
```diff
@@ -236,9 +237,14 @@
 from accelerate.utils import DeepSpeedSchedulerWrapper

 if is_accelerate_available("0.28.0"):
+    from accelerate.optimizer import AcceleratedOptimizer
     from accelerate.utils import DataLoaderConfiguration

+
+if is_lomo_available():
+    from lomo_optim import AdaLomo, Lomo
+

 def _is_peft_model(model):
     if is_peft_available():
         classes_to_check = (PeftModel,) if is_peft_available() else ()
```
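As an aside, the `is_lomo_available()` guard follows the library's usual optional-dependency pattern. A minimal sketch of what such a guard boils down to (this is a stand-in, not the actual `transformers` helper):

```python
# Sketch of an availability guard for an optional dependency.
# The real is_lomo_available() lives in transformers' import utilities.
import importlib.util


def is_lomo_available() -> bool:
    # True only when `lomo_optim` is installed (`pip install lomo-optim`)
    return importlib.util.find_spec("lomo_optim") is not None


if is_lomo_available():
    # Imported only when present, so lomo_optim never becomes a hard requirement
    from lomo_optim import AdaLomo, Lomo
```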
```diff
@@ -251,6 +257,12 @@ def _is_peft_model(model):
     return False


+def _unwrap_optimizer(optimizer):
+    if isinstance(optimizer, AcceleratedOptimizer):
+        optimizer = optimizer.optimizer
+    return optimizer
+
+
 def _get_fsdp_ckpt_kwargs():
     # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
     if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
```
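For context on why a single unwrap suffices: `accelerate` wraps the user's optimizer exactly once in an `AcceleratedOptimizer`, which keeps the original instance on its `.optimizer` attribute (see the review exchange at the end of this diff). A small illustration, assuming a CPU-only `Accelerator`:

```python
# Illustration (not part of the PR): one prepare() call yields one wrapper,
# and the original optimizer stays reachable via the .optimizer attribute.
import torch
from accelerate import Accelerator
from accelerate.optimizer import AcceleratedOptimizer

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accelerator = Accelerator(cpu=True)
prepared = accelerator.prepare(optimizer)

assert isinstance(prepared, AcceleratedOptimizer)
assert prepared.optimizer is optimizer  # a single level of wrapping
```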
```diff
@@ -1059,12 +1071,23 @@ def create_optimizer(self):
         if "params" in optimizer_kwargs:
             optimizer_grouped_parameters = optimizer_kwargs.pop("params")

+        # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
+        # e.g. for LOMO optimizer.
+        if "model" in optimizer_kwargs:
+            optimizer_grouped_parameters = optimizer_kwargs.pop("model")
+
         # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
         # to avoid arguments conflicts.
         if "optimizer_dict" in optimizer_kwargs:
             optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

         self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

+        # LOMO has a slightly different optimizer API, see: https://github.com/OpenLMLab/LOMO/issues/73#issuecomment-2049612639
+        self._is_lomo_optimizer = is_lomo_available() and isinstance(
+            _unwrap_optimizer(self.optimizer), (Lomo, AdaLomo)
+        )
+
         if optimizer_cls.__name__ == "Adam8bit":
             import bitsandbytes
```

Review comment on the `self._is_lomo_optimizer` check: We can certainly do this, or just do […]

Reply: This is a bit safer, so it seems good to me :)
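The `model` key exists because LOMO's constructor takes the model itself rather than parameter groups: the optimizer registers gradient hooks so that the parameter update is fused into the backward pass. A hedged sketch, assuming the constructor shape exposed by the `lomo_optim` package (exact keyword arguments may differ across versions):

```python
# Sketch only: LOMO is constructed from the model, not from model.parameters().
from lomo_optim import Lomo  # requires `pip install lomo-optim`

# `model` is assumed to be an nn.Module defined elsewhere
optimizer = Lomo(model, lr=1e-2)
```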
```diff
@@ -1382,6 +1405,26 @@ def optimizer_hook(param):

         if args.optim == OptimizerNames.GALORE_ADAFACTOR:
             optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+    elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+        if not is_lomo_available():
+            raise ImportError(
+                "You need to install `lomo_optim` in order to use LOMO optimizers."
+                " Install it with `pip install lomo-optim`."
+            )
+        if not is_accelerate_available("0.30.0"):
+            raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers.")
+
+        if model is None:
+            raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")
+
+        from lomo_optim import AdaLomo, Lomo
+
+        if "ada" in args.optim:
+            optimizer_cls = AdaLomo
+        else:
+            optimizer_cls = Lomo
+
+        optimizer_kwargs.update({"model": model})
     else:
         raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
     return optimizer_cls, optimizer_kwargs
```
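From the user's side, all of the above is reached through `TrainingArguments` alone, as the tests later in this diff confirm. For example (`model` and `train_dataset` are assumed to be defined elsewhere):

```python
# End-to-end usage mirroring the tests below.
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="lomo-out",
    optim="lomo",        # or "adalomo" for the adaptive variant
    learning_rate=1e-2,  # matches the Lomo test below
    max_steps=20,
)
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
```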
```diff
@@ -2045,6 +2088,9 @@ def _inner_training_loop(
             model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                 self.model, self.optimizer, self.lr_scheduler
             )
+        elif self._is_lomo_optimizer:
+            # In this case we are in DDP + LOMO, which should be supported
+            self.optimizer = self.accelerator.prepare(self.optimizer)

         if self.is_fsdp_enabled:
             self.model = self.model_wrapped = model
```
```diff
@@ -2143,7 +2189,6 @@ def _inner_training_loop(
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
         grad_norm: Optional[float] = None

         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

         total_batched_samples = 0
```
```diff
@@ -2275,8 +2320,8 @@ def _inner_training_loop(
             else:
                 grad_norm = _grad_norm

             # Optimizer step
             self.optimizer.step()

             optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
             if optimizer_was_run:
                 # Delay optimizer scheduling until metrics are generated
```
```diff
@@ -3187,7 +3232,7 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):

         return ctx_manager

-    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], **kwargs) -> torch.Tensor:
         """
         Perform a training step on a batch of inputs.
```

```diff
@@ -3201,12 +3246,24 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                 argument `labels`. Check your model's documentation for all accepted arguments.
+            kwargs:
+                Additional keyword arguments to pass along for custom optimizers.

         Return:
             `torch.Tensor`: The tensor with training loss on this batch.
         """
         model.train()
         inputs = self._prepare_inputs(inputs)
+        _is_lomo = False
+
+        if is_lomo_available():
+            from lomo_optim import AdaLomo, Lomo
+
+            _is_lomo = isinstance(_unwrap_optimizer(self.optimizer), (Lomo, AdaLomo))
+
+        # For LOMO optimizers you need to explicitly pass the learning rate
+        if _is_lomo:
+            kwargs["learning_rate"] = self._get_learning_rate()

         if is_sagemaker_mp_enabled():
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
```

Review comment on the `_is_lomo` check: hmmmm... needing to have this in `training_step` is a good indication the abstractions here are leaky. Once we have the optimizer created, we shouldn't really need to know what type of optimizer it is in the rest of the code.

Reply: Nice catch! I think it was old code; it should be much cleaner now.

Review comment: I'd move out the common logic to something like […]

Review comment: If we have the optimizer set, don't we also have […]?

Reply: Yes indeed! Changed that.
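To make the control flow concrete: with LOMO, gradient computation and the parameter update happen together inside `backward`, so the learning rate has to travel with the loss rather than with `optimizer.step()`. A simplified sketch of what `training_step` now effectively does, assuming `accelerate>=0.30.0` routes the extra kwarg to the LOMO optimizer (`compute_loss` and `current_lr` are stand-ins, not the literal Trainer code):

```python
loss = compute_loss(model, inputs)          # stand-in for the Trainer's loss computation
current_lr = lr_scheduler.get_last_lr()[0]  # stand-in for self._get_learning_rate()

if is_lomo_optimizer:
    # LOMO: gradients and the parameter update are fused inside backward,
    # so the learning rate rides along with the loss.
    accelerator.backward(loss, learning_rate=current_lr)
else:
    # Ordinary optimizers compute gradients here and update later in optimizer.step()
    accelerator.backward(loss)
```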
```diff
@@ -3225,7 +3282,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            self.accelerator.backward(loss)
+            self.accelerator.backward(loss, **kwargs)

         return loss.detach() / self.args.gradient_accumulation_steps
```

Review comment: What happens if we pass the learning rate through when LOMO isn't being used?

Reply: It will break .. 😢 but we: […]
Second file in the diff (the Trainer test suite):
```diff
@@ -63,6 +63,7 @@
     require_deepspeed,
     require_galore_torch,
     require_intel_extension_for_pytorch,
+    require_lomo,
     require_optuna,
     require_peft,
     require_ray,
```
```diff
@@ -1202,6 +1203,49 @@ def test_dataloader_without_dataset(self):
         trainer.train()
         trainer.evaluate()

+    @require_lomo
+    @require_torch_gpu
+    def test_lomo(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+
+        previous_params = {n: p.clone() for n, p in tiny_llama.named_parameters()}
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, optim="lomo", max_steps=20)
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
+        for name, param in tiny_llama.named_parameters():
+            self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
```

Review comment: This tolerance is super small, do we expect optimizers to make changes on this order?

Reply: It is ok to put it higher; I decided to put it low so that even small changes are captured by the test (sometimes higher tolerances would fail even though the weights were properly updated, especially with a high learning rate), just to be in the safe zone.
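To illustrate the tolerance point: with `rtol=atol=1e-12`, `torch.allclose` returns `True` only for near-bit-identical tensors, so the `assertFalse` trips as soon as a parameter moves at all, whereas a looser tolerance could miss a tiny but real update:

```python
# Why the tiny tolerance makes the "parameters changed" check maximally sensitive.
import torch

before = torch.tensor([1.0])
after = before + 1e-6  # a very small parameter update

print(torch.allclose(before, after, rtol=1e-12, atol=1e-12))  # False -> counts as changed
print(torch.allclose(before, after, rtol=1e-4, atol=1e-4))    # True  -> update would go unnoticed
```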
```diff
+    @require_lomo
+    @require_torch_gpu
+    def test_adalomo(self):
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir,
+                learning_rate=1e-9,
+                logging_steps=5,
+                optim="adalomo",
+            )
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()

     def test_galore_matched_modules(self):
         regex_patterns = [r".*.attn.*", r".*.mlp.*"]
```
Review comment on `_unwrap_optimizer`: Is it guaranteed to only ever be one level of wrapping?

Reply: Yes indeed! https://github.com/huggingface/accelerate/blob/4ad4d28c49a9818e985ea12d66a89fe73fe73c87/src/accelerate/optimizer.py#L56