diff --git a/data/base_dataset.py b/data/base_dataset.py index 9db1a33f1..58f20f22a 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -35,6 +35,9 @@ def __init__(self, opt): opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ self.opt = opt + + self.use_domain_B = not "self_supervised" in self.opt.data_dataset_mode + self.root = opt.dataroot self.sv_dir = os.path.join(opt.checkpoints_dir, opt.name) self.warning_mode = self.opt.warning_mode @@ -146,16 +149,21 @@ def set_dataset_dirs_and_dims(self): self.dir_A = os.path.join( self.opt.dataroot, self.opt.phase + "A" ) # create a path '/path/to/data/trainA' - self.dir_B = os.path.join( - self.opt.dataroot, self.opt.phase + "B" - ) # create a path '/path/to/data/trainB' + + if self.use_domain_B: + + self.dir_B = os.path.join( + self.opt.dataroot, self.opt.phase + "B" + ) # create a path '/path/to/data/trainB' else: self.dir_A = os.path.join( self.opt.dataroot, self.opt.phase + "B" ) # create a path '/path/to/data/trainB' - self.dir_B = os.path.join( - self.opt.dataroot, self.opt.phase + "A" - ) # create a path '/path/to/data/trainA' + + if self.use_domain_B: + self.dir_B = os.path.join( + self.opt.dataroot, self.opt.phase + "A" + ) # create a path '/path/to/data/trainA' def get_validation_set(self, size): return_A_list = [] diff --git a/data/temporal_dataset.py b/data/temporal_dataset.py index a525c32b6..a7002b03d 100644 --- a/data/temporal_dataset.py +++ b/data/temporal_dataset.py @@ -35,39 +35,47 @@ def __init__(self, opt): self.dir_A, "/paths.txt" ) # load images from '/path/to/data/trainA/paths.txt' as well as labels - self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset( - self.dir_B, "/paths.txt" - ) # load images from '/path/to/data/trainB' + if self.use_domain_B: + self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset( + self.dir_B, "/paths.txt" + ) # load images from '/path/to/data/trainB' # sort self.A_img_paths.sort(key=natural_keys) self.A_label_mask_paths.sort(key=natural_keys) - self.B_img_paths.sort(key=natural_keys) - self.B_label_mask_paths.sort(key=natural_keys) + + if self.use_domain_B: + self.B_img_paths.sort(key=natural_keys) + self.B_label_mask_paths.sort(key=natural_keys) self.A_img_paths, self.A_label_mask_paths = ( self.A_img_paths[: opt.data_max_dataset_size], self.A_label_mask_paths[: opt.data_max_dataset_size], ) - self.B_img_paths, self.B_label_mask_paths = ( - self.B_img_paths[: opt.data_max_dataset_size], - self.B_label_mask_paths[: opt.data_max_dataset_size], - ) + + if self.use_domain_B: + self.B_img_paths, self.B_label_mask_paths = ( + self.B_img_paths[: opt.data_max_dataset_size], + self.B_label_mask_paths[: opt.data_max_dataset_size], + ) self.transform = get_transform_list(self.opt, grayscale=(self.input_nc == 1)) - self.num_A = len(self.A_img_paths) - self.num_B = len(self.B_img_paths) self.num_frames = opt.D_temporal_number_frames self.frame_step = opt.D_temporal_frame_step + + self.num_A = len(self.A_img_paths) self.range_A = self.num_A - self.num_frames * self.frame_step - self.range_B = self.num_B - self.num_frames * self.frame_step + + if self.use_domain_B: + self.num_B = len(self.B_img_paths) + self.range_B = self.num_B - self.num_frames * self.frame_step self.num_common_char = self.opt.D_temporal_num_common_char self.opt = opt self.A_size = len(self.A_img_paths) # get the size of dataset A - if os.path.exists(self.dir_B): + if self.use_domain_B and os.path.exists(self.dir_B): self.B_size = len(self.B_img_paths) # get the size of dataset B def get_img( @@ -149,36 +157,55 @@ def get_img( labels_A = torch.stack(labels_A) - index_B = random.randint(0, self.range_B - 1) + if self.use_domain_B: - images_B = [] - labels_B = [] + index_B = random.randint(0, self.range_B - 1) - ref_name_B = self.B_img_paths[index_B].split("/")[-1][: self.num_common_char] + images_B = [] + labels_B = [] - for i in range(self.num_frames): - cur_index_B = index_B + i * self.frame_step + ref_name_B = self.B_img_paths[index_B].split("/")[-1][ + : self.num_common_char + ] - if ( - self.num_common_char != -1 - and self.B_img_paths[cur_index_B].split("/")[-1][: self.num_common_char] - not in ref_name_B - ): - return None + for i in range(self.num_frames): + cur_index_B = index_B + i * self.frame_step - cur_B_img_path, cur_B_label_path = ( - self.B_img_paths[cur_index_B], - self.B_label_mask_paths[cur_index_B], - ) + if ( + self.num_common_char != -1 + and self.B_img_paths[cur_index_B].split("/")[-1][ + : self.num_common_char + ] + not in ref_name_B + ): + return None - if self.opt.data_relative_paths: - cur_B_img_path = os.path.join(self.root, cur_B_img_path) - if cur_B_label_path is not None: - cur_B_label_path = os.path.join(self.root, cur_B_label_path) + cur_B_img_path, cur_B_label_path = ( + self.B_img_paths[cur_index_B], + self.B_label_mask_paths[cur_index_B], + ) - try: - if i == 0: - crop_coordinates = crop_image( + if self.opt.data_relative_paths: + cur_B_img_path = os.path.join(self.root, cur_B_img_path) + if cur_B_label_path is not None: + cur_B_label_path = os.path.join(self.root, cur_B_label_path) + + try: + if i == 0: + crop_coordinates = crop_image( + cur_B_img_path, + cur_B_label_path, + mask_delta=self.opt.data_online_creation_mask_delta_B, + crop_delta=self.opt.data_online_creation_crop_delta_B, + mask_square=self.opt.data_online_creation_mask_square_B, + crop_dim=self.opt.data_online_creation_crop_size_B, + output_dim=self.opt.data_load_size, + context_pixels=self.opt.data_online_context_pixels, + load_size=self.opt.data_online_creation_load_size_B, + get_crop_coordinates=True, + ) + + cur_B_img, cur_B_label = crop_image( cur_B_img_path, cur_B_label_path, mask_delta=self.opt.data_online_creation_mask_delta_B, @@ -188,34 +215,25 @@ def get_img( output_dim=self.opt.data_load_size, context_pixels=self.opt.data_online_context_pixels, load_size=self.opt.data_online_creation_load_size_B, - get_crop_coordinates=True, + crop_coordinates=crop_coordinates, ) - cur_B_img, cur_B_label = crop_image( - cur_B_img_path, - cur_B_label_path, - mask_delta=self.opt.data_online_creation_mask_delta_B, - crop_delta=self.opt.data_online_creation_crop_delta_B, - mask_square=self.opt.data_online_creation_mask_square_B, - crop_dim=self.opt.data_online_creation_crop_size_B, - output_dim=self.opt.data_load_size, - context_pixels=self.opt.data_online_context_pixels, - load_size=self.opt.data_online_creation_load_size_B, - crop_coordinates=crop_coordinates, - ) + except Exception as e: + print(e, f"{i+1}th frame of domain B in temporal dataloading") + return None - except Exception as e: - print(e, f"{i+1}th frame of domain B in temporal dataloading") - return None + images_B.append(cur_B_img) + labels_B.append(cur_B_label) - images_B.append(cur_B_img) - labels_B.append(cur_B_label) + images_B, labels_B = self.transform(images_B, labels_B) - images_B, labels_B = self.transform(images_B, labels_B) + images_B = torch.stack(images_B) - images_B = torch.stack(images_B) + labels_B = torch.stack(labels_B) - labels_B = torch.stack(labels_B) + else: + images_B = None + labels_B = None result = {"A": images_A, "B": images_B} diff --git a/data/unaligned_labeled_mask_dataset.py b/data/unaligned_labeled_mask_dataset.py index e426e4ae2..76f27177a 100644 --- a/data/unaligned_labeled_mask_dataset.py +++ b/data/unaligned_labeled_mask_dataset.py @@ -42,7 +42,7 @@ def __init__(self, opt): ) # load images from '/path/to/data/trainA/paths.txt' as well as labels self.A_size = len(self.A_img_paths) # get the size of dataset A - if os.path.exists(self.dir_B): + if self.use_domain_B and os.path.exists(self.dir_B): self.B_img_paths, self.B_label = make_labeled_path_dataset( self.dir_B, "/paths.txt", opt.data_max_dataset_size ) # load images from '/path/to/data/trainB' @@ -56,7 +56,7 @@ def __init__(self, opt): for label in self.A_label: self.A_label_mask_paths.append(label.split(" ")[-1]) - if hasattr(self, "B_label"): + if self.use_domain_B and hasattr(self, "B_label"): for label in self.B_label: self.B_label_mask_paths.append(label.split(" ")[-1]) diff --git a/data/unaligned_labeled_mask_online_dataset.py b/data/unaligned_labeled_mask_online_dataset.py index 8de68f809..606125aa8 100644 --- a/data/unaligned_labeled_mask_online_dataset.py +++ b/data/unaligned_labeled_mask_online_dataset.py @@ -57,7 +57,7 @@ def __init__(self, opt): opt.dataroot, "/paths.txt" ) # load images from '/path/to/data/trainA/paths.txt' as well as labels - if os.path.exists(self.dir_B): + if self.use_domain_B and os.path.exists(self.dir_B): self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset( self.dir_B, "/paths.txt" ) # load images from '/path/to/data/trainB' @@ -81,14 +81,14 @@ def __init__(self, opt): self.A_img_paths[: opt.data_max_dataset_size], self.A_label_mask_paths[: opt.data_max_dataset_size], ) - if os.path.exists(self.dir_B): + if self.use_domain_B and os.path.exists(self.dir_B): self.B_img_paths, self.B_label_mask_paths = ( self.B_img_paths[: opt.data_max_dataset_size], self.B_label_mask_paths[: opt.data_max_dataset_size], ) self.A_size = len(self.A_img_paths) # get the size of dataset A - if os.path.exists(self.dir_B): + if self.use_domain_B and os.path.exists(self.dir_B): self.B_size = len(self.B_img_paths) # get the size of dataset B self.transform = get_transform_seg(self.opt, grayscale=(self.input_nc == 1))