Skip to content

Commit

Permalink
fix alpha mask without disk cache closes #1351, ref #1339
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Jun 2, 2024
1 parent 0d96e10 commit e5bab69
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,8 @@ def __getitem__(self, index):
if subset.alpha_mask:
if img.shape[2] == 4:
alpha_mask = img[:, :, 3] # [H,W]
alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1
alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0
alpha_mask = torch.FloatTensor(alpha_mask)
else:
alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32)
else:
Expand Down Expand Up @@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(
npz_path,
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
Expand All @@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask # ndarray
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
Expand Down Expand Up @@ -2496,8 +2497,9 @@ def cache_batch_latents(
if image.shape[2] == 4:
alpha_mask = image[:, :, 3] # [H,W]
alpha_mask = alpha_mask.astype(np.float32) / 255.0
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
else:
alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32)
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
else:
alpha_mask = None
alpha_masks.append(alpha_mask)
Expand Down

0 comments on commit e5bab69

Please sign in to comment.