refactor advanced prompt core. #1350

Merged: 2 commits, Oct 18, 2023
4 changes: 4 additions & 0 deletions api/core/model_providers/models/entity/model_params.py
@@ -4,6 +4,10 @@
from langchain.load.serializable import Serializable
from pydantic import BaseModel

class AppMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'


class ModelMode(enum.Enum):
COMPLETION = 'completion'
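The new AppMode enum gives the app-level axis (how end users interact with the app) the same treatment ModelMode already gives the model-level axis (how prompts are fed to the model), replacing raw string comparisons. A minimal sketch of how the enum is meant to be used (illustrative only, not part of the diff):

# Illustrative use of the new enum:
AppMode('chat')          # -> AppMode.CHAT
AppMode('completion')    # -> AppMode.COMPLETION
AppMode('workflow')      # raises ValueError instead of silently falling through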
242 changes: 172 additions & 70 deletions api/core/model_providers/models/llm/base.py
@@ -15,7 +15,7 @@
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules, AppMode
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
@@ -330,83 +330,191 @@ def get_prompt(self, mode: str,
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops

def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}

raw_prompt_list = []
app_mode_enum = AppMode(app_mode)
model_mode_enum = ModelMode(model_mode)

prompt_messages = []

if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
if app_mode_enum == AppMode.CHAT:
if model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory)
elif model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory)
elif app_mode_enum == AppMode.COMPLETION:
if model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
elif model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)

return prompt_messages
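The refactor replaces the removed if/elif chain above with one private builder per (app mode, model mode) combination, so each prompt-assembly strategy lives in its own method. A hedged sketch of the same routing expressed as a lookup table (a hypothetical alternative shape, not the PR's code):

# Hypothetical dispatch table; handler names match the methods in this diff.
PROMPT_BUILDERS = {
    (AppMode.CHAT, ModelMode.COMPLETION): '_get_chat_app_completion_model_prompt_messages',
    (AppMode.CHAT, ModelMode.CHAT): '_get_chat_app_chat_model_prompt_messages',
    (AppMode.COMPLETION, ModelMode.CHAT): '_get_completion_app_chat_model_prompt_messages',
    (AppMode.COMPLETION, ModelMode.COMPLETION): '_get_completion_app_completion_model_prompt_messages',
}

One behavioral note: unlike the removed code, which raised on an unsupported combination, the new dispatch falls through and returns the empty prompt_messages list.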

def _set_context_variable(self, context, prompt_template, prompt_inputs):
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''

def _set_query_variable(self, query, prompt_template, prompt_inputs):
if '#query#' in prompt_template.variable_keys:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''

def _set_histories_variable(self, memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs):
if '#histories#' in prompt_template.variable_keys:
if memory:
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=raw_prompt,
inputs={ '#histories#': '', **prompt_inputs }
)

rest_tokens = self._calculate_rest_token([tmp_human_message])

memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
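Each setter only fills its reserved key when the template actually declares it, so templates that omit #histories# never pay the history token-counting cost. A small sketch of the flow, assuming the repo's {{...}} placeholder syntax for PromptTemplateParser (the template text here is hypothetical):

# Hypothetical template using the reserved keys the setters look for.
raw_prompt = 'Context: {{#context#}}\nQuestion: {{#query#}}\nAnswer:'
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {}
# _set_context_variable / _set_query_variable then populate, e.g.:
#   prompt_inputs == {'#context#': '<retrieved passages>', '#query#': '<user question>'}
# Keys absent from prompt_template.variable_keys are simply skipped.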

def _append_chat_histories(self, memory, prompt_messages):
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages)

memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
prompt_messages.extend(histories)
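Where the completion-model path renders history into a #histories# template variable, this chat-model path appends history as structured messages after the configured prompt. An illustrative result, assuming MessageType.USER and MessageType.ASSISTANT serialize to 'user' and 'assistant':

# Before: prompt_messages == [<configured template messages>]
# After _append_chat_histories (contents illustrative):
#   [<configured template messages>,
#    PromptMessage(type='user', content='earlier question'),
#    PromptMessage(type='assistant', content='earlier answer')]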

def _calculate_rest_token(self, prompt_messages):
rest_tokens = 2000

if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages(prompt_messages))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)

return rest_tokens
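The budget arithmetic: whatever the context window holds after reserving the completion's max_tokens and the prompt assembled so far is what history may consume; 2000 is only a fallback for models that declare no hard window. A worked example with hypothetical numbers:

# model_rules.max_tokens.max = 4096  (context window, hypothetical)
# model_kwargs.max_tokens    = 512   (reserved for the model's answer)
# curr_message_tokens        = 1300  (prompt assembled so far)
rest_tokens = max(4096 - 512 - 1300, 0)   # == 2284 tokens left for histories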

def _format_prompt(self, prompt_template, prompt_inputs):
prompt = prompt_template.format(
prompt_inputs
)

prompt = re.sub(r'<\|.*?\|>', '', prompt)
return prompt
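Formatting is followed by a sweep that strips anything shaped like a special-token delimiter, so template inputs cannot smuggle model control tokens into the prompt. For example:

import re

prompt = 'Hello <|im_start|>world<|im_end|>'
print(re.sub(r'<\|.*?\|>', '', prompt))   # -> 'Hello world'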

def _get_chat_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:

raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']

prompt_messages = []
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

self._set_query_variable(query, prompt_template, prompt_inputs)

self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt))

return prompt_messages
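For reference, a sketch of the completion_prompt_config shape this builder reads (key names are taken from the accesses above; the values are placeholders):

# Placeholder config; only the keys this method touches.
completion_prompt_config = {
    'prompt': {
        'text': 'You are helpful.\n{{#histories#}}\nHuman: {{#query#}}\nAssistant: '
    },
    'conversation_histories_role': {
        'user_prefix': 'Human',
        'assistant_prefix': 'Assistant',
    },
}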

def _get_chat_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
prompt = prompt_item['text']
raw_prompt = prompt_item['text']
prompt = ''

# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''

if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''

if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''

prompt = prompt_template.format(
prompt_inputs
)
self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

self._append_chat_histories(memory, prompt_messages)

if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)
prompt_messages.append(PromptMessage(type=MessageType.USER, content=query))

return prompt_messages
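The chat-model counterpart consumes a list-shaped config, one entry per message, and the current query is appended last as a user message. A placeholder sketch (the role string assumes MessageType accepts it):

chat_prompt_config = {
    'prompt': [
        {'role': 'system', 'text': 'Answer strictly from {{#context#}}.'},
    ]
}
# Resulting order: configured messages, then chat history, then the query
# as the trailing 'user' message.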

def _get_completion_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']

prompt_messages = []
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt))

return prompt_messages

def _get_completion_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

self._set_context_variable(context, prompt_template, prompt_inputs)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

return prompt_messages

def prompt_file_name(self, mode: str) -> str:
@@ -452,13 +560,7 @@ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict
}
)

if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
rest_tokens = self._calculate_rest_token([tmp_human_message])

memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
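The prefix fallbacks could equally be written with dict.get, the more idiomatic spelling of the same lookup (an observation, not a change this PR makes):

memory.human_prefix = prompt_rules.get('human_prefix', 'Human')
memory.ai_prefix = prompt_rules.get('assistant_prefix', 'Assistant')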
25 changes: 13 additions & 12 deletions api/services/advanced_prompt_template_service.py
@@ -1,6 +1,7 @@

import copy

from core.model_providers.models.entity.model_params import AppMode, ModelMode
from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT

@@ -22,15 +23,15 @@ def get_prompt(cls, args: dict) -> dict:
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
context_prompt = copy.deepcopy(CONTEXT)

if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)

@classmethod
@@ -51,13 +52,13 @@ def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str)
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)

if app_mode == 'chat':
if model_mode == 'completion':
if app_mode == AppMode.CHAT.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif app_mode == 'completion':
if model_mode == 'completion':
elif app_mode == AppMode.COMPLETION.value:
if model_mode == ModelMode.COMPLETION.value:
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
elif model_mode == 'chat':
elif model_mode == ModelMode.CHAT.value:
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
11 changes: 9 additions & 2 deletions api/services/app_model_config_service.py
@@ -3,7 +3,7 @@

from core.agent.agent_executor import PlanningStrategy
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelMode
from core.model_providers.models.entity.model_params import ModelType, ModelMode, AppMode
from models.account import Account
from services.dataset_service import DatasetService

@@ -418,7 +418,7 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
if config['model']["mode"] not in ['chat', 'completion']:
raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")

if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']

@@ -427,3 +427,10 @@ def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:

if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'


if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']

if len(prompt_list) > 10:
raise ValueError("prompt messages must be no more than 10")
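A minimal fragment that satisfies the new guard (only the fields this validation step inspects; all other required config omitted):

config = {
    'model': {'mode': 'chat'},
    'chat_prompt_config': {
        'prompt': [
            {'role': 'system', 'text': 'You are helpful.'},
            # ...up to 10 entries pass; an 11th raises ValueError
        ]
    },
}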