various batch of fixes #1785

Merged · 6 commits · Jul 28, 2024
62 changes: 62 additions & 0 deletions examples/llama-3/qlora-fsdp-405b.yaml
@@ -0,0 +1,62 @@
base_model: meta-llama/Meta-Llama-3.1-405B
tokenizer_type: AutoTokenizer

load_in_4bit: true
strict: false

datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b

adapter: qlora

sequence_len: 1024
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: true
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>
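
For reference, the effective global batch size implied by this config is micro_batch_size × gradient_accumulation_steps × world size. A minimal sketch, where the world size is an assumption (not part of the config or this PR):

    # sketch only: world_size is hypothetical, e.g. 4 nodes x 8 GPUs for a 405B QLoRA+FSDP run
    micro_batch_size = 1
    gradient_accumulation_steps = 4
    world_size = 32

    global_batch_size = micro_batch_size * gradient_accumulation_steps * world_size
    print(global_batch_size)  # 128 packed sequences of up to sequence_len=1024 tokens per optimizer step
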
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.43.1
transformers==4.43.3
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0
@@ -32,6 +32,7 @@ fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e59
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5

mamba-ssm==1.2.0.post1

9 changes: 7 additions & 2 deletions src/axolotl/cli/preprocess.py
@@ -2,6 +2,7 @@
CLI to run training on a model
"""
import logging
import warnings
from pathlib import Path
from typing import Union

@@ -76,8 +77,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):

if parsed_cli_args.download:
model_name = parsed_cfg.base_model
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
warnings.simplefilter("ignore")
with init_empty_weights(include_buffers=True):
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

LOG.info(
Fore.GREEN
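
A self-contained sketch of the pattern this change uses — scoping the warning filter to a single block and instantiating on the meta device so no real memory is allocated. It assumes torch and accelerate are installed; the Linear layer is just a stand-in for the real model load:

    import warnings

    import torch
    from accelerate import init_empty_weights

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # only silences warnings raised inside this block
        with init_empty_weights(include_buffers=True):
            layer = torch.nn.Linear(4096, 4096)  # stand-in for AutoModelForCausalLM.from_pretrained(...)

    print(layer.weight.device)  # meta -- parameters exist only as shapes/dtypes, no data is allocated
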
14 changes: 14 additions & 0 deletions src/axolotl/common/architectures.py
@@ -0,0 +1,14 @@
"""
Common architecture specific constants
"""

MOE_ARCH_BLOCK = {
"dbrx": "DbrxFFN",
"jamba": "JambaSparseMoeBlock",
"jetmoe": [
"JetMoeMoA",
"JetMoeMoE",
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
}
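
The mapping stores either a single class name or a list of class names per model type. A hypothetical helper (illustrative only, not part of this PR) that normalizes lookups might look like:

    from axolotl.common.architectures import MOE_ARCH_BLOCK

    def moe_block_names(model_type):
        # normalize the str-or-list values to a list of transformer block class names
        blocks = MOE_ARCH_BLOCK.get(model_type)
        if blocks is None:
            return []
        return blocks if isinstance(blocks, list) else [blocks]

    print(moe_block_names("jetmoe"))   # ['JetMoeMoA', 'JetMoeMoE']
    print(moe_block_names("mixtral"))  # ['MixtralSparseMoeBlock']
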
154 changes: 86 additions & 68 deletions src/axolotl/core/trainer_builder.py
@@ -8,6 +8,7 @@
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from collections import defaultdict
@@ -28,7 +29,7 @@
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
CPOConfig,
@@ -286,7 +287,77 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
)


class AxolotlTrainer(Trainer):
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""

args = None # type: AxolotlTrainingArguments

def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.

Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)

use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)

# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

return self.lr_scheduler


class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
@@ -404,68 +475,6 @@ def create_optimizer(self):

return self.optimizer

def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.

Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)

use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)

# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

return self.lr_scheduler

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
@@ -830,6 +839,14 @@ def store_metrics(
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)

def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)


class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -929,7 +946,7 @@ def create_scheduler(
return self.lr_scheduler


class AxolotlDPOTrainer(DPOTrainer):
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
@@ -990,23 +1007,23 @@ def tokenize_row(
return res


class AxolotlORPOTrainer(ORPOTrainer):
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""

tag_names = ["axolotl", "orpo"]


class AxolotlKTOTrainer(KTOTrainer):
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""

tag_names = ["axolotl", "kto"]


class AxolotlCPOTrainer(CPOTrainer):
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
@@ -1750,6 +1767,7 @@ def build_training_arguments(self, total_num_steps):
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
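
An illustrative sketch (not axolotl code) of why SchedulerMixin is listed before the concrete TRL/HF trainer in each class above: Python's MRO resolves create_scheduler on the mixin first, and the mixin can still defer to the underlying trainer via super() when none of the custom cosine schedulers apply:

    class BaseTrainer:
        def create_scheduler(self):
            return "base scheduler"

    class SchedulerMixin(BaseTrainer):
        def create_scheduler(self):
            # custom scheduler handling would go here; otherwise defer to the base trainer
            return super().create_scheduler()

    class CustomTrainer(SchedulerMixin, BaseTrainer):
        pass

    print(CustomTrainer().create_scheduler())           # "base scheduler", resolved through the mixin
    print([c.__name__ for c in CustomTrainer.__mro__])  # ['CustomTrainer', 'SchedulerMixin', 'BaseTrainer', 'object']
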
4 changes: 2 additions & 2 deletions src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -62,7 +62,7 @@ def transform_fn(sample, tokenizer=None):
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:]
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()

result["rejected"] = tokenizer.apply_chat_template(
[rejected],
@@ -71,7 +71,7 @@ def transform_fn(sample, tokenizer=None):
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:]
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()

return result

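
A quick sketch of what the added .rstrip() changes, using made-up strings: the sliced completion previously kept any trailing whitespace or newline appended by the chat template.

    templated = "<|assistant|> hello there \n"  # hypothetical output of apply_chat_template
    content = "hello there"

    start = templated.find(content)
    print(repr(templated[start:]))           # 'hello there \n'  (old behavior: trailing whitespace kept)
    print(repr(templated[start:].rstrip()))  # 'hello there'     (new behavior)
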
29 changes: 13 additions & 16 deletions src/axolotl/train.py
@@ -212,26 +212,23 @@ def terminate_handler(_, __, model_weakref):
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
trainer.save_model(cfg.output_dir)

# the trainer saved a model.safetensors file in the output directory,
# but it is a proxy model and should be deleted
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
# but it is most likely a proxy model and if so, should be deleted
maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
maybe_sharded = os.path.exists(
os.path.join(cfg.output_dir, "model.safetensors.index.json")
)

if maybe_proxy and maybe_sharded:
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
LOG.info("This is a proxy model and should be deleted")
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))

# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
try:
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass

elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
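
A hypothetical helper (names are illustrative, not from the PR) capturing the new logic: model.safetensors is only treated as a deletable ZeRO-3 proxy when a sharded index file sits next to it, and the removal tolerates another process getting there first:

    import os

    def remove_proxy_safetensors(output_dir):
        single = os.path.join(output_dir, "model.safetensors")
        index = os.path.join(output_dir, "model.safetensors.index.json")
        # only delete when both the single file and the sharded index exist,
        # i.e. the single file is most likely a ZeRO-3 proxy
        if os.path.exists(single) and os.path.exists(index):
            try:
                os.remove(single)
            except FileNotFoundError:
                pass  # another process may have removed it already
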