diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..58527fa00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2123,18 +2123,21 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): if not os.path.exists(npz_path): return False - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: + if npz["latents"].shape[1:3] != expected_latents_size: return False + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + except: + raise RuntimeError(f"Error loading file: {npz_path}") + return True