Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[NeuralChat] add customized system prompts (#1179)
Browse files Browse the repository at this point in the history
* add customized system prompts

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
  • Loading branch information
Spycsh authored Feb 4, 2024
1 parent e6f87ab commit 04b2f87
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,14 @@ def prepare_prompt(self, prompt: str, task: str = ""):
self.conv_template.append_message(self.conv_template.roles[1], None)
return self.conv_template.get_prompt()

def set_customized_system_prompts(self, system_prompts, model_path: str, task: str = ""):
    """Install a caller-supplied system message on the conversation template.

    Resolves the conversation template for ``model_path``/``task`` via
    ``get_conv_template`` and then overrides its system message with
    ``system_prompts``.

    Raises:
        Exception: if ``system_prompts`` is None or empty.
    """
    # Guard clause: refuse a missing/empty prompt before touching the template.
    if system_prompts is None or len(system_prompts) == 0:
        raise Exception("Please check the model system prompts, should not be None!")
    self.get_conv_template(model_path, task)
    # NOTE(review): this writes conv_template.conv.system_message, while other
    # methods (e.g. prepare_prompt) use conv_template directly — confirm the
    # wrapper attribute is intended.
    self.conv_template.conv.system_message = system_prompts

def register_plugin_instance(self, plugin_name, instance):
"""
Register a plugin instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@

class UnitTest(unittest.TestCase):
def setUp(self) -> None:
    """Build a chatbot once per test and register it with the router.

    The pipeline config and chatbot are stored on ``self`` so individual
    tests (e.g. the customized-system-prompt test) can reach the chatbot
    instance directly.
    """
    # NOTE: the pasted diff residue showed both the pre- and post-change
    # bodies; only the post-change version (attributes on self) is kept.
    self.config = PipelineConfig(model_name_or_path="facebook/opt-125m")
    self.chatbot = build_chatbot(self.config)
    router.set_chatbot(self.chatbot)

def test_text_chat(self):
# Create a sample chat completion request object
Expand All @@ -42,6 +42,16 @@ def test_text_chat(self):
response = client.post("/v1/chat/completions", json=chat_request.dict())
assert response.status_code == 200

def test_text_chat_with_customized_prompt(self):
    """Chat completion should still return 200 after overriding the system prompt."""
    # Override the system prompt for the chatbot's own model.
    self.chatbot.set_customized_system_prompts(
        system_prompts="You cannot tell jokes",
        model_path=self.chatbot.model_name,
    )
    # Issue a sample chat completion request against the overridden prompt.
    request = ChatCompletionRequest(prompt="Tell me about Intel Xeon processors.")
    response = client.post("/v1/chat/completions", json=request.dict())
    assert response.status_code == 200


# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
8 changes: 8 additions & 0 deletions workflows/chatbot/inference/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ def parse_args():
default="v2",
help="the version of return stats format",
)
parser.add_argument(
    "--system_prompt",
    type=str,
    # BUG FIX: the default was the STRING "None", which is truthy, so the
    # `if args.system_prompt:` check in main() always fired and overwrote the
    # model's system prompt with the literal text "None". The default must be
    # the Python None so the override only happens when the flag is passed.
    default=None,
    help="the customized system prompt",
)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -226,6 +232,8 @@ def main():
optimization_config=MixedPrecisionConfig(dtype=args.dtype)
)
chatbot = build_chatbot(config)
if args.system_prompt:
chatbot.set_customized_system_prompts(system_prompts=args.system_prompt, model_path=base_model_path)
gen_config = GenerationConfig(
task=args.task,
temperature=args.temperature,
Expand Down

0 comments on commit 04b2f87

Please sign in to comment.