Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent going past token limit in OpenAI calls in PromptNode #4179

Merged
merged 45 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
05780e1
Refactoring to remove duplicate code when using OpenAI API
sjrl Feb 16, 2023
7b425c4
Adding docstrings
sjrl Feb 16, 2023
f4a8e58
Fix mypy issue
sjrl Feb 16, 2023
5e851ab
Moved retry mechanism to openai_request function in openai_utils
sjrl Feb 16, 2023
6b8dbb0
Migrate OpenAI embedding encoder to use the openai_request util funct…
sjrl Feb 16, 2023
c274885
Adding docstrings.
sjrl Feb 16, 2023
4479f41
pylint import errors
sjrl Feb 16, 2023
34a7943
More pylint import errors
sjrl Feb 16, 2023
31f707a
Move construction of headers into openai_request and api_key as input…
sjrl Feb 16, 2023
feb1a24
Made _openai_text_completion_tokenization_details so can be resued in…
sjrl Feb 16, 2023
952214e
Add prompt truncation to the PromptNode.
sjrl Feb 16, 2023
68e18fe
Removed commented out test.
sjrl Feb 16, 2023
7453b9a
Bump version of tiktoken to 0.2.0 so we can use MODEL_TO_ENCODING to …
sjrl Feb 16, 2023
8e017b4
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Feb 17, 2023
88444a0
Change one method back to public
sjrl Feb 17, 2023
2793ffb
Fixed bug in token length truncation. Included answer length into tru…
sjrl Feb 17, 2023
017acb9
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Feb 17, 2023
8f728d9
Pylint error
sjrl Feb 17, 2023
a88bbc9
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Feb 20, 2023
23c4b13
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Feb 21, 2023
80f6aed
Improved warning message
sjrl Feb 21, 2023
4e3d1c7
Added _ensure_token_limit for HFLocalInvocationLayer. Had to remove m…
sjrl Feb 21, 2023
d9d17a3
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Feb 21, 2023
6200b2e
Adding tests
sjrl Feb 21, 2023
6fc0fc2
Expanded on doc strings
sjrl Feb 21, 2023
5383f77
Updated tests
sjrl Feb 21, 2023
8238703
Update docstrings
sjrl Feb 21, 2023
86557e7
Update tests, and go back to how USE_TIKTOKEN was used before.
sjrl Feb 21, 2023
0d9f308
Update haystack/nodes/prompt/prompt_node.py
sjrl Feb 23, 2023
e736717
Update haystack/nodes/prompt/prompt_node.py
sjrl Feb 23, 2023
3af1b19
Update haystack/nodes/prompt/prompt_node.py
sjrl Feb 23, 2023
2b490b4
Update haystack/nodes/retriever/_openai_encoder.py
sjrl Feb 23, 2023
e44931f
Update haystack/utils/openai_utils.py
sjrl Feb 23, 2023
922d141
Update haystack/utils/openai_utils.py
sjrl Feb 23, 2023
1dd9485
Updated docstrings, and added integration marks
sjrl Feb 27, 2023
2261bf9
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Feb 27, 2023
bf37ff2
Remove comment
sjrl Feb 27, 2023
ff51d35
Update test
sjrl Feb 28, 2023
aa80052
Fix test
sjrl Mar 1, 2023
8134d76
Update test
sjrl Mar 1, 2023
e47de9f
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Mar 2, 2023
1185ce0
Merge branch 'main' of github.com:deepset-ai/haystack into openai/tok…
sjrl Mar 2, 2023
ae5d933
Updated openai_request function to work with the azure api
sjrl Mar 2, 2023
5a50ce5
Fixed error in _openai_encodery.py
sjrl Mar 2, 2023
3c8db04
Merge branch 'main' into openai/token_limit
vblagoje Mar 3, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 21 additions & 85 deletions haystack/nodes/answer_generator/openai.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,22 @@
import json
import logging
import os
import platform
import sys
from typing import List, Optional, Tuple, Union

import requests

from haystack import Document
from haystack.environment import (
HAYSTACK_REMOTE_API_BACKOFF_SEC,
HAYSTACK_REMOTE_API_MAX_RETRIES,
HAYSTACK_REMOTE_API_TIMEOUT_SEC,
)
from haystack.errors import OpenAIError, OpenAIRateLimitError
from haystack.environment import HAYSTACK_REMOTE_API_TIMEOUT_SEC
from haystack.nodes.answer_generator import BaseGenerator
from haystack.utils.reflection import retry_with_exponential_backoff
from haystack.nodes.prompt import PromptTemplate
from haystack.utils.openai_utils import (
load_openai_tokenizer,
openai_request,
count_openai_tokens,
_openai_text_completion_tokenization_details,
_check_openai_text_completion_answers,
)

logger = logging.getLogger(__name__)

machine = platform.machine().lower()
system = platform.system()

USE_TIKTOKEN = False
if sys.version_info >= (3, 8) and (machine in ["amd64", "x86_64"] or (machine == "arm64" and system == "Darwin")):
USE_TIKTOKEN = True

if USE_TIKTOKEN:
import tiktoken # pylint: disable=import-error
else:
logger.warning(
"OpenAI tiktoken module is not available for Python < 3.8,Linux ARM64 and AARCH64. Falling back to GPT2TokenizerFast."
)
from transformers import GPT2TokenizerFast, PreTrainedTokenizerFast


OPENAI_TIMEOUT = float(os.environ.get(HAYSTACK_REMOTE_API_TIMEOUT_SEC, 30))
OPENAI_BACKOFF = float(os.environ.get(HAYSTACK_REMOTE_API_BACKOFF_SEC, 10))
OPENAI_MAX_RETRIES = int(os.environ.get(HAYSTACK_REMOTE_API_MAX_RETRIES, 5))


class OpenAIAnswerGenerator(BaseGenerator):
Expand Down Expand Up @@ -182,27 +160,13 @@ def __init__(
self.stop_words = stop_words
self.prompt_template = prompt_template
self.context_join_str = context_join_str

tokenizer = "gpt2"
if "davinci" in self.model:
self.MAX_TOKENS_LIMIT = 4000
if self.model.endswith("-003") and USE_TIKTOKEN:
tokenizer = "cl100k_base"
else:
self.MAX_TOKENS_LIMIT = 2048

self.using_azure = self.azure_deployment_name is not None and self.azure_base_url is not None

if USE_TIKTOKEN:
logger.debug("Using tiktoken %s tokenizer", tokenizer)
self._tk_tokenizer: tiktoken.Encoding = tiktoken.get_encoding(tokenizer)
else:
logger.debug("Using GPT2TokenizerFast")
self._hf_tokenizer: PreTrainedTokenizerFast = GPT2TokenizerFast.from_pretrained(tokenizer)
tokenizer_name, max_tokens_limit = _openai_text_completion_tokenization_details(model_name=self.model)

self.MAX_TOKENS_LIMIT = max_tokens_limit
self._tokenizer = load_openai_tokenizer(tokenizer_name=tokenizer_name)

@retry_with_exponential_backoff(
backoff_in_seconds=OPENAI_BACKOFF, max_retries=OPENAI_MAX_RETRIES, errors=(OpenAIRateLimitError, OpenAIError)
)
def predict(
self,
query: str,
Expand Down Expand Up @@ -253,41 +217,19 @@ def predict(
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
}
url = "https://api.openai.com/v1/completions"
if self.using_azure:
url = f"{self.azure_base_url}/openai/deployments/{self.azure_deployment_name}/completions?api-version={self.api_version}"
else:
url = "https://api.openai.com/v1/completions"

headers = {"Content-Type": "application/json"}
if self.using_azure:
headers = {"api-key": self.api_key, **headers}
headers["api-key"] = self.api_key
else:
headers = {"Authorization": f"Bearer {self.api_key}", **headers}

response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=timeout)
res = json.loads(response.text)

if response.status_code != 200 or "choices" not in res:
openai_error: OpenAIError
if response.status_code == 429:
openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
else:
openai_error = OpenAIError(
f"OpenAI returned an error.\n"
f"Status code: {response.status_code}\n"
f"Response body: {response.text}",
status_code=response.status_code,
)
raise openai_error

number_of_truncated_answers = sum(1 for ans in res["choices"] if ans["finish_reason"] == "length")
if number_of_truncated_answers > 0:
logger.warning(
"%s out of the %s answers have been truncated before reaching a natural stopping point."
"Consider increasing the max_tokens parameter to allow for longer answers.",
number_of_truncated_answers,
top_k,
)
headers["Authorization"] = f"Bearer {self.api_key}"

res = openai_request(url=url, headers=headers, payload=payload, timeout=timeout)
_check_openai_text_completion_answers(result=res, payload=payload)
generated_answers = [ans["text"] for ans in res["choices"]]
answers = self._create_answers(generated_answers, input_docs)
result = {"query": query, "answers": answers}
Expand Down Expand Up @@ -323,7 +265,7 @@ def _build_prompt_within_max_length(self, query: str, documents: List[Document])
construct the context) are thrown away until the prompt length fits within the MAX_TOKENS_LIMIT.
"""
full_prompt = self._fill_prompt(query, documents)
n_full_prompt_tokens = self._count_tokens(full_prompt)
n_full_prompt_tokens = count_openai_tokens(text=full_prompt, tokenizer=self._tokenizer)

# for length restrictions of prompt see: https://platform.openai.com/docs/api-reference/completions/create#completions/create-max_tokens
leftover_token_len = self.MAX_TOKENS_LIMIT - n_full_prompt_tokens - self.max_tokens
Expand All @@ -333,7 +275,7 @@ def _build_prompt_within_max_length(self, query: str, documents: List[Document])
skipped_docs = 0
# If leftover_token_len is negative we have gone past the MAX_TOKENS_LIMIT and the prompt must be trimmed
if leftover_token_len < 0:
n_docs_tokens = [self._count_tokens(doc.content) for doc in documents]
n_docs_tokens = [count_openai_tokens(text=doc.content, tokenizer=self._tokenizer) for doc in documents]
sjrl marked this conversation as resolved.
Show resolved Hide resolved
logger.debug("Number of tokens in documents: %s", n_docs_tokens)

# Reversing the order of documents b/c we want to throw away less relevant docs first
Expand All @@ -349,7 +291,7 @@ def _build_prompt_within_max_length(self, query: str, documents: List[Document])
# Throw away least relevant docs
input_docs = documents[:-skipped_docs]
full_prompt = self._fill_prompt(query, input_docs)
n_full_prompt_tokens = self._count_tokens(full_prompt)
n_full_prompt_tokens = count_openai_tokens(text=full_prompt, tokenizer=self._tokenizer)

if len(input_docs) == 0:
logger.warning(
Expand All @@ -367,9 +309,3 @@ def _build_prompt_within_max_length(self, query: str, documents: List[Document])
logger.debug("Number of tokens in full prompt: %s", n_full_prompt_tokens)
logger.debug("Full prompt: %s", full_prompt)
return full_prompt, input_docs

def _count_tokens(self, text: str) -> int:
if USE_TIKTOKEN:
return len(self._tk_tokenizer.encode(text))
else:
return len(self._hf_tokenizer.tokenize(text))
Loading