From b393c0967aa8ea5dbfa62d9cf592fc571d4e37b3 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 1 Aug 2024 14:40:43 +0800 Subject: [PATCH] feat(model): Support mistral nemo --- dbgpt/configs/model_config.py | 4 ++++ dbgpt/model/adapter/hf_adapter.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index bd28072fc..4d02a2730 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -211,6 +211,9 @@ def get_device() -> str: "mixtral-8x7b-instruct-v0.1": os.path.join( MODEL_PATH, "Mixtral-8x7B-Instruct-v0.1" ), + "mistral-nemo-instruct-2407": os.path.join( + MODEL_PATH, "Mistral-Nemo-Instruct-2407" + ), # https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0 "solar-10.7b-instruct-v1.0": os.path.join(MODEL_PATH, "SOLAR-10.7B-Instruct-v1.0"), # https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca @@ -238,6 +241,7 @@ def get_device() -> str: "gemma-7b-it": os.path.join(MODEL_PATH, "gemma-7b-it"), # https://huggingface.co/google/gemma-2b-it "gemma-2b-it": os.path.join(MODEL_PATH, "gemma-2b-it"), + "gemma-2-2b-it": os.path.join(MODEL_PATH, "gemma-2-2b-it"), "gemma-2-9b-it": os.path.join(MODEL_PATH, "gemma-2-9b-it"), "gemma-2-27b-it": os.path.join(MODEL_PATH, "gemma-2-27b-it"), "starling-lm-7b-beta": os.path.join(MODEL_PATH, "Starling-LM-7B-beta"), diff --git a/dbgpt/model/adapter/hf_adapter.py b/dbgpt/model/adapter/hf_adapter.py index 415a16a58..3facbac24 100644 --- a/dbgpt/model/adapter/hf_adapter.py +++ b/dbgpt/model/adapter/hf_adapter.py @@ -198,6 +198,16 @@ def do_match(self, lower_model_name_or_path: Optional[str] = None): ) +class MistralNemo(NewHFChatModelAdapter): + def do_match(self, lower_model_name_or_path: Optional[str] = None): + return ( + lower_model_name_or_path + and "mistral" in lower_model_name_or_path + and "nemo" in lower_model_name_or_path + and "instruct" in lower_model_name_or_path + ) + + class SOLARAdapter(NewHFChatModelAdapter): """ https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0 @@ -627,6 +637,7 @@ def load(self, model_path: str, from_pretrained_kwargs: dict): register_model_adapter(YiAdapter) register_model_adapter(Yi15Adapter) register_model_adapter(Mixtral8x7BAdapter) +register_model_adapter(MistralNemo) register_model_adapter(SOLARAdapter) register_model_adapter(GemmaAdapter) register_model_adapter(Gemma2Adapter)