Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(bug-fix) : fix VertexAI missing configurations #1926

Merged
merged 10 commits into from
Oct 3, 2024
3 changes: 2 additions & 1 deletion docs/mint.json
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
"components/embedders/models/azure_openai",
"components/embedders/models/ollama",
"components/embedders/models/huggingface",
"components/embedders/models/vertexai",
"components/embedders/models/gemini"
]
}
Expand Down Expand Up @@ -231,4 +232,4 @@
"apiHost": "https://us.i.posthog.com"
}
}
}
}
5 changes: 5 additions & 0 deletions mem0/configs/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def __init__(
# AzureOpenAI specific
azure_kwargs: Optional[AzureConfig] = {},
http_client_proxies: Optional[Union[Dict, str]] = None,
# VertexAI specific
vertex_credentials_json: Optional[str] = None,
):
"""
Initializes a configuration class instance for the Embeddings.
Expand Down Expand Up @@ -63,3 +65,6 @@ def __init__(

# AzureOpenAI specific
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}

# VertexAI specific
self.vertex_credentials_json = vertex_credentials_json
2 changes: 1 addition & 1 deletion mem0/embeddings/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mem0.embeddings.base import EmbeddingBase


class VertexAI(EmbeddingBase):
class VertexAIEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)

Expand Down
1 change: 1 addition & 0 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class EmbedderFactory:
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
}

@classmethod
Expand Down
12 changes: 6 additions & 6 deletions tests/embeddings/test_vertexai_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from unittest.mock import Mock, patch
from mem0.embeddings.vertexai import VertexAI
from mem0.embeddings.vertexai import VertexAIEmbedding
from mem0.configs.embeddings.base import BaseEmbedderConfig


Expand Down Expand Up @@ -32,7 +32,7 @@ def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_co
mock_config.return_value.embedding_dims = 256

config = mock_config()
embedder = VertexAI(config)
embedder = VertexAIEmbedding(config)

mock_embedding = Mock(values=[0.1, 0.2, 0.3])
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
Expand All @@ -57,7 +57,7 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con

config = mock_config()

embedder = VertexAI(config)
embedder = VertexAIEmbedding(config)

mock_embedding = Mock(values=[0.4, 0.5, 0.6])
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
Expand All @@ -81,7 +81,7 @@ def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_c
mock_os.getenv.return_value = "/path/to/env/credentials.json"
mock_config.vertex_credentials_json = None
config = mock_config()
VertexAI(config)
VertexAIEmbedding(config)

mock_os.environ.setitem.assert_not_called()

Expand All @@ -96,7 +96,7 @@ def test_missing_credentials(mock_os, mock_text_embedding_model, mock_config):
with pytest.raises(
ValueError, match="Google application credentials JSON is not provided"
):
VertexAI(config)
VertexAIEmbedding(config)


@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
Expand All @@ -107,7 +107,7 @@ def test_embed_with_different_dimensions(
mock_config.return_value.embedding_dims = 1024

config = mock_config()
embedder = VertexAI(config)
embedder = VertexAIEmbedding(config)

mock_embedding = Mock(values=[0.1] * 1024)
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
Expand Down
Loading