diff --git a/README.md b/README.md
index e1391e39b..f8f226963 100644
--- a/README.md
+++ b/README.md
@@ -171,6 +171,9 @@ base_model_ignore_patterns:
 # if the base_model repo on hf hub doesn't include configuration .json files,
 # you can set that here, or leave this empty to default to base_model
 base_model_config: ./llama-7b-hf
+# Optional tokenizer configuration override in case you want to use a different tokenizer
+# than the one defined in the base model
+tokenizer_config:
 # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
 model_type: AutoModelForCausalLM
 # Corresponding tokenizer for the model AutoTokenizer is a good choice
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 6c42b3061..e1b0b2e59 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -171,8 +171,9 @@ def train(
     validate_config(cfg)
 
     # load the tokenizer first
-    logging.info("loading tokenizer...")
-    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    logging.info(f"loading tokenizer... {tokenizer_config}")
+    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
 
     if check_not_in(
         ["inference", "shard", "merge_lora"], kwargs
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0737d0f12..cf351a78d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -10,9 +10,14 @@
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import AutoModelForCausalLM  # noqa: F401
 from transformers import PreTrainedModel  # noqa: F401
-from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
+from transformers import (  # noqa: F401
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    LlamaConfig,
+)
 
 try:
     from transformers import LlamaForCausalLM
@@ -25,24 +30,23 @@
 
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401
-    from transformers import PreTrainedTokenizer  # noqa: F401
 
     from axolotl.utils.dict import DictDefault  # noqa: F401
 
 
 def load_tokenizer(
-    base_model_config,
+    tokenizer_config,
     tokenizer_type,
     cfg,
 ):
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
-            base_model_config,
+            tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
         )
 
@@ -172,8 +176,10 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        config = LlamaConfig.from_pretrained(base_model_config)
         model = LlamaForCausalLM.from_pretrained(
             base_model,
+            config=config,
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
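
Usage note: with this patch, the tokenizer can be loaded from a different source than base_model_config; if tokenizer_config is left unset, load_tokenizer() still falls back to cfg.base_model_config, so existing configs behave as before. A minimal config sketch (the repo names below are placeholders for illustration, not part of this patch):

    # model weights and model config come from the base model repo
    base_model: huggyllama/llama-7b
    base_model_config: huggyllama/llama-7b
    # tokenizer is taken from a different repo via the new override;
    # leave empty to default to base_model_config
    tokenizer_config: openlm-research/open_llama_7b
    model_type: AutoModelForCausalLM
    tokenizer_type: AutoTokenizer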