Commit 9bd31b7

fixes
Signed-off-by: Simon Brugman <sfbbrugman@gmail.com>
sbrugman committed Jan 29, 2024
1 parent 4a6fc53 commit 9bd31b7
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions kedro-airflow/kedro_airflow/grouping.py
@@ -19,7 +19,7 @@ def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:
     """Gather all datasets in the pipeline that are of type MemoryDataset, excluding 'parameters'."""
     return {
         dataset_name
-        for dataset_name in pipeline.data_sets()
+        for dataset_name in pipeline.datasets()
         if _is_memory_dataset(catalog, dataset_name)
     }

@@ -29,6 +29,13 @@ def node_sequence_name(node_sequence: list[Node]) -> str:


 def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline):
+    """
+    Nodes that are connected through MemoryDatasets cannot be distributed across
+    multiple machines, e.g. be in different Kubernetes pods. This function
+    groups nodes that are connected through MemoryDatasets in the pipeline
+    together. Essentially, this computes connected components over the graph of
+    nodes connected by MemoryDatasets.
+    """
     # get all memory datasets in the pipeline
     memory_datasets = get_memory_datasets(catalog, pipeline)

@@ -38,19 +45,19 @@ def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline):
     # Mapping from dataset name -> node sequence index
     sequence_map = {}
     for node in pipeline.nodes:
-        if all(o not in ds for o in node.inputs + node.outputs):
+        if all(o not in memory_datasets for o in node.inputs + node.outputs):
             # standalone node
             node_sequences.append([node])
         else:
-            if all(i not in ds for i in node.inputs):
+            if all(i not in memory_datasets for i in node.inputs):
                 # start of a sequence; create a new sequence and store the id
                 node_sequences.append([node])
                 sequence_id = len(node_sequences) - 1
             else:
                 # continuation of a sequence; retrieve sequence_id
                 sequence_id = None
                 for i in node.inputs:
-                    if i in ds:
+                    if i in memory_datasets:
                         assert sequence_id is None or sequence_id == sequence_map[i]
                         sequence_id = sequence_map[i]

@@ -59,7 +66,7 @@ def group_memory_nodes(catalog: DataCatalog, pipeline: Pipeline):

             # map outputs to sequence_id
             for o in node.outputs:
-                if o in ds:
+                if o in memory_datasets:
                     sequence_map[o] = sequence_id

     # Named node sequences
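Note: the new docstring frames the grouping as connected components over the graph of nodes linked by MemoryDatasets. Below is a minimal standalone sketch of that idea; the node and dataset names are made up for illustration, and this is not code from this commit or from group_memory_nodes itself.

# Standalone sketch: group nodes that exchange data through in-memory datasets
# so each group can be scheduled on a single machine. Illustrative names only.
from collections import defaultdict

# toy "pipeline": node name -> (inputs, outputs)
nodes = {
    "split": (["raw"], ["train_mem", "test_mem"]),     # produces in-memory data
    "train": (["train_mem"], ["model"]),               # consumes in-memory data
    "evaluate": (["test_mem", "model"], ["metrics"]),  # consumes in-memory data
    "report": (["metrics"], ["report_pdf"]),           # only persisted datasets
}
memory_datasets = {"train_mem", "test_mem"}  # would come from get_memory_datasets()

# union-find over node names
parent = {name: name for name in nodes}

def find(x):
    # follow parent pointers to the root, with path halving
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

def union(a, b):
    parent[find(a)] = find(b)

# connect the producer and every consumer of each in-memory dataset
producers = defaultdict(list)
consumers = defaultdict(list)
for name, (inputs, outputs) in nodes.items():
    for ds in outputs:
        if ds in memory_datasets:
            producers[ds].append(name)
    for ds in inputs:
        if ds in memory_datasets:
            consumers[ds].append(name)
for ds in memory_datasets:
    for p in producers[ds]:
        for c in consumers[ds]:
            union(p, c)

# collect connected components -> groups that must stay together
groups = defaultdict(list)
for name in nodes:
    groups[find(name)].append(name)
print(list(groups.values()))
# e.g. [['split', 'train', 'evaluate'], ['report']]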
