Merge 'main' into fix/exists-method-for-shared-memory-dataset
Signed-off-by: Felix Scherz <felixwscherz@gmail.com>
felixscherz committed Aug 29, 2024
2 parents e34d849 + f6319dd commit a97a74e
Showing 4 changed files with 152 additions and 24 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -3,7 +3,9 @@
## Major features and improvements
* Enhanced `OmegaConfigLoader` configuration validation to detect duplicate keys at all parameter levels, ensuring comprehensive nested key checking.
## Bug fixes and other changes
* Fixed a bug where using dataset factories would break with `ThreadRunner`.
* Fixed a bug where `SharedMemoryDataset.exists` would not call the underlying `MemoryDataset`.

## Breaking changes to the API

## Documentation changes
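The `SharedMemoryDataset.exists` fix named in the branch lands outside the files shown in this merge diff. As a rough illustration of the behaviour the changelog entry describes — the class internals below are hypothetical, not Kedro's actual implementation — the fix amounts to delegating the existence check to the wrapped dataset:

class SharedMemoryDataset:
    """Sketch: a wrapper around a MemoryDataset held in shared memory."""

    def __init__(self, wrapped):
        # Hypothetical handle to the underlying MemoryDataset proxy.
        self._wrapped = wrapped

    def exists(self) -> bool:
        # The bug described in RELEASE.md: an exists() that never consulted
        # the underlying MemoryDataset. The fix is straight delegation.
        return self._wrapped.exists()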
88 changes: 65 additions & 23 deletions kedro/config/omegaconf_config.py
@@ -9,6 +9,7 @@
import mimetypes
import typing
from collections.abc import KeysView
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, Iterable

@@ -26,6 +27,17 @@
_NO_VALUE = object()


class MergeStrategies(Enum):
SOFT = auto()
DESTRUCTIVE = auto()


MERGING_IMPLEMENTATIONS = {
MergeStrategies.SOFT: "_soft_merge",
MergeStrategies.DESTRUCTIVE: "_destructive_merge",
}


class OmegaConfigLoader(AbstractConfigLoader):
"""Recursively scan directories (config paths) contained in ``conf_source`` for
configuration files with a ``yaml``, ``yml`` or ``json`` extension, load and merge
@@ -131,18 +143,9 @@ def __init__( # noqa: PLR0913
self._register_new_resolvers(custom_resolvers)
# Register globals resolver
self._register_globals_resolver()
- file_mimetype, _ = mimetypes.guess_type(conf_source)
- if file_mimetype == "application/x-tar":
- self._protocol = "tar"
- elif file_mimetype in (
- "application/zip",
- "application/x-zip-compressed",
- "application/zip-compressed",
- ):
- self._protocol = "zip"
- else:
- self._protocol = "file"
- self._fs = fsspec.filesystem(protocol=self._protocol, fo=conf_source)
+
+ # Setup file system and protocol
+ self._fs, self._protocol = self._initialise_filesystem_and_protocol(conf_source)

super().__init__(
conf_source=conf_source,
@@ -220,6 +223,11 @@ def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912

# Load chosen env config
run_env = self.env or self.default_run_env

# Return early if the chosen env is the base env, to avoid loading the same config twice
if run_env == self.base_env:
return config # type: ignore[no-any-return]

if self._protocol == "file":
env_path = str(Path(self.conf_source) / run_env)
else:
@@ -236,16 +244,7 @@
else:
raise exc

- merging_strategy = self.merge_strategy.get(key)
- if merging_strategy == "soft":
- resulting_config = self._soft_merge(config, env_config)
- elif merging_strategy == "destructive" or not merging_strategy:
- resulting_config = self._destructive_merge(config, env_config, env_path)
- else:
- raise ValueError(
- f"Merging strategy {merging_strategy} not supported. The accepted merging "
- f"strategies are `soft` and `destructive`."
- )
+ resulting_config = self._merge_configs(config, env_config, key, env_path)

if not processed_files and key != "globals":
raise MissingConfigException(
@@ -355,6 +354,47 @@ def load_and_merge_dir_config(
if not k.startswith("_")
}

@staticmethod
def _initialise_filesystem_and_protocol(
conf_source: str,
) -> tuple[fsspec.AbstractFileSystem, str]:
"""Set up the file system based on the file type detected in conf_source."""
file_mimetype, _ = mimetypes.guess_type(conf_source)
if file_mimetype == "application/x-tar":
protocol = "tar"
elif file_mimetype in (
"application/zip",
"application/x-zip-compressed",
"application/zip-compressed",
):
protocol = "zip"
else:
protocol = "file"
fs = fsspec.filesystem(protocol=protocol, fo=conf_source)
return fs, protocol

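For a sense of the detection logic above, a quick sketch of what `mimetypes.guess_type` returns for typical `conf_source` values (paths are illustrative):

import mimetypes

# Archive sources are recognised by extension and mapped to fsspec protocols.
print(mimetypes.guess_type("conf.zip"))      # ('application/zip', None)   -> "zip"
print(mimetypes.guess_type("conf.tar"))      # ('application/x-tar', None) -> "tar"
# A plain directory has no guessable type, so the loader falls back to "file".
print(mimetypes.guess_type("project/conf"))  # (None, None)                -> "file"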
def _merge_configs(
self,
config: dict[str, Any],
env_config: dict[str, Any],
key: str,
env_path: str,
) -> Any:
merging_strategy = self.merge_strategy.get(key, "destructive")
try:
strategy = MergeStrategies[merging_strategy.upper()]

# Get the corresponding merge function and call it
merge_function_name = MERGING_IMPLEMENTATIONS[strategy]
merge_function = getattr(self, merge_function_name)
return merge_function(config, env_config, env_path)
except KeyError:
allowed_strategies = [strategy.name.lower() for strategy in MergeStrategies]
raise ValueError(
f"Merging strategy {merging_strategy} not supported. The accepted merging "
f"strategies are {allowed_strategies}."
)

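Standalone, the enum-keyed dispatch used by `_merge_configs` works as below — a minimal sketch with toy merge functions (`Merger`, `_soft` and `_destructive` are stand-ins, not Kedro code):

from enum import Enum, auto
from typing import Any


class Strategies(Enum):
    SOFT = auto()
    DESTRUCTIVE = auto()


IMPLEMENTATIONS = {Strategies.SOFT: "_soft", Strategies.DESTRUCTIVE: "_destructive"}


class Merger:
    def merge(self, name: str, base: dict, env: dict) -> Any:
        try:
            strategy = Strategies[name.upper()]  # KeyError for unknown names
        except KeyError:
            allowed = [s.name.lower() for s in Strategies]
            raise ValueError(f"Strategy {name!r} not supported; use one of {allowed}.")
        # Resolve the method name via the mapping, then look it up on self.
        return getattr(self, IMPLEMENTATIONS[strategy])(base, env)

    def _soft(self, base: dict, env: dict) -> dict:
        return {**base, **env}  # env wins on clashing keys, base keys survive

    def _destructive(self, base: dict, env: dict) -> dict:
        return dict(env)  # env replaces base wholesale


print(Merger().merge("soft", {"a": 1}, {"b": 2}))  # {'a': 1, 'b': 2}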
def _get_all_keys(self, cfg: Any, parent_key: str = "") -> set[str]:
keys: set[str] = set()

@@ -499,7 +539,9 @@ def _destructive_merge(
return config

@staticmethod
- def _soft_merge(config: dict[str, Any], env_config: dict[str, Any]) -> Any:
+ def _soft_merge(
+ config: dict[str, Any], env_config: dict[str, Any], env_path: str | None = None
+ ) -> Any:
# Soft merge the two env dirs; the chosen env overrides base where keys clash.
# `env_path` is unused here but accepted so both merge strategies share the
# call signature resolved dynamically in `_merge_configs`.
return OmegaConf.to_container(OmegaConf.merge(config, env_config))

8 changes: 7 additions & 1 deletion kedro/framework/session/session.py
@@ -24,7 +24,7 @@
validate_settings,
)
from kedro.io.core import generate_timestamp
- from kedro.runner import AbstractRunner, SequentialRunner
+ from kedro.runner import AbstractRunner, SequentialRunner, ThreadRunner
from kedro.utils import _find_kedro_project

if TYPE_CHECKING:
@@ -395,6 +395,12 @@ def run( # noqa: PLR0913
)

try:
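# Materialise datasets that only exist as dataset factory patterns before
# a threaded run, so that worker threads never mutate the catalog.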
if isinstance(runner, ThreadRunner):
for ds in filtered_pipeline.datasets():
if catalog._match_pattern(
catalog._dataset_patterns, ds
) or catalog._match_pattern(catalog._default_pattern, ds):
_ = catalog._get_dataset(ds)
run_result = runner.run(
filtered_pipeline, catalog, hook_manager, session_id
)
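The warm-up loop above resolves pattern-matched datasets on the main thread so that `ThreadRunner`'s worker threads only ever read the catalog. A rough standalone sketch of that idea (the dict and helper below are stand-ins, not Kedro APIs):

from concurrent.futures import ThreadPoolExecutor

catalog: dict[str, str] = {}  # stand-in for a DataCatalog

def materialise(name: str) -> None:
    # Stand-in for catalog._get_dataset(name) resolving a factory pattern.
    catalog.setdefault(name, f"dataset:{name}")

datasets = ["ds_1", "ds_2", "ds_3"]

# Main thread: resolve every pattern-matched dataset up front...
for ds in datasets:
    materialise(ds)

# ...so the threaded run only reads the catalog, never writes to it.
with ThreadPoolExecutor() as pool:
    loaded = list(pool.map(catalog.get, datasets))

print(loaded)  # ['dataset:ds_1', 'dataset:ds_2', 'dataset:ds_3']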
78 changes: 78 additions & 0 deletions tests/framework/session/test_session.py
@@ -87,6 +87,16 @@ def mock_runner(mocker):
return mock_runner


@pytest.fixture
def mock_thread_runner(mocker):
mock_runner = mocker.patch(
"kedro.runner.thread_runner.ThreadRunner",
autospec=True,
)
mock_runner.__name__ = "MockThreadRunner"
return mock_runner


@pytest.fixture
def mock_context_class(mocker):
mock_cls = create_attrs_autospec(KedroContext)
@@ -693,6 +703,74 @@ def test_run(
catalog=mock_catalog,
)

@pytest.mark.usefixtures("mock_settings_context_class")
@pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME])
@pytest.mark.parametrize("match_pattern", [True, False])
def test_run_thread_runner(
self,
fake_project,
fake_session_id,
fake_pipeline_name,
mock_context_class,
mock_thread_runner,
mocker,
match_pattern,
):
"""Test running the project via the session"""

mock_hook = mocker.patch(
"kedro.framework.session.session._create_hook_manager"
).return_value.hook

ds_mock = mocker.Mock(**{"datasets.return_value": ["ds_1", "ds_2"]})
filter_mock = mocker.Mock(**{"filter.return_value": ds_mock})
pipelines_ret = {
_FAKE_PIPELINE_NAME: filter_mock,
"__default__": filter_mock,
}
mocker.patch("kedro.framework.session.session.pipelines", pipelines_ret)
mocker.patch(
"kedro.io.data_catalog.DataCatalog._match_pattern",
return_value=match_pattern,
)

with KedroSession.create(fake_project) as session:
session.run(runner=mock_thread_runner, pipeline_name=fake_pipeline_name)

mock_context = mock_context_class.return_value
record_data = {
"session_id": fake_session_id,
"project_path": fake_project.as_posix(),
"env": mock_context.env,
"kedro_version": kedro_version,
"tags": None,
"from_nodes": None,
"to_nodes": None,
"node_names": None,
"from_inputs": None,
"to_outputs": None,
"load_versions": None,
"extra_params": {},
"pipeline_name": fake_pipeline_name,
"namespace": None,
"runner": mock_thread_runner.__name__,
}
mock_catalog = mock_context._get_catalog.return_value
mock_pipeline = filter_mock.filter()

mock_hook.before_pipeline_run.assert_called_once_with(
run_params=record_data, pipeline=mock_pipeline, catalog=mock_catalog
)
mock_thread_runner.run.assert_called_once_with(
mock_pipeline, mock_catalog, session._hook_manager, fake_session_id
)
mock_hook.after_pipeline_run.assert_called_once_with(
run_params=record_data,
run_result=mock_thread_runner.run.return_value,
pipeline=mock_pipeline,
catalog=mock_catalog,
)

@pytest.mark.usefixtures("mock_settings_context_class")
@pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME])
def test_run_multiple_times(
