Skip to content

Commit

Permalink
Adding the general save/load methods to ivy to be used as ivy.save an…
Browse files Browse the repository at this point in the history
…d ivy.load for both containers and modules. Additionally, added the intermediate Container.cont_save and Container.cont_load methods. (ivy-llc#18954)
  • Loading branch information
saeedashrraf authored Jul 7, 2023
1 parent 5be1d05 commit dd0c39c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
27 changes: 24 additions & 3 deletions ivy/data_classes/container/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,17 @@ def cont_identical_array_shapes(containers, exclusive=False):
return False
return True

@staticmethod
def cont_load(filepath, format="h5py"):
if format == "json":
return ivy.Container.cont_from_disk_as_json(filepath)
elif format == "pickle":
return ivy.Container.cont_from_disk_as_pickled(filepath)
elif format == "h5py":
return ivy.Container.cont_from_disk_as_hdf5(filepath)
else:
raise ivy.utils.exceptions.IvyException("Unsupported format")

@staticmethod
def cont_from_disk_as_hdf5(
h5_obj_or_filepath, slice_obj=slice(None), alphabetical_keys=True, ivyh=None
Expand Down Expand Up @@ -1948,6 +1959,16 @@ def cont_size_ordered_arrays(self, exclusive=False):
alphabetical_keys=False,
)

def cont_save(self, filepath, format="h5py"):
if format == "json":
self.cont_to_disk_as_json(filepath)
elif format == "pickle":
self.cont_to_disk_as_pickled(filepath)
elif format == "h5py":
self.cont_to_disk_as_hdf5(filepath)
else:
raise ValueError("Unsupported format")

def cont_to_disk_as_hdf5(
self, h5_obj_or_filepath, starting_index=0, mode="a", max_batch_size=None
):
Expand Down Expand Up @@ -2002,9 +2023,9 @@ def cont_to_disk_as_hdf5(
)
space_left = max_batch_size - starting_index
amount_to_write = min(this_batch_size, space_left)
h5_obj[key][starting_index : starting_index + amount_to_write] = (
value_as_np[0:amount_to_write]
)
h5_obj[key][
starting_index : starting_index + amount_to_write
] = value_as_np[0:amount_to_write]

def cont_to_disk_as_pickled(self, pickle_filepath):
"""
Expand Down
29 changes: 29 additions & 0 deletions ivy/functional/ivy/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,32 @@ def any(
}
"""
return ivy.current_backend(x).any(x, axis=axis, keepdims=keepdims, out=out)


# Extra #
# ----- #


def save(item, filepath, format=None):
if isinstance(item, ivy.Container):
if format is not None:
item.cont_save(filepath, format=format)
else:
item.cont_save(filepath)
elif isinstance(item, ivy.Module):
item.save(filepath)
else:
raise ivy.utils.exceptions.IvyException("Unsupported item type for saving.")


@staticmethod
def load(filepath, format=None, type="module"):
if type == "module":
return ivy.Module.load(filepath)
elif type == "container":
if format is not None:
return ivy.Container.cont_load(filepath, format=format)
else:
return ivy.Container.cont_load(filepath)
else:
raise ivy.utils.exceptions.IvyException("Unsupported item type for loading.")

0 comments on commit dd0c39c

Please sign in to comment.