Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataset factories #2635

Merged
merged 29 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ec66a12
Cleaned up and up to date version of dataset factories code
merelcht May 22, 2023
5e6c15d
Add some simple tests
merelcht May 23, 2023
0fca72c
Add parsing rules
ankatiyar Jun 2, 2023
06ed1a4
Refactor
ankatiyar Jun 8, 2023
b0e3fb9
Add some tests
ankatiyar Jun 8, 2023
0833af2
Add unit tests
ankatiyar Jun 12, 2023
8fc80f9
Fix test + refactor runner
ankatiyar Jun 12, 2023
091f794
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jun 12, 2023
8c192ee
Add comments + update specificity fn
ankatiyar Jun 13, 2023
3e2642c
Update function names
ankatiyar Jun 15, 2023
c2635d0
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jun 15, 2023
d310486
Update test
ankatiyar Jun 15, 2023
573c67f
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jun 19, 2023
9d80de4
Release notes + update resume scenario fix
ankatiyar Jun 19, 2023
549823f
revert change to suggest resume scenario
ankatiyar Jun 19, 2023
e052ae6
Update tests DataSet->Dataset
ankatiyar Jun 19, 2023
96c219f
Small refactor + move parsing rules to a new fn
ankatiyar Jun 20, 2023
c2634e5
Fix problem with load_version + refactor
ankatiyar Jul 3, 2023
eee606a
linting + small fix _get_datasets
ankatiyar Jul 3, 2023
635510a
Remove check for existence
ankatiyar Jul 4, 2023
394f37b
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jul 4, 2023
a1c602d
Add updated tests + Release notes
ankatiyar Jul 5, 2023
978d0a5
change classmethod to staticmethod for _match_patterns
ankatiyar Jul 5, 2023
b4fe7a7
Add test for layer
ankatiyar Jul 5, 2023
2782dca
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jul 5, 2023
85d3df1
Minor change from code review
ankatiyar Jul 5, 2023
8904ce3
Remove type conversion
ankatiyar Jul 6, 2023
bdc953d
Add warning for catch-all patterns [dataset factories] (#2774)
ankatiyar Jul 6, 2023
fa6c256
Merge branch 'main' into feature/dataset-factories-for-catalog
ankatiyar Jul 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dependency/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ importlib_resources>=1.3 # The `files()` API was introduced in `importlib_resou
jmespath>=0.9.5, <1.0
more_itertools~=9.0
omegaconf~=2.3
parse~=1.19.0
pip-tools~=6.5
pluggy~=1.0.0
PyYAML>=4.2, <7.0
Expand Down
175 changes: 154 additions & 21 deletions kedro/io/data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import logging
import re
from collections import defaultdict
from typing import Any
from typing import Any, Iterable

from parse import parse

from kedro.io.core import (
AbstractDataSet,
Expand Down Expand Up @@ -136,11 +138,14 @@ class DataCatalog:
to the underlying data sets.
"""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
data_sets: dict[str, AbstractDataSet] = None,
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved
feed_dict: dict[str, Any] = None,
layers: dict[str, set[str]] = None,
dataset_patterns: dict[str, dict[str, Any]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it is good to introduce keyword-only arguments here, similar to the proposal to get us more freedom to re-arrange the arguments later without breaking changes. The current argument lists does not make too much sense

load_versions: dict[str, str] = None,
save_version: str = None,
merelcht marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""``DataCatalog`` stores instances of ``AbstractDataSet``
implementations to provide ``load`` and ``save`` capabilities from
Expand Down Expand Up @@ -170,8 +175,12 @@ def __init__(
self._data_sets = dict(data_sets or {})
self.datasets = _FrozenDatasets(self._data_sets)
self.layers = layers
# Keep a record of all patterns in the catalog.
# {dataset pattern name : dataset pattern body}
self._dataset_patterns = dict(dataset_patterns or {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still confuse me a lot, is there any difference? Isn't dataset_patterns a dict already?

Suggested change
self._dataset_patterns = dict(dataset_patterns or {})
self._dataset_patterns =dataset_patterns or {}

self._load_versions = dict(load_versions or {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as self._dataset_patterns

self._save_version = save_version

# import the feed dict
if feed_dict:
self.add_feed_dict(feed_dict)

Expand All @@ -181,7 +190,7 @@ def _logger(self):

@classmethod
def from_config(
cls: type,
cls,
catalog: dict[str, dict[str, Any]] | None,
credentials: dict[str, dict[str, Any]] = None,
load_versions: dict[str, str] = None,
Expand Down Expand Up @@ -257,36 +266,105 @@ class to be loaded is specified with the key ``type`` and their
>>> catalog.save("boats", df)
"""
data_sets = {}
dataset_patterns = {}
catalog = copy.deepcopy(catalog) or {}
credentials = copy.deepcopy(credentials) or {}
save_version = save_version or generate_timestamp()
load_versions = copy.deepcopy(load_versions) or {}
layers: dict[str, set[str]] = defaultdict(set)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same argument, good to have keywords-only argument here even we are going to remove layers soon


missing_keys = load_versions.keys() - catalog.keys()
for ds_name, ds_config in catalog.items():
ds_config = _resolve_credentials(ds_config, credentials)
if cls._is_pattern(ds_name):
# Add each factory to the dataset_patterns dict.
dataset_patterns[ds_name] = ds_config
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved

else:
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
dataset_layers = layers or None
sorted_patterns = cls._sort_patterns(dataset_patterns)
missing_keys = [
key
for key in load_versions.keys()
if not (cls._match_pattern(sorted_patterns, key) or key in catalog)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test should be reordered, since key in catalog is easier to test and probably will match more often. No need to go through _match_pattern for all the non-patterned datasets.

]
if missing_keys:
raise DatasetNotFoundError(
f"'load_versions' keys [{', '.join(sorted(missing_keys))}] "
f"are not found in the catalog."
)

layers: dict[str, set[str]] = defaultdict(set)
for ds_name, ds_config in catalog.items():
ds_layer = ds_config.pop("layer", None)
if ds_layer is not None:
layers[ds_layer].add(ds_name)
return cls(
data_sets=data_sets,
layers=dataset_layers,
dataset_patterns=sorted_patterns,
load_versions=load_versions,
save_version=save_version,
)

ds_config = _resolve_credentials(ds_config, credentials)
data_sets[ds_name] = AbstractDataSet.from_config(
ds_name, ds_config, load_versions.get(ds_name), save_version
)
@staticmethod
merelcht marked this conversation as resolved.
Show resolved Hide resolved
def _is_pattern(pattern: str):
"""Check if a given string is a pattern. Assume that any name with '{' is a pattern."""
if "{" in pattern:
return True
return False
Copy link
Contributor

@noklam noklam Jul 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if "{" in pattern:
return True
return False
return "{" in pattern


dataset_layers = layers or None
return cls(data_sets=data_sets, layers=dataset_layers)
@classmethod
def _match_pattern(
cls, data_set_patterns: dict[str, dict[str, Any]], data_set_name: str
) -> str | None:
"""Match a dataset name against patterns in a dictionary containing patterns"""
for pattern, _ in data_set_patterns.items():
result = parse(pattern, data_set_name)
if result:
return pattern
return None
Comment on lines +329 to +333
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be nicely rewritten to:

matches = (parse(pattern, data_set_name) for pattern in data_set_patterns.keys())
return next(filter(None, matches), None)


@classmethod
def _sort_patterns(
cls, data_set_patterns: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
"""Sort a dictionary of dataset patterns according to parsing rules -
1. Decreasing specificity (no of characters outside the curly brackets
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved
2. Decreasing number of placeholders (no of curly bracket pairs)
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved
3. Alphabetically
"""
sorted_keys = sorted(
data_set_patterns,
key=lambda pattern: (
-(cls._specificity(pattern)),
-pattern.count("{"),
pattern,
),
)
sorted_patterns = {}
for key in sorted_keys:
sorted_patterns[key] = data_set_patterns[key]
return sorted_patterns
Comment on lines +352 to +355
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shorter and neater.

return {key: data_set_patterns[key] for key in sorted_keys}


@staticmethod
def _specificity(pattern: str) -> int:
"""Helper function to check the length of exactly matched characters not inside brackets
Example -
specificity("{namespace}.companies") = 10
specificity("{namespace}.{dataset}") = 1
specificity("france.companies") = 16
Args:
pattern: The factory pattern
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For completeness this doc string should also include the return type/value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a private method, the Args: part is leftover from when it was a helper fn outside DataCatalog I've removed it

"""
# Remove all the placeholders from the pattern and count the number of remaining chars
result = re.sub(r"\{.*?\}", "", pattern)
return len(result)

def _get_dataset(
self, data_set_name: str, version: Version = None, suggest: bool = True
) -> AbstractDataSet:
if data_set_name not in self._data_sets:
merelcht marked this conversation as resolved.
Show resolved Hide resolved
if data_set_name not in self:
error_msg = f"Dataset '{data_set_name}' not found in the catalog"

# Flag to turn on/off fuzzy-matching which can be time consuming and
Expand All @@ -298,9 +376,7 @@ def _get_dataset(
if matches:
suggestions = ", ".join(matches)
error_msg += f" - did you mean one of these instead: {suggestions}"

raise DatasetNotFoundError(error_msg)

data_set = self._data_sets[data_set_name]
if version and isinstance(data_set, AbstractVersionedDataSet):
# we only want to return a similar-looking dataset,
Expand All @@ -311,6 +387,53 @@ def _get_dataset(

return data_set

def __contains__(self, data_set_name):
merelcht marked this conversation as resolved.
Show resolved Hide resolved
"""Check if an item is in the catalog as a materialised dataset or pattern,
add to catalog if it is a pattern"""
if data_set_name in self._data_sets:
return True
matched_pattern = self._match_pattern(self._dataset_patterns, data_set_name)
if matched_pattern:
# If the dataset is a patterned dataset, materialise it and add it to
# the catalog
data_set_config = self._resolve_config(data_set_name, matched_pattern)
ds_layer = data_set_config.pop("layer", None)
if ds_layer:
if not self.layers:
self.layers = {}
self.layers.setdefault(ds_layer, set()).add(data_set_name)
data_set = AbstractDataSet.from_config(
data_set_name,
data_set_config,
self._load_versions.get(data_set_name),
self._save_version,
)
self.add(data_set_name, data_set)
return True
return False

def _resolve_config(
self,
data_set_name: str,
matched_pattern: str,
) -> dict[str, Any]:
"""Get resolved AbstractDataSet from a factory config"""
result = parse(matched_pattern, data_set_name)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
# Resolve the factory config for the dataset
for key, value in config_copy.items():
if isinstance(value, Iterable) and "}" in value:
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
try:
config_copy[key] = str(value).format_map(result.named)
except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{key}' for the pattern '{matched_pattern}'"
) from exc
return config_copy

def load(self, name: str, version: str = None) -> Any:
"""Loads a registered data set.

Expand Down Expand Up @@ -573,10 +696,20 @@ def shallow_copy(self) -> DataCatalog:
Returns:
Copy of the current object.
"""
return DataCatalog(data_sets=self._data_sets, layers=self.layers)
return DataCatalog(
data_sets=self._data_sets,
layers=self.layers,
dataset_patterns=self._dataset_patterns,
load_versions=self._load_versions,
save_version=self._save_version,
)

def __eq__(self, other):
ankatiyar marked this conversation as resolved.
Show resolved Hide resolved
return (self._data_sets, self.layers) == (other._data_sets, other.layers)
return (self._data_sets, self.layers, self._dataset_patterns) == (
other._data_sets,
other.layers,
other._dataset_patterns,
)

def confirm(self, name: str) -> None:
"""Confirm a dataset by its name.
Expand Down
21 changes: 16 additions & 5 deletions kedro/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class AbstractRunner(ABC):
"""

def __init__(self, is_async: bool = False):
"""Instantiates the runner classs.
"""Instantiates the runner class.

Args:
is_async: If True, the node inputs and outputs are loaded and saved
Expand Down Expand Up @@ -74,14 +74,25 @@ def run(
hook_manager = hook_manager or _NullPluginManager()
catalog = catalog.shallow_copy()

unsatisfied = pipeline.inputs() - set(catalog.list())
# Check which datasets used in the pipeline are in the catalog or match
# a pattern in the catalog
registered_ds = [ds for ds in pipeline.data_sets() if ds in catalog]

# Check if there are any input datasets that aren't in the catalog and
# don't match a pattern in the catalog.
unsatisfied = pipeline.inputs() - set(registered_ds)

if unsatisfied:
raise ValueError(
f"Pipeline input(s) {unsatisfied} not found in the DataCatalog"
)

free_outputs = pipeline.outputs() - set(catalog.list())
unregistered_ds = pipeline.data_sets() - set(catalog.list())
# Check if there's any output datasets that aren't in the catalog and don't match a pattern
# in the catalog.
free_outputs = pipeline.outputs() - set(registered_ds)
unregistered_ds = pipeline.data_sets() - set(registered_ds)

# Create a default dataset for unregistered datasets
for ds_name in unregistered_ds:
catalog.add(ds_name, self.create_default_data_set(ds_name))

Expand Down Expand Up @@ -420,7 +431,7 @@ def _run_node_sequential(
items: Iterable = outputs.items()
# if all outputs are iterators, then the node is a generator node
if all(isinstance(d, Iterator) for d in outputs.values()):
# Python dictionaries are ordered so we are sure
# Python dictionaries are ordered, so we are sure
# the keys and the chunk streams are in the same order
# [a, b, c]
keys = list(outputs.keys())
Expand Down
Loading