Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Merged with main
  • Loading branch information
SauravP97 committed Feb 3, 2025
2 parents dd70b31 + 3d62a21 commit 870ef21
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 17 deletions.
46 changes: 31 additions & 15 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@
_MIN_BATCH_SIZE = 5


# Closed set of task-type strings accepted by the ``embeddings_task_type``
# parameter of ``embed()``, ``embed_documents()`` and ``embed_query()``.
# Declared once at module level so the three signatures share one definition
# instead of repeating the Literal inline.
EmbeddingTaskTypes = Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
"CODE_RETRIEVAL_QUERY",
]


class GoogleEmbeddingModelType(str, Enum):
TEXT = auto()
MULTIMODAL = auto()
Expand All @@ -63,6 +75,7 @@ class GoogleEmbeddingModelVersion(str, Enum):
EMBEDDINGS_NOV_2023 = auto()
EMBEDDINGS_DEC_2023 = auto()
EMBEDDINGS_MAY_2024 = auto()
EMBEDDINGS_NOV_2024 = auto()

@classmethod
def _missing_(cls, value: Any) -> "GoogleEmbeddingModelVersion":
Expand All @@ -82,6 +95,8 @@ def _missing_(cls, value: Any) -> "GoogleEmbeddingModelVersion":
or "text-multilingual-embedding-preview-0409" in value.lower()
):
return GoogleEmbeddingModelVersion.EMBEDDINGS_MAY_2024
if "text-embedding-005" in value.lower():
return GoogleEmbeddingModelVersion.EMBEDDINGS_NOV_2024

return GoogleEmbeddingModelVersion.EMBEDDINGS_JUNE_2023

Expand Down Expand Up @@ -376,17 +391,7 @@ def embed(
self,
texts: List[str],
batch_size: int = 0,
embeddings_task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
embeddings_task_type: Optional[EmbeddingTaskTypes] = None,
dimensions: Optional[int] = None,
) -> List[List[float]]:
"""Embed a list of strings.
Expand All @@ -406,6 +411,8 @@ def embed(
for Semantic Textual Similarity (STS).
CLASSIFICATION - Embeddings will be used for classification.
CLUSTERING - Embeddings will be used for clustering.
CODE_RETRIEVAL_QUERY - Embeddings will be used for
code retrieval for Java and Python.
The following are only supported on preview models:
QUESTION_ANSWERING
FACT_VERIFICATION
Expand Down Expand Up @@ -447,7 +454,11 @@ def embed(
return embeddings

def embed_documents(
self, texts: List[str], batch_size: int = 0
self,
texts: List[str],
batch_size: int = 0,
*,
embeddings_task_type: EmbeddingTaskTypes = "RETRIEVAL_DOCUMENT",
) -> List[List[float]]:
"""Embed a list of documents.
Expand All @@ -460,9 +471,14 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
return self.embed(texts, batch_size, embeddings_task_type)

def embed_query(self, text: str) -> List[float]:
def embed_query(
self,
text: str,
*,
embeddings_task_type: EmbeddingTaskTypes = "RETRIEVAL_QUERY",
) -> List[float]:
"""Embed a text.
Args:
Expand All @@ -471,7 +487,7 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
return self.embed([text], 1, "RETRIEVAL_QUERY")[0]
return self.embed([text], 1, embeddings_task_type)[0]

@deprecated(
since="2.0.1", removal="3.0.0", alternative="VertexAIEmbeddings.embed_images()"
Expand Down
34 changes: 34 additions & 0 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ def test_langchain_google_vertexai_embedding_documents(
assert model.model_name == model_name


@pytest.mark.release
@pytest.mark.parametrize(
    "model_name, embeddings_dim",
    _EMBEDDING_MODELS,
)
def test_langchain_google_vertexai_embedding_documents_with_task_type(
    model_name: str,
    embeddings_dim: int,
) -> None:
    """Embed documents while overriding the default task type.

    ``embed_documents`` defaults ``embeddings_task_type`` to
    ``"RETRIEVAL_DOCUMENT"``; this test passes a different value so the
    keyword-only override is actually exercised rather than repeating the
    plain ``embed_documents`` test.

    Args:
        model_name: Name of the embedding model under test.
        embeddings_dim: Expected length of each returned embedding vector.
    """
    documents = ["foo bar"] * 8
    model = VertexAIEmbeddings(model_name)
    # NOTE(review): SEMANTIC_SIMILARITY is listed among the generally
    # available task types in embed()'s docstring; confirm every entry of
    # _EMBEDDING_MODELS accepts it.
    output = model.embed_documents(
        documents, embeddings_task_type="SEMANTIC_SIMILARITY"
    )
    assert len(output) == len(documents)
    for embedding in output:
        assert len(embedding) == embeddings_dim
    assert model.model_name == model.client._model_id
    assert model.model_name == model_name


@pytest.mark.release
@pytest.mark.parametrize(
"model_name, embeddings_dim",
Expand All @@ -65,6 +84,21 @@ def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) -
assert len(output) == embeddings_dim


@pytest.mark.release
@pytest.mark.parametrize(
    "model_name, embeddings_dim",
    _EMBEDDING_MODELS,
)
def test_langchain_google_vertexai_embedding_query_with_task_type(
    model_name: str,
    embeddings_dim: int,
) -> None:
    """Embed a query while overriding the default task type.

    ``embed_query`` defaults ``embeddings_task_type`` to
    ``"RETRIEVAL_QUERY"``; this test passes a different value so the
    keyword-only override is actually exercised rather than repeating the
    plain ``embed_query`` test.

    Args:
        model_name: Name of the embedding model under test.
        embeddings_dim: Expected length of the returned embedding vector.
    """
    document = "foo bar"
    model = VertexAIEmbeddings(model_name)
    # NOTE(review): SEMANTIC_SIMILARITY is listed among the generally
    # available task types in embed()'s docstring; confirm every entry of
    # _EMBEDDING_MODELS accepts it.
    output = model.embed_query(
        document, embeddings_task_type="SEMANTIC_SIMILARITY"
    )
    assert len(output) == embeddings_dim


@pytest.mark.release
@pytest.mark.parametrize(
"dim, expected_dim",
Expand Down
50 changes: 48 additions & 2 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Any, Dict
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from pydantic import model_validator
from typing_extensions import Self

from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType
from langchain_google_vertexai.embeddings import (
EmbeddingTaskTypes,
GoogleEmbeddingModelType,
)


def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
Expand All @@ -29,6 +32,49 @@ def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None:
assert len(batches) == 2


@patch.object(VertexAIEmbeddings, "embed")
def test_embed_documents_with_question_answering_task(mock_embed) -> None:
    """Check that embed_documents() hands a custom task type to embed().

    ``embed`` is patched on the class, so no model call is made; the test
    only checks the returned stub vectors and the arguments forwarded.
    """
    model = MockVertexAIEmbeddings("text-embedding-005")
    inputs = [f"text {idx}" for idx in range(5)]

    dim = 768
    task_type: EmbeddingTaskTypes = "QUESTION_ANSWERING"

    # One stub vector per input text.
    stub_vector = [0.001] * dim
    mock_embed.return_value = [list(stub_vector) for _ in inputs]

    result = model.embed_documents(
        texts=inputs,
        embeddings_task_type=task_type,
    )

    assert isinstance(result, list)
    assert len(result) == len(inputs)
    assert len(result[0]) == dim

    # The batch size is left at its default of 0 and forwarded positionally.
    mock_embed.assert_called_once_with(inputs, 0, task_type)


@patch.object(VertexAIEmbeddings, "embed")
def test_embed_query_with_question_answering_task(mock_embed) -> None:
    """Check that embed_query() hands a custom task type to embed().

    ``embed`` is patched on the class, so no model call is made; the test
    only checks the returned stub vector and the arguments forwarded.
    """
    model = MockVertexAIEmbeddings("text-embedding-005")
    query = "text 0"

    dim = 768
    task_type: EmbeddingTaskTypes = "QUESTION_ANSWERING"

    # embed() returns a list of vectors; embed_query() takes the first.
    mock_embed.return_value = [[0.001] * dim]

    result = model.embed_query(
        text=query,
        embeddings_task_type=task_type,
    )

    assert isinstance(result, list)
    assert len(result) == dim

    # The query is wrapped in a one-element list with a batch size of 1.
    mock_embed.assert_called_once_with([query], 1, task_type)


class MockVertexAIEmbeddings(VertexAIEmbeddings):
"""
A mock class for avoiding instantiating VertexAI and the EmbeddingModel client
Expand Down

0 comments on commit 870ef21

Please sign in to comment.