Skip to content

Commit

Permalink
[Python] Check feature store existence at pipeline construction time (#…
Browse files Browse the repository at this point in the history
…30668)

* check feature store existence at construction time

* postcommit
  • Loading branch information
riteshghorse authored Mar 19, 2024
1 parent c3b3fa6 commit 50f33cd
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 24 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,32 @@ def __init__(
else:
self.kwargs['client_options'] = {"api_endpoint": self.api_endpoint}

# check if the feature store exists
try:
admin_client = aiplatform.gapic.FeatureOnlineStoreAdminServiceClient(
**self.kwargs)
except Exception:
_LOGGER.warning(
'Due to insufficient admin permission, could not verify '
'the existence of feature store. If the `exception_level` '
'is set to WARN then make sure the feature store exists '
'otherwise the data enrichment will not happen without '
'throwing an error.')
else:
location_path = admin_client.common_location_path(
project=self.project, location=self.location)
feature_store_path = admin_client.feature_online_store_path(
project=self.project,
location=self.location,
feature_online_store=self.feature_store_name)
feature_store = admin_client.get_feature_online_store(
name=feature_store_path)

if not feature_store:
raise NotFound(
'Vertex AI Feature Store %s does not exists in %s' %
(self.feature_store_name, location_path))

def __enter__(self):
"""Connect with the Vertex AI Feature Store."""
self.client = aiplatform.gapic.FeatureOnlineStoreServiceClient(
Expand Down Expand Up @@ -228,26 +254,25 @@ def __init__(
else:
self.kwargs['client_options'] = {"api_endpoint": self.api_endpoint}

def __enter__(self):
"""Connect with the Vertex AI Feature Store (Legacy)."""
# checks if feature store exists
try:
# checks if feature store exists
_ = aiplatform.Featurestore(
featurestore_name=self.feature_store_id,
project=self.project,
location=self.location,
credentials=self.kwargs.get('credentials'),
)
self.client = aiplatform.gapic.FeaturestoreOnlineServingServiceClient(
**self.kwargs)
self.entity_type_path = self.client.entity_type_path(
self.project,
self.location,
self.feature_store_id,
self.entity_type_id)
except NotFound:
raise ValueError(
'Vertex AI Feature Store %s does not exist' % self.feature_store_id)
raise NotFound(
'Vertex AI Feature Store (Legacy) %s does not exist' %
self.feature_store_id)

def __enter__(self):
"""Connect with the Vertex AI Feature Store (Legacy)."""
self.client = aiplatform.gapic.FeaturestoreOnlineServingServiceClient(
**self.kwargs)
self.entity_type_path = self.client.entity_type_path(
self.project, self.location, self.feature_store_id, self.entity_type_id)

def __call__(self, request: beam.Row, *args, **kwargs):
"""Fetches feature value for an entity-id from
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

# pylint: disable=ungrouped-imports
try:
from google.api_core.exceptions import NotFound
from testcontainers.redis import RedisContainer
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel
Expand Down Expand Up @@ -71,7 +72,7 @@ def setUp(self) -> None:
self.entity_type_name = "entity_id"
self.api_endpoint = "us-central1-aiplatform.googleapis.com"
self.feature_ids = ['title', 'genres']

self.retries = 3
self._start_container()

def _start_container(self):
Expand Down Expand Up @@ -124,6 +125,26 @@ def test_vertex_ai_feature_store_bigtable_serving_enrichment(self):
| Enrichment(handler)
| beam.ParDo(ValidateResponse(expected_fields)))

def test_vertex_ai_feature_store_wrong_name(self):
requests = [
beam.Row(entity_id="847", name='cardigan jacket'),
beam.Row(entity_id="16050", name='stripe t-shirt'),
]

with self.assertRaises(NotFound):
handler = VertexAIFeatureStoreEnrichmentHandler(
project=self.project,
location=self.location,
api_endpoint=self.api_endpoint,
feature_store_name="incorrect_name",
feature_view_name=self.feature_view_name,
row_key=self.entity_type_name,
)
test_pipeline = beam.Pipeline()
_ = (test_pipeline | beam.Create(requests) | Enrichment(handler))
res = test_pipeline.run()
res.wait_until_finish()

def test_vertex_ai_feature_store_bigtable_serving_enrichment_bad(self):
requests = [
beam.Row(entity_id="ui", name="fred perry men\'s sharp stripe t-shirt")
Expand Down Expand Up @@ -203,18 +224,18 @@ def test_vertex_ai_legacy_feature_store_invalid_featurestore(self):
]
feature_store_id = "invalid_name"
entity_type_id = "movies"
handler = VertexAIFeatureStoreLegacyEnrichmentHandler(
project=self.project,
location=self.location,
api_endpoint=self.api_endpoint,
feature_store_id=feature_store_id,
entity_type_id=entity_type_id,
feature_ids=self.feature_ids,
row_key=self.entity_type_name,
exception_level=ExceptionLevel.RAISE,
)

with self.assertRaises(ValueError):
with self.assertRaises(NotFound):
handler = VertexAIFeatureStoreLegacyEnrichmentHandler(
project=self.project,
location=self.location,
api_endpoint=self.api_endpoint,
feature_store_id=feature_store_id,
entity_type_id=entity_type_id,
feature_ids=self.feature_ids,
row_key=self.entity_type_name,
exception_level=ExceptionLevel.RAISE,
)
test_pipeline = beam.Pipeline()
_ = (
test_pipeline
Expand Down

0 comments on commit 50f33cd

Please sign in to comment.