Skip to content

Commit

Permalink
Update get_unpad_data patching for multipack (#2013)
Browse files Browse the repository at this point in the history
* Update `get_unpad_data` patching for multipack

* Update src/axolotl/utils/models.py

* Update src/axolotl/utils/models.py

* Add test case

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
  • Loading branch information
3 people authored and bursteratom committed Nov 18, 2024
1 parent 802a68c commit 6ca9151
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 58 deletions.
72 changes: 15 additions & 57 deletions src/axolotl/monkeypatch/multipack.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""multipack patching for v2 of sample packing"""

import importlib

import transformers
Expand Down Expand Up @@ -27,71 +28,28 @@
]


def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return

# retain for legacy
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2_moe":
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()


def patch_remote(model_name, config_name, modeling_name):
def patch_remote(model_name):
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_* to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
parts = model_config.__class__.__module__.split(".")
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
module_name = ".".join(parts)
modeling_arch = importlib.import_module(module_name)
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
if hasattr(modeling_arch, "_get_unpad_data"):
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
9 changes: 8 additions & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,17 @@ def apply_patches(self) -> None:
and self.cfg.flash_attention
and self.cfg.sample_packing
):
has_remote_code = (
"auto_map" in self.model_config
and "AutoModelForCausalLM" in self.model_config["auto_map"]
)
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code
patch_for_multipack(
self.cfg.model_config_type,
model_name=self.cfg.base_model,
is_remote_code=self.cfg.trust_remote_code,
has_remote_code=has_remote_code,
)

if self.cfg.is_llama_derived_model:
Expand Down
66 changes: 66 additions & 0 deletions tests/e2e/test_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
E2E tests for llama
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestLlama(unittest.TestCase):
"""
Test case for Llama models
"""

@with_temp_dir
def test_fft_trust_remote_code(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

0 comments on commit 6ca9151

Please sign in to comment.