diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e4907ccb6e..c18af9760f 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -358,19 +358,13 @@ def load_model(
     if cfg.gradient_checkpointing == "unsloth":
         transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
 
-    if (
-        hasattr(text_model_config, "model_type")
-        and text_model_config.model_type == "mllama"
-    ):
+    if hasattr(model_config, "model_type") and model_config.model_type == "mllama":
         if cfg.flash_attention:
             from axolotl.monkeypatch.attention.mllama import patch_mllama
 
             patch_mllama()
 
-    if (
-        hasattr(text_model_config, "model_type")
-        and text_model_config.model_type == "btlm"
-    ):
+    if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                 replace_btlm_attn_with_flash_attn,
@@ -379,8 +373,8 @@ def load_model(
             replace_btlm_attn_with_flash_attn(cfg.base_model)
 
     if (
-        hasattr(text_model_config, "model_type")
-        and text_model_config.model_type == "stablelm_epoch"
+        hasattr(model_config, "model_type")
+        and model_config.model_type == "stablelm_epoch"
     ):
         if cfg.flash_attention and cfg.sample_packing:
             from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
@@ -702,7 +696,7 @@ def load_model(
             )
             skip_move_to_device = True
         elif (
-            text_model_config.model_type == "llama"
+            model_config.model_type == "llama"
            and not cfg.trust_remote_code
             and not cfg.gptq
         ):
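Note: a minimal sketch of what the updated guards check, assuming a composite multimodal config whose nested text config reports a different model_type than the top-level config. The attribute names and values below (text_config, "mllama_text_model") are illustrative stand-ins, not taken from this diff:

from types import SimpleNamespace

# Hypothetical stand-ins for the configs load_model() works with; real values
# depend on the checkpoint and the installed transformers version.
text_model_config = SimpleNamespace(model_type="mllama_text_model")
model_config = SimpleNamespace(model_type="mllama", text_config=text_model_config)


def wants_mllama_patch(config) -> bool:
    # The guard from the patched code: dispatch on the config's model_type.
    return hasattr(config, "model_type") and config.model_type == "mllama"


assert wants_mllama_patch(model_config)                   # new check on the top-level config matches
assert not wants_mllama_patch(model_config.text_config)   # the same check on the nested text config would not

The btlm, stablelm_epoch, and llama branches in the diff follow the same pattern: each dispatches on model_config.model_type before applying its flash-attention monkeypatch.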