From 517a59a0d5b02f24f294bd5fe65acf0ade25ad65 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 6 Aug 2023 03:55:50 +0900
Subject: [PATCH] Feat: Add rope scaling

---
 README.md                   | 4 ++++
 src/axolotl/utils/models.py | 1 +
 2 files changed, 5 insertions(+)

diff --git a/README.md b/README.md
index fe22bbc31b..4238362462 100644
--- a/README.md
+++ b/README.md
@@ -472,6 +472,10 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # llama only
 xpos_rope:
+# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+rope_scaling:
+  type: # linear | dynamic
+  factor: # float
 
 # resume from a specific checkpoint dir
 resume_from_checkpoint:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 253bdcbd84..f4edcf9757 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -223,6 +223,7 @@ def load_model(
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map="auto" if cfg.world_size == 1 else cfg.device_map,
+            rope_scaling=cfg.rope_scaling,
             **model_kwargs,
         )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
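For illustration, a minimal sketch of how a user might fill in the new option in an axolotl YAML config. The values are assumptions, not defaults: `factor: 2.0` is only an example, and linear scaling by 2x roughly doubles the usable context length of a LLaMA-derived model.

```yaml
# Example only: enable linear RoPE scaling with an illustrative factor.
rope_scaling:
  type: linear   # or: dynamic
  factor: 2.0    # float > 1.0; scales the original context length by this amount
```

Because `cfg.rope_scaling` is forwarded unchanged to `from_pretrained`, leaving the section unset should keep the upstream default of no scaling.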