Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

beta support for multipack with gemmoe #1402

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/axolotl/monkeypatch/multipack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""multipack patching for v2 of sample packing"""
import importlib

import transformers
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations import is_deepspeed_zero3_enabled

from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
Expand All @@ -12,11 +15,12 @@
"falcon",
"phi",
"gemma",
"gemmoe",
"starcoder2",
]


def patch_for_multipack(model_type):
def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
Expand All @@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_gemmoe to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_gemmoe", ".modeling_gemmoe"
)
modeling_gemmoe = importlib.import_module(module_name)
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
2 changes: 1 addition & 1 deletion src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type)
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block

Expand Down
Loading