From feefcf256e78a5f8d60c3a940f2be3b5c3ca335d Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Thu, 18 Apr 2024 23:15:36 -0400 Subject: [PATCH] Display name of error latent file When trying to load stored latents, if an error occurs, this change will tell you what file failed to load Currently it will just tell you that something failed without telling you which file --- library/train_util.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..58527fa00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2123,18 +2123,21 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): if not os.path.exists(npz_path): return False - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: + if npz["latents"].shape[1:3] != expected_latents_size: return False + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + except: + raise RuntimeError(f"Error loading file: {npz_path}") + return True