
WIP: galore optimizer #1370

Closed · wants to merge 5 commits

Conversation

maximegmd
Contributor

Adds support for GaLore optimizers.

Still a WIP, untested.

@fakerybakery

@maximegmd any chance you could provide an example config file on how to use this?

@casper-hansen
Collaborator

> @maximegmd any chance you could provide an example config file on how to use this?

Set the optimizer argument in the axolotl config to one of [galore_adamw, galore_adamw8bit, galore_ada_factor]; galore_adamw8bit will likely give the largest memory savings.
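
For reference, here is a minimal sketch of what that could look like in an axolotl YAML config. Only the optimizer value follows the description above; the base model and other keys are placeholders, and any additional GaLore-specific options depend on how this PR lands:

```yaml
# Hypothetical sketch: only `optimizer` reflects the description above.
# Everything else is a placeholder; extra GaLore-specific keys may be
# required depending on the final implementation in this PR.
base_model: mistralai/Mistral-7B-v0.1
optimizer: galore_adamw8bit   # or galore_adamw / galore_ada_factor
learning_rate: 0.0002
micro_batch_size: 2
gradient_checkpointing: true
```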

@younesbelkada
Contributor

younesbelkada commented Mar 11, 2024

Hi!
I tried to upstream these changes into transformers so that you can leverage them directly in axolotl: huggingface/transformers#29588. I am running some quick experiments; so far the training seems quite slow. Here is how I am running the training with GaLore:

import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    galore_target_modules=["attn", "mlp"],
    gradient_checkpointing=True,
)

# model_id = "mistralai/Mistral-7B-v0.1"
model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)  # random init from the config, not pretrained weights

trainer = trl.SFTTrainer(
    model=model, 
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()

@younesbelkada
Contributor

Got it working on Gemma-2b!

import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
    output_dir="./test-galore-new",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    galore_target_modules=["attn", "mlp"],
    gradient_checkpointing=True,
    logging_strategy="steps",
    logging_steps=5,
    learning_rate=2e-3,
    save_strategy="no",
    run_name="galore-imdb"
)

model_id = "google/gemma-2b"

config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)  # random init from the config, not pretrained weights

trainer = trl.SFTTrainer(
    model=model, 
    args=args,
    train_dataset=train_dataset,
    dataset_text_field='text',
    max_seq_length=512,
)

trainer.train()

After ~50 steps:

{'loss': 11.8705, 'grad_norm': 13.43569564819336, 'learning_rate': 0.0019, 'epoch': 0.0}
{'loss': 9.8208, 'grad_norm': 7.467105865478516, 'learning_rate': 0.0018000000000000002, 'epoch': 0.0}
{'loss': 8.606, 'grad_norm': 6.2992963790893555, 'learning_rate': 0.0017, 'epoch': 0.0}
{'loss': 7.8436, 'grad_norm': 5.3465986251831055, 'learning_rate': 0.0016, 'epoch': 0.0}
{'loss': 7.6177, 'grad_norm': 6.2392964363098145, 'learning_rate': 0.0015, 'epoch': 0.0}
{'loss': 7.5346, 'grad_norm': 4.487287998199463, 'learning_rate': 0.0014, 'epoch': 0.0}
{'loss': 7.6909, 'grad_norm': 4.615128517150879, 'learning_rate': 0.0013000000000000002, 'epoch': 0.0}
{'loss': 7.0826, 'grad_norm': 5.807451248168945, 'learning_rate': 0.0012, 'epoch': 0.0}
{'loss': 7.1936, 'grad_norm': 3.470165729522705, 'learning_rate': 0.0011, 'epoch': 0.0}
{'loss': 7.1926, 'grad_norm': 4.511063575744629, 'learning_rate': 0.001, 'epoch': 0.0}

Running on a single A100 80GB, the loss seems to converge nicely. It is expected that the optimizer takes some time to initialize itself at the start of training.
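
For context on the startup delay: GaLore builds a low-rank projection of each 2-D gradient via an SVD and keeps the optimizer states in that projected space, so the first step (and every periodic projection refresh) pays the SVD cost. A rough sketch of that idea, not the actual GaLore implementation (the rank and tensor shapes are illustrative):

```python
# Rough sketch of GaLore's gradient projection (illustrative, not library code):
# the gradient of a 2-D weight is projected into a rank-r subspace computed via
# SVD, and the optimizer moments are kept at that reduced size. The SVD is why
# the very first step (and each projection refresh) is noticeably slow.
import torch

def project_gradient(grad: torch.Tensor, rank: int = 128):
    # Expensive step: SVD of the full gradient matrix.
    U, S, Vh = torch.linalg.svd(grad, full_matrices=False)
    P = U[:, :rank]             # projection matrix, shape (m, r)
    return P, P.T @ grad        # low-rank gradient, shape (r, n)

grad = torch.randn(4096, 4096)              # stand-in for a weight gradient
P, low_rank_grad = project_gradient(grad)
print(low_rank_grad.shape)                  # torch.Size([128, 4096])
```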

@savanth14

@younesbelkada I tried your gemma code and faced the following error:

(screenshot of the error attached)

@winglian
Collaborator

Thanks @younesbelkada! I'll open up another PR with just the validation and training args pieces and wait for the upstream integration. Much appreciated!

@younesbelkada
Contributor

Thanks so much @winglian!

@winglian
Collaborator

Superseded by #1409. Thanks for getting this rolling @maximegmd. Props to @younesbelkada for getting this working upstream in transformers.

@winglian closed this Mar 19, 2024