Skip to content

Commit

Permalink
feat(ml): modif for horse2zebra prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 authored and beniz committed Jun 20, 2024
1 parent 023dd54 commit b66a954
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions data/unaligned_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.utils import load_image
from data.image_folder import make_dataset
from data.image_folder import make_dataset, make_ref_path_list
from PIL import Image
import random

Expand Down Expand Up @@ -40,6 +40,11 @@ def __init__(self, opt, phase, name=""):

self.header = ["img"]

if os.path.isfile(self.dir_B + "/prompts.txt"):
self.B_img_prompt = make_ref_path_list(self.dir_B, "/prompts.txt")
else:
self.B_img_prompt = None

# A_label_path and B_label_path are unused
def get_img(
self,
Expand All @@ -57,7 +62,17 @@ def get_img(
A = self.transform_A(A_img)
B = self.transform_B(B_img)

return {"A": A, "B": B, "A_img_paths": A_img_path, "B_img_paths": B_img_path}
real_B_prompt = self.B_img_prompt[B_img_path]
if len(real_B_prompt) == 1 and isinstance(real_B_prompt[0], str):
real_B_prompt = real_B_prompt[0]

return {
"A": A,
"B": B,
"A_img_paths": A_img_path,
"B_img_paths": B_img_path,
"real_B_prompt": real_B_prompt,
}

def __len__(self):
"""Return the total number of images in the dataset.
Expand Down

0 comments on commit b66a954

Please sign in to comment.