Limiting Conversational history #1500

Merged
merged 6 commits into from Dec 19, 2024
Changes from all commits
7 changes: 3 additions & 4 deletions application/api/answer/routes.py
@@ -18,7 +18,7 @@
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields
from application.utils import check_required_fields, limit_chat_history

logger = logging.getLogger(__name__)

@@ -324,8 +324,7 @@

try:
question = data["question"]
history = str(data.get("history", []))
history = str(json.loads(history))
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)

conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")

@@ -456,7 +455,7 @@

try:
question = data["question"]
history = data.get("history", [])
history = limit_chat_history(json.loads(data.get("history", [])), gpt_model=gpt_model)

conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
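Taken together, the two hunks above parse the "history" field once and cap it to the model's token budget before it reaches the retriever, instead of passing a stringified history around. A minimal sketch of that flow with an illustrative payload (the string default "[]" here is an assumption of the sketch so json.loads always receives text; the diff itself passes a list default):

import json
from application.utils import limit_chat_history

# Illustrative request body: "history" arrives as a JSON-encoded list of
# {"prompt": ..., "response": ...} pairs.
data = {
    "question": "How do I enable the Elasticsearch vector store?",
    "history": '[{"prompt": "hi", "response": "Hello! How can I help?"}]',
}
gpt_model = "gpt-3.5-turbo"

# Parse once, then trim to the model's budget; the result is a plain list of
# dicts that is handed to the retriever as chat_history.
history = limit_chat_history(json.loads(data.get("history", "[]")), gpt_model=gpt_model)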
2 changes: 1 addition & 1 deletion application/core/settings.py
@@ -16,7 +16,7 @@ class Settings(BaseSettings):
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
MODEL_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
DEFAULT_MAX_HISTORY: int = 150
MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5}
MODEL_TOKEN_LIMITS: dict = {"gpt-4o-mini": 128000, "gpt-3.5-turbo": 4096, "claude-2": 1e5}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
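For reference, the new helper resolves its budget from these settings: a model present in MODEL_TOKEN_LIMITS gets its configured context size, while anything else (including the default "docsgpt" model) falls back to DEFAULT_MAX_HISTORY. A short sketch using the values shown in the diff:

from application.core.settings import settings

# Known model: use its configured token limit.
limit_known = settings.MODEL_TOKEN_LIMITS.get("gpt-4o-mini", settings.DEFAULT_MAX_HISTORY)    # 128000
# Unlisted model: fall back to the conservative default history budget.
limit_default = settings.MODEL_TOKEN_LIMITS.get("docsgpt", settings.DEFAULT_MAX_HISTORY)      # 150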
8 changes: 0 additions & 8 deletions application/retriever/brave_search.py
@@ -2,7 +2,6 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import num_tokens_from_string
from langchain_community.tools import BraveSearch


@@ -73,15 +72,8 @@ def gen(self):
yield {"source": doc}

if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
12 changes: 2 additions & 10 deletions application/retriever/classic_rag.py
@@ -3,7 +3,6 @@
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator

from application.utils import num_tokens_from_string


class ClassicRAG(BaseRetriever):
@@ -73,23 +72,16 @@
yield {"source": doc}

if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
if "prompt" in i and "response" in i:

messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})

llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
)
12 changes: 2 additions & 10 deletions application/retriever/duckduck_search.py
@@ -1,7 +1,6 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import num_tokens_from_string
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper

@@ -89,16 +88,9 @@
for doc in docs:
yield {"source": doc}

if len(self.chat_history) > 1:
tokens_current_history = 0
# count tokens in history
if len(self.chat_history) > 1:

for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
if "prompt" in i and "response" in i:

messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
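Across brave_search, classic_rag and duckduck_search the per-retriever token counting is gone; the retrievers now only map an already-trimmed chat_history into chat messages. A sketch of the simplified loop, with names taken from the diffs above (illustrative, not the verbatim new code):

# self.chat_history is assumed to have been trimmed upstream by limit_chat_history.
if len(self.chat_history) > 1:
    for i in self.chat_history:
        if "prompt" in i and "response" in i:
            messages_combine.append({"role": "user", "content": i["prompt"]})
            messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append({"role": "user", "content": self.question})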
37 changes: 37 additions & 0 deletions application/utils.py
@@ -46,3 +46,40 @@
def get_hash(data):
return hashlib.md5(data.encode()).hexdigest()

def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
    """
    Limits chat history based on token count.
    Returns a list of messages that fit within the token limit.
    """
    from application.core.settings import settings

    max_token_limit = (
        max_token_limit
        if max_token_limit and
        max_token_limit < settings.MODEL_TOKEN_LIMITS.get(
            gpt_model, settings.DEFAULT_MAX_HISTORY
        )
        else settings.MODEL_TOKEN_LIMITS.get(
            gpt_model, settings.DEFAULT_MAX_HISTORY
        )
    )

    if not history:
        return []

    tokens_current_history = 0
    trimmed_history = []

    for message in reversed(history):
        if "prompt" in message and "response" in message:
            tokens_batch = num_tokens_from_string(message["prompt"]) + num_tokens_from_string(
                message["response"]
            )
            if tokens_current_history + tokens_batch < max_token_limit:
                tokens_current_history += tokens_batch
                trimmed_history.insert(0, message)
            else:
                break

    return trimmed_history

Codecov / codecov/patch: the lines added to application/utils.py in this PR are not covered by tests.
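A small usage sketch of the new helper (the sample history and the tight 30-token budget are illustrative, not from the PR): pairs are counted from the newest backwards, and older ones are dropped once the budget would be exceeded.

from application.utils import limit_chat_history

history = [
    {"prompt": "What is DocsGPT?", "response": "An open-source documentation assistant."},
    {"prompt": "How do I install it?", "response": "Clone the repository and run docker compose."},
    {"prompt": "Where do settings live?", "response": "In application/core/settings.py."},
]

# With the per-model budget (gpt-3.5-turbo -> 4096 tokens) everything fits;
# with an artificially tiny budget only the most recent pair(s) survive.
full = limit_chat_history(history, gpt_model="gpt-3.5-turbo")
short = limit_chat_history(history, max_token_limit=30, gpt_model="gpt-3.5-turbo")
assert len(short) <= len(full) <= len(history)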
2 changes: 1 addition & 1 deletion frontend/package-lock.json

Some generated files are not rendered by default.