diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 964b41f70..fbcaf7a66 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -1,6 +1,9 @@ """multipack patching for v2 of sample packing""" +import importlib import transformers +from accelerate import init_empty_weights +from transformers import AutoConfig, AutoModelForCausalLM from transformers.integrations import is_deepspeed_zero3_enabled from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 @@ -12,11 +15,12 @@ "falcon", "phi", "gemma", + "gemmoe", "starcoder2", ] -def patch_for_multipack(model_type): +def patch_for_multipack(model_type, model_name=None): if model_type == "mixtral": transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data @@ -43,3 +47,15 @@ def patch_for_multipack(model_type): transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data ) + elif model_type == "gemmoe": + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + # we need to load the model here in order for modeling_gemmoe to be available + with init_empty_weights(): + AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True) + module_name = model_config.__class__.__module__.replace( + ".configuration_gemmoe", ".modeling_gemmoe" + ) + modeling_gemmoe = importlib.import_module(module_name) + modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 53201c996..fce7b20a7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -429,7 +429,7 @@ def load_model( and cfg.flash_attention and cfg.sample_packing ): - patch_for_multipack(cfg.model_config_type) + patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) elif cfg.is_llama_derived_model: # Modify all llama derived models in one block