Skip to content

Commit

Permalink
ci(airflow): Replace type hints with CatalogProtocol (#845)
Browse files Browse the repository at this point in the history
* Replace type checking with CatalogProtocol

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Add try except for import

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Ignore bandit warnings

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Remove any

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

---------

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
  • Loading branch information
ankatiyar committed Sep 25, 2024
1 parent 4b75db7 commit 0f0c59e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
13 changes: 10 additions & 3 deletions kedro-airflow/kedro_airflow/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,21 @@
from kedro.pipeline.node import Node
from kedro.pipeline.pipeline import Pipeline

try:
from kedro.io import CatalogProtocol
except ImportError: # pragma: no cover
pass


def _is_memory_dataset(catalog, dataset_name: str) -> bool:
if dataset_name not in catalog:
return True
return False


def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:
def get_memory_datasets(
catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline
) -> set[str]:
"""Gather all datasets in the pipeline that are of type MemoryDataset, excluding 'parameters'."""
return {
dataset_name
Expand All @@ -21,7 +28,7 @@ def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:


def create_adjacency_list(
catalog: DataCatalog, pipeline: Pipeline
catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline
) -> tuple[dict[str, set], dict[str, set]]:
"""
Builds adjacency list (adj_list) to search connected components - undirected graph,
Expand All @@ -48,7 +55,7 @@ def create_adjacency_list(


def group_memory_nodes(
catalog: DataCatalog, pipeline: Pipeline
catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline
) -> tuple[dict[str, list[Node]], dict[str, list[str]]]:
"""
Nodes that are connected through MemoryDatasets cannot be distributed across
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def _describe(self) -> dict[str, Any]:

def _load(self) -> Any:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
return torch.load(load_path, **self._fs_open_args_load)
return torch.load(load_path, **self._fs_open_args_load) #nosec: B614

def _save(self, data: torch.nn.Module) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
torch.save(data.state_dict(), save_path, **self._fs_open_args_save)
torch.save(data.state_dict(), save_path, **self._fs_open_args_save) #nosec: B614

self._invalidate_cache()

Expand Down

0 comments on commit 0f0c59e

Please sign in to comment.