From 04b2f87e0582d95ec8030061bd0c91a1b9d87174 Mon Sep 17 00:00:00 2001
From: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
Date: Mon, 5 Feb 2024 06:49:10 +0800
Subject: [PATCH] [NeuralChat] add customized system prompts (#1179)

* add customized system prompts

Signed-off-by: lvliang-intel
---
 .../neural_chat/models/base_model.py          |  8 ++++++++
 .../tests/ci/server/test_textchat_server.py   | 16 +++++++++++++---
 workflows/chatbot/inference/generate.py       |  8 ++++++++
 3 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/intel_extension_for_transformers/neural_chat/models/base_model.py b/intel_extension_for_transformers/neural_chat/models/base_model.py
index 57bbb1f8e45..5d2303fd4a8 100644
--- a/intel_extension_for_transformers/neural_chat/models/base_model.py
+++ b/intel_extension_for_transformers/neural_chat/models/base_model.py
@@ -455,6 +455,14 @@ def prepare_prompt(self, prompt: str, task: str = ""):
         self.conv_template.append_message(self.conv_template.roles[1], None)
         return self.conv_template.get_prompt()
 
+    def set_customized_system_prompts(self, system_prompts, model_path: str, task: str = ""):
+        """Override the default system prompt for the given model path and task."""
+        if system_prompts is None or len(system_prompts) == 0:
+            raise Exception("Please check the system prompts; they should not be None or empty!")
+        else:
+            self.get_conv_template(model_path, task)
+            self.conv_template.conv.system_message = system_prompts
+
     def register_plugin_instance(self, plugin_name, instance):
         """
         Register a plugin instance.
diff --git a/intel_extension_for_transformers/neural_chat/tests/ci/server/test_textchat_server.py b/intel_extension_for_transformers/neural_chat/tests/ci/server/test_textchat_server.py
index 0ce11697a32..66c262931ee 100644
--- a/intel_extension_for_transformers/neural_chat/tests/ci/server/test_textchat_server.py
+++ b/intel_extension_for_transformers/neural_chat/tests/ci/server/test_textchat_server.py
@@ -29,9 +29,9 @@ class UnitTest(unittest.TestCase):
 
 
     def setUp(self) -> None:
-        config = PipelineConfig(model_name_or_path="facebook/opt-125m")
-        chatbot = build_chatbot(config)
-        router.set_chatbot(chatbot)
+        self.config = PipelineConfig(model_name_or_path="facebook/opt-125m")
+        self.chatbot = build_chatbot(self.config)
+        router.set_chatbot(self.chatbot)
 
     def test_text_chat(self):
         # Create a sample chat completion request object
@@ -42,6 +42,16 @@
         response = client.post("/v1/chat/completions", json=chat_request.dict())
         assert response.status_code == 200
 
+    def test_text_chat_with_customized_prompt(self):
+        self.chatbot.set_customized_system_prompts(system_prompts="You cannot tell jokes",
+                                                   model_path=self.chatbot.model_name,)
+        # Create a sample chat completion request object
+        chat_request = ChatCompletionRequest(
+            prompt="Tell me about Intel Xeon processors.",
+        )
+        response = client.post("/v1/chat/completions", json=chat_request.dict())
+        assert response.status_code == 200
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/workflows/chatbot/inference/generate.py b/workflows/chatbot/inference/generate.py
index 45654d31a63..2838b156f0e 100644
--- a/workflows/chatbot/inference/generate.py
+++ b/workflows/chatbot/inference/generate.py
@@ -168,6 +168,12 @@ def parse_args():
         default="v2",
         help="the version of return stats format",
     )
+    parser.add_argument(
+        "--system_prompt",
+        type=str,
+        default=None,
+        help="the customized system prompt",
+    )
     args = parser.parse_args()
     return args
 
@@ -226,6 +232,8 @@ def main():
         optimization_config=MixedPrecisionConfig(dtype=args.dtype)
     )
     chatbot = build_chatbot(config)
+    if args.system_prompt:
+        chatbot.set_customized_system_prompts(system_prompts=args.system_prompt, model_path=base_model_path)
     gen_config = GenerationConfig(
         task=args.task,
         temperature=args.temperature,
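
For reference, the new API can be exercised end to end as below. This is a minimal sketch, assuming the `PipelineConfig`/`build_chatbot` entry points and the `facebook/opt-125m` model used by the CI test in this patch; the prompts are illustrative, and `predict` is used here as the usual NeuralChat chat entry point.

```python
# Minimal usage sketch of set_customized_system_prompts, mirroring the CI test.
# Model name and prompt strings are illustrative only.
from intel_extension_for_transformers.neural_chat import PipelineConfig, build_chatbot

config = PipelineConfig(model_name_or_path="facebook/opt-125m")
chatbot = build_chatbot(config)

# Rebuilds the conversation template for this model/task and overrides its
# system message; raises if the prompt is None or empty.
chatbot.set_customized_system_prompts(
    system_prompts="You cannot tell jokes",
    model_path=chatbot.model_name,
)

# Subsequent generations are conditioned on the customized system prompt.
print(chatbot.predict("Tell me about Intel Xeon processors."))
```

The same override is reachable from the command line via the new `--system_prompt` flag in workflows/chatbot/inference/generate.py, which forwards the value to `set_customized_system_prompts` only when it is non-empty.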