diff --git a/ivy/data_classes/container/base.py b/ivy/data_classes/container/base.py index 17fb7ac4f9635..f893659cc93c3 100644 --- a/ivy/data_classes/container/base.py +++ b/ivy/data_classes/container/base.py @@ -1093,6 +1093,17 @@ def cont_identical_array_shapes(containers, exclusive=False): return False return True + @staticmethod + def cont_load(filepath, format="h5py"): + if format == "json": + return ivy.Container.cont_from_disk_as_json(filepath) + elif format == "pickle": + return ivy.Container.cont_from_disk_as_pickled(filepath) + elif format == "h5py": + return ivy.Container.cont_from_disk_as_hdf5(filepath) + else: + raise ivy.utils.exceptions.IvyException("Unsupported format") + @staticmethod def cont_from_disk_as_hdf5( h5_obj_or_filepath, slice_obj=slice(None), alphabetical_keys=True, ivyh=None @@ -1948,6 +1959,16 @@ def cont_size_ordered_arrays(self, exclusive=False): alphabetical_keys=False, ) + def cont_save(self, filepath, format="h5py"): + if format == "json": + self.cont_to_disk_as_json(filepath) + elif format == "pickle": + self.cont_to_disk_as_pickled(filepath) + elif format == "h5py": + self.cont_to_disk_as_hdf5(filepath) + else: + raise ValueError("Unsupported format") + def cont_to_disk_as_hdf5( self, h5_obj_or_filepath, starting_index=0, mode="a", max_batch_size=None ): @@ -2002,9 +2023,9 @@ def cont_to_disk_as_hdf5( ) space_left = max_batch_size - starting_index amount_to_write = min(this_batch_size, space_left) - h5_obj[key][starting_index : starting_index + amount_to_write] = ( - value_as_np[0:amount_to_write] - ) + h5_obj[key][ + starting_index : starting_index + amount_to_write + ] = value_as_np[0:amount_to_write] def cont_to_disk_as_pickled(self, pickle_filepath): """ diff --git a/ivy/functional/ivy/utility.py b/ivy/functional/ivy/utility.py index 8240428e66a44..7d91c1e1b0809 100644 --- a/ivy/functional/ivy/utility.py +++ b/ivy/functional/ivy/utility.py @@ -232,3 +232,32 @@ def any( } """ return ivy.current_backend(x).any(x, axis=axis, keepdims=keepdims, out=out) + + +# Extra # +# ----- # + + +def save(item, filepath, format=None): + if isinstance(item, ivy.Container): + if format is not None: + item.cont_save(filepath, format=format) + else: + item.cont_save(filepath) + elif isinstance(item, ivy.Module): + item.save(filepath) + else: + raise ivy.utils.exceptions.IvyException("Unsupported item type for saving.") + + +@staticmethod +def load(filepath, format=None, type="module"): + if type == "module": + return ivy.Module.load(filepath) + elif type == "container": + if format is not None: + return ivy.Container.cont_load(filepath, format=format) + else: + return ivy.Container.cont_load(filepath) + else: + raise ivy.utils.exceptions.IvyException("Unsupported item type for loading.")