feat(datasets)!: expose load and save publicly #829

Merged (3 commits) on Oct 5, 2024.

2 changes: 2 additions & 0 deletions kedro-datasets/RELEASE.md
@@ -21,6 +21,8 @@
* Fixed incorrect `pandas` optional dependency

## Breaking Changes
* Exposed `load` and `save` publicly for each dataset. This requires Kedro version 0.19.7 or higher.

## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [Brandon Meek](https://github.com/bpmeek)
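The breaking change above is applied mechanically to every dataset in this PR: the implementation moves from the protected `_load`/`_save` hooks onto the public `load`/`save` methods, which Kedro 0.19.7+ can wrap with its usual error handling and versioning. A minimal sketch of a custom dataset under the new convention follows; `SimpleTextDataset`, its file handling, and the example path are illustrative and not part of this PR.

```python
# Minimal sketch, assuming Kedro >= 0.19.7: subclasses of AbstractDataset
# implement the public `load`/`save` directly instead of `_load`/`_save`.
from kedro.io import AbstractDataset


class SimpleTextDataset(AbstractDataset[str, str]):
    """Toy dataset used only to show the renamed methods."""

    def __init__(self, filepath: str):
        self._filepath = filepath

    def load(self) -> str:  # previously: def _load(self) -> str
        with open(self._filepath, encoding="utf-8") as f:
            return f.read()

    def save(self, data: str) -> None:  # previously: def _save(self, data) -> None
        with open(self._filepath, "w", encoding="utf-8") as f:
            f.write(data)

    def _describe(self) -> dict:
        return {"filepath": self._filepath}


dataset = SimpleTextDataset("example.txt")
dataset.save("hello")  # public save, the same entry point Kedro itself uses
print(dataset.load())  # public load, call sites are unchanged
```

Code that already called `dataset.load()`/`dataset.save()` (including the catalog and runners) should be unaffected; only subclasses that overrode `_load`/`_save` need the rename.
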
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/api/api_dataset.py
@@ -184,7 +184,7 @@ def _execute_request(self, session: Session) -> requests.Response:

return response

def _load(self) -> requests.Response:
def load(self) -> requests.Response:
if self._request_args["method"] == "GET":
with sessions.Session() as session:
return self._execute_request(session)
@@ -219,7 +219,7 @@ def _execute_save_request(self, json_data: Any) -> requests.Response:
raise DatasetError("Failed to connect to the remote server") from exc
return response

def _save(self, data: Any) -> requests.Response: # type: ignore[override]
def save(self, data: Any) -> requests.Response: # type: ignore[override]
if self._request_args["method"] in ["PUT", "POST"]:
if isinstance(data, list):
return self._execute_save_with_chunks(json_data=data)
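With `load` and `save` public on `APIDataset` itself, the dataset can also be exercised directly, for example in a notebook. A hedged usage sketch follows; the endpoint and `load_args` are invented, and the exact constructor arguments should be checked against the class docstring.

```python
# Hypothetical direct usage of the APIDataset shown above; the URL and query
# parameters are made up for illustration.
from kedro_datasets.api import APIDataset

dataset = APIDataset(
    url="https://api.example.com/items",  # hypothetical endpoint
    method="GET",
    load_args={"params": {"limit": 10}},  # assumed to be forwarded to requests
)
response = dataset.load()  # returns a requests.Response, per the signature above
print(response.status_code)
```
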
@@ -122,12 +122,12 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> list:
def load(self) -> list:
load_path = get_filepath_str(self._filepath, self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return list(SeqIO.parse(handle=fs_file, **self._load_args))

def _save(self, data: list) -> None:
def save(self, data: list) -> None:
save_path = get_filepath_str(self._filepath, self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/dask/csv_dataset.py
@@ -105,12 +105,12 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> dd.DataFrame:
def load(self) -> dd.DataFrame:
return dd.read_csv(
self._filepath, storage_options=self.fs_args, **self._load_args
)

def _save(self, data: dd.DataFrame) -> None:
def save(self, data: dd.DataFrame) -> None:
data.to_csv(self._filepath, storage_options=self.fs_args, **self._save_args)

def _exists(self) -> bool:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/dask/parquet_dataset.py
@@ -135,12 +135,12 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> dd.DataFrame:
def load(self) -> dd.DataFrame:
return dd.read_parquet(
self._filepath, storage_options=self.fs_args, **self._load_args
)

def _save(self, data: dd.DataFrame) -> None:
def save(self, data: dd.DataFrame) -> None:
self._process_schema()
data.to_parquet(
path=self._filepath, storage_options=self.fs_args, **self._save_args
@@ -288,7 +288,7 @@ def __init__( # noqa: PLR0913
exists_function=self._exists, # type: ignore[arg-type]
)

def _load(self) -> DataFrame | pd.DataFrame:
def load(self) -> DataFrame | pd.DataFrame:
"""Loads the version of data in the format defined in the init
(spark|pandas dataframe)

@@ -380,7 +380,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
else:
self._save_append(update_data)

def _save(self, data: DataFrame | pd.DataFrame) -> None:
def save(self, data: DataFrame | pd.DataFrame) -> None:
"""Saves the data based on the write_mode and dataframe_type in the init.
If write_mode is pandas, Spark dataframe is created first.
If schema is provided, data is matched to schema before saving
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/email/message_dataset.py
@@ -160,13 +160,13 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> Message:
def load(self) -> Message:
load_path = get_filepath_str(self._get_load_path(), self._protocol)

with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return Parser(**self._parser_args).parse(fs_file, **self._load_args)

def _save(self, data: Message) -> None:
def save(self, data: Message) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py
@@ -126,12 +126,12 @@ def __init__( # noqa: PLR0913
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save

def _load(self) -> gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame]:
def load(self) -> gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame]:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return gpd.read_file(fs_file, **self._load_args)

def _save(self, data: gpd.GeoDataFrame) -> None:
def save(self, data: gpd.GeoDataFrame) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
data.to_file(fs_file, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py
@@ -110,10 +110,10 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> NoReturn:
def load(self) -> NoReturn:
raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'")

def _save(self, data: HoloViews) -> None:
def save(self, data: HoloViews) -> None:
bytes_buffer = io.BytesIO()
hv.save(data, bytes_buffer, **self._save_args)

@@ -47,10 +47,10 @@ def __init__(
self._dataset_kwargs = dataset_kwargs or {}
self.metadata = metadata

def _load(self):
def load(self):
return load_dataset(self.dataset_name, **self._dataset_kwargs)

def _save(self):
def save(self):
raise NotImplementedError("Not yet implemented")

def _describe(self) -> dict[str, Any]:
@@ -62,10 +62,10 @@ def __init__(
self._pipeline_kwargs.pop("task", None)
self._pipeline_kwargs.pop("model", None)

def _load(self) -> Pipeline:
def load(self) -> Pipeline:
return pipeline(self._task, model=self._model_name, **self._pipeline_kwargs)

def _save(self, pipeline: Pipeline) -> None:
def save(self, pipeline: Pipeline) -> None:
raise NotImplementedError("Not yet implemented")

def _describe(self) -> dict[str, t.Any]:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/ibis/table_dataset.py
@@ -163,7 +163,7 @@ def hashable(value):

return cls._connections[key]

def _load(self) -> ir.Table:
def load(self) -> ir.Table:
if self._filepath is not None:
if self._file_format is None:
raise NotImplementedError
@@ -173,7 +173,7 @@ def _load(self) -> ir.Table:
else:
return self.connection.table(self._table_name)

def _save(self, data: ir.Table) -> None:
def save(self, data: ir.Table) -> None:
if self._table_name is None:
raise DatasetError("Must provide `table_name` for materialization.")

6 changes: 3 additions & 3 deletions kedro-datasets/kedro_datasets/json/json_dataset.py
@@ -134,13 +134,13 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> Any:
def load(self) -> Any:
load_path = get_filepath_str(self._get_load_path(), self._protocol)

with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return json.load(fs_file)

def _save(self, data: Any) -> None:
def save(self, data: Any) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
@@ -172,6 +172,6 @@ def preview(self) -> JSONPreview:
Returns:
A string representing the JSON data for previewing.
"""
data = self._load()
data = self.load.__wrapped__(self) # type: ignore[attr-defined]

return JSONPreview(json.dumps(data))
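The `preview` change above is the one place where the rename is more than mechanical: `preview` used to call `self._load()` for the raw data and now calls `self.load.__wrapped__(self)`. The sketch below illustrates the presumed mechanism: the Kedro base class is assumed to wrap each subclass's public `load` (for logging and error handling) using `functools.wraps`, which exposes the undecorated function as `__wrapped__`, so calling it with an explicit `self` bypasses the wrapper.

```python
# Stand-alone illustration of the `__wrapped__` pattern; `_wrap_load` only
# mimics what the Kedro base class is assumed to do and is not Kedro code.
import functools


def _wrap_load(func):
    @functools.wraps(func)  # copies metadata and sets wrapper.__wrapped__ = func
    def wrapper(self):
        print("logging / DatasetError handling would happen here")
        return func(self)

    return wrapper


class DemoDataset:
    def load(self) -> dict:
        return {"a": 1}

    load = _wrap_load(load)  # mimic the base-class wrapping of the public method


ds = DemoDataset()
ds.load()                # goes through the wrapper
ds.load.__wrapped__(ds)  # calls the original body directly, as preview() does
```
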
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/matlab/matlab_dataset.py
@@ -125,7 +125,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> np.ndarray:
def load(self) -> np.ndarray:
"""
Access the specific variable in the .mat file, e.g, data['variable_name']
"""
@@ -134,7 +134,7 @@ def _load(self) -> np.ndarray:
data = io.loadmat(f)
return data

def _save(self, data: np.ndarray) -> None:
def save(self, data: np.ndarray) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as f:
io.savemat(f, {"data": data})
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py
@@ -197,10 +197,10 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> NoReturn:
def load(self) -> NoReturn:
raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'")

def _save(self, data: Figure | (list[Figure] | dict[str, Figure])) -> None:
def save(self, data: Figure | (list[Figure] | dict[str, Figure])) -> None:
save_path = self._get_save_path()

if isinstance(data, (list, dict)) and self._overwrite and self._exists():
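As with `HoloviewsWriter` earlier in the diff, `MatplotlibWriter.save` is now the public implementation and `load` raises a `DatasetError`. A hedged usage sketch follows, with an invented output path.

```python
# Hypothetical direct usage of MatplotlibWriter; the output path is made up.
import matplotlib.pyplot as plt
from kedro_datasets.matplotlib import MatplotlibWriter

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 5, 6])

writer = MatplotlibWriter(filepath="data/08_reporting/example_plot.png")
writer.save(fig)  # public save; a list or dict of figures is also accepted
# writer.load()   # would raise DatasetError, as shown in the diff above
```
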
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/networkx/gml_dataset.py
@@ -116,13 +116,13 @@ def __init__( # noqa: PLR0913
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
def load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
data = networkx.read_gml(fs_file, **self._load_args)
return data

def _save(self, data: networkx.Graph) -> None:
def save(self, data: networkx.Graph) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
networkx.write_gml(data, fs_file, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/networkx/graphml_dataset.py
@@ -115,12 +115,12 @@ def __init__( # noqa: PLR0913
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
def load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return networkx.read_graphml(fs_file, **self._load_args)

def _save(self, data: networkx.Graph) -> None:
def save(self, data: networkx.Graph) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
networkx.write_graphml(data, fs_file, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/networkx/json_dataset.py
@@ -112,14 +112,14 @@ def __init__( # noqa: PLR0913
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
def load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
json_payload = json.load(fs_file)

return networkx.node_link_graph(json_payload, **self._load_args)

def _save(self, data: networkx.Graph) -> None:
def save(self, data: networkx.Graph) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

json_graph = networkx.node_link_data(data, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/csv_dataset.py
@@ -162,7 +162,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
load_path = str(self._get_load_path())
if self._protocol == "file":
# file:// protocol seems to misbehave on Windows
@@ -176,7 +176,7 @@
load_path, storage_options=self._storage_options, **self._load_args
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py
@@ -219,10 +219,10 @@ def get_loaded_version(self) -> int | None:
"""Returns the version of the DeltaTableDataset that is currently loaded."""
return self._delta_table.version() if self._delta_table else None

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
return self._delta_table.to_pandas() if self._delta_table else None

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
if self.is_empty_dir:
# first time creation of delta table
write_deltalake(
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/excel_dataset.py
@@ -222,7 +222,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> pd.DataFrame | dict[str, pd.DataFrame]:
def load(self) -> pd.DataFrame | dict[str, pd.DataFrame]:
load_path = str(self._get_load_path())
if self._protocol == "file":
# file:// protocol seems to misbehave on Windows
@@ -236,7 +236,7 @@
load_path, storage_options=self._storage_options, **self._load_args
)

def _save(self, data: pd.DataFrame | dict[str, pd.DataFrame]) -> None:
def save(self, data: pd.DataFrame | dict[str, pd.DataFrame]) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/feather_dataset.py
@@ -161,7 +161,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
load_path = str(self._get_load_path())
if self._protocol == "file":
# file:// protocol seems to misbehave on Windows
@@ -175,7 +175,7 @@
load_path, storage_options=self._storage_options, **self._load_args
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
8 changes: 4 additions & 4 deletions kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
@@ -137,7 +137,7 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
sql = f"select * from {self._dataset}.{self._table_name}" # nosec
self._load_args.setdefault("query_or_table", sql)
return pd_gbq.read_gbq(
Expand All @@ -146,7 +146,7 @@ def _load(self) -> pd.DataFrame:
**self._load_args,
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
pd_gbq.to_gbq(
dataframe=data,
destination_table=f"{self._dataset}.{self._table_name}",
@@ -299,7 +299,7 @@ def _describe(self) -> dict[str, Any]:

return desc

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)

if self._filepath:
@@ -313,5 +313,5 @@ def _load(self) -> pd.DataFrame:
**load_args,
)

def _save(self, data: None) -> NoReturn:
def save(self, data: None) -> NoReturn:
raise DatasetError("'save' is not supported on GBQQueryDataset")
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/generic_dataset.py
@@ -190,7 +190,7 @@ def _ensure_file_system_target(self) -> None:
f"does not support a filepath target/source."
)

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
self._ensure_file_system_target()

load_path = get_filepath_str(self._get_load_path(), self._protocol)
@@ -204,7 +204,7 @@ def _load(self) -> pd.DataFrame:
"https://pandas.pydata.org/docs/reference/io.html"
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
self._ensure_file_system_target()

save_path = get_filepath_str(self._get_save_path(), self._protocol)