From e5bab69e3a8f3dc4afb1badba65b6c50ca2f36d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 2 Jun 2024 21:11:40 +0900 Subject: [PATCH] fix alpha mask without disk cache closes #1351, ref #1339 --- library/train_util.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5df..566f59279 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1265,7 +1265,8 @@ def __getitem__(self, index): if subset.alpha_mask: if img.shape[2] == 4: alpha_mask = img[:, :, 3] # [H,W] - alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) else: alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) else: @@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask # ndarray + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2496,8 +2497,9 @@ def cache_batch_latents( if image.shape[2] == 4: alpha_mask = image[:, :, 3] # [H,W] alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] else: - alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] else: alpha_mask = None alpha_masks.append(alpha_mask)