diff --git a/sdk/python/feast/infra/registry_stores/sql.py b/sdk/python/feast/infra/registry_stores/sql.py index 2d3ac9d683..a7380e141b 100644 --- a/sdk/python/feast/infra/registry_stores/sql.py +++ b/sdk/python/feast/infra/registry_stores/sql.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum from pathlib import Path -from typing import Any, List, Optional, Set, Union +from typing import Any, Callable, List, Optional, Set, Union from sqlalchemy import ( # type: ignore BigInteger, @@ -555,7 +555,7 @@ def update_infra(self, infra: Infra, project: str, commit: bool = True): ) def get_infra(self, project: str, allow_cache: bool = False) -> Infra: - return self._get_object( + infra_object = self._get_object( managed_infra, "infra_obj", project, @@ -565,6 +565,8 @@ def get_infra(self, project: str, allow_cache: bool = False) -> Infra: "infra_proto", None, ) + infra_object = infra_object or InfraProto() + return Infra.from_proto(infra_object) def apply_user_metadata( self, @@ -676,11 +678,18 @@ def commit(self): pass def _apply_object( - self, table, project: str, id_field_name, obj, proto_field_name, name=None + self, + table: Table, + project: str, + id_field_name, + obj, + proto_field_name, + name=None, ): self._maybe_init_project_metadata(project) - name = name or obj.name + name = name or obj.name if hasattr(obj, "name") else None + assert name, f"name needs to be provided for {obj}" with self.engine.connect() as conn: update_datetime = datetime.utcnow() update_time = int(update_datetime.timestamp()) @@ -738,7 +747,14 @@ def _maybe_init_project_metadata(self, project): conn.execute(insert_stmt) usage.set_current_project_uuid(new_project_uuid) - def _delete_object(self, table, name, project, id_field_name, not_found_exception): + def _delete_object( + self, + table: Table, + name: str, + project: str, + id_field_name: str, + not_found_exception: Optional[Callable], + ): with self.engine.connect() as conn: stmt = delete(table).where( getattr(table.c, id_field_name) == name, table.c.project_id == project @@ -752,14 +768,14 @@ def _delete_object(self, table, name, project, id_field_name, not_found_exceptio def _get_object( self, - table, - name, - project, - proto_class, - python_class, - id_field_name, - proto_field_name, - not_found_exception, + table: Table, + name: str, + project: str, + proto_class: Any, + python_class: Any, + id_field_name: str, + proto_field_name: str, + not_found_exception: Optional[Callable], ): self._maybe_init_project_metadata(project) @@ -771,10 +787,18 @@ def _get_object( if row: _proto = proto_class.FromString(row[proto_field_name]) return python_class.from_proto(_proto) - raise not_found_exception(name, project) + if not_found_exception: + raise not_found_exception(name, project) + else: + return None def _list_objects( - self, table, project, proto_class, python_class, proto_field_name + self, + table: Table, + project: str, + proto_class: Any, + python_class: Any, + proto_field_name: str, ): self._maybe_init_project_metadata(project) with self.engine.connect() as conn: