Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding DatastoreOnlineStore 'database' argument. #4180

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions protos/feast/core/DatastoreTable.proto
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ message DatastoreTable {

// Datastore namespace
google.protobuf.StringValue namespace = 4;

// Firestore (Datastore mode) database ID
google.protobuf.StringValue database = 5;
}
30 changes: 24 additions & 6 deletions sdk/python/feast/infra/online_stores/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class DatastoreOnlineStoreConfig(FeastConfigBaseModel):
namespace: Optional[StrictStr] = None
""" (optional) Datastore namespace """

database: Optional[StrictStr] = None
""" (optional) Firestore database """

write_concurrency: Optional[PositiveInt] = 40
""" (optional) Amount of threads to use when writing batches of feature rows into Datastore"""

Expand Down Expand Up @@ -155,7 +158,9 @@ def teardown(
def _get_client(self, online_config: DatastoreOnlineStoreConfig):
if not self._client:
self._client = _initialize_client(
online_config.project_id, online_config.namespace
online_config.project_id,
online_config.namespace,
online_config.database,
)
return self._client

Expand Down Expand Up @@ -344,11 +349,14 @@ def worker(shared_counter):


def _initialize_client(
project_id: Optional[str], namespace: Optional[str]
project_id: Optional[str], namespace: Optional[str], database: Optional[str]
) -> datastore.Client:
try:
client = datastore.Client(
project=project_id, namespace=namespace, client_info=get_http_client_info()
project=project_id,
namespace=namespace,
database=database,
client_info=get_http_client_info(),
)
return client
except DefaultCredentialsError as e:
Expand All @@ -368,23 +376,27 @@ class DatastoreTable(InfraObject):
name: The name of the table.
project_id (optional): The GCP project id.
namespace (optional): Datastore namespace.
database (optional): Firestore (Datastore mode) database ID.
"""

project: str
project_id: Optional[str]
namespace: Optional[str]
database: Optional[str]

def __init__(
self,
project: str,
name: str,
project_id: Optional[str] = None,
namespace: Optional[str] = None,
database: Optional[str] = None,
):
super().__init__(name)
self.project = project
self.project_id = project_id
self.namespace = namespace
self.database = database

def to_infra_object_proto(self) -> InfraObjectProto:
datastore_table_proto = self.to_proto()
Expand All @@ -401,6 +413,8 @@ def to_proto(self) -> Any:
datastore_table_proto.project_id.value = self.project_id
if self.namespace:
datastore_table_proto.namespace.value = self.namespace
if self.database:
datastore_table_proto.database.value = self.database
return datastore_table_proto

@staticmethod
Expand All @@ -410,7 +424,7 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any:
name=infra_object_proto.datastore_table.name,
)

# Distinguish between null and empty string, since project_id and namespace are StringValues.
# Distinguish between null and empty string, since project_id, namespace and database are StringValues.
if infra_object_proto.datastore_table.HasField("project_id"):
datastore_table.project_id = (
infra_object_proto.datastore_table.project_id.value
Expand All @@ -419,6 +433,8 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any:
datastore_table.namespace = (
infra_object_proto.datastore_table.namespace.value
)
if infra_object_proto.datastore_table.HasField("database"):
datastore_table.database = infra_object_proto.datastore_table.database.value

return datastore_table

Expand All @@ -434,11 +450,13 @@ def from_proto(datastore_table_proto: DatastoreTableProto) -> Any:
datastore_table.project_id = datastore_table_proto.project_id.value
if datastore_table_proto.HasField("namespace"):
datastore_table.namespace = datastore_table_proto.namespace.value
if datastore_table_proto.HasField("database"):
datastore_table.database = datastore_table_proto.database.value

return datastore_table

def update(self):
client = _initialize_client(self.project_id, self.namespace)
client = _initialize_client(self.project_id, self.namespace, self.database)
key = client.key("Project", self.project, "Table", self.name)
entity = datastore.Entity(
key=key, exclude_from_indexes=("created_ts", "event_ts", "values")
Expand All @@ -447,7 +465,7 @@ def update(self):
client.put(entity)

def teardown(self):
client = _initialize_client(self.project_id, self.namespace)
client = _initialize_client(self.project_id, self.namespace, self.database)
key = client.key("Project", self.project, "Table", self.name)
_delete_all_values(client, key)

Expand Down
17 changes: 14 additions & 3 deletions sdk/python/tests/unit/diff/test_infra_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ def test_tag_infra_proto_objects_for_keep_delete_add():

def test_diff_between_datastore_tables():
pre_changed = DatastoreTable(
project="test", name="table", project_id="pre", namespace="pre"
project="test", name="table", project_id="pre", namespace="pre", database="pre"
).to_proto()
post_changed = DatastoreTable(
project="test", name="table", project_id="post", namespace="post"
project="test",
name="table",
project_id="post",
namespace="post",
database="post",
).to_proto()

infra_object_diff = diff_between(pre_changed, pre_changed, "datastore table")
Expand All @@ -51,7 +55,7 @@ def test_diff_between_datastore_tables():

infra_object_diff = diff_between(pre_changed, post_changed, "datastore table")
infra_object_property_diffs = infra_object_diff.infra_object_property_diffs
assert len(infra_object_property_diffs) == 2
assert len(infra_object_property_diffs) == 3

assert infra_object_property_diffs[0].property_name == "project_id"
assert infra_object_property_diffs[0].val_existing == wrappers.StringValue(
Expand All @@ -67,6 +71,13 @@ def test_diff_between_datastore_tables():
assert infra_object_property_diffs[1].val_declared == wrappers.StringValue(
value="post"
)
assert infra_object_property_diffs[2].property_name == "database"
assert infra_object_property_diffs[2].val_existing == wrappers.StringValue(
value="pre"
)
assert infra_object_property_diffs[2].val_declared == wrappers.StringValue(
value="post"
)


def test_diff_infra_protos():
Expand Down
Loading