diff --git a/intel_extension_for_transformers/neural_chat/README.md b/intel_extension_for_transformers/neural_chat/README.md
index e5aba9f997b..df8d7099afc 100644
--- a/intel_extension_for_transformers/neural_chat/README.md
+++ b/intel_extension_for_transformers/neural_chat/README.md
@@ -144,6 +144,7 @@ The table below displays the validated model list in NeuralChat for both inferen
 |LLaMA2 series| ✅| ✅|✅| ✅ |
 |MPT series| ✅| ✅|✅| ✅ |
 |Mistral| ✅| ✅|✅| ✅ |
+|Mixtral-8x7b-v0.1| ✅| ✅|✅| ✅ |
 |ChatGLM series| ✅| ✅|✅| ✅ |
 |Qwen series| ✅| ✅|✅| ✅ |
 |StarCoder series| | | | ✅ |
diff --git a/intel_extension_for_transformers/neural_chat/chatbot.py b/intel_extension_for_transformers/neural_chat/chatbot.py
index 57d7daefcbc..855c6837ade 100644
--- a/intel_extension_for_transformers/neural_chat/chatbot.py
+++ b/intel_extension_for_transformers/neural_chat/chatbot.py
@@ -96,7 +96,8 @@ def build_chatbot(config: PipelineConfig=None):
         "bloom" in config.model_name_or_path.lower() or \
         "starcoder" in config.model_name_or_path.lower() or \
         "codegen" in config.model_name_or_path.lower() or \
-        "magicoder" in config.model_name_or_path.lower():
+        "magicoder" in config.model_name_or_path.lower() or \
+        "mixtral" in config.model_name_or_path.lower():
         from .models.base_model import BaseModel
         adapter = BaseModel()
     else:
diff --git a/intel_extension_for_transformers/neural_chat/models/model_utils.py b/intel_extension_for_transformers/neural_chat/models/model_utils.py
index 014e68e38bf..ed132dbd90f 100644
--- a/intel_extension_for_transformers/neural_chat/models/model_utils.py
+++ b/intel_extension_for_transformers/neural_chat/models/model_utils.py
@@ -499,6 +499,7 @@ def load_model(
             or config.model_type == "mpt"
             or config.model_type == "llama"
             or config.model_type == "mistral"
+            or config.model_type == "mixtral"
         ) and not ipex_int8) or config.model_type == "opt":
         with smart_context_manager(use_deepspeed=use_deepspeed):
             model = AutoModelForCausalLM.from_pretrained(
@@ -553,7 +554,7 @@ def load_model(
             )
         else:
             raise ValueError(f"unsupported model name or path {model_name}, \
-                only supports FLAN-T5/LLAMA/MPT/GPT/BLOOM/OPT/QWEN/NEURAL-CHAT/MISTRAL/CODELLAMA/STARCODER/CODEGEN now.")
+                only supports t5/llama/mpt/gptj/bloom/opt/qwen/mistral/mixtral/gpt_bigcode model type now.")
     except EnvironmentError as e:
         logging.error(f"Exception: {e}")
         if "not a local folder and is not a valid model identifier" in str(e):
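
For reference, here is a minimal usage sketch (not part of this diff) of how the new Mixtral support could be exercised through the standard NeuralChat entry points. The model path is the public Hugging Face Mixtral checkpoint and the prompt is illustrative:

```python
# Minimal sketch: with this change, a model path containing "mixtral" is
# routed to BaseModel in build_chatbot(), and load_model() accepts the
# "mixtral" model_type via AutoModelForCausalLM.
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

config = PipelineConfig(model_name_or_path="mistralai/Mixtral-8x7B-v0.1")
chatbot = build_chatbot(config)
response = chatbot.predict("Tell me about Intel Xeon Scalable Processors.")
print(response)
```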