-
Notifications
You must be signed in to change notification settings - Fork 15.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ElasticsearchEmbeddings class for generating embeddings using Elasticsearch models (#3401)

This PR introduces a new module, `elasticsearch_embeddings.py`, which provides a wrapper around Elasticsearch embedding models. The new ElasticsearchEmbeddings class allows users to generate embeddings for documents and query texts using a [model deployed in an Elasticsearch cluster](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-model-ref.html#ml-nlp-model-ref-text-embedding).

### Main features:

1. The ElasticsearchEmbeddings class initializes with an Elasticsearch connection object and a model_id, providing an interface to interact with the Elasticsearch ML client through [infer_trained_model](https://elasticsearch-py.readthedocs.io/en/v8.7.0/api.html?highlight=trained%20model%20infer#elasticsearch.client.MlClient.infer_trained_model).
2. The `embed_documents()` method generates embeddings for a list of documents, and the `embed_query()` method generates an embedding for a single query text.
3. The class supports custom input text field names in case the deployed model expects a different field name than the default `text_field`.
4. The implementation is compatible with any model deployed in Elasticsearch that generates embeddings as output.

### Benefits:

1. Simplifies the process of generating embeddings using Elasticsearch models.
2. Provides a clean and intuitive interface to interact with the Elasticsearch ML client.
3. Allows users to easily integrate Elasticsearch-generated embeddings.

Related issue #3400

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
- Loading branch information
1 parent
fd99f3c
commit 1fdd086
Showing
4 changed files
with
324 additions
and
0 deletions.
There are no files selected for viewing
137 changes: 137 additions & 0 deletions
137
docs/modules/models/text_embedding/examples/elasticsearch.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "!pip install elasticsearch langchain"
      ],
      "metadata": {
        "id": "OOiBBjc0Kd-6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "from elasticsearch import Elasticsearch\n",
        "from elasticsearch.client import MlClient\n",
        "\n",
        "from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "%env ES_CLOUDID=<cloud id from cloud.elastic.co>\n",
        "%env ES_USER=<user>\n",
        "%env ES_PASS=<password>\n",
        "\n",
        "es_cloudid = os.environ.get(\"ES_CLOUDID\")\n",
        "es_user = os.environ.get(\"ES_USER\")\n",
        "es_pass = os.environ.get(\"ES_PASS\")"
      ],
      "metadata": {
        "id": "Wr8unljAKdCh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Connect to Elasticsearch\n",
        "es_connection = Elasticsearch(cloud_id=es_cloudid, basic_auth=(es_user, es_pass))"
      ],
      "metadata": {
        "id": "YIDsrBqTKs85"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Define the model ID and input field name (if different from default)\n",
        "model_id = \"your_model_id\"\n",
        "input_field = \"your_input_field\"  # Optional, only if different from 'text_field'"
      ],
      "metadata": {
        "id": "sfFhnFHOKvbM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Initialize the ElasticsearchEmbeddings instance.\n",
        "# The class expects an ML client (not the raw connection), and\n",
        "# input_field must be passed by keyword.\n",
        "embeddings_generator = ElasticsearchEmbeddings(\n",
        "    MlClient(es_connection), model_id, input_field=input_field\n",
        ")"
      ],
      "metadata": {
        "id": "V-pCgqLCKvYs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Generate embeddings for a list of documents\n",
        "documents = [\n",
        "    \"This is an example document.\",\n",
        "    \"Another example document to generate embeddings for.\",\n",
        "]\n",
        "document_embeddings = embeddings_generator.embed_documents(documents)"
      ],
      "metadata": {
        "id": "lJg2iRDWKvV_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the generated document embeddings\n",
        "for i, doc_embedding in enumerate(document_embeddings):\n",
        "    print(f\"Embedding for document {i + 1}: {doc_embedding}\")"
      ],
      "metadata": {
        "id": "R3sYQlh3KvTQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Generate an embedding for a single query text\n",
        "query_text = \"What is the meaning of life?\"\n",
        "query_embedding = embeddings_generator.embed_query(query_text)"
      ],
      "metadata": {
        "id": "n0un5Vc0KvQd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the generated query embedding\n",
        "print(f\"Embedding for query: {query_embedding}\")"
      ],
      "metadata": {
        "id": "PANph6pmKvLD"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, List, Optional | ||
|
||
from langchain.utils import get_from_env | ||
|
||
if TYPE_CHECKING: | ||
from elasticsearch.client import MlClient | ||
|
||
from langchain.embeddings.base import Embeddings | ||
|
||
|
||
class ElasticsearchEmbeddings(Embeddings):
    """
    Wrapper around Elasticsearch embedding models.

    This class provides an interface to generate embeddings using a model deployed
    in an Elasticsearch cluster. It requires an Elasticsearch ML client object
    and the model_id of the model deployed in the cluster.

    In Elasticsearch you need to have an embedding model loaded and deployed.
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    def __init__(
        self,
        client: MlClient,
        model_id: str,
        *,
        input_field: str = "text_field",
    ):
        """
        Initialize the ElasticsearchEmbeddings instance.

        Args:
            client (MlClient): An Elasticsearch ML client object.
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Keyword-only. Defaults to 'text_field'.
        """
        self.client = client
        self.model_id = model_id
        self.input_field = input_field

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        *,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_password: Optional[str] = None,
        input_field: str = "text_field",
    ) -> ElasticsearchEmbeddings:
        """Instantiate embeddings from Elasticsearch credentials.

        Any credential that is not passed (or is passed as an empty value) is
        looked up through ``get_from_env`` under ES_CLOUD_ID, ES_USER and
        ES_PASSWORD respectively.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.
            es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to.
            es_user: (str, optional): Elasticsearch username.
            es_password: (str, optional): Elasticsearch password.

        Returns:
            ElasticsearchEmbeddings: An instance backed by an ML client created
                from the resolved credentials.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.

        Example Usage:
            from langchain.embeddings import ElasticsearchEmbeddings

            # Define the model ID and input field name (if different from default)
            model_id = "your_model_id"
            # Optional, only if different from 'text_field'
            input_field = "your_input_field"

            # Credentials can be passed in two ways. Either set the env vars
            # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically pulled
            # in, or pass them in directly as kwargs.
            embeddings = ElasticsearchEmbeddings.from_credentials(
                model_id,
                input_field=input_field,
                # es_cloud_id="foo",
                # es_user="bar",
                # es_password="baz",
            )

            documents = [
                "This is an example document.",
                "Another example document to generate embeddings for.",
            ]
            embeddings.embed_documents(documents)
        """
        # Imported lazily so that elasticsearch stays an optional dependency.
        try:
            from elasticsearch import Elasticsearch
            from elasticsearch.client import MlClient
        except ImportError as err:
            raise ImportError(
                "elasticsearch package not found, please install with 'pip install "
                "elasticsearch'"
            ) from err

        # Explicit arguments win; fall back to the environment otherwise.
        es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID")
        es_user = es_user or get_from_env("es_user", "ES_USER")
        es_password = es_password or get_from_env("es_password", "ES_PASSWORD")

        # Connect to Elasticsearch
        es_connection = Elasticsearch(
            cloud_id=es_cloud_id, basic_auth=(es_user, es_password)
        )
        client = MlClient(es_connection)
        return cls(client, model_id, input_field=input_field)

    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for the given texts using the Elasticsearch model.

        Args:
            texts (List[str]): A list of text strings to generate embeddings for.

        Returns:
            List[List[float]]: A list of embeddings, one for each text in the input
                list. An empty input yields an empty list without contacting the
                cluster.
        """
        if not texts:
            # Nothing to embed; skip the round trip to the cluster.
            return []

        # Each text is wrapped in a one-key document under the configured field.
        response = self.client.infer_trained_model(
            model_id=self.model_id,
            docs=[{self.input_field: text} for text in texts],
        )
        return [doc["predicted_value"] for doc in response["inference_results"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of documents.

        Args:
            texts (List[str]): A list of document text strings to generate embeddings
                for.

        Returns:
            List[List[float]]: A list of embeddings, one for each document in the input
                list.
        """
        return self._embedding_func(texts)

    def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for a single query text.

        Args:
            text (str): The query text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input query text.
        """
        return self._embedding_func([text])[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
"""Test elasticsearch_embeddings embeddings.""" | ||
|
||
import pytest | ||
|
||
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings | ||
|
||
|
||
@pytest.fixture
def model_id() -> str:
    """Provide the model id used by the embedding tests.

    The returned value is a placeholder; replace it with your actual model_id.
    """
    placeholder_model_id = "your_model_id"
    return placeholder_model_id
|
||
|
||
def test_elasticsearch_embedding_documents(model_id: str) -> None:
    """Embed three documents and check the count and width of the vectors."""
    texts = ["foo bar", "bar foo", "foo"]
    embedder = ElasticsearchEmbeddings.from_credentials(model_id)
    vectors = embedder.embed_documents(texts)
    assert len(vectors) == 3
    # Change 768 to the expected embedding size
    for vector in vectors:
        assert len(vector) == 768
|
||
|
||
def test_elasticsearch_embedding_query(model_id: str) -> None:
    """Embed a single query text and check the width of the resulting vector."""
    query = "foo bar"
    embedder = ElasticsearchEmbeddings.from_credentials(model_id)
    vector = embedder.embed_query(query)
    # Change 768 to the expected embedding size
    expected_size = 768
    assert len(vector) == expected_size