Skip to content

Commit

Permalink
Fix image size batch for SDXL
Browse files: browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Jun 19, 2024
1 parent 44fa71c commit 6e124dc
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions library/train_util.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1342,24 +1342,38 @@ def __getitem__(self, index):
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)

target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)

if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)
else:
image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image(subset, image_info, image_info.absolute_path, flipped)
images.append(image)
latents_list.append(None)
alpha_mask_list.append(alpha_mask)

target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)

if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
# crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)

original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)

# Process the caption and the text encoder outputs
caption = image_info.caption # default
Expand Down

0 comments on commit 6e124dc

Please sign in to comment.