diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index bccffb662..c55ffe725 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -18,7 +18,7 @@
 from application.extensions import api
 from application.llm.llm_creator import LLMCreator
 from application.retriever.retriever_creator import RetrieverCreator
-from application.utils import check_required_fields
+from application.utils import check_required_fields, limit_chat_history
 
 logger = logging.getLogger(__name__)
 
@@ -324,8 +324,7 @@ def post(self):
         try:
             question = data["question"]
-            history = str(data.get("history", []))
-            history = str(json.loads(history))
+            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
@@ -456,7 +455,7 @@ def post(self):
         try:
             question = data["question"]
-            history = data.get("history", [])
+            history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)
             conversation_id = data.get("conversation_id")
             prompt_id = data.get("prompt_id", "default")
             chunks = int(data.get("chunks", 2))
diff --git a/application/core/settings.py b/application/core/settings.py
index a7811ec78..0bace432f 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -16,7 +16,7 @@ class Settings(BaseSettings):
     MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
     MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
     DEFAULT_MAX_HISTORY: int = 150
-    MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
+    MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
     UPLOAD_FOLDER: str = "inputs"
     PARSE_PDF_AS_IMAGE: bool = False
     VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index 1fd844b26..3d9ae89e6 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -2,7 +2,6 @@
 from application.retriever.base import BaseRetriever
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
-from application.utils import num_tokens_from_string
 from langchain_community.tools import BraveSearch
 
 
@@ -73,15 +72,8 @@ def gen(self):
             yield {"source": doc}
 
         if len(self.chat_history) > 1:
-            tokens_current_history = 0
-            # count tokens in history
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
-                        i["response"]
-                    )
-                    if tokens_current_history + tokens_batch < self.token_limit:
-                        tokens_current_history += tokens_batch
                     messages_combine.append(
                         {"role": "user", "content": i["prompt"]}
                     )
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 42e318d20..8de625dd8 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -3,7 +3,6 @@
 
 from application.vectorstore.vector_creator import VectorCreator
 from application.llm.llm_creator import LLMCreator
-from application.utils import num_tokens_from_string
 
 
 class ClassicRAG(BaseRetriever):
@@ -73,15 +72,8 @@ def gen(self):
             yield {"source": doc}
 
         if len(self.chat_history) > 1:
-            tokens_current_history = 0
-            # count tokens in history
             for i in self.chat_history:
-                if "prompt" in i and "response" in i:
-                    tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
-                        i["response"]
-                    )
-                    if tokens_current_history + tokens_batch < self.token_limit:
-                        tokens_current_history += tokens_batch
+                if "prompt" in i and "response" in i:
                     messages_combine.append(
                         {"role": "user", "content": i["prompt"]}
                     )
@@ -89,7 +81,7 @@ def gen(self):
                     {"role": "system", "content": i["response"]}
                 )
         messages_combine.append({"role": "user", "content": self.question})
-        
+
         llm = LLMCreator.create_llm(
             settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
         )
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index 6ae562269..fa19ead03 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -1,7 +1,6 @@
 from application.retriever.base import BaseRetriever
 from application.core.settings import settings
 from application.llm.llm_creator import LLMCreator
-from application.utils import num_tokens_from_string
 from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
 
@@ -89,16 +88,9 @@ def gen(self):
         for doc in docs:
             yield {"source": doc}
 
-        if len(self.chat_history) > 1:
-            tokens_current_history = 0
-            # count tokens in history
+        if len(self.chat_history) > 1:
             for i in self.chat_history:
-                if "prompt" in i and "response" in i:
-                    tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
-                        i["response"]
-                    )
-                    if tokens_current_history + tokens_batch < self.token_limit:
-                        tokens_current_history += tokens_batch
+                if "prompt" in i and "response" in i:
                     messages_combine.append(
                         {"role": "user", "content": i["prompt"]}
                     )
diff --git a/application/utils.py b/application/utils.py
index 1fc9e3291..7099a20a9 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -46,3 +46,40 @@ def check_required_fields(data, required_fields):
 
 def get_hash(data):
     return hashlib.md5(data.encode()).hexdigest()
+
+
+def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
+    """
+    Limits chat history based on token count.
+    Returns a list of messages that fit within the token limit.
+    """
+    from application.core.settings import settings
+
+    max_token_limit = (
+        max_token_limit
+        if max_token_limit
+        and max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
+            gpt_model, settings.DEFAULT_MAX_HISTORY
+        )
+        else settings.MODEL_TOKEN_LIMITS.get(
+            gpt_model, settings.DEFAULT_MAX_HISTORY
+        )
+    )
+
+    if not history:
+        return []
+
+    tokens_current_history = 0
+    trimmed_history = []
+
+    for message in reversed(history):
+        if "prompt" in message and "response" in message:
+            tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
+                message["response"]
+            )
+            if tokens_current_history + tokens_batch < max_token_limit:
+                tokens_current_history += tokens_batch
+                trimmed_history.insert(0, message)
+            else:
+                break
+
+    return trimmed_history
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 7b6f11d61..f96a17d40 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -1649,7 +1649,7 @@
       "version": "18.3.0",
       "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.0.tgz",
       "integrity": "sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==",
-      "devOptional": true,
+      "dev": true,
       "dependencies": {
         "@types/react": "*"
       }
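
Note on the new helper: the sketch below mirrors the trimming logic that limit_chat_history adds in application/utils.py, so the behavior can be checked in isolation. It is a minimal standalone sketch, not the shipped code: count_tokens is a hypothetical stand-in for the application's num_tokens_from_string helper, and the 15-token budget is arbitrary.

def count_tokens(text):
    # Hypothetical stand-in for num_tokens_from_string; a whitespace split is
    # only a rough approximation of a real tokenizer.
    return len(text.split())


def limit_chat_history_sketch(history, max_token_limit):
    """Keep the most recent prompt/response pairs that fit under max_token_limit."""
    tokens_current_history = 0
    trimmed_history = []
    # Walk from newest to oldest so the most recent exchanges survive trimming,
    # inserting at the front to preserve chronological order.
    for message in reversed(history):
        if "prompt" in message and "response" in message:
            tokens_batch = count_tokens(message["prompt"]) + count_tokens(message["response"])
            if tokens_current_history + tokens_batch < max_token_limit:
                tokens_current_history += tokens_batch
                trimmed_history.insert(0, message)
            else:
                break
    return trimmed_history


history = [
    {"prompt": "What is DocsGPT?", "response": "An open-source documentation assistant."},
    {"prompt": "How do I install it?", "response": "Clone the repo and run the setup script."},
    {"prompt": "Where do settings live?", "response": "In application/core/settings.py."},
]
# With a tight 15-token budget only the newest exchange fits, so the two
# older messages are dropped.
print(limit_chat_history_sketch(history, max_token_limit=15))

Breaking out of the loop on the first pair that exceeds the budget, rather than continuing to scan older messages, keeps the retained window contiguous; this matches the behavior of the function added in the diff.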