diff --git a/changelog/6629.bugfix.md b/changelog/6629.bugfix.md new file mode 100644 index 000000000000..43063d142356 --- /dev/null +++ b/changelog/6629.bugfix.md @@ -0,0 +1 @@ +Fixed a bug that occurred when setting multiple Sanic workers in combination with a custom [Lock Store](lock-stores.mdx). Previously, if the number was set higher than 1 and you were using a custom lock store, it would reject because of a strict check to use a [Redis Lock Store](lock-stores.mdx#redislockstore). diff --git a/rasa/core/utils.py b/rasa/core/utils.py index 289de4900934..ad7443925bc3 100644 --- a/rasa/core/utils.py +++ b/rasa/core/utils.py @@ -32,7 +32,7 @@ # backwards compatibility 1.0.x # noinspection PyUnresolvedReferences -from rasa.core.lock_store import LockStore, RedisLockStore +from rasa.core.lock_store import LockStore, RedisLockStore, InMemoryLockStore from rasa.utils.endpoints import EndpointConfig, read_endpoint_config from sanic import Sanic from sanic.views import CompositionView @@ -450,25 +450,29 @@ def replace_decimals_with_floats(obj: Any) -> Any: return json.loads(json.dumps(obj, cls=DecimalEncoder)) -def _lock_store_is_redis_lock_store( +def _lock_store_is_multi_worker_compatible( lock_store: Union[EndpointConfig, LockStore, None] ) -> bool: + if isinstance(lock_store, InMemoryLockStore): + return False + if isinstance(lock_store, RedisLockStore): return True - if isinstance(lock_store, LockStore): - return False - # `lock_store` is `None` or `EndpointConfig` - return lock_store is not None and lock_store.type == "redis" + return ( + lock_store is not None + and isinstance(lock_store, EndpointConfig) + and lock_store.type != "in_memory" + ) def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None]) -> int: """Get the number of Sanic workers to use in `app.run()`. If the environment variable constants.ENV_SANIC_WORKERS is set and is not equal to - 1, that value will only be permitted if the used lock store supports shared - resources across multiple workers (e.g. ``RedisLockStore``). + 1, that value will only be permitted if the used lock store is not the + `InMemoryLockStore`. """ def _log_and_get_default_number_of_workers(): @@ -496,12 +500,12 @@ def _log_and_get_default_number_of_workers(): ) return _log_and_get_default_number_of_workers() - if _lock_store_is_redis_lock_store(lock_store): + if _lock_store_is_multi_worker_compatible(lock_store): logger.debug(f"Using {env_value} Sanic workers.") return env_value logger.debug( f"Unable to assign desired number of Sanic workers ({env_value}) as " - f"no `RedisLockStore` endpoint configuration has been found." + f"no `RedisLockStore` or custom `LockStore` endpoint configuration has been found." ) return _log_and_get_default_number_of_workers() diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index 3e9160739003..e21a60e41ea3 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -15,6 +15,10 @@ from tests.conftest import write_endpoint_config_to_yaml +class CustomRedisLockStore(RedisLockStore): + """Test class used to test the behavior of custom lock stores.""" + + def test_is_int(): assert utils.is_int(1) assert utils.is_int(1.0) @@ -132,8 +136,10 @@ def test_replace_decimals_with_floats(_input: Any, expected: Any): (5, "in_memory", 1), (2, None, 1), (0, "in_memory", 1), + (3, "tests/core/test_utils.CustomRedisLockStore", 3), (3, RedisLockStore(), 3), (2, InMemoryLockStore(), 1), + (3, CustomRedisLockStore(), 3), ], ) def test_get_number_of_sanic_workers( @@ -168,16 +174,19 @@ def test_get_number_of_sanic_workers( (EndpointConfig(type="redis"), True), (RedisLockStore(), True), (EndpointConfig(type="in_memory"), False), - (EndpointConfig(type="random_store"), False), + (EndpointConfig(type="custom_lock_store"), True), (None, False), (InMemoryLockStore(), False), + (CustomRedisLockStore(), True), ], ) -def test_lock_store_is_redis_lock_store( +def test_lock_store_is_multi_worker_compatible( lock_store: Union[EndpointConfig, LockStore, None], expected: bool ): # noinspection PyProtectedMember - assert rasa.core.utils._lock_store_is_redis_lock_store(lock_store) == expected + assert ( + rasa.core.utils._lock_store_is_multi_worker_compatible(lock_store) == expected + ) def test_read_endpoints_from_path(tmp_path: Path):