diff --git a/protos/feast/core/DatastoreTable.proto b/protos/feast/core/DatastoreTable.proto index 4246a6ae6e..acd3ba57b5 100644 --- a/protos/feast/core/DatastoreTable.proto +++ b/protos/feast/core/DatastoreTable.proto @@ -36,4 +36,7 @@ message DatastoreTable { // Datastore namespace google.protobuf.StringValue namespace = 4; + + // Firestore database + google.protobuf.StringValue database = 5; } \ No newline at end of file diff --git a/sdk/python/feast/infra/online_stores/datastore.py b/sdk/python/feast/infra/online_stores/datastore.py index 149354b472..bf44a74966 100644 --- a/sdk/python/feast/infra/online_stores/datastore.py +++ b/sdk/python/feast/infra/online_stores/datastore.py @@ -80,6 +80,9 @@ class DatastoreOnlineStoreConfig(FeastConfigBaseModel): namespace: Optional[StrictStr] = None """ (optional) Datastore namespace """ + database: Optional[StrictStr] = None + """ (optional) Firestore database """ + write_concurrency: Optional[PositiveInt] = 40 """ (optional) Amount of threads to use when writing batches of feature rows into Datastore""" @@ -155,7 +158,9 @@ def teardown( def _get_client(self, online_config: DatastoreOnlineStoreConfig): if not self._client: self._client = _initialize_client( - online_config.project_id, online_config.namespace + online_config.project_id, + online_config.namespace, + online_config.database, ) return self._client @@ -344,11 +349,14 @@ def worker(shared_counter): def _initialize_client( - project_id: Optional[str], namespace: Optional[str] + project_id: Optional[str], namespace: Optional[str], database: Optional[str] ) -> datastore.Client: try: client = datastore.Client( - project=project_id, namespace=namespace, client_info=get_http_client_info() + project=project_id, + namespace=namespace, + database=database, + client_info=get_http_client_info(), ) return client except DefaultCredentialsError as e: @@ -368,11 +376,13 @@ class DatastoreTable(InfraObject): name: The name of the table. project_id (optional): The GCP project id. namespace (optional): Datastore namespace. + database (optional): Firestore database. """ project: str project_id: Optional[str] namespace: Optional[str] + database: Optional[str] def __init__( self, @@ -380,11 +390,13 @@ def __init__( name: str, project_id: Optional[str] = None, namespace: Optional[str] = None, + database: Optional[str] = None, ): super().__init__(name) self.project = project self.project_id = project_id self.namespace = namespace + self.database = database def to_infra_object_proto(self) -> InfraObjectProto: datastore_table_proto = self.to_proto() @@ -401,6 +413,8 @@ def to_proto(self) -> Any: datastore_table_proto.project_id.value = self.project_id if self.namespace: datastore_table_proto.namespace.value = self.namespace + if self.database: + datastore_table_proto.database.value = self.database return datastore_table_proto @staticmethod @@ -410,7 +424,7 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: name=infra_object_proto.datastore_table.name, ) - # Distinguish between null and empty string, since project_id and namespace are StringValues. + # Distinguish between null and empty string, since project_id, namespace and database are StringValues. if infra_object_proto.datastore_table.HasField("project_id"): datastore_table.project_id = ( infra_object_proto.datastore_table.project_id.value @@ -419,6 +433,8 @@ def from_infra_object_proto(infra_object_proto: InfraObjectProto) -> Any: datastore_table.namespace = ( infra_object_proto.datastore_table.namespace.value ) + if infra_object_proto.datastore_table.HasField("database"): + datastore_table.database = infra_object_proto.datastore_table.database.value return datastore_table @@ -434,11 +450,13 @@ def from_proto(datastore_table_proto: DatastoreTableProto) -> Any: datastore_table.project_id = datastore_table_proto.project_id.value if datastore_table_proto.HasField("namespace"): datastore_table.namespace = datastore_table_proto.namespace.value + if datastore_table_proto.HasField("database"): + datastore_table.database = datastore_table_proto.database.value return datastore_table def update(self): - client = _initialize_client(self.project_id, self.namespace) + client = _initialize_client(self.project_id, self.namespace, self.database) key = client.key("Project", self.project, "Table", self.name) entity = datastore.Entity( key=key, exclude_from_indexes=("created_ts", "event_ts", "values") @@ -447,7 +465,7 @@ def update(self): client.put(entity) def teardown(self): - client = _initialize_client(self.project_id, self.namespace) + client = _initialize_client(self.project_id, self.namespace, self.database) key = client.key("Project", self.project, "Table", self.name) _delete_all_values(client, key) diff --git a/sdk/python/tests/unit/diff/test_infra_diff.py b/sdk/python/tests/unit/diff/test_infra_diff.py index 8e3d5b765f..3a0443e634 100644 --- a/sdk/python/tests/unit/diff/test_infra_diff.py +++ b/sdk/python/tests/unit/diff/test_infra_diff.py @@ -39,10 +39,14 @@ def test_tag_infra_proto_objects_for_keep_delete_add(): def test_diff_between_datastore_tables(): pre_changed = DatastoreTable( - project="test", name="table", project_id="pre", namespace="pre" + project="test", name="table", project_id="pre", namespace="pre", database="pre" ).to_proto() post_changed = DatastoreTable( - project="test", name="table", project_id="post", namespace="post" + project="test", + name="table", + project_id="post", + namespace="post", + database="post", ).to_proto() infra_object_diff = diff_between(pre_changed, pre_changed, "datastore table") @@ -51,7 +55,7 @@ def test_diff_between_datastore_tables(): infra_object_diff = diff_between(pre_changed, post_changed, "datastore table") infra_object_property_diffs = infra_object_diff.infra_object_property_diffs - assert len(infra_object_property_diffs) == 2 + assert len(infra_object_property_diffs) == 3 assert infra_object_property_diffs[0].property_name == "project_id" assert infra_object_property_diffs[0].val_existing == wrappers.StringValue( @@ -67,6 +71,13 @@ def test_diff_between_datastore_tables(): assert infra_object_property_diffs[1].val_declared == wrappers.StringValue( value="post" ) + assert infra_object_property_diffs[2].property_name == "database" + assert infra_object_property_diffs[2].val_existing == wrappers.StringValue( + value="pre" + ) + assert infra_object_property_diffs[2].val_declared == wrappers.StringValue( + value="post" + ) def test_diff_infra_protos():