diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 4192ffbb3..7f386bbff 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -11,7 +11,13 @@ from langchain.chat_models.base import BaseChatModel from langchain.llms.sagemaker_endpoint import LLMContentHandler from langchain.llms.utils import enforce_stop_tokens -from langchain.prompts import PromptTemplate +from langchain.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, + PromptTemplate, + SystemMessagePromptTemplate, +) from langchain.pydantic_v1 import BaseModel, Extra, root_validator from langchain.schema import LLMResult from langchain.utils import get_from_dict_or_env @@ -42,6 +48,23 @@ from pydantic.main import ModelMetaclass +SYSTEM_PROMPT = """ +You are Jupyternaut, a conversational assistant living in JupyterLab to help users. +You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. +You are talkative and you provide lots of specific details from the foundation model's context. +You may use Markdown to format your response. +Code blocks must be formatted in Markdown. +Math should be rendered with inline TeX markup, surrounded by $. +If you do not know the answer to a question, answer truthfully by responding that you do not know. +The following is a friendly conversation between you and a human. +""".strip() + +DEFAULT_TEMPLATE = """Current conversation: +{history} +Human: {input} +AI:""" + + class EnvAuthStrategy(BaseModel): """Require one auth token via an environment variable.""" @@ -265,6 +288,31 @@ def get_prompt_template(self, format) -> PromptTemplate: else: return self.prompt_templates["text"] # Default to plain format + def get_chat_prompt_template(self) -> PromptTemplate: + """ + Produce a prompt template optimised for chat conversation. + The template should take two variables: history and input. + """ + if self.is_chat_provider: + return ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( + provider_name=self.name, local_model_id=self.model_id + ), + MessagesPlaceholder(variable_name="history"), + HumanMessagePromptTemplate.from_template("{input}"), + ] + ) + else: + return PromptTemplate( + input_variables=["history", "input"], + template=SYSTEM_PROMPT.format( + provider_name=self.name, local_model_id=self.model_id + ) + + "\n\n" + + DEFAULT_TEMPLATE, + ) + @property def is_chat_provider(self): return isinstance(self, BaseChatModel) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3a76fba44..584f0b33f 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -4,32 +4,9 @@ from jupyter_ai_magics.providers import BaseProvider from langchain.chains import ConversationChain from langchain.memory import ConversationBufferWindowMemory -from langchain.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, - PromptTemplate, - SystemMessagePromptTemplate, -) from .base import BaseChatHandler, SlashCommandRoutingType -SYSTEM_PROMPT = """ -You are Jupyternaut, a conversational assistant living in JupyterLab to help users. -You are not a language model, but rather an application built on a foundation model from {provider_name} called {local_model_id}. -You are talkative and you provide lots of specific details from the foundation model's context. -You may use Markdown to format your response. -Code blocks must be formatted in Markdown. -Math should be rendered with inline TeX markup, surrounded by $. -If you do not know the answer to a question, answer truthfully by responding that you do not know. -The following is a friendly conversation between you and a human. -""".strip() - -DEFAULT_TEMPLATE = """Current conversation: -{history} -Human: {input} -AI:""" - class DefaultChatHandler(BaseChatHandler): id = "default" @@ -49,27 +26,10 @@ def create_llm_chain( model_parameters = self.get_model_parameters(provider, provider_params) llm = provider(**provider_params, **model_parameters) - if llm.is_chat_provider: - prompt_template = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT).format( - provider_name=llm.name, local_model_id=llm.model_id - ), - MessagesPlaceholder(variable_name="history"), - HumanMessagePromptTemplate.from_template("{input}"), - ] - ) - self.memory = ConversationBufferWindowMemory(return_messages=True, k=2) - else: - prompt_template = PromptTemplate( - input_variables=["history", "input"], - template=SYSTEM_PROMPT.format( - provider_name=llm.name, local_model_id=llm.model_id - ) - + "\n\n" - + DEFAULT_TEMPLATE, - ) - self.memory = ConversationBufferWindowMemory(k=2) + prompt_template = llm.get_chat_prompt_template() + self.memory = ConversationBufferWindowMemory( + return_messages=llm.is_chat_provider, k=2 + ) self.llm = llm self.llm_chain = ConversationChain(