Skip to content

Commit

Permalink
feat: Adding Feature Store Vector DB option for RAG corpuses to SDK
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673571692
  • Loading branch information
speedstorm1 authored and copybara-github committed Sep 11, 2024
1 parent 73490b2 commit cfc3421
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 7 deletions.
21 changes: 21 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
JiraSource,
JiraQuery,
Weaviate,
VertexFeatureStore,
)
from google.cloud.aiplatform_v1beta1 import (
GoogleDriveSource,
Expand Down Expand Up @@ -68,6 +69,7 @@
collection_name=TEST_WEAVIATE_COLLECTION_NAME,
api_key=TEST_WEAVIATE_API_KEY_SECRET_VERSION,
)
# Placeholder FeatureView resource name shared by the Vertex Feature Store test constants.
TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME = "test-feature-view-resource-name"
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
Expand All @@ -94,9 +96,22 @@
),
),
)
# Gapic-level corpus whose rag_vector_db_config selects the vertex_feature_store
# option, pointing at the test FeatureView resource name.
TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=RagVectorDbConfig(
vertex_feature_store=RagVectorDbConfig.VertexFeatureStore(
feature_view_resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME
),
),
)
# Embedding model config referencing the textembedding-gecko publisher model.
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
publisher_model="publishers/google/models/textembedding-gecko",
)
# SDK-level VertexFeatureStore config built from the same FeatureView resource
# name used in the Gapic corpus above.
TEST_VERTEX_FEATURE_STORE_CONFIG = VertexFeatureStore(
resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME,
)
TEST_RAG_CORPUS = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
Expand All @@ -109,6 +124,12 @@
description=TEST_CORPUS_DISCRIPTION,
vector_db=TEST_WEAVIATE_CONFIG,
)
# SDK-level RagCorpus whose vector_db is the Vertex Feature Store test config.
TEST_RAG_CORPUS_VERTEX_FEATURE_STORE = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
vector_db=TEST_VERTEX_FEATURE_STORE_CONFIG,
)
TEST_PAGE_TOKEN = "test-page-token"

# RagFiles
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/vertex_rag/test_rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ def create_rag_corpus_mock_weaviate():
yield create_rag_corpus_mock_weaviate


@pytest.fixture
def create_rag_corpus_mock_vertex_feature_store():
    """Patch ``VertexRagDataServiceClient.create_rag_corpus`` for the test.

    The patched method returns a mocked operation that reports itself as done
    and whose result is ``tc.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE``.

    Yields:
        The mock that replaces ``create_rag_corpus`` while the patch is active.
    """
    # Mocked long-running operation that is already finished.
    finished_lro = mock.Mock(ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = tc.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE

    with mock.patch.object(
        VertexRagDataServiceClient,
        "create_rag_corpus",
    ) as patched_create:
        patched_create.return_value = finished_lro
        yield patched_create


@pytest.fixture
def list_rag_corpora_pager_mock():
with mock.patch.object(
Expand Down Expand Up @@ -216,6 +233,15 @@ def test_create_corpus_weaviate_success(self):

rag_corpus_eq(rag_corpus, tc.TEST_RAG_CORPUS_WEAVIATE)

@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_feature_store")
def test_create_corpus_vertex_feature_store_success(self):
    """Create a corpus with a VertexFeatureStore vector_db and check the result.

    The created corpus is compared with
    ``tc.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE`` via ``rag_corpus_eq``.
    """
    created_corpus = rag.create_corpus(
        vector_db=tc.TEST_VERTEX_FEATURE_STORE_CONFIG,
        display_name=tc.TEST_CORPUS_DISPLAY_NAME,
    )
    expected_corpus = tc.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE

    rag_corpus_eq(created_corpus, expected_corpus)

@pytest.mark.usefixtures("rag_data_client_mock_exception")
def test_create_corpus_failure(self):
with pytest.raises(RuntimeError) as e:
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
RagResource,
SlackChannel,
SlackChannelsSource,
VertexFeatureStore,
Weaviate,
)

Expand All @@ -59,6 +60,7 @@
"Retrieval",
"SlackChannel",
"SlackChannelsSource",
"VertexFeatureStore",
"VertexRagStore",
"Weaviate",
"create_corpus",
Expand Down
3 changes: 2 additions & 1 deletion vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
RagCorpus,
RagFile,
SlackChannelsSource,
VertexFeatureStore,
Weaviate,
)

Expand All @@ -56,7 +57,7 @@ def create_corpus(
display_name: Optional[str] = None,
description: Optional[str] = None,
embedding_model_config: Optional[EmbeddingModelConfig] = None,
vector_db: Optional[Weaviate] = None,
vector_db: Optional[Union[Weaviate, VertexFeatureStore]] = None,
) -> RagCorpus:
"""Creates a new RagCorpus resource.
Expand Down
21 changes: 17 additions & 4 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
RagFile,
SlackChannelsSource,
JiraSource,
VertexFeatureStore,
Weaviate,
)

Expand Down Expand Up @@ -97,14 +98,18 @@ def convert_gapic_to_embedding_model_config(

def convert_gapic_to_vector_db(
    gapic_vector_db: RagVectorDbConfig,
) -> Union[Weaviate, VertexFeatureStore, None]:
    """Convert a Gapic RagVectorDbConfig to the matching SDK vector DB object.

    Args:
        gapic_vector_db: The vector DB configuration taken from a Gapic
            RagCorpus.

    Returns:
        A Weaviate when the ``weaviate`` field is present, a
        VertexFeatureStore when the ``vertex_feature_store`` field is
        present, and None when neither is present.
    """
    if "weaviate" in gapic_vector_db:
        weaviate_config = gapic_vector_db.weaviate
        return Weaviate(
            weaviate_http_endpoint=weaviate_config.http_endpoint,
            collection_name=weaviate_config.collection_name,
            # The API key is read from the config's api_auth field rather
            # than from the weaviate sub-message.
            api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
        )
    if "vertex_feature_store" in gapic_vector_db:
        return VertexFeatureStore(
            resource_name=gapic_vector_db.vertex_feature_store.feature_view_resource_name,
        )
    # Neither supported vector DB is set; callers receive None.
    return None

Expand Down Expand Up @@ -390,7 +395,7 @@ def set_embedding_model_config(


def set_vector_db(
vector_db: Weaviate,
vector_db: Union[Weaviate, VertexFeatureStore],
rag_corpus: GapicRagCorpus,
) -> None:
"""Sets the vector db configuration for the rag corpus."""
Expand All @@ -410,5 +415,13 @@ def set_vector_db(
),
),
)
elif isinstance(vector_db, VertexFeatureStore):
resource_name = vector_db.resource_name

rag_corpus.rag_vector_db_config = RagVectorDbConfig(
vertex_feature_store=RagVectorDbConfig.VertexFeatureStore(
feature_view_resource_name=resource_name,
),
)
else:
raise TypeError("vector_db must be a Weaviate.")
raise TypeError("vector_db must be a Weaviate or VertexFeatureStore.")
17 changes: 15 additions & 2 deletions vertexai/preview/rag/utils/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

import dataclasses
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union

from google.protobuf import timestamp_pb2

Expand Down Expand Up @@ -85,6 +85,19 @@ class Weaviate:
api_key: str


@dataclasses.dataclass
class VertexFeatureStore:
    """Vertex Feature Store vector DB option for a RAG corpus.

    Attributes:
        resource_name: The resource name of the FeatureView. Format:
            ``projects/{project}/locations/{location}/featureOnlineStores/
            {feature_online_store}/featureViews/{feature_view}``
    """

    resource_name: str


@dataclasses.dataclass
class RagCorpus:
"""RAG corpus(output only).
Expand All @@ -102,7 +115,7 @@ class RagCorpus:
display_name: Optional[str] = None
description: Optional[str] = None
embedding_model_config: Optional[EmbeddingModelConfig] = None
vector_db: Optional[Weaviate] = None
vector_db: Optional[Union[Weaviate, VertexFeatureStore]] = None


@dataclasses.dataclass
Expand Down

0 comments on commit cfc3421

Please sign in to comment.