Retain alpha in pil_resize for --alpha_mask #1619

Merged
merged 2 commits into from
Sep 20, 2024

emcmanus

Currently `pil_resize()` drops the alpha channel when `--alpha_mask` is supplied. This only happens when the image does not exceed the bucket size in both dimensions, that is, when its width or its height is not larger than the target.

This codepath is entered on the last line, here:
```
def trim_and_resize_if_required(
    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
    image_height, image_width = image.shape[0:2]
    original_size = (image_width, image_height)  # size before resize

    if image_width != resized_size[0] or image_height != resized_size[1]:
        # resize
        if image_width > resized_size[0] and image_height > resized_size[1]:
            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because we want INTER_AREA
        else:
            image = pil_resize(image, resized_size)
```
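To make the branch condition above easier to check, here is a minimal sketch that only mirrors the two `if` tests from the quoted snippet. The helper name `resize_backend` and its string return values are hypothetical and not part of the repository; sizes are `(width, height)` tuples as in `resized_size`.

```python
from typing import Tuple


def resize_backend(image_size: Tuple[int, int], resized_size: Tuple[int, int]) -> str:
    """Report which resize path the quoted condition would take.

    Hypothetical helper for illustration only. Both sizes are (width, height).
    """
    image_width, image_height = image_size

    # No resize is needed when the image already matches the target size.
    if image_width == resized_size[0] and image_height == resized_size[1]:
        return "none"

    # cv2.resize with INTER_AREA is used only when BOTH dimensions shrink.
    if image_width > resized_size[0] and image_height > resized_size[1]:
        return "cv2"

    # Everything else (upscaling in at least one dimension) goes to pil_resize.
    return "pil"


if __name__ == "__main__":
    target = (512, 512)
    for size in [(1024, 1024), (1024, 384), (384, 384), (512, 512)]:
        print(size, "->", resize_backend(size, target))
```

Under this reading, an image that is wider than the bucket but not taller still takes the `pil_resize()` path, so the alpha loss is not limited to images whose width fits inside the bucket.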
@kohya-ss
Owner

Thank you for this!

@kohya-ss kohya-ss merged commit 95ff9db into kohya-ss:sd3 Sep 20, 2024
1 check passed
@Maru-mee
Contributor

Maru-mee commented Sep 22, 2024

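If my understanding is correct, this change has only been applied to sd3 and has not been reflected in the dev branch.
However, the same problem (*1) appears to occur on dev as well, so I would like to ask for it to be merged there too, if possible.
It is related to PR #1632, and I would like to have it resolved first.

*1 The problem is as follows:
PIL drops the alpha channel, leaving a 3-channel image
→ when alpha_mask is created, the check
`if image.shape[2] == 4:` is not satisfied, so execution falls into
else:
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
and the run is forcibly stopped.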
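The channel check described above can be sketched as follows. This is an illustration only: it uses NumPy rather than torch, the function name `alpha_mask_from_image` is hypothetical, and scaling the alpha channel by 255 is an assumption, since the comment quotes only the `else` branch.

```python
import numpy as np


def alpha_mask_from_image(image: np.ndarray) -> np.ndarray:
    """Build an [H, W] float32 mask from an [H, W, C] uint8 image.

    Hypothetical helper. With 4 channels the alpha channel is used
    (scaled to 0..1, an assumption); otherwise the mask is all ones.
    """
    if image.ndim == 3 and image.shape[2] == 4:
        return image[:, :, 3].astype(np.float32) / 255.0
    # Fallback taken when the alpha channel is missing, for example
    # after a resize step that returned a 3-channel image.
    return np.ones(image.shape[:2], dtype=np.float32)


if __name__ == "__main__":
    rgba = np.zeros((2, 3, 4), dtype=np.uint8)
    rgba[:, :, 3] = 128                       # half-transparent everywhere
    print(alpha_mask_from_image(rgba))        # per-pixel alpha values
    print(alpha_mask_from_image(rgba[:, :, :3]))  # alpha dropped: all ones
```

The second call shows why a resize that silently drops the fourth channel matters: the alpha information never reaches the mask, whatever the downstream code then does with it.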

@kohya-ss
Owner

I have made the same change to the dev branch as well.

@Maru-mee
Contributor

Thank you!
