diff --git a/RELEASE.md b/RELEASE.md
index f5d7f3cf61..daef0e614c 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -3,7 +3,9 @@
 ## Major features and improvements
 * Enhanced `OmegaConfigLoader` configuration validation to detect duplicate keys at all parameter levels, ensuring comprehensive nested key checking.
 ## Bug fixes and other changes
+* Fixed a bug where using dataset factories would break with `ThreadRunner`.
 * Fixed a bug where `SharedMemoryDataset.exists` would not call the underlying `MemoryDataset`.
+
 ## Breaking changes to the API
 ## Documentation changes
diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py
index 691ae30385..c4850159a1 100644
--- a/kedro/config/omegaconf_config.py
+++ b/kedro/config/omegaconf_config.py
@@ -9,6 +9,7 @@ import mimetypes
 import typing
 from collections.abc import KeysView
+from enum import Enum, auto
 from pathlib import Path
 from typing import Any, Callable, Iterable
@@ -26,6 +27,17 @@
 _NO_VALUE = object()
 
 
+class MergeStrategies(Enum):
+    SOFT = auto()
+    DESTRUCTIVE = auto()
+
+
+MERGING_IMPLEMENTATIONS = {
+    MergeStrategies.SOFT: "_soft_merge",
+    MergeStrategies.DESTRUCTIVE: "_destructive_merge",
+}
+
+
 class OmegaConfigLoader(AbstractConfigLoader):
     """Recursively scan directories (config paths) contained in ``conf_source`` for
     configuration files with a ``yaml``, ``yml`` or ``json`` extension, load and merge
@@ -131,18 +143,9 @@ def __init__(  # noqa: PLR0913
         self._register_new_resolvers(custom_resolvers)
         # Register globals resolver
         self._register_globals_resolver()
-        file_mimetype, _ = mimetypes.guess_type(conf_source)
-        if file_mimetype == "application/x-tar":
-            self._protocol = "tar"
-        elif file_mimetype in (
-            "application/zip",
-            "application/x-zip-compressed",
-            "application/zip-compressed",
-        ):
-            self._protocol = "zip"
-        else:
-            self._protocol = "file"
-        self._fs = fsspec.filesystem(protocol=self._protocol, fo=conf_source)
+
+        # Set up the file system and protocol
+        self._fs, self._protocol = self._initialise_filesystem_and_protocol(conf_source)
 
         super().__init__(
             conf_source=conf_source,
@@ -220,6 +223,11 @@ def __getitem__(self, key: str) -> dict[str, Any]:  # noqa: PLR0912
 
         # Load chosen env config
         run_env = self.env or self.default_run_env
+
+        # Return early if the chosen env is the base env, to avoid loading the same config twice
+        if run_env == self.base_env:
+            return config  # type: ignore[no-any-return]
+
         if self._protocol == "file":
             env_path = str(Path(self.conf_source) / run_env)
         else:
@@ -236,16 +244,7 @@
             else:
                 raise exc
 
-        merging_strategy = self.merge_strategy.get(key)
-        if merging_strategy == "soft":
-            resulting_config = self._soft_merge(config, env_config)
-        elif merging_strategy == "destructive" or not merging_strategy:
-            resulting_config = self._destructive_merge(config, env_config, env_path)
-        else:
-            raise ValueError(
-                f"Merging strategy {merging_strategy} not supported. The accepted merging "
-                f"strategies are `soft` and `destructive`."
-            )
+        resulting_config = self._merge_configs(config, env_config, key, env_path)
 
         if not processed_files and key != "globals":
             raise MissingConfigException(
@@ -355,6 +354,47 @@ def load_and_merge_dir_config(
             if not k.startswith("_")
         }
 
+    @staticmethod
+    def _initialise_filesystem_and_protocol(
+        conf_source: str,
+    ) -> tuple[fsspec.AbstractFileSystem, str]:
+        """Set up the file system based on the file type detected in conf_source."""
+        file_mimetype, _ = mimetypes.guess_type(conf_source)
+        if file_mimetype == "application/x-tar":
+            protocol = "tar"
+        elif file_mimetype in (
+            "application/zip",
+            "application/x-zip-compressed",
+            "application/zip-compressed",
+        ):
+            protocol = "zip"
+        else:
+            protocol = "file"
+        fs = fsspec.filesystem(protocol=protocol, fo=conf_source)
+        return fs, protocol
+
+    def _merge_configs(
+        self,
+        config: dict[str, Any],
+        env_config: dict[str, Any],
+        key: str,
+        env_path: str,
+    ) -> Any:
+        merging_strategy = self.merge_strategy.get(key, "destructive")
+        try:
+            strategy = MergeStrategies[merging_strategy.upper()]
+        except KeyError as exc:
+            allowed_strategies = [strategy.name.lower() for strategy in MergeStrategies]
+            raise ValueError(
+                f"Merging strategy {merging_strategy} not supported. The accepted merging "
+                f"strategies are {allowed_strategies}."
+            ) from exc
+
+        # Get the corresponding merge function and call it
+        merge_function_name = MERGING_IMPLEMENTATIONS[strategy]
+        merge_function = getattr(self, merge_function_name)
+        return merge_function(config, env_config, env_path)
+
     def _get_all_keys(self, cfg: Any, parent_key: str = "") -> set[str]:
         keys: set[str] = set()
@@ -499,7 +539,10 @@
         return config
 
     @staticmethod
-    def _soft_merge(config: dict[str, Any], env_config: dict[str, Any]) -> Any:
+    def _soft_merge(
+        config: dict[str, Any], env_config: dict[str, Any], env_path: str | None = None
+    ) -> Any:
+        # `env_path` is accepted but unused, so both merge functions share one signature.
         # Soft merge the two env dirs. The chosen env will override base if keys clash.
         return OmegaConf.to_container(OmegaConf.merge(config, env_config))
diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py
index 23ac653d20..91928f7c4b 100644
--- a/kedro/framework/session/session.py
+++ b/kedro/framework/session/session.py
@@ -24,7 +24,7 @@
     validate_settings,
 )
 from kedro.io.core import generate_timestamp
-from kedro.runner import AbstractRunner, SequentialRunner
+from kedro.runner import AbstractRunner, SequentialRunner, ThreadRunner
 from kedro.utils import _find_kedro_project
 
 if TYPE_CHECKING:
@@ -395,6 +395,14 @@ def run(  # noqa: PLR0913
         )
 
         try:
+            if isinstance(runner, ThreadRunner):
+                # Materialise datasets that match factory patterns up front, so
+                # ThreadRunner's worker threads do not race to resolve them mid-run.
+                for ds in filtered_pipeline.datasets():
+                    if catalog._match_pattern(
+                        catalog._dataset_patterns, ds
+                    ) or catalog._match_pattern(catalog._default_pattern, ds):
+                        _ = catalog._get_dataset(ds)
             run_result = runner.run(
                 filtered_pipeline, catalog, hook_manager, session_id
             )
diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py
index 749f730e69..bc25db37c7 100644
--- a/tests/framework/session/test_session.py
+++ b/tests/framework/session/test_session.py
@@ -87,6 +87,16 @@ def mock_runner(mocker):
     return mock_runner
 
 
+@pytest.fixture
+def mock_thread_runner(mocker):
+    mock_runner = mocker.patch(
+        "kedro.runner.thread_runner.ThreadRunner",
+        autospec=True,
+    )
+    mock_runner.__name__ = "MockThreadRunner"
+    return mock_runner
+
+
 @pytest.fixture
 def mock_context_class(mocker):
     mock_cls = create_attrs_autospec(KedroContext)
@@ -693,6 +703,74 @@ def test_run(
             catalog=mock_catalog,
         )
 
+    @pytest.mark.usefixtures("mock_settings_context_class")
+    @pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME])
+    @pytest.mark.parametrize("match_pattern", [True, False])
+    def test_run_thread_runner(
+        self,
+        fake_project,
+        fake_session_id,
+        fake_pipeline_name,
+        mock_context_class,
+        mock_thread_runner,
+        mocker,
+        match_pattern,
+    ):
+        """Test running the project via the session with a ThreadRunner."""
+
+        mock_hook = mocker.patch(
+            "kedro.framework.session.session._create_hook_manager"
+        ).return_value.hook
+
+        ds_mock = mocker.Mock(**{"datasets.return_value": ["ds_1", "ds_2"]})
+        filter_mock = mocker.Mock(**{"filter.return_value": ds_mock})
+        pipelines_ret = {
+            _FAKE_PIPELINE_NAME: filter_mock,
+            "__default__": filter_mock,
+        }
+        mocker.patch("kedro.framework.session.session.pipelines", pipelines_ret)
+        mocker.patch(
+            "kedro.io.data_catalog.DataCatalog._match_pattern",
+            return_value=match_pattern,
+        )
+
+        with KedroSession.create(fake_project) as session:
+            session.run(runner=mock_thread_runner, pipeline_name=fake_pipeline_name)
+
+        mock_context = mock_context_class.return_value
+        record_data = {
+            "session_id": fake_session_id,
+            "project_path": fake_project.as_posix(),
+            "env": mock_context.env,
+            "kedro_version": kedro_version,
+            "tags": None,
+            "from_nodes": None,
+            "to_nodes": None,
+            "node_names": None,
+            "from_inputs": None,
+            "to_outputs": None,
+            "load_versions": None,
+            "extra_params": {},
+            "pipeline_name": fake_pipeline_name,
+            "namespace": None,
+            "runner": mock_thread_runner.__name__,
+        }
+        mock_catalog = mock_context._get_catalog.return_value
+        mock_pipeline = filter_mock.filter()
+
+        mock_hook.before_pipeline_run.assert_called_once_with(
+            run_params=record_data, pipeline=mock_pipeline, catalog=mock_catalog
+        )
+        mock_thread_runner.run.assert_called_once_with(
+            mock_pipeline, mock_catalog, session._hook_manager, fake_session_id
+        )
+        mock_hook.after_pipeline_run.assert_called_once_with(
+            run_params=record_data,
run_result=mock_thread_runner.run.return_value, + pipeline=mock_pipeline, + catalog=mock_catalog, + ) + @pytest.mark.usefixtures("mock_settings_context_class") @pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME]) def test_run_multiple_times(
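
Reviewer note: a minimal sketch of how the new merge-strategy dispatch behaves from user code. The `merge_strategy` argument already exists on `OmegaConfigLoader` (the diff reads `self.merge_strategy`); the `conf/` layout and the `base`/`local` environments below are illustrative assumptions, with `parameters.yml` present in both.

```python
from kedro.config import OmegaConfigLoader

# Hypothetical layout: conf/base/parameters.yml and conf/local/parameters.yml.
loader = OmegaConfigLoader(
    conf_source="conf",
    base_env="base",
    default_run_env="local",
    # Strategy names map onto MergeStrategies members via .upper();
    # "destructive" is the default for any key not listed here.
    merge_strategy={"parameters": "soft"},
)
params = loader["parameters"]  # local soft-merged over base; local wins on clashes

# An unsupported strategy raises a ValueError listing the names derived from the enum:
bad_loader = OmegaConfigLoader(
    conf_source="conf",
    base_env="base",
    default_run_env="local",
    merge_strategy={"parameters": "additive"},
)
try:
    bad_loader["parameters"]
except ValueError as err:
    # Merging strategy additive not supported. The accepted merging
    # strategies are ['soft', 'destructive'].
    print(err)
```

Note that with the early return added to `__getitem__`, no merge happens when the chosen run env equals `base_env`, so the strategy only comes into play when the two environments differ.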