Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

750 save csv v2 #941

Merged
merged 23 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN)
- [#894](https://github.com/helmholtz-analytics/heat/pull/894) Change inclusion of license file
- [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use.
- [#941](https://github.com/helmholtz-analytics/heat/pull/941) Add function to save data as CSV.

## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
Expand Down
136 changes: 133 additions & 3 deletions heat/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import os.path
from math import log10
import numpy as np
import torch
import warnings
Expand All @@ -14,6 +15,8 @@

from .communication import Communication, MPI, MPI_WORLD, sanitize_comm
from .dndarray import DNDarray
from .manipulations import hsplit, vsplit
from .statistics import max as smax, min as smin
from .stride_tricks import sanitize_axis
from .types import datatype

Expand All @@ -23,7 +26,7 @@
__NETCDF_EXTENSIONS = frozenset([".nc", ".nc4", "netcdf"])
__NETCDF_DIM_TEMPLATE = "{}_dim_{}"

__all__ = ["load", "load_csv", "save", "supports_hdf5", "supports_netcdf"]
__all__ = ["load", "load_csv", "save_csv", "save", "supports_hdf5", "supports_netcdf"]

try:
import h5py
Expand Down Expand Up @@ -143,7 +146,9 @@ def load_hdf5(

return DNDarray(data, gshape, dtype, split, device, comm, balanced)

def save_hdf5(data: str, path: str, dataset: str, mode: str = "w", **kwargs: Dict[str, object]):
def save_hdf5(
data: DNDarray, path: str, dataset: str, mode: str = "w", **kwargs: Dict[str, object]
):
"""
Saves ``data`` to an HDF5 file. Attempts to utilize parallel I/O if possible.

Expand Down Expand Up @@ -344,7 +349,7 @@ def load_netcdf(
return DNDarray(data, gshape, dtype, split, device, comm, balanced)

def save_netcdf(
data: str,
data: DNDarray,
path: str,
variable: str,
mode: str = "w",
Expand Down Expand Up @@ -918,6 +923,129 @@ def load_csv(
return resulting_tensor


def save_csv(
data: DNDarray,
path: str,
header_lines: Iterable[str] = None,
sep: str = ",",
decimals: int = -1,
encoding: str = "utf-8",
comm: Optional[Communication] = None,
truncate: bool = True,
):
"""
Saves data to CSV files. Only 2D data, all split axes.

Parameters
----------
data : DNDarray
The DNDarray to be saved to CSV.
path : str
The path as a string.
header_lines : Iterable[str]
Optional iterable of str to prepend at the beginning of the file. No
pound sign or any other comment marker will be inserted.
sep : str
The separator character used in this CSV.
decimals: int
Number of digits after decimal point.
encoding : str
The encoding to be used in this CSV.
comm : Optional[Communication]
An optional object of type Communication to be used.
truncate : bool
Whether to truncate an existing file before writing, i.e. fully overwrite it.
The sane default is True. Setting it to False will not shorten files if
needed and thus may leave garbage at the end of existing files.
"""
if not isinstance(path, str):
raise TypeError("path must be str, not {}".format(type(path)))
if not isinstance(sep, str):
raise TypeError("separator must be str, not {}".format(type(sep)))
# check this to allow None
if not isinstance(header_lines, Iterable) and header_lines is not None:
raise TypeError("header_lines must Iterable[str], not {}".format(type(header_lines)))
if data.split not in [None, 0, 1]:
raise ValueError("split must be in [None, 0, 1], but is {}".format(data.split))

if os.path.exists(path) and truncate:
if data.comm.rank == 0:
os.truncate(path, 0)
# avoid truncating and writing at the same time
data.comm.handle.Barrier()

amode = MPI.MODE_WRONLY | MPI.MODE_CREATE
csv_out = MPI.File.Open(data.comm.handle, path, amode)

# will be needed as an additional offset later
hl_displacement = 0
if header_lines is not None:
hl_displacement = sum(len(hl) for hl in header_lines)
# count additions everywhere, but write only on rank 0, avoiding reduce op to share final hl_displacement
for hl in header_lines:
if not hl.endswith("\n"):
hl = hl + "\n"
hl_displacement = hl_displacement + 1
if data.comm.rank == 0 and header_lines:
csv_out.Write(hl.encode(encoding))

# formatting and element width
data_min = smin(data).item() # at least min is used twice, so cache it here
data_max = smax(data).item()
sign = 1 if data_min < 0 else 0
pre_point_digits = int(log10(max(data_max, abs(data_min)))) + 1

dec_sep = 1
fmt = ""
if types.issubdtype(data.dtype, types.integer):
decimals = 0
dec_sep = 0
if sign == 1:
fmt = "{: %dd}" % (pre_point_digits + 1)
else:
fmt = "{:%dd}" % (pre_point_digits)
elif types.issubdtype(data.dtype, types.floating):
if decimals == -1:
decimals = 7 if data.dtype is types.float32 else 15
if sign == 1:
fmt = "{: %d.%df}" % (pre_point_digits + decimals + 2, decimals)
else:
fmt = "{:%d.%df}" % (pre_point_digits + decimals + 1, decimals)

# sign + decimal separator + pre separator digits + decimals (post separator)
item_size = decimals + dec_sep + sign + pre_point_digits
# each item is one position larger than its representation, either b/c of separator or line break
row_width = item_size + 1
if len(data.shape) > 1:
row_width = data.shape[1] * (item_size + 1)

offset = hl_displacement # all splits
if data.split == 0:
_, displs = data.counts_displs()
offset = offset + displs[data.comm.rank] * row_width
elif data.split == 1:
_, displs = data.counts_displs()
offset = offset + displs[data.comm.rank] * (item_size + 1)

for i in range(data.lshape[0]):
# if lshape is of the form (x,), then there will only be a single element per row
if len(data.lshape) == 1:
row = fmt.format(data.larray[i])
else:
row = sep.join(fmt.format(item) for item in data.larray[i])

if data.split is None or data.split == 0 or data.comm.rank == (data.comm.size - 1):
row = row + "\n"
else:
row = row + sep

csv_out.Write_at(offset, row.encode("utf-8"))
offset = offset + row_width

csv_out.Close()
data.comm.handle.Barrier()


def save(
data: DNDarray, path: str, *args: Optional[List[object]], **kwargs: Optional[Dict[str, object]]
):
Expand Down Expand Up @@ -962,6 +1090,8 @@ def save(
save_netcdf(data, path, *args, **kwargs)
else:
raise RuntimeError("netcdf is required for file extension {}".format(extension))
elif extension in __CSV_EXTENSION:
save_csv(data, path, *args, **kwargs)
else:
raise ValueError("Unsupported file extension {}".format(extension))

Expand Down
75 changes: 75 additions & 0 deletions heat/core/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def setUpClass(cls):

# load comparison data from csv
cls.CSV_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.csv")
cls.CSV_OUT_PATH = pwd + "/test.csv"
cls.IRIS = (
torch.from_numpy(np.loadtxt(cls.CSV_PATH, delimiter=";"))
.float()
Expand Down Expand Up @@ -141,6 +142,80 @@ def test_load_csv(self):
with self.assertRaises(TypeError):
ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0)

def test_save_csv(self):
if os.path.exists(self.CSV_OUT_PATH):
os.truncate(self.CSV_OUT_PATH, 0)
# {signed, unsigned} x {float, integer} x {',', '|', ';'} x {None, 0, 1} x {header_lines, no_header_lines}
data = ht.arange(100).reshape(20, 5)
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None, sep=";")
comparison = ht.load_csv(self.CSV_OUT_PATH, sep=";", dtype=data.dtype)
if data.comm.rank == 0:
self.assertTrue((data.larray == comparison.larray).all().item())

if os.path.exists(self.CSV_OUT_PATH):
os.truncate(self.CSV_OUT_PATH, 0)
header_lines = ["col1,col2,col3,col4,col5", "second line\n"]
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=header_lines, sep=",")
comparison = ht.load_csv(self.CSV_OUT_PATH, sep=",", dtype=data.dtype, header_lines=2)
if data.comm.rank == 0:
self.assertTrue((data.larray == comparison.larray).all().item())

if os.path.exists(self.CSV_OUT_PATH):
os.truncate(self.CSV_OUT_PATH, 0)
data = data - 50
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
if data.comm.rank == 0:
self.assertTrue((data.larray == comparison.larray).all().item())

if os.path.exists(self.CSV_OUT_PATH):
os.truncate(self.CSV_OUT_PATH, 0)
data = ht.random.rand(100.0, split=0).reshape(20, 5)
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
self.assertTrue(
ht.max(data - comparison).item() < 0.0001
) # need a parameter to control precision

if os.path.exists(self.CSV_OUT_PATH):
os.truncate(self.CSV_OUT_PATH, 0)
data = ht.random.rand(100.0, split=0).reshape(50, 2)
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
self.assertTrue(ht.max(data - comparison).item() < 0.0001)

if os.path.exists(self.CSV_OUT_PATH):
os.truncate(self.CSV_OUT_PATH, 0)
data = ht.arange(100.0, split=0, dtype=ht.float64).reshape(50, 2)
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
self.assertTrue(ht.max(data - comparison).item() < 0.0000001)

# 1D vector (25,)
# + do not truncate this time, but rely on save_csv
data = ht.arange(25, split=0, dtype=ht.int)
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
# read_csv always reads in (n, m), so need to reshape to (n,)
self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() == 0)

# 1D vector (1, 25)
data = ht.arange(25, split=0, dtype=ht.int).reshape(1, 25)
ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
# read_csv always reads in (n, m), so need to reshape to (n,)
self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() == 0)

data = ht.random.randint(0, 100, (25, 4), split=1)
data.save(self.CSV_OUT_PATH)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001)

data = ht.random.rand(250, 10, dtype=ht.float64, split=0)
data.save(self.CSV_OUT_PATH)
comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype)
self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001)

def test_load_exception(self):
# correct extension, file does not exist
if ht.io.supports_hdf5():
Expand Down