Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make dataset factory resolve nested dict properly #2993

Merged
merged 16 commits into from
Sep 7, 2023
Merged
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

## Major features and improvements
## Bug fixes and other changes
* Updated dataset factories to resolve nested catalog config properly.

## Documentation changes
## Breaking changes to the API
## Upcoming deprecations for Kedro 0.19.0
Expand Down
18 changes: 16 additions & 2 deletions kedro/framework/cli/catalog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A collection of CLI commands for working with Kedro catalog."""
import copy
from collections import defaultdict
from itertools import chain

Expand Down Expand Up @@ -84,7 +85,13 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env):
data_catalog._dataset_patterns, ds_name
)
if matched_pattern:
ds_config = data_catalog._resolve_config(ds_name, matched_pattern)
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns[matched_pattern]
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)
factory_ds_by_type[ds_config["type"]].append(ds_name)

default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values()))
Expand Down Expand Up @@ -244,7 +251,14 @@ def resolve_patterns(metadata: ProjectMetadata, env):
data_catalog._dataset_patterns, ds_name
)
if matched_pattern:
ds_config = data_catalog._resolve_config(ds_name, matched_pattern)
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns[matched_pattern]
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)

ds_config["filepath"] = _trim_filepath(
str(context.project_path) + "/", ds_config["filepath"]
)
Expand Down
14 changes: 11 additions & 3 deletions kedro/io/data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,10 @@ def _get_dataset(
if data_set_name not in self._data_sets and matched_pattern:
# If the dataset is a patterned dataset, materialise it and add it to
# the catalog
data_set_config = self._resolve_config(data_set_name, matched_pattern)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
data_set_config = self._resolve_config(
data_set_name, matched_pattern, config_copy
)
ds_layer = data_set_config.pop("layer", None)
if ds_layer:
self.layers = self.layers or {}
Expand Down Expand Up @@ -436,16 +439,21 @@ def __contains__(self, data_set_name):
return True
return False

@classmethod
def _resolve_config(
self,
cls,
data_set_name: str,
matched_pattern: str,
config_copy: dict,
) -> dict[str, Any]:
"""Get resolved AbstractDataset from a factory config"""
result = parse(matched_pattern, data_set_name)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
# Resolve the factory config for the dataset
for key, value in config_copy.items():
if isinstance(value, Dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep dict here? Unsure why we need typing here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed!

config_copy[key] = cls._resolve_config(
data_set_name, matched_pattern, value
)
if isinstance(value, Iterable) and "}" in value:
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
Expand Down
19 changes: 19 additions & 0 deletions tests/io/test_data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def config_with_dataset_factories():
}


@pytest.fixture
def config_with_dataset_factories_nested():
    """Catalog config with a dataset factory pattern whose entry contains a
    nested dict (``metadata``) that itself holds a ``{brand}`` placeholder,
    so resolution must recurse into nested mappings."""
    nested_metadata = {"my-plugin": {"brand": "{brand}"}}
    factory_entry = {
        "type": "PartitionedDataset",
        "path": "data/01_raw",
        "dataset": "pandas.CSVDataSet",
        "metadata": nested_metadata,
    }
    return {"catalog": {"{brand}_cars": factory_entry}}


@pytest.fixture
def config_with_dataset_factories_with_default(config_with_dataset_factories):
config_with_dataset_factories["catalog"]["{default_dataset}"] = {
Expand Down Expand Up @@ -896,3 +910,8 @@ def test_factory_config_versioned(
microsecond=current_ts.microsecond // 1000 * 1000, tzinfo=None
)
assert actual_timestamp == expected_timestamp

def test_factory_nested_config(self, config_with_dataset_factories_nested):
    """Placeholders inside nested dicts of a factory pattern are resolved
    with the values captured from the matched dataset name."""
    resolved_catalog = DataCatalog.from_config(**config_with_dataset_factories_nested)
    materialised = resolved_catalog._get_dataset("tesla_cars")
    brand = materialised.metadata["my-plugin"]["brand"]
    assert brand == "tesla"