reverse some checks against text_model_config
winglian committed Oct 2, 2024
1 parent c66abb2 commit 2835f0e
Showing 1 changed file with 5 additions and 11 deletions.
src/axolotl/utils/models.py — 16 changes: 5 additions & 11 deletions
@@ -358,19 +358,13 @@ def load_model(
     if cfg.gradient_checkpointing == "unsloth":
         transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
 
-    if (
-        hasattr(text_model_config, "model_type")
-        and text_model_config.model_type == "mllama"
-    ):
+    if hasattr(model_config, "model_type") and model_config.model_type == "mllama":
         if cfg.flash_attention:
             from axolotl.monkeypatch.attention.mllama import patch_mllama
 
             patch_mllama()
 
-    if (
-        hasattr(text_model_config, "model_type")
-        and text_model_config.model_type == "btlm"
-    ):
+    if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                 replace_btlm_attn_with_flash_attn,
@@ -379,8 +373,8 @@ def load_model(
             replace_btlm_attn_with_flash_attn(cfg.base_model)
 
     if (
-        hasattr(text_model_config, "model_type")
-        and text_model_config.model_type == "stablelm_epoch"
+        hasattr(model_config, "model_type")
+        and model_config.model_type == "stablelm_epoch"
     ):
         if cfg.flash_attention and cfg.sample_packing:
             from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
@@ -702,7 +696,7 @@ def load_model(
             )
             skip_move_to_device = True
         elif (
-            text_model_config.model_type == "llama"
+            model_config.model_type == "llama"
            and not cfg.trust_remote_code
            and not cfg.gptq
        ):
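
Context for the change (my reading, not part of the commit): in Transformers, composite multimodal configs such as Mllama expose the language backbone under a nested text_config, and that nested config carries a different model_type than the top-level config. Checks keyed on top-level names like "mllama" therefore have to inspect model_config rather than the extracted text sub-config. A minimal sketch of the distinction, assuming the Transformers MllamaConfig class and its default nested text_config (both are assumptions, not taken from this diff):

    # Minimal illustrative sketch, not from the commit; assumes a transformers
    # release that ships MllamaConfig with a nested text_config attribute.
    from transformers import MllamaConfig

    model_config = MllamaConfig()                 # top-level multimodal config
    text_model_config = model_config.text_config  # nested language-backbone config

    # The top-level model_type is what the patch conditions in load_model key on.
    print(model_config.model_type)       # expected: "mllama"
    print(text_model_config.model_type)  # expected: "mllama_text_model"

If that reading is right, a check like text_model_config.model_type == "mllama" would never match, which is consistent with reversing these checks back to model_config.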
