feat(datasets)!: expose load and save publicly
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
deepyaman committed Sep 10, 2024
1 parent 3da39f7 commit 2971fe2
Showing 60 changed files with 136 additions and 136 deletions.
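For orientation, here is a minimal sketch of what this breaking change looks like from a custom dataset implementation's point of view. It is a hypothetical example, not a file touched by this commit; the class name and the details of kedro's AbstractDataset base-class API are assumptions made for illustration. Subclasses now override the public load and save methods directly, where they previously overrode the underscore-prefixed _load and _save hooks.

# Hypothetical custom dataset illustrating the new public API;
# not part of this commit, sketched against the assumed AbstractDataset API.
from pathlib import Path
from typing import Any

from kedro.io import AbstractDataset


class SimpleTextDataset(AbstractDataset[str, str]):
    """Reads and writes a plain-text file."""

    def __init__(self, filepath: str):
        self._filepath = Path(filepath)

    # Previously: def _load(self) -> str:
    def load(self) -> str:
        return self._filepath.read_text()

    # Previously: def _save(self, data: str) -> None:
    def save(self, data: str) -> None:
        self._filepath.write_text(data)

    def _describe(self) -> dict[str, Any]:
        return {"filepath": self._filepath}

Note that only load and save become public in this change; helpers such as _describe keep their underscore prefix, as the hunks below show. Calling code should be unaffected, since catalog access and direct dataset.load() / dataset.save(data) calls already went through the public methods.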
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/api/api_dataset.py
@@ -184,7 +184,7 @@ def _execute_request(self, session: Session) -> requests.Response:

return response

def _load(self) -> requests.Response:
def load(self) -> requests.Response:
if self._request_args["method"] == "GET":
with sessions.Session() as session:
return self._execute_request(session)
@@ -219,7 +219,7 @@ def _execute_save_request(self, json_data: Any) -> requests.Response:
raise DatasetError("Failed to connect to the remote server") from exc
return response

def _save(self, data: Any) -> requests.Response: # type: ignore[override]
def save(self, data: Any) -> requests.Response: # type: ignore[override]
if self._request_args["method"] in ["PUT", "POST"]:
if isinstance(data, list):
return self._execute_save_with_chunks(json_data=data)
@@ -122,12 +122,12 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> list:
def load(self) -> list:
load_path = get_filepath_str(self._filepath, self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return list(SeqIO.parse(handle=fs_file, **self._load_args))

def _save(self, data: list) -> None:
def save(self, data: list) -> None:
save_path = get_filepath_str(self._filepath, self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/dask/csv_dataset.py
@@ -105,12 +105,12 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> dd.DataFrame:
def load(self) -> dd.DataFrame:
return dd.read_csv(
self._filepath, storage_options=self.fs_args, **self._load_args
)

def _save(self, data: dd.DataFrame) -> None:
def save(self, data: dd.DataFrame) -> None:
data.to_csv(self._filepath, storage_options=self.fs_args, **self._save_args)

def _exists(self) -> bool:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/dask/parquet_dataset.py
@@ -135,12 +135,12 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> dd.DataFrame:
def load(self) -> dd.DataFrame:
return dd.read_parquet(
self._filepath, storage_options=self.fs_args, **self._load_args
)

def _save(self, data: dd.DataFrame) -> None:
def save(self, data: dd.DataFrame) -> None:
self._process_schema()
data.to_parquet(
path=self._filepath, storage_options=self.fs_args, **self._save_args
@@ -288,7 +288,7 @@ def __init__( # noqa: PLR0913
exists_function=self._exists, # type: ignore[arg-type]
)

def _load(self) -> DataFrame | pd.DataFrame:
def load(self) -> DataFrame | pd.DataFrame:
"""Loads the version of data in the format defined in the init
(spark|pandas dataframe)
@@ -380,7 +380,7 @@ def _save_upsert(self, update_data: DataFrame) -> None:
else:
self._save_append(update_data)

def _save(self, data: DataFrame | pd.DataFrame) -> None:
def save(self, data: DataFrame | pd.DataFrame) -> None:
"""Saves the data based on the write_mode and dataframe_type in the init.
If write_mode is pandas, Spark dataframe is created first.
If schema is provided, data is matched to schema before saving
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/email/message_dataset.py
@@ -160,13 +160,13 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> Message:
def load(self) -> Message:
load_path = get_filepath_str(self._get_load_path(), self._protocol)

with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return Parser(**self._parser_args).parse(fs_file, **self._load_args)

def _save(self, data: Message) -> None:
def save(self, data: Message) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py
@@ -126,12 +126,12 @@ def __init__( # noqa: PLR0913
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save

def _load(self) -> gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame]:
def load(self) -> gpd.GeoDataFrame | dict[str, gpd.GeoDataFrame]:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return gpd.read_file(fs_file, **self._load_args)

def _save(self, data: gpd.GeoDataFrame) -> None:
def save(self, data: gpd.GeoDataFrame) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
data.to_file(fs_file, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/holoviews/holoviews_writer.py
@@ -110,10 +110,10 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> NoReturn:
def load(self) -> NoReturn:
raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'")

def _save(self, data: HoloViews) -> None:
def save(self, data: HoloViews) -> None:
bytes_buffer = io.BytesIO()
hv.save(data, bytes_buffer, **self._save_args)

@@ -47,10 +47,10 @@ def __init__(
self._dataset_kwargs = dataset_kwargs or {}
self.metadata = metadata

def _load(self):
def load(self):
return load_dataset(self.dataset_name, **self._dataset_kwargs)

def _save(self):
def save(self):
raise NotImplementedError("Not yet implemented")

def _describe(self) -> dict[str, Any]:
@@ -62,10 +62,10 @@ def __init__(
self._pipeline_kwargs.pop("task", None)
self._pipeline_kwargs.pop("model", None)

def _load(self) -> Pipeline:
def load(self) -> Pipeline:
return pipeline(self._task, model=self._model_name, **self._pipeline_kwargs)

def _save(self, pipeline: Pipeline) -> None:
def save(self, pipeline: Pipeline) -> None:
raise NotImplementedError("Not yet implemented")

def _describe(self) -> dict[str, t.Any]:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/ibis/table_dataset.py
@@ -163,7 +163,7 @@ def hashable(value):

return cls._connections[key]

def _load(self) -> ir.Table:
def load(self) -> ir.Table:
if self._filepath is not None:
if self._file_format is None:
raise NotImplementedError
@@ -173,7 +173,7 @@ def _load(self) -> ir.Table:
else:
return self.connection.table(self._table_name)

def _save(self, data: ir.Table) -> None:
def save(self, data: ir.Table) -> None:
if self._table_name is None:
raise DatasetError("Must provide `table_name` for materialization.")

6 changes: 3 additions & 3 deletions kedro-datasets/kedro_datasets/json/json_dataset.py
@@ -134,13 +134,13 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> Any:
def load(self) -> Any:
load_path = get_filepath_str(self._get_load_path(), self._protocol)

with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return json.load(fs_file)

def _save(self, data: Any) -> None:
def save(self, data: Any) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
@@ -172,6 +172,6 @@ def preview(self) -> JSONPreview:
Returns:
A string representing the JSON data for previewing.
"""
data = self._load()
data = self.load.__wrapped__(self) # type: ignore[attr-defined]

return JSONPreview(json.dumps(data))
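A brief, assumed explanation of the preview hunk above, which switches from self._load() to self.load.__wrapped__(self): if the AbstractDataset base class wraps a subclass's load at class-creation time (for example to add versioning or error handling) and applies functools.wraps, the original undecorated method remains reachable as load.__wrapped__, so preview can re-read the data without going through that wrapper. That mechanism is not shown in this diff; the sketch below is an illustrative pattern with made-up names, not kedro's actual implementation.

import functools


class Base:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        original = cls.load  # the subclass's own `load`

        @functools.wraps(original)  # also sets wrapper.__wrapped__ = original
        def wrapper(self):
            # extra behaviour added by the framework, e.g. logging or versioning
            return original(self)

        cls.load = wrapper


class MyDataset(Base):
    def load(self):
        return {"key": "value"}


ds = MyDataset()
ds.load()                         # goes through the framework wrapper
MyDataset.load.__wrapped__(ds)    # calls the subclass's original method directly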
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/matlab/matlab_dataset.py
@@ -125,7 +125,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> np.ndarray:
def load(self) -> np.ndarray:
"""
Access the specific variable in the .mat file, e.g, data['variable_name']
"""
@@ -134,7 +134,7 @@ def _load(self) -> np.ndarray:
data = io.loadmat(f)
return data

def _save(self, data: np.ndarray) -> None:
def save(self, data: np.ndarray) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as f:
io.savemat(f, {"data": data})
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/matplotlib/matplotlib_writer.py
@@ -197,10 +197,10 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> NoReturn:
def load(self) -> NoReturn:
raise DatasetError(f"Loading not supported for '{self.__class__.__name__}'")

def _save(self, data: Figure | (list[Figure] | dict[str, Figure])) -> None:
def save(self, data: Figure | (list[Figure] | dict[str, Figure])) -> None:
save_path = self._get_save_path()

if isinstance(data, (list, dict)) and self._overwrite and self._exists():
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/networkx/gml_dataset.py
@@ -116,13 +116,13 @@ def __init__( # noqa: PLR0913
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
def load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
data = networkx.read_gml(fs_file, **self._load_args)
return data

def _save(self, data: networkx.Graph) -> None:
def save(self, data: networkx.Graph) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
networkx.write_gml(data, fs_file, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/networkx/graphml_dataset.py
@@ -115,12 +115,12 @@ def __init__( # noqa: PLR0913
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
def load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return networkx.read_graphml(fs_file, **self._load_args)

def _save(self, data: networkx.Graph) -> None:
def save(self, data: networkx.Graph) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
networkx.write_graphml(data, fs_file, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/networkx/json_dataset.py
@@ -112,14 +112,14 @@ def __init__( # noqa: PLR0913
**(_fs_open_args_save or {}),
}

def _load(self) -> networkx.Graph:
def load(self) -> networkx.Graph:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
json_payload = json.load(fs_file)

return networkx.node_link_graph(json_payload, **self._load_args)

def _save(self, data: networkx.Graph) -> None:
def save(self, data: networkx.Graph) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

json_graph = networkx.node_link_data(data, **self._save_args)
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/csv_dataset.py
@@ -162,7 +162,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
load_path = str(self._get_load_path())
if self._protocol == "file":
# file:// protocol seems to misbehave on Windows
@@ -176,7 +176,7 @@ def _load(self) -> pd.DataFrame:
load_path, storage_options=self._storage_options, **self._load_args
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py
@@ -219,10 +219,10 @@ def get_loaded_version(self) -> int | None:
"""Returns the version of the DeltaTableDataset that is currently loaded."""
return self._delta_table.version() if self._delta_table else None

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
return self._delta_table.to_pandas() if self._delta_table else None

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
if self.is_empty_dir:
# first time creation of delta table
write_deltalake(
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/excel_dataset.py
@@ -222,7 +222,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> pd.DataFrame | dict[str, pd.DataFrame]:
def load(self) -> pd.DataFrame | dict[str, pd.DataFrame]:
load_path = str(self._get_load_path())
if self._protocol == "file":
# file:// protocol seems to misbehave on Windows
@@ -236,7 +236,7 @@ def _load(self) -> pd.DataFrame | dict[str, pd.DataFrame]:
load_path, storage_options=self._storage_options, **self._load_args
)

def _save(self, data: pd.DataFrame | dict[str, pd.DataFrame]) -> None:
def save(self, data: pd.DataFrame | dict[str, pd.DataFrame]) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/feather_dataset.py
@@ -161,7 +161,7 @@ def _describe(self) -> dict[str, Any]:
"version": self._version,
}

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
load_path = str(self._get_load_path())
if self._protocol == "file":
# file:// protocol seems to misbehave on Windows
@@ -175,7 +175,7 @@ def _load(self) -> pd.DataFrame:
load_path, storage_options=self._storage_options, **self._load_args
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
8 changes: 4 additions & 4 deletions kedro-datasets/kedro_datasets/pandas/gbq_dataset.py
@@ -136,7 +136,7 @@ def _describe(self) -> dict[str, Any]:
"save_args": self._save_args,
}

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
sql = f"select * from {self._dataset}.{self._table_name}" # nosec
self._load_args.setdefault("query", sql)
return pd.read_gbq(
Expand All @@ -145,7 +145,7 @@ def _load(self) -> pd.DataFrame:
**self._load_args,
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
data.to_gbq(
f"{self._dataset}.{self._table_name}",
project_id=self._project_id,
@@ -297,7 +297,7 @@ def _describe(self) -> dict[str, Any]:

return desc

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)

if self._filepath:
@@ -311,5 +311,5 @@ def _load(self) -> pd.DataFrame:
**load_args,
)

def _save(self, data: None) -> NoReturn:
def save(self, data: None) -> NoReturn:
raise DatasetError("'save' is not supported on GBQQueryDataset")
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/pandas/generic_dataset.py
@@ -190,7 +190,7 @@ def _ensure_file_system_target(self) -> None:
f"does not support a filepath target/source."
)

def _load(self) -> pd.DataFrame:
def load(self) -> pd.DataFrame:
self._ensure_file_system_target()

load_path = get_filepath_str(self._get_load_path(), self._protocol)
@@ -204,7 +204,7 @@ def _load(self) -> pd.DataFrame:
"https://pandas.pydata.org/docs/reference/io.html"
)

def _save(self, data: pd.DataFrame) -> None:
def save(self, data: pd.DataFrame) -> None:
self._ensure_file_system_target()

save_path = get_filepath_str(self._get_save_path(), self._protocol)