Merge pull request #120 from OpenAccess-AI-Collective/model-from-path
Split up Llama model loading so that the config can be loaded from the base config and the model weights can be loaded from a separate path.
winglian authored May 31, 2023
2 parents 876edd8 + e3c494c commit c7021e1
Showing 3 changed files with 18 additions and 8 deletions.
README.md (3 additions, 0 deletions)
@@ -171,6 +171,9 @@ base_model_ignore_patterns:
 # if the base_model repo on hf hub doesn't include configuration .json files,
 # you can set that here, or leave this empty to default to base_model
 base_model_config: ./llama-7b-hf
+# Optional tokenizer configuration override in case you want to use a different tokenizer
+# than the one defined in the base model
+tokenizer_config:
 # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
 model_type: AutoModelForCausalLM
 # Corresponding tokenizer for the model AutoTokenizer is a good choice
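For context, the new tokenizer_config key simply points the tokenizer loader at a different repo or local path than the model weights. A minimal sketch of that separation using plain transformers (the path below is a placeholder, not a value from this repo):

# Minimal sketch: loading the tokenizer from a location that differs from the
# model weights. The path is a placeholder used only for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "./custom-tokenizer",   # what tokenizer_config would point at
    trust_remote_code=False,
)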
scripts/finetune.py (3 additions, 2 deletions)
@@ -173,8 +173,9 @@ def train(
         cfg.bf16 = False
 
     # load the tokenizer first
-    logging.info("loading tokenizer...")
-    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    logging.info(f"loading tokenizer... {tokenizer_config}")
+    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
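The fallback above is plain Python truthiness: leaving tokenizer_config unset (None or an empty value) resolves to base_model_config. A small sketch with hypothetical values:

# Sketch of the cfg.tokenizer_config or cfg.base_model_config fallback;
# the values are hypothetical.
base_model_config = "./llama-7b-hf"

tokenizer_config = None                        # tokenizer_config: left empty in the YAML
print(tokenizer_config or base_model_config)   # -> ./llama-7b-hf

tokenizer_config = "./custom-tokenizer"        # explicit override
print(tokenizer_config or base_model_config)   # -> ./custom-tokenizer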
src/axolotl/utils/models.py (12 additions, 6 deletions)
@@ -10,9 +10,14 @@
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
+from transformers import (  # noqa: F401
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    LlamaConfig,
+)
 
 try:
     from transformers import LlamaForCausalLM
@@ -25,24 +30,23 @@
 
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401
     from transformers import PreTrainedTokenizer  # noqa: F401
 
     from axolotl.utils.dict import DictDefault  # noqa: F401
 
 
 def load_tokenizer(
-    base_model_config,
+    tokenizer_config,
     tokenizer_type,
     cfg,
 ):
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
         )
 
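The tokenizer_type branch resolves a tokenizer class by name from the transformers module via getattr. A minimal sketch of that pattern (the class name and local path are illustrative only):

# Minimal sketch of resolving a tokenizer class by name, as the tokenizer_type
# branch above does; "LlamaTokenizer" and the path are illustrative only.
import transformers

tokenizer_cls = getattr(transformers, "LlamaTokenizer")
tokenizer = tokenizer_cls.from_pretrained(
    "./llama-7b-hf",
    trust_remote_code=False,
)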
@@ -172,8 +176,10 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
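The net effect of this hunk is that the Llama configuration is read from base_model_config while the weights are loaded from base_model, so the two can point at different paths. A hedged standalone sketch of the same pattern (both paths are placeholders):

# Hedged sketch of loading the config and the weights from different locations;
# both paths are placeholders, not values from this repo.
import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig.from_pretrained("./llama-7b-hf")   # base_model_config
model = LlamaForCausalLM.from_pretrained(
    "./llama-7b-custom-weights",                        # base_model
    config=config,
    torch_dtype=torch.float16,
)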
