diff --git a/CHANGELOG.md b/CHANGELOG.md
index ddb06676e4..d3e43ec59d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,7 @@
 - [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN)
 - [#894](https://github.com/helmholtz-analytics/heat/pull/894) Change inclusion of license file
 - [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use.
+- [#941](https://github.com/helmholtz-analytics/heat/pull/941) Add function to save data as CSV.
 
 ## Bug Fixes
 - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
diff --git a/heat/core/io.py b/heat/core/io.py
index 27515e5987..c58e1d5fd4 100644
--- a/heat/core/io.py
+++ b/heat/core/io.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import os.path
+from math import log10
 import numpy as np
 import torch
 import warnings
@@ -14,6 +15,8 @@
 
 from .communication import Communication, MPI, MPI_WORLD, sanitize_comm
 from .dndarray import DNDarray
+from .manipulations import hsplit, vsplit
+from .statistics import max as smax, min as smin
 from .stride_tricks import sanitize_axis
 from .types import datatype
 
@@ -23,7 +26,7 @@
 __NETCDF_EXTENSIONS = frozenset([".nc", ".nc4", "netcdf"])
 __NETCDF_DIM_TEMPLATE = "{}_dim_{}"
 
-__all__ = ["load", "load_csv", "save", "supports_hdf5", "supports_netcdf"]
+__all__ = ["load", "load_csv", "save_csv", "save", "supports_hdf5", "supports_netcdf"]
 
 try:
     import h5py
@@ -143,7 +146,9 @@ def load_hdf5(
 
         return DNDarray(data, gshape, dtype, split, device, comm, balanced)
 
-    def save_hdf5(data: str, path: str, dataset: str, mode: str = "w", **kwargs: Dict[str, object]):
+    def save_hdf5(
+        data: DNDarray, path: str, dataset: str, mode: str = "w", **kwargs: Dict[str, object]
+    ):
         """
         Saves ``data`` to an HDF5 file. Attempts to utilize parallel I/O if possible.
 
@@ -344,7 +349,7 @@ def load_netcdf(
         return DNDarray(data, gshape, dtype, split, device, comm, balanced)
 
     def save_netcdf(
-        data: str,
+        data: DNDarray,
         path: str,
         variable: str,
         mode: str = "w",
@@ -918,6 +923,162 @@ def load_csv(
     return resulting_tensor
 
 
+def save_csv(
+    data: DNDarray,
+    path: str,
+    header_lines: Iterable[str] = None,
+    sep: str = ",",
+    decimals: int = -1,
+    encoding: str = "utf-8",
+    comm: Optional[Communication] = None,
+    truncate: bool = True,
+):
+    """
+    Saves data to a CSV file using parallel MPI I/O. Only 1D and 2D data, all split axes.
+    Every item is written with the same fixed width, so that each process can compute the
+    file position of its local chunk without communication.
+
+    Parameters
+    ----------
+    data : DNDarray
+        The DNDarray to be saved to CSV. Its data type must be integer or floating point.
+    path : str
+        The path as a string.
+    header_lines : Iterable[str]
+        Optional iterable of str to prepend at the beginning of the file. No
+        pound sign or any other comment marker will be inserted. A line break
+        is appended to every line that does not already end with one.
+    sep : str
+        The separator used in this CSV.
+    decimals: int
+        Number of digits after decimal point. The default of -1 selects 7 digits
+        for ``float32`` data and 15 digits for other floating point data. Ignored
+        for integer data.
+    encoding : str
+        The encoding to be used in this CSV. File offsets are computed under the
+        assumption that digits, signs and line breaks occupy one byte each, which
+        holds for ASCII-compatible encodings such as UTF-8.
+    comm : Optional[Communication]
+        An optional object of type Communication. It is currently not used, the
+        communicator of ``data`` handles the file access.
+    truncate : bool
+        Whether to truncate an existing file before writing, i.e. fully overwrite it.
+        The sane default is True. Setting it to False will not shorten files if
+        needed and thus may leave garbage at the end of existing files.
+
+    Raises
+    ------
+    TypeError
+        If ``data`` is not a DNDarray of integer or floating point type, if ``path``
+        or ``sep`` is not a str, or if ``header_lines`` is not iterable.
+    ValueError
+        If ``data`` is neither 1D nor 2D, if its split axis is not supported, or if
+        ``decimals`` is smaller than -1.
+    """
+    if not isinstance(data, DNDarray):
+        raise TypeError("data must be a DNDarray, not {}".format(type(data)))
+    if not isinstance(path, str):
+        raise TypeError("path must be str, not {}".format(type(path)))
+    if not isinstance(sep, str):
+        raise TypeError("separator must be str, not {}".format(type(sep)))
+    # check this to allow None
+    if not isinstance(header_lines, Iterable) and header_lines is not None:
+        raise TypeError("header_lines must be Iterable[str], not {}".format(type(header_lines)))
+    if len(data.shape) not in (1, 2):
+        raise ValueError("data must be 1D or 2D, but has {} dimensions".format(len(data.shape)))
+    if data.split not in [None, 0, 1]:
+        raise ValueError("split must be in [None, 0, 1], but is {}".format(data.split))
+    if decimals < -1:
+        raise ValueError("decimals must be -1 or non-negative, but is {}".format(decimals))
+    is_integer = types.issubdtype(data.dtype, types.integer)
+    if not is_integer and not types.issubdtype(data.dtype, types.floating):
+        raise TypeError("data must be of integer or floating type, not {}".format(data.dtype))
+
+    # encode the header once on every process: this supports one-shot iterators and yields
+    # the size in bytes (not characters) that the data has to be shifted by
+    if header_lines is None:
+        header_lines = []
+    elif isinstance(header_lines, str):
+        # a single str is one header line, it must not be split into its characters
+        header_lines = [header_lines]
+    header = b"".join(
+        (hl if hl.endswith("\n") else hl + "\n").encode(encoding) for hl in header_lines
+    )
+
+    # formatting and element width
+    data_min = smin(data).item()
+    data_max = smax(data).item()
+    abs_max = max(abs(data_max), abs(data_min))
+    sign = 1 if data_min < 0 else 0
+    if is_integer:
+        decimals = 0
+        dec_sep = 0
+        pre_point_digits = len(str(abs_max))
+    else:
+        if decimals == -1:
+            decimals = 7 if data.dtype is types.float32 else 15
+        dec_sep = 1
+        # count the digits of the formatted value instead of using log10: rounding may add
+        # a digit (9.99 -> 10.0) and values below 0.1 still need one digit before the point
+        pre_point_digits = len("{:.{}f}".format(abs_max, decimals).partition(".")[0])
+
+    # sign + pre separator digits + decimal separator + decimals (post separator)
+    item_size = sign + pre_point_digits + dec_sep + decimals
+    # the blank flag reserves the sign column for non-negative items
+    flag = " " if sign else ""
+    if is_integer:
+        fmt = "{:%s%dd}" % (flag, item_size)
+    else:
+        fmt = "{:%s%d.%df}" % (flag, item_size, decimals)
+
+    # every item is followed by either a separator or, at the end of a row, a line break
+    sep_size = len(sep.encode(encoding))
+    n_cols = data.shape[1] if len(data.shape) > 1 else 1
+    row_width = n_cols * item_size + (n_cols - 1) * sep_size + len("\n".encode(encoding))
+
+    rank = data.comm.rank
+    offset = len(header)  # all splits
+    ends_row = True
+    if data.split is not None:
+        _, displs = data.counts_displs()
+        if data.split == 0:
+            offset = offset + displs[rank] * row_width
+        else:
+            offset = offset + displs[rank] * (item_size + sep_size)
+            # only the process holding the last columns terminates the rows
+            ends_row = displs[rank] + data.lshape[1] == data.shape[1]
+    terminator = "\n" if ends_row else sep
+
+    if truncate:
+        if rank == 0 and os.path.exists(path):
+            os.truncate(path, 0)
+        # avoid truncating and writing at the same time
+        data.comm.handle.Barrier()
+
+    amode = MPI.MODE_WRONLY | MPI.MODE_CREATE
+    csv_out = MPI.File.Open(data.comm.handle, path, amode)
+    try:
+        if rank == 0 and header:
+            csv_out.Write_at(0, header)
+
+        # without split every process holds all data, hence only the first one writes it
+        has_columns = len(data.lshape) == 1 or data.lshape[1] > 0
+        if has_columns and (data.split is not None or rank == 0):
+            for i in range(data.lshape[0]):
+                items = data.larray[i].tolist()
+                if len(data.lshape) == 1:
+                    # lshape is of the form (x,), there is only a single element per row
+                    row = fmt.format(items)
+                else:
+                    row = sep.join(fmt.format(item) for item in items)
+                csv_out.Write_at(offset, (row + terminator).encode(encoding))
+                offset = offset + row_width
+    finally:
+        # the file is closed even if formatting or writing fails
+        csv_out.Close()
+    data.comm.handle.Barrier()
+
+
 def save(
     data: DNDarray, path: str, *args: Optional[List[object]], **kwargs: Optional[Dict[str, object]]
 ):
@@ -962,6 +1123,8 @@ def save(
             save_netcdf(data, path, *args, **kwargs)
         else:
             raise RuntimeError("netcdf is required for file extension {}".format(extension))
+    elif extension in __CSV_EXTENSION:
+        save_csv(data, path, *args, **kwargs)
     else:
         raise ValueError("Unsupported file extension {}".format(extension))
 
diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py
index 24afb25a86..47ad519cad 100644
--- a/heat/core/tests/test_io.py
+++ b/heat/core/tests/test_io.py
@@ -1,6 +1,7 @@
 import numpy as np
 import os
 import torch
+import tempfile
 
 import heat as ht
 from .test_suites.basic_test import TestCase
@@ -22,6 +23,7 @@ def setUpClass(cls):
 
         # load comparison data from csv
         cls.CSV_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.csv")
+        cls.CSV_OUT_PATH = pwd + "/test.csv"
         cls.IRIS = (
             torch.from_numpy(np.loadtxt(cls.CSV_PATH, delimiter=";"))
             .float()
@@ -141,6 +143,62 @@ def test_load_csv(self):
         with self.assertRaises(TypeError):
             ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0)
 
+    def assert_csv_roundtrip(self, data, headers, separator):
+        # only the first process creates the file, its name is shared with all others
+        filename = None
+        if data.comm.rank == 0:
+            tmpfile = tempfile.NamedTemporaryFile(
+                prefix="test_io_", suffix=".csv", delete=False
+            )
+            tmpfile.close()
+            filename = tmpfile.name
+        filename = data.comm.handle.bcast(filename, root=0)
+
+        try:
+            data.save(filename, header_lines=headers, sep=separator)
+            comparison = ht.load_csv(
+                filename,
+                header_lines=0 if headers is None else len(headers),
+                sep=separator,
+            )
+            resid = data - comparison
+            self.assertTrue(ht.max(resid).item() < 0.00001 and ht.min(resid).item() > -0.00001)
+        finally:
+            # remove the file even if saving, loading or the comparison fails
+            data.comm.handle.Barrier()
+            if data.comm.rank == 0:
+                os.unlink(filename)
+
+    def test_save_csv(self):
+        for rnd_type in [
+            (ht.random.randint, ht.types.int32),
+            (ht.random.randint, ht.types.int64),
+            (ht.random.rand, ht.types.float32),
+            (ht.random.rand, ht.types.float64),
+        ]:
+            for separator in [",", ";", "|"]:
+                for split in [None, 0, 1]:
+                    for headers in [None, ["# This", "# is a", "# test."], ["an,ordinary,header"]]:
+                        for shape in [(1, 1), (10, 10), (20, 1), (1, 20), (25, 4), (4, 25)]:
+                            if rnd_type[0] == ht.random.randint:
+                                data = rnd_type[0](
+                                    -100, 1000, size=shape, dtype=rnd_type[1], split=split
+                                )
+                            else:
+                                data = rnd_type[0](
+                                    shape[0],
+                                    shape[1],
+                                    split=split,
+                                    dtype=rnd_type[1],
+                                )
+                            self.assert_csv_roundtrip(data, headers, separator)
+
+    def test_save_csv_small_values(self):
+        # values below 0.1 must not be formatted wider than the computed item size
+        for split in [None, 0, 1]:
+            data = ht.random.rand(10, 4, split=split, dtype=ht.types.float64) * 0.01
+            self.assert_csv_roundtrip(data, None, ",")
+
     def test_load_exception(self):
         # correct extension, file does not exist
         if ht.io.supports_hdf5():