FEAT / Trainer: LOMO optimizer support #30178

Status: merged on May 21, 2024 (23 commits). The diff below shows changes from 13 of the 23 commits.

Commits
e7d7bbe: add V1 - adalomo not working yet (younesbelkada, Apr 11, 2024)
8cdc21e: add todo docs + refactor from comments (younesbelkada, Apr 11, 2024)
029a9c9: adjust LR (younesbelkada, Apr 11, 2024)
62b5e0e: add docs (younesbelkada, Apr 11, 2024)
629413c: Merge remote-tracking branch 'upstream/main' into add-lomo (younesbelkada, Apr 17, 2024)
4907531: add more elaborated test (younesbelkada, Apr 17, 2024)
51c8e9e: Apply suggestions from code review (younesbelkada, Apr 22, 2024)
d9499c5: fix (younesbelkada, Apr 22, 2024)
afaabfc: push (younesbelkada, Apr 22, 2024)
a57dd5e: Merge remote-tracking branch 'upstream/main' into add-lomo (younesbelkada, May 3, 2024)
beb7edc: add accelerate check (younesbelkada, May 3, 2024)
5184057: Merge remote-tracking branch 'upstream/main' into add-lomo (younesbelkada, May 7, 2024)
ac007ee: fix DDP case (younesbelkada, May 7, 2024)
741a1a4: Merge remote-tracking branch 'origin/main' into add-lomo (younesbelkada, May 16, 2024)
80105e1: Apply suggestions from code review (younesbelkada, May 16, 2024)
49ce45e: fix (younesbelkada, May 16, 2024)
8d008a5: Merge branch 'add-lomo' of https://github.com/younesbelkada/transform… (younesbelkada, May 16, 2024)
40db2fa: init kwargs (younesbelkada, May 16, 2024)
5a536bf: safely add attribute (younesbelkada, May 16, 2024)
c1ac8bf: Merge remote-tracking branch 'origin/main' into add-lomo (younesbelkada, May 16, 2024)
9d547be: revert to enum logic (younesbelkada, May 17, 2024)
efe04a5: Update src/transformers/trainer.py (younesbelkada, May 17, 2024)
6cadb75: Merge remote-tracking branch 'origin/main' into add-lomo (younesbelkada, May 21, 2024)
50 changes: 50 additions & 0 deletions docs/source/en/trainer.md
@@ -382,6 +382,56 @@ trainer.train()

Note that layerwise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so you can only run the training script on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.

## LOMO optimizer

The LOMO optimizers were introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
Both are efficient full-parameter fine-tuning methods that fuse the gradient computation and the parameter update into a single step to reduce memory usage. The supported LOMO optimizers are `"lomo"` and `"adalomo"`. First install LOMO from PyPI with `pip install lomo-optim`, or install it from source with `pip install git+https://github.com/OpenLMLab/LOMO.git`.

<Tip>

According to the authors, it is recommended to use `AdaLomo` without `grad_norm` to get better performance and higher throughput.

</Tip>
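To follow this recommendation with [`Trainer`], one option is to disable gradient clipping by setting `max_grad_norm` to `0` in [`TrainingArguments`]. This is only a sketch of that idea, and it relies on the assumption that the [`Trainer`] skips gradient clipping whenever `max_grad_norm` is not positive:

```python
from transformers import TrainingArguments

# Sketch: run AdaLomo without gradient norm clipping, as recommended above.
# Assumption: the Trainer only clips gradients when max_grad_norm > 0.
args = TrainingArguments(
    output_dir="./test-adalomo",
    optim="adalomo",
    max_grad_norm=0.0,  # disable gradient clipping
)
```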

Below is a simple script to demonstrate how to fine-tune [google/gemma-2b](https://huggingface.co/google/gemma-2b) on the IMDB dataset in full precision:

```python
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-lomo",
    max_steps=1000,
    per_device_train_batch_size=4,
    optim="adalomo",
    gradient_checkpointing=True,
    logging_strategy="steps",
    logging_steps=1,
    learning_rate=2e-6,
    save_strategy="no",
    run_name="lomo-imdb",
)

model_id = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)

trainer = trl.SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=1024,
)

trainer.train()
```

## Accelerate and Trainer

The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -81,6 +81,7 @@
    is_keras_nlp_available,
    is_levenshtein_available,
    is_librosa_available,
    is_lomo_available,
    is_natten_available,
    is_nltk_available,
    is_onnx_available,
@@ -337,6 +338,14 @@ def require_galore_torch(test_case):
    return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


def require_lomo(test_case):
    """
    Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
    https://github.com/OpenLMLab/LOMO
    """
    return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)


def require_cv2(test_case):
    """
    Decorator marking a test that requires OpenCV.
65 changes: 61 additions & 4 deletions src/transformers/trainer.py
@@ -153,6 +153,7 @@
    is_galore_torch_available,
    is_in_notebook,
    is_ipex_available,
    is_lomo_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
@@ -236,9 +237,14 @@
        from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available("0.28.0"):
    from accelerate.optimizer import AcceleratedOptimizer
    from accelerate.utils import DataLoaderConfiguration


if is_lomo_available():
    from lomo_optim import AdaLomo, Lomo


def _is_peft_model(model):
    if is_peft_available():
        classes_to_check = (PeftModel,) if is_peft_available() else ()
@@ -251,6 +257,12 @@ def _is_peft_model(model):
    return False


def _unwrap_optimizer(optimizer):
    if isinstance(optimizer, AcceleratedOptimizer):
        optimizer = optimizer.optimizer
    return optimizer
Collaborator: Is it guaranteed to only ever be one level of wrapping?
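For context, a more defensive variant of this helper (hypothetical only, not the code added by this PR) could unwrap repeatedly in case an optimizer is ever wrapped more than once:

```python
# Hypothetical sketch: keep unwrapping while an AcceleratedOptimizer wrapper remains.
def _unwrap_optimizer(optimizer):
    while isinstance(optimizer, AcceleratedOptimizer):
        optimizer = optimizer.optimizer
    return optimizer
```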



def _get_fsdp_ckpt_kwargs():
    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
@@ -1059,12 +1071,23 @@ def create_optimizer(self):
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid arguments conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            # LOMO has a slightly different optimizer API, see: https://github.com/OpenLMLab/LOMO/issues/73#issuecomment-2049612639
            self._is_lomo_optimizer = is_lomo_available() and isinstance(
                _unwrap_optimizer(self.optimizer), (Lomo, AdaLomo)
            )
Contributor: We can certainly do this, or just do optimizer.optimizer (since we know it'll be wrapped by accelerate). This is a bit safer so seems good to me :)


            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

@@ -1382,6 +1405,26 @@ def optimizer_hook(param):

        if args.optim == OptimizerNames.GALORE_ADAFACTOR:
            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
        elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            if not is_lomo_available():
                raise ImportError(
                    "You need to install `lomo_optim` in order to use LOMO optimizers"
                    " install it with `pip install lomo-optim`"
                )
            if not is_accelerate_available("0.30.0"):
                raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")

            if model is None:
                raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")

            from lomo_optim import AdaLomo, Lomo

            if "ada" in args.optim:
                optimizer_cls = AdaLomo
            else:
                optimizer_cls = Lomo

            optimizer_kwargs.update({"model": model})
        else:
            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
        return optimizer_cls, optimizer_kwargs
@@ -2045,6 +2088,9 @@ def _inner_training_loop(
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
        elif self._is_lomo_optimizer:
            # In this case we are in DDP + LOMO, which should be supported
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model
@@ -2143,7 +2189,6 @@ def _inner_training_loop(
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
        grad_norm: Optional[float] = None

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        total_batched_samples = 0
@@ -2275,8 +2320,8 @@ def _inner_training_loop(
                    else:
                        grad_norm = _grad_norm

                    # Optimizer step
                    self.optimizer.step()

                    optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                    if optimizer_was_run:
                        # Delay optimizer scheduling until metrics are generated
@@ -3187,7 +3232,7 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):

        return ctx_manager

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], **kwargs) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.

@@ -3201,12 +3246,24 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            kwargs:
                Additional key-word arguments to pass along for custom optimizers

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        _is_lomo = False
Collaborator: hmmmm..... needing to have this in the training_step is a good indication the abstractions here are leaky. Once we have the optimizer created, we shouldn't really need to know what type of optimizer it is in the rest of the code

Contributor Author: nice catch .. i think it was an old code, now should be much cleaner !


        if is_lomo_available():
            from lomo_optim import AdaLomo, Lomo

            _is_lomo = isinstance(_unwrap_optimizer(self.optimizer), (Lomo, AdaLomo))
Collaborator: I'd move out the common logic to something like _is_lomo_optimizer from here and L1086, which handles importing Lomo and AdaLomo and unwrapping the optimizer
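A possible shape for the helper suggested here (a sketch only; the name and placement are assumptions rather than the merged code):

```python
# Hypothetical helper: detect a (possibly accelerate-wrapped) LOMO optimizer.
def _is_lomo_optimizer(optimizer) -> bool:
    if not is_lomo_available():
        return False
    from lomo_optim import AdaLomo, Lomo

    return isinstance(_unwrap_optimizer(optimizer), (Lomo, AdaLomo))
```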


        # For LOMO optimizers you need to explicitly use the learning rate
        if _is_lomo:
            kwargs["learning_rate"] = self._get_learning_rate()
Collaborator: If we have the optimizer set, don't we also have self._is_lomo_optimizer?

Contributor Author: yes indeed ! changed that


        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
@@ -3225,7 +3282,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)
            self.accelerator.backward(loss, **kwargs)
Collaborator: What happens if we pass the learning rate through when lomo isn't being used?

Contributor Author: It will break .. 😢 but we:
1- raise an error if users do not have the correct accelerate version with init-ing the trainer with lomo
2- pass learning_rate only if the optimizer is a lomo optimizer
3- removed kwargs in training step
So hopefully this should be safe enough 🙏
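A sketch of the guarded call the author describes, assuming `self._is_lomo_optimizer` was set when the optimizer was created (illustrative only, not the exact merged diff):

```python
# Only LOMO optimizers receive the learning rate through backward();
# every other optimizer keeps the plain call, so nothing else changes.
if self._is_lomo_optimizer:
    self.accelerator.backward(loss, learning_rate=self._get_learning_rate())
else:
    self.accelerator.backward(loss)
```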


        return loss.detach() / self.args.gradient_accumulation_steps

2 changes: 2 additions & 0 deletions src/transformers/training_args.py
@@ -171,6 +171,8 @@ class OptimizerNames(ExplicitEnum):
    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    LOMO = "lomo"
    ADALOMO = "adalomo"


# Sometimes users will pass in a `str` repr of a dict in the CLI
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -139,6 +139,7 @@
    is_keras_nlp_available,
    is_levenshtein_available,
    is_librosa_available,
    is_lomo_available,
    is_mlx_available,
    is_natten_available,
    is_ninja_available,
5 changes: 5 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -99,6 +99,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
@@ -322,6 +323,10 @@ def is_galore_torch_available():
    return _galore_torch_available


def is_lomo_available():
    return _lomo_available


def is_pyctcdecode_available():
    return _pyctcdecode_available

44 changes: 44 additions & 0 deletions tests/trainer/test_trainer.py
@@ -63,6 +63,7 @@
    require_deepspeed,
    require_galore_torch,
    require_intel_extension_for_pytorch,
    require_lomo,
    require_optuna,
    require_peft,
    require_ray,
@@ -1202,6 +1203,49 @@ def test_dataloader_without_dataset(self):
        trainer.train()
        trainer.evaluate()

    @require_lomo
    @require_torch_gpu
    def test_lomo(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)

        previous_params = {n: p.clone() for n, p in tiny_llama.named_parameters()}

        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, optim="lomo", max_steps=20)
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

        for name, param in tiny_llama.named_parameters():
            self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
Collaborator: This tolerance is super small, do we expect optimizers to make changes on this order?

Contributor Author: It is ok to put it higher, I decided to put it low so that even small changes would be captured by the test (sometimes higher tolerances would fail even though the weights are properly updated + with a high learning rate, so just to be on the safe zone)


    @require_lomo
    @require_torch_gpu
    def test_adalomo(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        tiny_llama = LlamaForCausalLM(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        with tempfile.TemporaryDirectory() as tmpdir:
            # Trainer without inf/nan filter
            args = TrainingArguments(
                tmpdir,
                learning_rate=1e-9,
                logging_steps=5,
                optim="adalomo",
            )
            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)

            # Check this works
            _ = trainer.train()

    def test_galore_matched_modules(self):
        regex_patterns = [r".*.attn.*", r".*.mlp.*"]
