feat: Add OpenAIEmbeddingEncoder to EmbeddingRetriever #3356
Conversation
This looks already pretty good, left a few comments with some possible improvements.
test/nodes/test_retriever.py
Outdated
if isinstance(document_store, WeaviateDocumentStore):
    # Weaviate sets the embedding dimension to 768 as soon as it is initialized.
    # We need 1024 here and therefore initialize a new WeaviateDocumentStore.
    document_store = WeaviateDocumentStore(index="haystack_test", embedding_dim=1024, recreate_index=True)
I think this is not needed, as we specify to use only InMemoryDocumentStore in the test parameters.
test/nodes/test_retriever.py
Outdated
if isinstance(document_store, WeaviateDocumentStore):
    # Weaviate sets the embedding dimension to 768 as soon as it is initialized.
    # We need 1024 here and therefore initialize a new WeaviateDocumentStore.
    document_store = WeaviateDocumentStore(index="haystack_test", embedding_dim=1024, recreate_index=True)
Same here.
self.doc_model_encoder_engine = f"text-search-{model_class}-doc-001"
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

def ensure_texts_limit(self, text: str):
I know we are inside a private class here, but I'd still make this method private, as it's presumably not supposed to be used outside of that class.
tokenized_payload = self.tokenizer(text)
return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])

def embed(self, model, text: str) -> np.ndarray:
We should add a type annotation for the model argument.
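Since the engine name passed in is a plain string, str is the natural annotation. This is a hypothetical standalone sketch, not the actual Haystack method; the body is a placeholder instead of a real OpenAI API call.

```python
import numpy as np


def embed(model: str, text: str) -> np.ndarray:
    """Sketch only: `model` is an engine name string such as
    "text-search-ada-doc-001", so `str` is the fitting annotation."""
    # Placeholder body; the real method would send `text` to the OpenAI
    # embeddings endpoint and return the resulting vector.
    return np.zeros(1024, dtype=np.float32)
```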
for doc in docs:
    embedding = self.embed(self.doc_model_encoder_engine, doc.content)
    embeddings.append(embedding)
According to the OpenAI documentation, we can get embeddings for multiple inputs in a single request. My guess is that this would be a bit more efficient than making one request per Document.
We should also probably take care of OpenAI's rate limit, given that we usually create embeddings for a large number of Documents. Timo worked on a solution for the OpenAIAnswerGenerator in #3078 (unfortunately, that PR went stale). Other than that, you might also want to take a look at this notebook by OpenAI on best practices for rate limit handling.
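The two ideas above (batching inputs and backing off when the rate limit is hit) could be sketched roughly like this. All names here (chunk, with_backoff, send_batch) are hypothetical, not Haystack or OpenAI API; the real request function would POST the whole batch to the embeddings endpoint.

```python
import time
from typing import Callable, List


def chunk(texts: List[str], batch_size: int) -> List[List[str]]:
    # Split the inputs into batches so each API call embeds several texts at once.
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]


def with_backoff(
    send_batch: Callable[[List[str]], List[List[float]]],
    batch: List[str],
    max_retries: int = 5,
) -> List[List[float]]:
    # Retry with exponentially growing sleeps when a request fails,
    # e.g. because the rate limit was hit.
    delay = 1.0
    for attempt in range(max_retries):
        try:
            return send_batch(batch)
        except Exception:
            if attempt == max_retries - 1:
                raise
            time.sleep(delay)
            delay *= 2
    return []
```

Per-batch requests cut the number of round trips roughly by the batch size, which is why this tends to be faster than one request per Document.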
haystack/nodes/retriever/dense.py
Outdated
@@ -1541,6 +1543,7 @@ def __init__(
            This approach is also used in the TableTextRetriever paper and is likely to improve
            performance if your titles contain meaningful information for retrieval
            (topic, entities etc.).
        :param api_key: The OpenAI API key
The docstring should explain that the OpenAI API key is only needed when using a model by OpenAI, and maybe link to the OpenAI page where the user can sign up for an API key.
self.max_seq_len = retriever.max_seq_len
self.url = "https://api.openai.com/v1/embeddings"
self.api_key = retriever.api_key
model_class: str = next(
Why not just model_class = retriever.embedding_model?
Yeah, good point, but I wanted to handle the case where users accidentally specify the full name of the model. Some might specify "ada", "babbage" etc., and some might specify the full name. This way we handle both use cases properly.
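A minimal sketch of that matching idea, assuming a next()-based lookup over known model classes; the class list, function name, and "ada" fallback are assumptions for illustration, not the actual Haystack code.

```python
from typing import List

# Hypothetical list of OpenAI model classes the encoder knows about.
KNOWN_MODEL_CLASSES: List[str] = ["ada", "babbage", "curie", "davinci"]


def infer_model_class(embedding_model: str) -> str:
    # next() returns the first known class contained in the user's string,
    # so both "ada" and "text-search-ada-doc-001" resolve to "ada";
    # the fallback default here is an assumption.
    return next((m for m in KNOWN_MODEL_CLASSES if m in embedding_model), "ada")
```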
Makes sense. I'm just wondering: what if the user wants to use the text-similarity-ada-001 model, for example? In that case, we would silently use text-search-ada-doc-001 / text-search-ada-query-001 without the user knowing.
We should also probably adapt the docstring of the embedding_model param of the EmbeddingRetriever, what do you think?
@bogdankostic I thought about it too, but that should not happen, as the use case does not match. See https://beta.openai.com/docs/guides/embeddings/similarity-embeddings and https://beta.openai.com/docs/guides/embeddings/text-search-embeddings for the recommended use cases.
Our use case is definitely Text search embeddings
Almost good to go, just proposed some minor improvements.
batch_limited = []
batch = text[i : i + self.batch_size]
for content in batch:
    batch_limited.append(self._ensure_text_limit(content))
Let's make use of a list comprehension here:

batch = text[i : i + self.batch_size]
batch_limited = [self._ensure_text_limit(content) for content in batch]
self.doc_model_encoder_engine = f"text-search-{model_class}-doc-001"
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

def _ensure_text_limit(self, text: str):
Let's add the return type here.
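With the return type added, the method would be annotated -> str, since it decodes the truncated token IDs back into a string. A self-contained sketch of that, with a stub whitespace tokenizer standing in for the real gpt2 AutoTokenizer (the class names and max_seq_len default here are assumptions):

```python
from typing import Dict, List


class StubTokenizer:
    """Stand-in for the gpt2 AutoTokenizer so this sketch runs without
    transformers: it 'tokenizes' on whitespace instead of real BPE tokens."""

    def __call__(self, text: str) -> Dict[str, List[str]]:
        return {"input_ids": text.split()}

    def decode(self, ids: List[str]) -> str:
        return " ".join(ids)


class Encoder:
    def __init__(self, max_seq_len: int = 4) -> None:
        self.max_seq_len = max_seq_len
        self.tokenizer = StubTokenizer()

    def _ensure_text_limit(self, text: str) -> str:
        # Truncate to max_seq_len tokens; decoding yields a string,
        # hence the `-> str` return annotation.
        tokenized_payload = self.tokenizer(text)
        return self.tokenizer.decode(tokenized_payload["input_ids"][: self.max_seq_len])
```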
haystack/nodes/retriever/dense.py
Outdated
:param api_key: The OpenAI API key. Required if one wants to use OpenAI embeddings. For more
                details see https://beta.openai.com/account/api-keys for more details
"for more details" is doubled here.
LGTM
Related Issues
Proposed Changes:
Added OpenAIEmbeddingEncoder as a method to create document and query embeddings.
How did you test it?
Added a unit test; the OpenAI API key needs to be injected into unit tests (as a secret).
Notes for the reviewer
LMK if anything is unclear
Checklist