From 7fc2e0e490e7f1b68a93f2c8b05ed2ed488ad904 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Wed, 23 Mar 2022 11:48:09 +0100 Subject: [PATCH 01/21] An initial implementation for save_csv it has problems when running with mpirun, so will try a different approach on a different branch. --- heat/core/io.py | 85 ++++++++++++++++++++++++++++++++++++-- heat/core/tests/test_io.py | 8 ++++ 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index b390747878..f4f23cae65 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -5,6 +5,7 @@ import numpy as np import torch import warnings +import sys from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -14,6 +15,7 @@ from .communication import Communication, MPI, MPI_WORLD, sanitize_comm from .dndarray import DNDarray +from .manipulations import hsplit, vsplit from .stride_tricks import sanitize_axis from .types import datatype @@ -23,7 +25,7 @@ __NETCDF_EXTENSIONS = frozenset([".nc", ".nc4", "netcdf"]) __NETCDF_DIM_TEMPLATE = "{}_dim_{}" -__all__ = ["load", "load_csv", "save", "supports_hdf5", "supports_netcdf"] +__all__ = ["load", "load_csv", "save_csv", "save", "supports_hdf5", "supports_netcdf"] try: import h5py @@ -144,7 +146,9 @@ def load_hdf5( return DNDarray(data, gshape, dtype, split, device, comm, balanced) - def save_hdf5(data: str, path: str, dataset: str, mode: str = "w", **kwargs: Dict[str, object]): + def save_hdf5( + data: DNDarray, path: str, dataset: str, mode: str = "w", **kwargs: Dict[str, object] + ): """ Saves ``data`` to an HDF5 file. Attempts to utilize parallel I/O if possible. @@ -346,7 +350,7 @@ def load_netcdf( return DNDarray(data, gshape, dtype, split, device, comm, balanced) def save_netcdf( - data: str, + data: DNDarray, path: str, variable: str, mode: str = "w", @@ -920,6 +924,81 @@ def load_csv( return resulting_tensor +def save_csv( + data: DNDarray, + path: str, + header_lines: Iterable[str] = None, + sep: str = ",", + dtype: datatype = types.float32, + encoding: str = "utf-8", + comm: Optional[Communication] = None, +): + """ + Saves data to CSV files + + + """ + print("Saving CSV on rank %d of %d" % (data.comm.rank, data.comm.size)) + if data.comm.rank == 0 and header_lines: + with open(path, "w") as csv_out: + for hl in header_lines: + print(hl, file=csv_out) + csv_out.flush() + + print("Waiting for header writing on rank %d of %d" % (data.comm.rank, data.comm.size)) + data.comm.Barrier() + print("Waited for header writing on rank %d of %d" % (data.comm.rank, data.comm.size)) + print("Data split == %s on rank %d" % (data.split, data.comm.rank)) + + # split None can be written out by rank 0 + if data.split is None: + print("Split is None", file=sys.stderr) + if data.comm.rank == 0: + with open(path, "w") as csv_out: + for row in vsplit(data, data.lshape[0]): + print( + sep.join(map(lambda dnda: str(dnda.item()), hsplit(row, data.lshape[1]))), + file=csv_out, + ) + csv_out.flush() + print("Done on rank %d" % (data.comm.rank), file=sys.stderr) + elif data.split == 0: + print("Split is 0", file=sys.stderr) + if data.comm.rank == 0: + print("On rank %d" % (data.comm.rank), file=sys.stderr) + with open(path, "w") as csv_out: + for row in vsplit(data, data.lshape[0]): + print( + sep.join(map(lambda dnda: str(dnda.item()), hsplit(row, data.lshape[1]))), + file=csv_out, + ) + csv_out.flush() + if data.comm.rank < data.comm.size - 1: + print( + "Rank %d: sending %d to rank %d" + % (data.comm.rank, data.comm.rank + 1, data.comm.rank + 1), + file=sys.stderr, + ) + data.comm.Send([None, 0, MPI.INT], dest=data.comm.rank + 1) + else: + rank = -1 + print("Rank %d: waiting for my turn" % (data.comm.rank), file=sys.stderr) + data.comm.Recv([None, 0, MPI.INT], source=data.comm.rank - 1) + if rank != data.comm.rank: + raise IndexError("Wrong write sequence for CSV file.") + with open(path, "w") as csv_out: + for row in vsplit(data, data.lshape[0]): + print( + sep.join(map(lambda dnda: str(dnda.item()), hsplit(row, data.lshape[1]))), + file=csv_out, + ) + csv_out.flush() + if data.comm.rank < data.comm.size - 1: + data.comm.Isend([None, 0, MPI.INT], dest=data.comm.rank + 1) + else: + raise NotImplementedError() + + def save( data: DNDarray, path: str, *args: Optional[List[object]], **kwargs: Optional[Dict[str, object]] ): diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index 24afb25a86..e0c8b2e09d 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -22,6 +22,7 @@ def setUpClass(cls): # load comparison data from csv cls.CSV_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.csv") + cls.CSV_OUT_PATH = pwd + "/test.csv" cls.IRIS = ( torch.from_numpy(np.loadtxt(cls.CSV_PATH, delimiter=";")) .float() @@ -141,6 +142,13 @@ def test_load_csv(self): with self.assertRaises(TypeError): ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0) + def test_save_csv(self): + data = ht.arange(100).reshape(20, 5) + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None, sep=";", dtype=data.dtype.char()) + comparison = ht.load_csv(self.CSV_OUT_PATH, sep=";", dtype=data.dtype.char()) + if data.comm.rank == 0: + self.assertTrue((data.larray == comparison.larray).all().item()) + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5(): From 507d5259e3e811294908761459c62ec0f456793d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Wed, 23 Mar 2022 12:05:14 +0100 Subject: [PATCH 02/21] save_csv docstring --- heat/core/io.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/heat/core/io.py b/heat/core/io.py index f4f23cae65..478f12b2f5 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -936,7 +936,23 @@ def save_csv( """ Saves data to CSV files - + Parameters + ---------- + data : DNDarray + The DNDarray to be saved to CSV. + path : str + The path as a string. + header_lines : Iterable[str] + Optional iterable of str to prepend at the beginning of the file. No + pound sign or any other comment marker will be inserted. + sep : str + The separator character used in this CSV. + dtype : datatype + The datatype used in this CSV. Currently, the value is ignored. + encoding : str + The encoding to be used in this CSV. + comm : Optional[Communication] + An optional object of type Communication to be used. """ print("Saving CSV on rank %d of %d" % (data.comm.rank, data.comm.size)) if data.comm.rank == 0 and header_lines: From a5d82c27f0ba80955f7d901dcc20a99cd3af9b23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Thu, 24 Mar 2022 14:43:58 +0100 Subject: [PATCH 03/21] save_csv function This is an implementation of save_csv that covers split None and 0. resolves #750 --- heat/core/io.py | 137 +++++++++++++++++++++---------------- heat/core/tests/test_io.py | 28 +++++++- 2 files changed, 105 insertions(+), 60 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 478f12b2f5..81ce0a0340 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -2,6 +2,7 @@ from __future__ import annotations import os.path +from math import log10 import numpy as np import torch import warnings @@ -16,6 +17,7 @@ from .communication import Communication, MPI, MPI_WORLD, sanitize_comm from .dndarray import DNDarray from .manipulations import hsplit, vsplit +from .statistics import max as smax, min as smin from .stride_tricks import sanitize_axis from .types import datatype @@ -929,7 +931,7 @@ def save_csv( path: str, header_lines: Iterable[str] = None, sep: str = ",", - dtype: datatype = types.float32, + decimals: int = -1, encoding: str = "utf-8", comm: Optional[Communication] = None, ): @@ -947,73 +949,92 @@ def save_csv( pound sign or any other comment marker will be inserted. sep : str The separator character used in this CSV. - dtype : datatype - The datatype used in this CSV. Currently, the value is ignored. + decimals: int + Number of digits after decimal point. encoding : str The encoding to be used in this CSV. comm : Optional[Communication] An optional object of type Communication to be used. """ - print("Saving CSV on rank %d of %d" % (data.comm.rank, data.comm.size)) - if data.comm.rank == 0 and header_lines: - with open(path, "w") as csv_out: - for hl in header_lines: - print(hl, file=csv_out) - csv_out.flush() - - print("Waiting for header writing on rank %d of %d" % (data.comm.rank, data.comm.size)) - data.comm.Barrier() - print("Waited for header writing on rank %d of %d" % (data.comm.rank, data.comm.size)) - print("Data split == %s on rank %d" % (data.split, data.comm.rank)) - - # split None can be written out by rank 0 + if not isinstance(path, str): + raise TypeError("path must be str, not {}".format(type(path))) + if not isinstance(sep, str): + raise TypeError("separator must be str, not {}".format(type(sep))) + # check this to allow None + if not isinstance(header_lines, Iterable) and header_lines is not None: + raise TypeError("header_lines must Iterable[str], not {}".format(type(header_lines))) + if data.split not in [None, 0]: + raise ValueError("split must be in [None, 0], but is {}".format(data.split)) + + amode = MPI.MODE_WRONLY | MPI.MODE_CREATE + csv_out = MPI.File.Open(data.comm.handle, path, amode) + + # will be needed as an additional offset later + hl_displacement = 0 + if header_lines is not None: + hl_displacement = sum(len(hl) for hl in header_lines) + # count additions everywhere, but write only on rank 0, avoiding reduce op to share final hl_displacement + for hl in header_lines: + if not hl.endswith("\n"): + hl = hl + "\n" + hl_displacement = hl_displacement + 1 + if data.comm.rank == 0 and header_lines: + csv_out.Write(hl.encode(encoding)) + + # formatting and element width + data_min = smin(data).item() # at least min is used twice, so cache it here + data_max = smax(data).item() + sign = 1 if data_min < 0 else 0 + pre_point_digits = int(log10(max(data_max, abs(data_min)))) + 1 + + dec_sep = 1 + fmt = "" + if types.issubdtype(data.dtype, types.integer): + decimals = 0 + dec_sep = 0 + if sign == 1: + fmt = "{: %dd}" % (pre_point_digits + 1) + else: + fmt = "{:%dd}" % (pre_point_digits) + elif types.issubdtype(data.dtype, types.float): + if decimals == -1: + decimals = 7 if data.dtype is types.float32 else 15 + print("Decimals: %d" % (decimals)) + dec_sep = 1 + if sign == 1: + fmt = "{: %d.%df}" % (pre_point_digits + decimals, decimals) + else: + fmt = "{:%d.%df}" % (pre_point_digits + decimals, decimals) + + # sign + decimal separator + pre separator digits + decimals (post separator) + item_size = decimals + dec_sep + sign + pre_point_digits + # number of items times their size + (items - 1) commas + final nl, the last two cancelling each other out + row_width = data.lshape[1] * (item_size + 1) + if data.split is None: - print("Split is None", file=sys.stderr) - if data.comm.rank == 0: - with open(path, "w") as csv_out: - for row in vsplit(data, data.lshape[0]): - print( - sep.join(map(lambda dnda: str(dnda.item()), hsplit(row, data.lshape[1]))), - file=csv_out, - ) - csv_out.flush() - print("Done on rank %d" % (data.comm.rank), file=sys.stderr) + offset = hl_displacement elif data.split == 0: - print("Split is 0", file=sys.stderr) - if data.comm.rank == 0: - print("On rank %d" % (data.comm.rank), file=sys.stderr) - with open(path, "w") as csv_out: - for row in vsplit(data, data.lshape[0]): - print( - sep.join(map(lambda dnda: str(dnda.item()), hsplit(row, data.lshape[1]))), - file=csv_out, - ) - csv_out.flush() - if data.comm.rank < data.comm.size - 1: - print( - "Rank %d: sending %d to rank %d" - % (data.comm.rank, data.comm.rank + 1, data.comm.rank + 1), - file=sys.stderr, - ) - data.comm.Send([None, 0, MPI.INT], dest=data.comm.rank + 1) - else: - rank = -1 - print("Rank %d: waiting for my turn" % (data.comm.rank), file=sys.stderr) - data.comm.Recv([None, 0, MPI.INT], source=data.comm.rank - 1) - if rank != data.comm.rank: - raise IndexError("Wrong write sequence for CSV file.") - with open(path, "w") as csv_out: - for row in vsplit(data, data.lshape[0]): - print( - sep.join(map(lambda dnda: str(dnda.item()), hsplit(row, data.lshape[1]))), - file=csv_out, - ) - csv_out.flush() - if data.comm.rank < data.comm.size - 1: - data.comm.Isend([None, 0, MPI.INT], dest=data.comm.rank + 1) + # v1: via counts_displs + _, displs = data.counts_displs() + offset = displs[data.comm.rank] + # v2: via lshape_map, did not work for me + # offset = a.lshape_map[:, a.split][a.comm.rank] + offset = offset * row_width + hl_displacement else: raise NotImplementedError() + for i in range(data.lshape[0]): + row = sep.join(fmt.format(item) for item in data.larray[i]) + row = row + "\n" + # buf = BytesIO(row) + # buf = StringIO() + # buf.write(row) + csv_out.Write_at_all(offset, row.encode("utf-8")) + # buf.close() + offset = offset + row_width + + csv_out.Close() + def save( data: DNDarray, path: str, *args: Optional[List[object]], **kwargs: Optional[Dict[str, object]] diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index e0c8b2e09d..ecec8dff3a 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -143,12 +143,36 @@ def test_load_csv(self): ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0) def test_save_csv(self): + os.truncate(self.CSV_OUT_PATH, 0) + # {signed, unsigned} x {float, integer} x {',', '|', ';'} x {None, 0, 1} x {header_lines, no_header_lines} data = ht.arange(100).reshape(20, 5) - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None, sep=";", dtype=data.dtype.char()) - comparison = ht.load_csv(self.CSV_OUT_PATH, sep=";", dtype=data.dtype.char()) + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None, sep=";") + comparison = ht.load_csv(self.CSV_OUT_PATH, sep=";", dtype=data.dtype) if data.comm.rank == 0: self.assertTrue((data.larray == comparison.larray).all().item()) + os.truncate(self.CSV_OUT_PATH, 0) + header_lines = ["col1,col2,col3,col4,col5", "second line\n"] + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=header_lines, sep=",") + comparison = ht.load_csv(self.CSV_OUT_PATH, sep=",", dtype=data.dtype, header_lines=2) + if data.comm.rank == 0: + self.assertTrue((data.larray == comparison.larray).all().item()) + + os.truncate(self.CSV_OUT_PATH, 0) + data = data - 50 + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + if data.comm.rank == 0: + self.assertTrue((data.larray == comparison.larray).all().item()) + + os.truncate(self.CSV_OUT_PATH, 0) + data = ht.random.rand(100.0, split=0).reshape(20, 5) + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + self.assertTrue( + ht.max(data - comparison).item() < 0.0001 + ) # need a parameter to control precision + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5(): From e4e8fa4510c600b371885b19f089e8f86b7d6f9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Thu, 24 Mar 2022 15:18:47 +0100 Subject: [PATCH 04/21] Test whether test.csv exists before truncation --- heat/core/tests/test_io.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index ecec8dff3a..433e71c7e0 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -143,7 +143,8 @@ def test_load_csv(self): ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0) def test_save_csv(self): - os.truncate(self.CSV_OUT_PATH, 0) + if os.path.exists(self.CSV_OUT_PATH): + os.truncate(self.CSV_OUT_PATH, 0) # {signed, unsigned} x {float, integer} x {',', '|', ';'} x {None, 0, 1} x {header_lines, no_header_lines} data = ht.arange(100).reshape(20, 5) ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None, sep=";") @@ -151,21 +152,24 @@ def test_save_csv(self): if data.comm.rank == 0: self.assertTrue((data.larray == comparison.larray).all().item()) - os.truncate(self.CSV_OUT_PATH, 0) + if os.path.exists(self.CSV_OUT_PATH): + os.truncate(self.CSV_OUT_PATH, 0) header_lines = ["col1,col2,col3,col4,col5", "second line\n"] ht.save_csv(data, self.CSV_OUT_PATH, header_lines=header_lines, sep=",") comparison = ht.load_csv(self.CSV_OUT_PATH, sep=",", dtype=data.dtype, header_lines=2) if data.comm.rank == 0: self.assertTrue((data.larray == comparison.larray).all().item()) - os.truncate(self.CSV_OUT_PATH, 0) + if os.path.exists(self.CSV_OUT_PATH): + os.truncate(self.CSV_OUT_PATH, 0) data = data - 50 ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) if data.comm.rank == 0: self.assertTrue((data.larray == comparison.larray).all().item()) - os.truncate(self.CSV_OUT_PATH, 0) + if os.path.exists(self.CSV_OUT_PATH): + os.truncate(self.CSV_OUT_PATH, 0) data = ht.random.rand(100.0, split=0).reshape(20, 5) ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) From 0fcd7976c8e432b1c16bd6058423148d7a0b7b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Thu, 24 Mar 2022 21:20:24 +0100 Subject: [PATCH 05/21] Use more appropriate MPI.File.Write_at Using the collective MPI.File.Write_at_all led to problems with not perfectly balanced chunks. The ordinary Write_at is much better for this purpose. On the way, I also removed print statements and comments. --- heat/core/io.py | 14 +++----------- heat/core/tests/test_io.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 81ce0a0340..01ba78e655 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -999,8 +999,6 @@ def save_csv( elif types.issubdtype(data.dtype, types.float): if decimals == -1: decimals = 7 if data.dtype is types.float32 else 15 - print("Decimals: %d" % (decimals)) - dec_sep = 1 if sign == 1: fmt = "{: %d.%df}" % (pre_point_digits + decimals, decimals) else: @@ -1008,17 +1006,15 @@ def save_csv( # sign + decimal separator + pre separator digits + decimals (post separator) item_size = decimals + dec_sep + sign + pre_point_digits - # number of items times their size + (items - 1) commas + final nl, the last two cancelling each other out + # each item is one position larger than its representation, either b/c of separator or line break row_width = data.lshape[1] * (item_size + 1) if data.split is None: - offset = hl_displacement + offset = hl_displacement # split None elif data.split == 0: # v1: via counts_displs _, displs = data.counts_displs() offset = displs[data.comm.rank] - # v2: via lshape_map, did not work for me - # offset = a.lshape_map[:, a.split][a.comm.rank] offset = offset * row_width + hl_displacement else: raise NotImplementedError() @@ -1026,11 +1022,7 @@ def save_csv( for i in range(data.lshape[0]): row = sep.join(fmt.format(item) for item in data.larray[i]) row = row + "\n" - # buf = BytesIO(row) - # buf = StringIO() - # buf.write(row) - csv_out.Write_at_all(offset, row.encode("utf-8")) - # buf.close() + csv_out.Write_at(offset, row.encode("utf-8")) offset = offset + row_width csv_out.Close() diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index 433e71c7e0..cbf9625b88 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -177,6 +177,16 @@ def test_save_csv(self): ht.max(data - comparison).item() < 0.0001 ) # need a parameter to control precision + if os.path.exists(self.CSV_OUT_PATH): + os.truncate(self.CSV_OUT_PATH, 0) + data = ht.random.rand(100.0, split=0).reshape(50, 2) + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + self.assertTrue( + ht.max(data - comparison).item() < 0.0001 + ) # need a parameter to control precision + # print("Done testing save_csv") + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5(): From 38ac575b550ba81d7584a4a891287018ff5a4a42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 11:40:54 +0100 Subject: [PATCH 06/21] save_csv: fix float64 problem floating is the supertype of float32 and float64, not float, which is just an alias. Added a corresponding test. --- heat/core/io.py | 2 +- heat/core/tests/test_io.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 01ba78e655..34047700ef 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -996,7 +996,7 @@ def save_csv( fmt = "{: %dd}" % (pre_point_digits + 1) else: fmt = "{:%dd}" % (pre_point_digits) - elif types.issubdtype(data.dtype, types.float): + elif types.issubdtype(data.dtype, types.floating): if decimals == -1: decimals = 7 if data.dtype is types.float32 else 15 if sign == 1: diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index cbf9625b88..e8790c9df5 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -182,10 +182,14 @@ def test_save_csv(self): data = ht.random.rand(100.0, split=0).reshape(50, 2) ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - self.assertTrue( - ht.max(data - comparison).item() < 0.0001 - ) # need a parameter to control precision - # print("Done testing save_csv") + self.assertTrue(ht.max(data - comparison).item() < 0.0001) + + if os.path.exists(self.CSV_OUT_PATH): + os.truncate(self.CSV_OUT_PATH, 0) + data = ht.random.rand(100.0, split=0, dtype=ht.float64).reshape(50, 2) + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + self.assertTrue(ht.max(data - comparison).item() < 0.0000001) def test_load_exception(self): # correct extension, file does not exist From c6ffcee08aaff57915ec5148c46c81769e0e80fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 12:43:14 +0100 Subject: [PATCH 07/21] Adjust format to match field width --- heat/core/io.py | 7 ++++--- heat/core/tests/test_io.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 34047700ef..341a327145 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1000,10 +1000,11 @@ def save_csv( if decimals == -1: decimals = 7 if data.dtype is types.float32 else 15 if sign == 1: - fmt = "{: %d.%df}" % (pre_point_digits + decimals, decimals) + fmt = "{: %d.%df}" % (pre_point_digits + decimals + 2, decimals) else: - fmt = "{:%d.%df}" % (pre_point_digits + decimals, decimals) - + fmt = "{:%d.%df}" % (pre_point_digits + decimals + 1, decimals) + print("Format string: %s" % (fmt)) + print("pre_point_digits: %d" % (pre_point_digits)) # sign + decimal separator + pre separator digits + decimals (post separator) item_size = decimals + dec_sep + sign + pre_point_digits # each item is one position larger than its representation, either b/c of separator or line break diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index e8790c9df5..26c0bb33e5 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -186,7 +186,7 @@ def test_save_csv(self): if os.path.exists(self.CSV_OUT_PATH): os.truncate(self.CSV_OUT_PATH, 0) - data = ht.random.rand(100.0, split=0, dtype=ht.float64).reshape(50, 2) + data = ht.arange(100.0, split=0, dtype=ht.float64).reshape(50, 2) ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) self.assertTrue(ht.max(data - comparison).item() < 0.0000001) From d6d1d4991822a0bd912f5ac069f8c50deea719bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 12:45:18 +0100 Subject: [PATCH 08/21] Truncate file before writing The way we use MPI-IO does not reset existing contents of files and therefore may leave garbage at the end if the data to be written has a shorter representation than the existing file. Therefore, we reset by default, but allow to omit this step. --- heat/core/io.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/heat/core/io.py b/heat/core/io.py index 341a327145..87e6de0598 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -934,6 +934,7 @@ def save_csv( decimals: int = -1, encoding: str = "utf-8", comm: Optional[Communication] = None, + truncate: bool = True, ): """ Saves data to CSV files @@ -955,6 +956,10 @@ def save_csv( The encoding to be used in this CSV. comm : Optional[Communication] An optional object of type Communication to be used. + truncate : bool + Whether to truncate an existing file before writing, i.e. fully overwrite it. + The sane default is True. Setting it to False will not shorten files if + needed and thus may leave garbage at the end of existing files. """ if not isinstance(path, str): raise TypeError("path must be str, not {}".format(type(path))) @@ -966,6 +971,12 @@ def save_csv( if data.split not in [None, 0]: raise ValueError("split must be in [None, 0], but is {}".format(data.split)) + if os.path.exists(path) and truncate: + if data.comm.rank == 0: + os.truncate(path, 0) + # avoid truncating and writing at the same time + data.comm.handle.Barrier() + amode = MPI.MODE_WRONLY | MPI.MODE_CREATE csv_out = MPI.File.Open(data.comm.handle, path, amode) From eda98250ee9d6ccc93cf328b449aee2a0943912c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 12:49:26 +0100 Subject: [PATCH 09/21] Remove print statements --- heat/core/io.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 87e6de0598..863307fb05 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1014,8 +1014,7 @@ def save_csv( fmt = "{: %d.%df}" % (pre_point_digits + decimals + 2, decimals) else: fmt = "{:%d.%df}" % (pre_point_digits + decimals + 1, decimals) - print("Format string: %s" % (fmt)) - print("pre_point_digits: %d" % (pre_point_digits)) + # sign + decimal separator + pre separator digits + decimals (post separator) item_size = decimals + dec_sep + sign + pre_point_digits # each item is one position larger than its representation, either b/c of separator or line break From ed54d313f82f0ae2b4e395c507243b6846f1dc76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 14:24:32 +0100 Subject: [PATCH 10/21] Add changelog entry for save_csv --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ddb06676e4..d3e43ec59d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN) - [#894](https://github.com/helmholtz-analytics/heat/pull/894) Change inclusion of license file - [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use. +- [#941](https://github.com/helmholtz-analytics/heat/pull/941) Add function to save data as CSV. ## Bug Fixes - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension From c5727ab258e61c151c58db4a0d13d0e2b35a5b1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 15:04:35 +0100 Subject: [PATCH 11/21] Allow for 1D vectors --- heat/core/io.py | 10 ++++++++-- heat/core/tests/test_io.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 863307fb05..29dbb574a6 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1018,7 +1018,9 @@ def save_csv( # sign + decimal separator + pre separator digits + decimals (post separator) item_size = decimals + dec_sep + sign + pre_point_digits # each item is one position larger than its representation, either b/c of separator or line break - row_width = data.lshape[1] * (item_size + 1) + row_width = item_size + 1 + if len(data.lshape) > 1: + row_width = data.lshape[1] * (item_size + 1) if data.split is None: offset = hl_displacement # split None @@ -1031,7 +1033,11 @@ def save_csv( raise NotImplementedError() for i in range(data.lshape[0]): - row = sep.join(fmt.format(item) for item in data.larray[i]) + # if lshape is of the form (x,), then there will only be a single element per row + if len(data.lshape) == 1: + row = fmt.format(data.larray[i]) + else: + row = sep.join(fmt.format(item) for item in data.larray[i]) row = row + "\n" csv_out.Write_at(offset, row.encode("utf-8")) offset = offset + row_width diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index 26c0bb33e5..e49a3b781f 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -191,6 +191,21 @@ def test_save_csv(self): comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) self.assertTrue(ht.max(data - comparison).item() < 0.0000001) + # 1D vector (25,) + # + do not truncate this time, but rely on save_csv + data = ht.arange(25, split=0, dtype=ht.int) + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + # read_csv always reads in (n, m), so need to reshape to (n,) + self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() == 0) + + # 1D vector (1, 25) + data = ht.arange(25, split=0, dtype=ht.int).reshape(1, 25) + ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + # read_csv always reads in (n, m), so need to reshape to (n,) + self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() == 0) + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5(): From d42f99c2a3d3d07664ebbc14502c45e6219cd129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 15:11:18 +0100 Subject: [PATCH 12/21] Add CSV to generic save function --- heat/core/io.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/heat/core/io.py b/heat/core/io.py index 29dbb574a6..255fa4d689 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1089,6 +1089,8 @@ def save( save_netcdf(data, path, *args, **kwargs) else: raise RuntimeError("netcdf is required for file extension {}".format(extension)) + elif extension in __CSV_EXTENSION: + save_csv(data, path, *args, **kwargs) else: raise ValueError("Unsupported file extension {}".format(extension)) From b0407a31b2eabbda0888cf7c1ee75b1ff9724893 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 22:12:13 +0100 Subject: [PATCH 13/21] save_csv for split 1 the difference from split 0 is not so big after all --- heat/core/io.py | 22 ++++++++++++++-------- heat/core/tests/test_io.py | 10 ++++++++++ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 255fa4d689..9728ccccf3 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -968,8 +968,8 @@ def save_csv( # check this to allow None if not isinstance(header_lines, Iterable) and header_lines is not None: raise TypeError("header_lines must Iterable[str], not {}".format(type(header_lines))) - if data.split not in [None, 0]: - raise ValueError("split must be in [None, 0], but is {}".format(data.split)) + if data.split not in [None, 0, 1]: + raise ValueError("split must be in [None, 0, 1], but is {}".format(data.split)) if os.path.exists(path) and truncate: if data.comm.rank == 0: @@ -1019,16 +1019,17 @@ def save_csv( item_size = decimals + dec_sep + sign + pre_point_digits # each item is one position larger than its representation, either b/c of separator or line break row_width = item_size + 1 - if len(data.lshape) > 1: - row_width = data.lshape[1] * (item_size + 1) + if len(data.shape) > 1: + row_width = data.shape[1] * (item_size + 1) if data.split is None: offset = hl_displacement # split None elif data.split == 0: - # v1: via counts_displs _, displs = data.counts_displs() - offset = displs[data.comm.rank] - offset = offset * row_width + hl_displacement + offset = displs[data.comm.rank] * row_width + hl_displacement + elif data.split == 1: + _, displs = data.counts_displs() + offset = displs[data.comm.rank] * (item_size + 1) + hl_displacement else: raise NotImplementedError() @@ -1038,7 +1039,12 @@ def save_csv( row = fmt.format(data.larray[i]) else: row = sep.join(fmt.format(item) for item in data.larray[i]) - row = row + "\n" + + if data.split is None or data.split == 0 or data.comm.rank == (data.comm.size - 1): + row = row + "\n" + else: + row = row + sep + csv_out.Write_at(offset, row.encode("utf-8")) offset = offset + row_width diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index e49a3b781f..b3747e1458 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -206,6 +206,16 @@ def test_save_csv(self): # read_csv always reads in (n, m), so need to reshape to (n,) self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() == 0) + data = ht.random.randint(0, 100, (25, 4), split=1) + data.save(self.CSV_OUT_PATH) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001) + + data = ht.random.rand(250, 10, dtype=ht.float64, split=0) + data.save(self.CSV_OUT_PATH) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001) + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5(): From 3bf1319944ea74261838884841394273efa81463 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 22:12:36 +0100 Subject: [PATCH 14/21] Remove unused import sys was only there for debugging purposes --- heat/core/io.py | 1 - 1 file changed, 1 deletion(-) diff --git a/heat/core/io.py b/heat/core/io.py index 9728ccccf3..84ee0ee2cd 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -6,7 +6,6 @@ import numpy as np import torch import warnings -import sys from typing import Dict, Iterable, List, Optional, Tuple, Union From d1dbc3871ec8b5bf835c7a66c56018c7f9249e9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Fri, 25 Mar 2022 22:19:13 +0100 Subject: [PATCH 15/21] Code simplification remove unreachable else branch and start from common offset for all splits --- heat/core/io.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/heat/core/io.py b/heat/core/io.py index 84ee0ee2cd..b2bf829018 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1021,16 +1021,13 @@ def save_csv( if len(data.shape) > 1: row_width = data.shape[1] * (item_size + 1) - if data.split is None: - offset = hl_displacement # split None - elif data.split == 0: + offset = hl_displacement # all splits + if data.split == 0: _, displs = data.counts_displs() - offset = displs[data.comm.rank] * row_width + hl_displacement + offset = offset + displs[data.comm.rank] * row_width elif data.split == 1: _, displs = data.counts_displs() - offset = displs[data.comm.rank] * (item_size + 1) + hl_displacement - else: - raise NotImplementedError() + offset = offset + displs[data.comm.rank] * (item_size + 1) for i in range(data.lshape[0]): # if lshape is of the form (x,), then there will only be a single element per row From 6fe203c836216ccd8d65508e73d45407009da2bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Mon, 28 Mar 2022 08:15:34 +0200 Subject: [PATCH 16/21] Make save_csv synchronous Not synchronizing at the end of writing the file may lead to strange effects for imbalanced tensors. --- heat/core/io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/heat/core/io.py b/heat/core/io.py index b2bf829018..2ee01bf25c 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -936,7 +936,7 @@ def save_csv( truncate: bool = True, ): """ - Saves data to CSV files + Saves data to CSV files. Only 2D data, all split axes. Parameters ---------- @@ -1045,6 +1045,7 @@ def save_csv( offset = offset + row_width csv_out.Close() + data.comm.handle.Barrier() def save( From a8740cba1a0957aa9b85d9bc2791e610cc8592de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Mon, 28 Mar 2022 15:55:37 +0200 Subject: [PATCH 17/21] Bugfix: nprocs > shape[1] Having more processes than chunks in split 1 did not work. Rather than checking whether we are the last (overall) rank, we check whether we have the last chunk of data and don't write anything if we have no data. Last chunk is relevant to distinguish newline or separator addition at the end of our buffer. --- heat/core/io.py | 8 +++++++- heat/core/tests/test_io.py | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/heat/core/io.py b/heat/core/io.py index 975c368f8f..b029cd2bc6 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1032,9 +1032,15 @@ def save_csv( if len(data.lshape) == 1: row = fmt.format(data.larray[i]) else: + if data.lshape[1] == 0: + break row = sep.join(fmt.format(item) for item in data.larray[i]) - if data.split is None or data.split == 0 or data.comm.rank == (data.comm.size - 1): + if ( + data.split is None + or data.split == 0 + or displs[data.comm.rank] + data.lshape[1] == data.shape[1] + ): row = row + "\n" else: row = row + sep diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index b3747e1458..43e6212019 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -216,6 +216,11 @@ def test_save_csv(self): comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001) + data = ht.random.randint(10, 100, (50, 2), split=1) + data.save(self.CSV_OUT_PATH) + comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) + self.assertTrue((data == comparison).all().item()) + def test_load_exception(self): # correct extension, file does not exist if ht.io.supports_hdf5(): From 7f76ef5b63c48480d9d667347ab54a3ff8fb0e0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Tue, 29 Mar 2022 23:05:11 +0200 Subject: [PATCH 18/21] Write split 0 only on rank 0 --- heat/core/io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/heat/core/io.py b/heat/core/io.py index b029cd2bc6..6f05d17ea0 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -1045,7 +1045,9 @@ def save_csv( else: row = row + sep - csv_out.Write_at(offset, row.encode("utf-8")) + if data.split is not None or data.comm.rank == 0: + csv_out.Write_at(offset, row.encode("utf-8")) + offset = offset + row_width csv_out.Close() From 6fb198113581c35846e78a4309a80f8291f601ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Tue, 29 Mar 2022 23:11:00 +0200 Subject: [PATCH 19/21] Enumerate test_csv cases and use individual temporary files --- heat/core/tests/test_io.py | 128 +++++++++++++++---------------------- 1 file changed, 51 insertions(+), 77 deletions(-) diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index 43e6212019..04fa908106 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -1,6 +1,7 @@ import numpy as np import os import torch +import tempfile import heat as ht from .test_suites.basic_test import TestCase @@ -143,83 +144,56 @@ def test_load_csv(self): ht.load_csv(self.CSV_PATH, header_lines="3", sep=";", split=0) def test_save_csv(self): - if os.path.exists(self.CSV_OUT_PATH): - os.truncate(self.CSV_OUT_PATH, 0) - # {signed, unsigned} x {float, integer} x {',', '|', ';'} x {None, 0, 1} x {header_lines, no_header_lines} - data = ht.arange(100).reshape(20, 5) - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None, sep=";") - comparison = ht.load_csv(self.CSV_OUT_PATH, sep=";", dtype=data.dtype) - if data.comm.rank == 0: - self.assertTrue((data.larray == comparison.larray).all().item()) - - if os.path.exists(self.CSV_OUT_PATH): - os.truncate(self.CSV_OUT_PATH, 0) - header_lines = ["col1,col2,col3,col4,col5", "second line\n"] - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=header_lines, sep=",") - comparison = ht.load_csv(self.CSV_OUT_PATH, sep=",", dtype=data.dtype, header_lines=2) - if data.comm.rank == 0: - self.assertTrue((data.larray == comparison.larray).all().item()) - - if os.path.exists(self.CSV_OUT_PATH): - os.truncate(self.CSV_OUT_PATH, 0) - data = data - 50 - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - if data.comm.rank == 0: - self.assertTrue((data.larray == comparison.larray).all().item()) - - if os.path.exists(self.CSV_OUT_PATH): - os.truncate(self.CSV_OUT_PATH, 0) - data = ht.random.rand(100.0, split=0).reshape(20, 5) - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - self.assertTrue( - ht.max(data - comparison).item() < 0.0001 - ) # need a parameter to control precision - - if os.path.exists(self.CSV_OUT_PATH): - os.truncate(self.CSV_OUT_PATH, 0) - data = ht.random.rand(100.0, split=0).reshape(50, 2) - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - self.assertTrue(ht.max(data - comparison).item() < 0.0001) - - if os.path.exists(self.CSV_OUT_PATH): - os.truncate(self.CSV_OUT_PATH, 0) - data = ht.arange(100.0, split=0, dtype=ht.float64).reshape(50, 2) - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - self.assertTrue(ht.max(data - comparison).item() < 0.0000001) - - # 1D vector (25,) - # + do not truncate this time, but rely on save_csv - data = ht.arange(25, split=0, dtype=ht.int) - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - # read_csv always reads in (n, m), so need to reshape to (n,) - self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() == 0) - - # 1D vector (1, 25) - data = ht.arange(25, split=0, dtype=ht.int).reshape(1, 25) - ht.save_csv(data, self.CSV_OUT_PATH, header_lines=None) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - # read_csv always reads in (n, m), so need to reshape to (n,) - self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() == 0) - - data = ht.random.randint(0, 100, (25, 4), split=1) - data.save(self.CSV_OUT_PATH) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001) - - data = ht.random.rand(250, 10, dtype=ht.float64, split=0) - data.save(self.CSV_OUT_PATH) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - self.assertTrue(ht.max(data - comparison.reshape(data.shape)).item() < 0.0000001) - - data = ht.random.randint(10, 100, (50, 2), split=1) - data.save(self.CSV_OUT_PATH) - comparison = ht.load_csv(self.CSV_OUT_PATH, dtype=data.dtype) - self.assertTrue((data == comparison).all().item()) + for rnd_type in [ + (ht.random.randint, ht.types.int32), + (ht.random.randint, ht.types.int64), + (ht.random.rand, ht.types.float32), + (ht.random.rand, ht.types.float64), + ]: + for separator in [",", ";", "|"]: + for split in [None, 0, 1]: + for headers in [None, ["# This", "# is a", "# test."], ["an,ordinary,header"]]: + for shape in [(1, 1), (10, 10), (20, 1), (1, 20), (25, 4), (4, 25)]: + if rnd_type[0] == ht.random.randint: + data = rnd_type[0]( + -100, 1000, size=shape, dtype=rnd_type[1], split=split + ) + else: + data = rnd_type[0]( + shape[0], + shape[1], + split=split, + dtype=rnd_type[1], + ) + + if data.comm.rank == 0: + tmpfile = tempfile.NamedTemporaryFile( + prefix="test_io_", suffix=".csv", delete=False + ) + tmpfile.close() + filename = tmpfile.name + else: + filename = None + filename = data.comm.handle.bcast(filename, root=0) + + data.save( + filename, + header_lines=headers, + sep=separator, + ) + comparison = ht.load_csv( + filename, + # split=split, + header_lines=0 if headers is None else len(headers), + sep=separator, + ) + resid = data - comparison + self.assertTrue( + ht.max(resid).item() < 0.00001 and ht.min(resid).item() > -0.00001 + ) + if data.comm.rank == 0: + os.unlink(filename) + data.comm.handle.Barrier() def test_load_exception(self): # correct extension, file does not exist From 7c2024ffb74a9854a0b8eb049237d2ed4839f398 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Wed, 30 Mar 2022 09:05:04 +0200 Subject: [PATCH 20/21] Barrier before unlink Sometimes, load_csv complained about files not being available anymore. We need to sync through a barrier before unlinking files. --- heat/core/tests/test_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/tests/test_io.py b/heat/core/tests/test_io.py index 04fa908106..47ad519cad 100644 --- a/heat/core/tests/test_io.py +++ b/heat/core/tests/test_io.py @@ -191,9 +191,9 @@ def test_save_csv(self): self.assertTrue( ht.max(resid).item() < 0.00001 and ht.min(resid).item() > -0.00001 ) + data.comm.handle.Barrier() if data.comm.rank == 0: os.unlink(filename) - data.comm.handle.Barrier() def test_load_exception(self): # correct extension, file does not exist From 7a673cdbe73a2989f186e3856b4df04c37397941 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Hagemeier?= Date: Wed, 30 Mar 2022 09:20:18 +0200 Subject: [PATCH 21/21] Use abs also for max value The maximum of a tensor can be less than 0, so need abs around the max, too, before passing it to log10. Also, a 0 value must be excluded. --- heat/core/io.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/heat/core/io.py b/heat/core/io.py index 6f05d17ea0..c58e1d5fd4 100644 --- a/heat/core/io.py +++ b/heat/core/io.py @@ -993,7 +993,10 @@ def save_csv( data_min = smin(data).item() # at least min is used twice, so cache it here data_max = smax(data).item() sign = 1 if data_min < 0 else 0 - pre_point_digits = int(log10(max(data_max, abs(data_min)))) + 1 + if abs(data_max) > 0 or abs(data_min) > 0: + pre_point_digits = int(log10(max(abs(data_max), abs(data_min)))) + 1 + else: + pre_point_digits = 1 dec_sep = 1 fmt = ""