diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py
index f7eab7d70a..298639028d 100644
--- a/sdk/python/feast/infra/registry/caching_registry.py
+++ b/sdk/python/feast/infra/registry/caching_registry.py
@@ -1,4 +1,6 @@
+import atexit
 import logging
+import threading
 from abc import abstractmethod
 from datetime import timedelta
 from threading import Lock
@@ -21,11 +23,7 @@ class CachingRegistry(BaseRegistry):
-    def __init__(
-        self,
-        project: str,
-        cache_ttl_seconds: int,
-    ):
+    def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str):
         self.cached_registry_proto = self.proto()
         proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
         self.cached_registry_proto_created = _utc_now()
@@ -33,6 +31,10 @@ def __init__(
         self.cached_registry_proto_ttl = timedelta(
             seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0
         )
+        self.cache_mode = cache_mode
+        if cache_mode == "thread":
+            self._start_thread_async_refresh(cache_ttl_seconds)
+            atexit.register(self._exit_handler)

     @abstractmethod
     def _get_data_source(self, name: str, project: str) -> DataSource:
@@ -322,22 +324,35 @@ def refresh(self, project: Optional[str] = None):
         self.cached_registry_proto_created = _utc_now()

     def _refresh_cached_registry_if_necessary(self):
-        with self._refresh_lock:
-            expired = (
-                self.cached_registry_proto is None
-                or self.cached_registry_proto_created is None
-            ) or (
-                self.cached_registry_proto_ttl.total_seconds()
-                > 0  # 0 ttl means infinity
-                and (
-                    _utc_now()
-                    > (
-                        self.cached_registry_proto_created
-                        + self.cached_registry_proto_ttl
+        if self.cache_mode == "sync":
+            with self._refresh_lock:
+                expired = (
+                    self.cached_registry_proto is None
+                    or self.cached_registry_proto_created is None
+                ) or (
+                    self.cached_registry_proto_ttl.total_seconds()
+                    > 0  # 0 ttl means infinity
+                    and (
+                        _utc_now()
+                        > (
+                            self.cached_registry_proto_created
+                            + self.cached_registry_proto_ttl
+                        )
                     )
                 )
-            )
+            if expired:
+                logger.info("Registry cache expired, so refreshing")
+                self.refresh()
+
+    def _start_thread_async_refresh(self, cache_ttl_seconds):
+        self.refresh()
+        if cache_ttl_seconds <= 0:
+            return
+        self.registry_refresh_thread = threading.Timer(
+            cache_ttl_seconds, self._start_thread_async_refresh, [cache_ttl_seconds]
+        )
+        self.registry_refresh_thread.setDaemon(True)
+        self.registry_refresh_thread.start()

-        if expired:
-            logger.info("Registry cache expired, so refreshing")
-            self.refresh()
+    def _exit_handler(self):
+        self.registry_refresh_thread.cancel()
diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py
index 6ef08989b7..a2b16a3a09 100644
--- a/sdk/python/feast/infra/registry/sql.py
+++ b/sdk/python/feast/infra/registry/sql.py
@@ -193,7 +193,9 @@ def __init__(
         )
         metadata.create_all(self.engine)
         super().__init__(
-            project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds
+            project=project,
+            cache_ttl_seconds=registry_config.cache_ttl_seconds,
+            cache_mode=registry_config.cache_mode,
         )

     def teardown(self):
diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py
index 8d6bff2818..137023ef22 100644
--- a/sdk/python/feast/repo_config.py
+++ b/sdk/python/feast/repo_config.py
@@ -124,6 +124,9 @@ class RegistryConfig(FeastBaseModel):
     sqlalchemy_config_kwargs: Dict[str, Any] = {}
     """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """

+    cache_mode: StrictStr = "sync"
+    """ str: Cache mode type. Possible options are sync and thread (asynchronous caching using the threading library). """
+
     @field_validator("path")
     def validate_path(cls, path: str, values: ValidationInfo) -> str:
         if values.data.get("registry_type") == "sql":
diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py
index c06ccf2d4d..b0738c8419 100644
--- a/sdk/python/tests/integration/registration/test_universal_registry.py
+++ b/sdk/python/tests/integration/registration/test_universal_registry.py
@@ -125,7 +125,7 @@ def minio_registry() -> Registry:
 logger = logging.getLogger(__name__)


-@pytest.fixture(scope="session")
+@pytest.fixture(scope="function")
 def pg_registry():
     container = (
         DockerContainer("postgres:latest")
@@ -137,6 +137,35 @@ def pg_registry():

     container.start()

+    registry_config = _given_registry_config_for_pg_sql(container)
+
+    yield SqlRegistry(registry_config, "project", None)
+
+    container.stop()
+
+
+@pytest.fixture(scope="function")
+def pg_registry_async():
+    container = (
+        DockerContainer("postgres:latest")
+        .with_exposed_ports(5432)
+        .with_env("POSTGRES_USER", POSTGRES_USER)
+        .with_env("POSTGRES_PASSWORD", POSTGRES_PASSWORD)
+        .with_env("POSTGRES_DB", POSTGRES_DB)
+    )
+
+    container.start()
+
+    registry_config = _given_registry_config_for_pg_sql(container, 2, "thread")
+
+    yield SqlRegistry(registry_config, "project", None)
+
+    container.stop()
+
+
+def _given_registry_config_for_pg_sql(
+    container, cache_ttl_seconds=2, cache_mode="sync"
+):
     log_string_to_wait_for = "database system is ready to accept connections"
     waited = wait_for_logs(
         container=container,
@@ -148,25 +177,42 @@ def pg_registry():
     container_port = container.get_exposed_port(5432)
     container_host = container.get_container_host_ip()

-    registry_config = RegistryConfig(
+    return RegistryConfig(
         registry_type="sql",
+        cache_ttl_seconds=cache_ttl_seconds,
+        cache_mode=cache_mode,
         # The `path` must include `+psycopg` in order for `sqlalchemy.create_engine()`
         # to understand that we are using psycopg3.
         path=f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}",
         sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
     )

+
+@pytest.fixture(scope="function")
+def mysql_registry():
+    container = MySqlContainer("mysql:latest")
+    container.start()
+
+    registry_config = _given_registry_config_for_mysql(container)
+
     yield SqlRegistry(registry_config, "project", None)

     container.stop()


-@pytest.fixture(scope="session")
-def mysql_registry():
+@pytest.fixture(scope="function")
+def mysql_registry_async():
     container = MySqlContainer("mysql:latest")
     container.start()

-    # testing for the database to exist and ready to connect and start testing.
+    registry_config = _given_registry_config_for_mysql(container, 2, "thread")
+
+    yield SqlRegistry(registry_config, "project", None)
+
+    container.stop()
+
+
+def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"):
     import sqlalchemy

     engine = sqlalchemy.create_engine(
@@ -174,16 +220,14 @@ def mysql_registry():
     )
     engine.connect()

-    registry_config = RegistryConfig(
+    return RegistryConfig(
         registry_type="sql",
         path=container.get_connection_url(),
+        cache_ttl_seconds=cache_ttl_seconds,
+        cache_mode=cache_mode,
         sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
     )

-    yield SqlRegistry(registry_config, "project", None)
-
-    container.stop()
-

 @pytest.fixture(scope="session")
 def sqlite_registry():
@@ -269,6 +313,17 @@ def mock_remote_registry():
         lazy_fixture("sqlite_registry"),
     ]

+async_sql_fixtures = [
+    pytest.param(
+        lazy_fixture("pg_registry_async"),
+        marks=pytest.mark.xdist_group(name="pg_registry_async"),
+    ),
+    pytest.param(
+        lazy_fixture("mysql_registry_async"),
+        marks=pytest.mark.xdist_group(name="mysql_registry_async"),
+    ),
+]
+

 @pytest.mark.integration
 @pytest.mark.parametrize("test_registry", all_fixtures)
@@ -999,6 +1054,44 @@ def test_registry_cache(test_registry):
     test_registry.teardown()


+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "test_registry",
+    async_sql_fixtures,
+)
+def test_registry_cache_thread_async(test_registry):
+    # Create data source
+    batch_source = FileSource(
+        name="test_source",
+        file_format=ParquetFormat(),
+        path="file://feast/*",
+        timestamp_field="ts_col",
+        created_timestamp_column="timestamp",
+    )
+
+    project = "project"
+
+    # Register data source
+    test_registry.apply_data_source(batch_source, project)
+    registry_data_sources_cached = test_registry.list_data_sources(
+        project, allow_cache=True
+    )
+    # async ttl yet to expire, so there will be a cache miss
+    assert len(registry_data_sources_cached) == 0
+
+    # Wait for cache to be refreshed
+    time.sleep(4)
+    # Now objects exist
+    registry_data_sources_cached = test_registry.list_data_sources(
+        project, allow_cache=True
+    )
+    assert len(registry_data_sources_cached) == 1
+    registry_data_source = registry_data_sources_cached[0]
+    assert registry_data_source == batch_source
+
+    test_registry.teardown()
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "test_registry",
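
For reference, a minimal usage sketch of the new cache_mode option, mirroring the test fixtures above. The connection string below is a hypothetical placeholder; with cache_mode="thread" the registry refreshes its cache from a background daemon thread every cache_ttl_seconds, while the default "sync" mode keeps the existing behaviour of refreshing lazily once the TTL has expired.

    from feast.infra.registry.sql import SqlRegistry
    from feast.repo_config import RegistryConfig

    # Hypothetical connection string; point this at a reachable Postgres instance.
    registry_config = RegistryConfig(
        registry_type="sql",
        path="postgresql+psycopg://user:password@localhost:5432/feast",
        cache_ttl_seconds=2,  # refresh interval of the background thread
        cache_mode="thread",  # "sync" (default) refreshes lazily on TTL expiry
        sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True},
    )

    # Reads with allow_cache=True are served from the cache maintained by the
    # background thread instead of triggering a synchronous refresh.
    registry = SqlRegistry(registry_config, "project", None)
    data_sources = registry.list_data_sources("project", allow_cache=True)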