Skip to content

Commit

Permalink
Add ElasticsearchEmbeddings class for generating embeddings using Ela…
Browse files Browse the repository at this point in the history
…sticsearch models (#3401)

This PR introduces a new module, `elasticsearch_embeddings.py`, which
provides a wrapper around Elasticsearch embedding models. The new
ElasticsearchEmbeddings class allows users to generate embeddings for
documents and query texts using a [model deployed in an Elasticsearch
cluster](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-model-ref.html#ml-nlp-model-ref-text-embedding).

### Main features:

1. The ElasticsearchEmbeddings class initializes with an Elasticsearch
connection object and a model_id, providing an interface to interact
with the Elasticsearch ML client through
[infer_trained_model](https://elasticsearch-py.readthedocs.io/en/v8.7.0/api.html?highlight=trained%20model%20infer#elasticsearch.client.MlClient.infer_trained_model)
.
2. The `embed_documents()` method generates embeddings for a list of
documents, and the `embed_query()` method generates an embedding for a
single query text.
3. The class supports custom input text field names in case the deployed
model expects a different field name than the default `text_field`.
4. The implementation is compatible with any model deployed in
Elasticsearch that generates embeddings as output.

### Benefits:

1. Simplifies the process of generating embeddings using Elasticsearch
models.
2. Provides a clean and intuitive interface to interact with the
Elasticsearch ML client.
3. Allows users to easily integrate Elasticsearch-generated embeddings.

Related issue #3400

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
  • Loading branch information
2 people authored and vowelparrot committed May 24, 2023
1 parent fd99f3c commit 1fdd086
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 0 deletions.
137 changes: 137 additions & 0 deletions docs/modules/models/text_embedding/examples/elasticsearch.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip install elasticsearch langchain"
],
"metadata": {
"id": "OOiBBjc0Kd-6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"%env ES_CLOUDID=<cloud id from cloud.elastic.co>\n",
"%env ES_USER=<user>\n",
"%env ES_PASS=<password>\n",
"\n",
"es_cloudid = os.environ.get(\"ES_CLOUDID\")\n",
"es_user = os.environ.get(\"ES_USER\")\n",
"es_pass = os.environ.get(\"ES_PASS\")"
],
"metadata": {
"id": "Wr8unljAKdCh"
},
"execution_count": null,
"outputs": []
},
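{
"cell_type": "code",
"source": [
"from elasticsearch import Elasticsearch\n",
"\n",
"from langchain.embeddings import ElasticsearchEmbeddings\n",
"\n",
"# Connect to Elasticsearch\n",
"es_connection = Elasticsearch(cloud_id=es_cloudid, basic_auth=(es_user, es_pass))"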
],
"metadata": {
"id": "YIDsrBqTKs85"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Define the model ID and input field name (if different from default)\n",
"model_id = \"your_model_id\"\n",
"input_field = \"your_input_field\" # Optional, only if different from 'text_field'"
],
"metadata": {
"id": "sfFhnFHOKvbM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Initialize the ElasticsearchEmbeddings instance with the connection's ML client\n",
"embeddings_generator = ElasticsearchEmbeddings(es_connection.ml, model_id, input_field=input_field)"
],
"metadata": {
"id": "V-pCgqLCKvYs"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Generate embeddings for a list of documents\n",
"documents = [\n",
" \"This is an example document.\",\n",
" \"Another example document to generate embeddings for.\",\n",
" ]\n",
"document_embeddings = embeddings_generator.embed_documents(documents)"
],
"metadata": {
"id": "lJg2iRDWKvV_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Print the generated document embeddings\n",
"for i, doc_embedding in enumerate(document_embeddings):\n",
" print(f\"Embedding for document {i + 1}: {doc_embedding}\")"
],
"metadata": {
"id": "R3sYQlh3KvTQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Generate an embedding for a single query text\n",
"query_text = \"What is the meaning of life?\"\n",
"query_embedding = embeddings_generator.embed_query(query_text)"
],
"metadata": {
"id": "n0un5Vc0KvQd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Print the generated query embedding\n",
"print(f\"Embedding for query: {query_embedding}\")"
],
"metadata": {
"id": "PANph6pmKvLD"
},
"execution_count": null,
"outputs": []
}
]
}
2 changes: 2 additions & 0 deletions langchain/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AlephAlphaSymmetricSemanticEmbedding,
)
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
from langchain.embeddings.fake import FakeEmbeddings
from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.embeddings.huggingface import (
Expand All @@ -32,6 +33,7 @@
"OpenAIEmbeddings",
"HuggingFaceEmbeddings",
"CohereEmbeddings",
"ElasticsearchEmbeddings",
"JinaEmbeddings",
"LlamaCppEmbeddings",
"HuggingFaceHubEmbeddings",
Expand Down
155 changes: 155 additions & 0 deletions langchain/embeddings/elasticsearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

from langchain.utils import get_from_env

if TYPE_CHECKING:
from elasticsearch.client import MlClient

from langchain.embeddings.base import Embeddings


class ElasticsearchEmbeddings(Embeddings):
    """
    Wrapper around Elasticsearch embedding models.

    This class provides an interface to generate embeddings using a model deployed
    in an Elasticsearch cluster. It requires an Elasticsearch ML client object
    and the model_id of the model deployed in the cluster.

    In Elasticsearch you need to have an embedding model loaded and deployed.
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    def __init__(
        self,
        client: MlClient,
        model_id: str,
        *,
        input_field: str = "text_field",
    ):
        """
        Initialize the ElasticsearchEmbeddings instance.

        Args:
            client (MlClient): An Elasticsearch ML client object.
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.
        """
        self.client = client
        self.model_id = model_id
        self.input_field = input_field

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        *,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_password: Optional[str] = None,
        input_field: str = "text_field",
    ) -> ElasticsearchEmbeddings:
        """Instantiate embeddings from Elasticsearch credentials.

        Args:
            model_id (str): The model_id of the model deployed in the Elasticsearch
                cluster.
            input_field (str): The name of the key for the input text field in the
                document. Defaults to 'text_field'.
            es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to.
            es_user: (str, optional): Elasticsearch username.
            es_password: (str, optional): Elasticsearch password.

        Returns:
            ElasticsearchEmbeddings: An instance backed by an ``MlClient`` built
                from the given (or environment-provided) credentials.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.

        Example Usage:
            from langchain.embeddings import ElasticsearchEmbeddings

            # Define the model ID and input field name (if different from default)
            model_id = "your_model_id"
            # Optional, only if different from 'text_field'
            input_field = "your_input_field"

            # Credentials can be passed in two ways. Either set the env vars
            # ES_CLOUD_ID, ES_USER, ES_PASSWORD and they will be automatically pulled
            # in, or pass them in directly as kwargs.
            embeddings = ElasticsearchEmbeddings.from_credentials(
                model_id,
                input_field=input_field,
                # es_cloud_id="foo",
                # es_user="bar",
                # es_password="baz",
            )

            documents = [
                "This is an example document.",
                "Another example document to generate embeddings for.",
            ]
            embeddings.embed_documents(documents)
        """
        # Imported lazily so that the elasticsearch package is only required
        # when this constructor is actually used.
        try:
            from elasticsearch import Elasticsearch
            from elasticsearch.client import MlClient
        except ImportError as err:
            raise ImportError(
                "elasticsearch package not found, please install with 'pip install "
                "elasticsearch'"
            ) from err

        # Explicit keyword arguments take precedence; otherwise fall back to the
        # corresponding environment variable lookup.
        es_cloud_id = es_cloud_id or get_from_env("es_cloud_id", "ES_CLOUD_ID")
        es_user = es_user or get_from_env("es_user", "ES_USER")
        es_password = es_password or get_from_env("es_password", "ES_PASSWORD")

        # Connect to Elasticsearch
        es_connection = Elasticsearch(
            cloud_id=es_cloud_id, basic_auth=(es_user, es_password)
        )
        client = MlClient(es_connection)
        return cls(client, model_id, input_field=input_field)

    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for the given texts using the Elasticsearch model.

        Args:
            texts (List[str]): A list of text strings to generate embeddings for.

        Returns:
            List[List[float]]: A list of embeddings, one for each text in the input
                list. An empty input list yields an empty result.
        """
        # Nothing to embed: skip the inference request entirely.
        if not texts:
            return []

        # Each text is wrapped in a document keyed by the configured input field.
        response = self.client.infer_trained_model(
            model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
        )

        return [doc["predicted_value"] for doc in response["inference_results"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of documents.

        Args:
            texts (List[str]): A list of document text strings to generate embeddings
                for.

        Returns:
            List[List[float]]: A list of embeddings, one for each document in the input
                list.
        """
        return self._embedding_func(texts)

    def embed_query(self, text: str) -> List[float]:
        """
        Generate an embedding for a single query text.

        Args:
            text (str): The query text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input query text.
        """
        return self._embedding_func([text])[0]
30 changes: 30 additions & 0 deletions tests/integration_tests/embeddings/test_elasticsearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Test elasticsearch_embeddings embeddings."""

import pytest

from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings


@pytest.fixture
def model_id() -> str:
    """Provide the model_id used by the embedding tests in this module."""
    # Placeholder value: replace with your actual model_id.
    deployed_model_id = "your_model_id"
    return deployed_model_id


def test_elasticsearch_embedding_documents(model_id: str) -> None:
    """Embed three documents and check the number and size of the vectors."""
    expected_dim = 768  # Change 768 to the expected embedding size
    texts = ["foo bar", "bar foo", "foo"]
    embedder = ElasticsearchEmbeddings.from_credentials(model_id)
    vectors = embedder.embed_documents(texts)
    assert len(vectors) == 3
    # Every returned vector must have the expected dimensionality.
    for vector in vectors:
        assert len(vector) == expected_dim


def test_elasticsearch_embedding_query(model_id: str) -> None:
    """Embed a single query string and check the size of the returned vector."""
    expected_dim = 768  # Change 768 to the expected embedding size
    embedder = ElasticsearchEmbeddings.from_credentials(model_id)
    query_vector = embedder.embed_query("foo bar")
    assert len(query_vector) == expected_dim

0 comments on commit 1fdd086

Please sign in to comment.