Skip to content

Commit

Permalink
fix: selfsupervised dataloaders don't need domain B anymore
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau authored and beniz committed Feb 17, 2023
1 parent 69a02db commit 257a056
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 68 deletions.
20 changes: 14 additions & 6 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def __init__(self, opt):
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.opt = opt

self.use_domain_B = not "self_supervised" in self.opt.data_dataset_mode

self.root = opt.dataroot
self.sv_dir = os.path.join(opt.checkpoints_dir, opt.name)
self.warning_mode = self.opt.warning_mode
Expand Down Expand Up @@ -146,16 +149,21 @@ def set_dataset_dirs_and_dims(self):
self.dir_A = os.path.join(
self.opt.dataroot, self.opt.phase + "A"
) # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(
self.opt.dataroot, self.opt.phase + "B"
) # create a path '/path/to/data/trainB'

if self.use_domain_B:

self.dir_B = os.path.join(
self.opt.dataroot, self.opt.phase + "B"
) # create a path '/path/to/data/trainB'
else:
self.dir_A = os.path.join(
self.opt.dataroot, self.opt.phase + "B"
) # create a path '/path/to/data/trainB'
self.dir_B = os.path.join(
self.opt.dataroot, self.opt.phase + "A"
) # create a path '/path/to/data/trainA'

if self.use_domain_B:
self.dir_B = os.path.join(
self.opt.dataroot, self.opt.phase + "A"
) # create a path '/path/to/data/trainA'

def get_validation_set(self, size):
return_A_list = []
Expand Down
132 changes: 75 additions & 57 deletions data/temporal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,39 +35,47 @@ def __init__(self, opt):
self.dir_A, "/paths.txt"
) # load images from '/path/to/data/trainA/paths.txt' as well as labels

self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
self.dir_B, "/paths.txt"
) # load images from '/path/to/data/trainB'
if self.use_domain_B:
self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
self.dir_B, "/paths.txt"
) # load images from '/path/to/data/trainB'

# sort
self.A_img_paths.sort(key=natural_keys)
self.A_label_mask_paths.sort(key=natural_keys)
self.B_img_paths.sort(key=natural_keys)
self.B_label_mask_paths.sort(key=natural_keys)

if self.use_domain_B:
self.B_img_paths.sort(key=natural_keys)
self.B_label_mask_paths.sort(key=natural_keys)

self.A_img_paths, self.A_label_mask_paths = (
self.A_img_paths[: opt.data_max_dataset_size],
self.A_label_mask_paths[: opt.data_max_dataset_size],
)
self.B_img_paths, self.B_label_mask_paths = (
self.B_img_paths[: opt.data_max_dataset_size],
self.B_label_mask_paths[: opt.data_max_dataset_size],
)

if self.use_domain_B:
self.B_img_paths, self.B_label_mask_paths = (
self.B_img_paths[: opt.data_max_dataset_size],
self.B_label_mask_paths[: opt.data_max_dataset_size],
)

self.transform = get_transform_list(self.opt, grayscale=(self.input_nc == 1))

self.num_A = len(self.A_img_paths)
self.num_B = len(self.B_img_paths)
self.num_frames = opt.D_temporal_number_frames
self.frame_step = opt.D_temporal_frame_step

self.num_A = len(self.A_img_paths)
self.range_A = self.num_A - self.num_frames * self.frame_step
self.range_B = self.num_B - self.num_frames * self.frame_step

if self.use_domain_B:
self.num_B = len(self.B_img_paths)
self.range_B = self.num_B - self.num_frames * self.frame_step
self.num_common_char = self.opt.D_temporal_num_common_char

self.opt = opt

self.A_size = len(self.A_img_paths) # get the size of dataset A
if os.path.exists(self.dir_B):
if self.use_domain_B and os.path.exists(self.dir_B):
self.B_size = len(self.B_img_paths) # get the size of dataset B

def get_img(
Expand Down Expand Up @@ -149,36 +157,55 @@ def get_img(

labels_A = torch.stack(labels_A)

index_B = random.randint(0, self.range_B - 1)
if self.use_domain_B:

images_B = []
labels_B = []
index_B = random.randint(0, self.range_B - 1)

ref_name_B = self.B_img_paths[index_B].split("/")[-1][: self.num_common_char]
images_B = []
labels_B = []

for i in range(self.num_frames):
cur_index_B = index_B + i * self.frame_step
ref_name_B = self.B_img_paths[index_B].split("/")[-1][
: self.num_common_char
]

if (
self.num_common_char != -1
and self.B_img_paths[cur_index_B].split("/")[-1][: self.num_common_char]
not in ref_name_B
):
return None
for i in range(self.num_frames):
cur_index_B = index_B + i * self.frame_step

cur_B_img_path, cur_B_label_path = (
self.B_img_paths[cur_index_B],
self.B_label_mask_paths[cur_index_B],
)
if (
self.num_common_char != -1
and self.B_img_paths[cur_index_B].split("/")[-1][
: self.num_common_char
]
not in ref_name_B
):
return None

if self.opt.data_relative_paths:
cur_B_img_path = os.path.join(self.root, cur_B_img_path)
if cur_B_label_path is not None:
cur_B_label_path = os.path.join(self.root, cur_B_label_path)
cur_B_img_path, cur_B_label_path = (
self.B_img_paths[cur_index_B],
self.B_label_mask_paths[cur_index_B],
)

try:
if i == 0:
crop_coordinates = crop_image(
if self.opt.data_relative_paths:
cur_B_img_path = os.path.join(self.root, cur_B_img_path)
if cur_B_label_path is not None:
cur_B_label_path = os.path.join(self.root, cur_B_label_path)

try:
if i == 0:
crop_coordinates = crop_image(
cur_B_img_path,
cur_B_label_path,
mask_delta=self.opt.data_online_creation_mask_delta_B,
crop_delta=self.opt.data_online_creation_crop_delta_B,
mask_square=self.opt.data_online_creation_mask_square_B,
crop_dim=self.opt.data_online_creation_crop_size_B,
output_dim=self.opt.data_load_size,
context_pixels=self.opt.data_online_context_pixels,
load_size=self.opt.data_online_creation_load_size_B,
get_crop_coordinates=True,
)

cur_B_img, cur_B_label = crop_image(
cur_B_img_path,
cur_B_label_path,
mask_delta=self.opt.data_online_creation_mask_delta_B,
Expand All @@ -188,34 +215,25 @@ def get_img(
output_dim=self.opt.data_load_size,
context_pixels=self.opt.data_online_context_pixels,
load_size=self.opt.data_online_creation_load_size_B,
get_crop_coordinates=True,
crop_coordinates=crop_coordinates,
)

cur_B_img, cur_B_label = crop_image(
cur_B_img_path,
cur_B_label_path,
mask_delta=self.opt.data_online_creation_mask_delta_B,
crop_delta=self.opt.data_online_creation_crop_delta_B,
mask_square=self.opt.data_online_creation_mask_square_B,
crop_dim=self.opt.data_online_creation_crop_size_B,
output_dim=self.opt.data_load_size,
context_pixels=self.opt.data_online_context_pixels,
load_size=self.opt.data_online_creation_load_size_B,
crop_coordinates=crop_coordinates,
)
except Exception as e:
print(e, f"{i+1}th frame of domain B in temporal dataloading")
return None

except Exception as e:
print(e, f"{i+1}th frame of domain B in temporal dataloading")
return None
images_B.append(cur_B_img)
labels_B.append(cur_B_label)

images_B.append(cur_B_img)
labels_B.append(cur_B_label)
images_B, labels_B = self.transform(images_B, labels_B)

images_B, labels_B = self.transform(images_B, labels_B)
images_B = torch.stack(images_B)

images_B = torch.stack(images_B)
labels_B = torch.stack(labels_B)

labels_B = torch.stack(labels_B)
else:
images_B = None
labels_B = None

result = {"A": images_A, "B": images_B}

Expand Down
4 changes: 2 additions & 2 deletions data/unaligned_labeled_mask_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, opt):
) # load images from '/path/to/data/trainA/paths.txt' as well as labels
self.A_size = len(self.A_img_paths) # get the size of dataset A

if os.path.exists(self.dir_B):
if self.use_domain_B and os.path.exists(self.dir_B):
self.B_img_paths, self.B_label = make_labeled_path_dataset(
self.dir_B, "/paths.txt", opt.data_max_dataset_size
) # load images from '/path/to/data/trainB'
Expand All @@ -56,7 +56,7 @@ def __init__(self, opt):
for label in self.A_label:
self.A_label_mask_paths.append(label.split(" ")[-1])

if hasattr(self, "B_label"):
if self.use_domain_B and hasattr(self, "B_label"):
for label in self.B_label:
self.B_label_mask_paths.append(label.split(" ")[-1])

Expand Down
6 changes: 3 additions & 3 deletions data/unaligned_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, opt):
opt.dataroot, "/paths.txt"
) # load images from '/path/to/data/trainA/paths.txt' as well as labels

if os.path.exists(self.dir_B):
if self.use_domain_B and os.path.exists(self.dir_B):
self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
self.dir_B, "/paths.txt"
) # load images from '/path/to/data/trainB'
Expand All @@ -81,14 +81,14 @@ def __init__(self, opt):
self.A_img_paths[: opt.data_max_dataset_size],
self.A_label_mask_paths[: opt.data_max_dataset_size],
)
if os.path.exists(self.dir_B):
if self.use_domain_B and os.path.exists(self.dir_B):
self.B_img_paths, self.B_label_mask_paths = (
self.B_img_paths[: opt.data_max_dataset_size],
self.B_label_mask_paths[: opt.data_max_dataset_size],
)

self.A_size = len(self.A_img_paths) # get the size of dataset A
if os.path.exists(self.dir_B):
if self.use_domain_B and os.path.exists(self.dir_B):
self.B_size = len(self.B_img_paths) # get the size of dataset B

self.transform = get_transform_seg(self.opt, grayscale=(self.input_nc == 1))
Expand Down

0 comments on commit 257a056

Please sign in to comment.