From 5af5803cbf921e99dcf3598e54fdf75a0259b735 Mon Sep 17 00:00:00 2001 From: Emmanuel Benazera Date: Fri, 13 Sep 2024 12:21:52 +0000 Subject: [PATCH] fix: load_image replacement --- data/aligned_dataset.py | 6 ------ data/base_dataset.py | 6 ++++++ data/unaligned_dataset.py | 31 ++++++++++++++++++++++++------- data/utils.py | 38 +++++++++++++++++++++++++------------- 4 files changed, 55 insertions(+), 26 deletions(-) diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py index 3ffd8990c..dd9d0909c 100644 --- a/data/aligned_dataset.py +++ b/data/aligned_dataset.py @@ -4,7 +4,6 @@ from data.utils import load_image from data.image_folder import make_dataset from PIL import Image -import tifffile class AlignedDataset(BaseDataset): @@ -34,11 +33,6 @@ def __init__(self, opt, phase, name=""): "aligned dataset: domain A and domain B should have the same number of images" ) - if opt.data_image_bits > 8 and opt.model_input_nc > 1: - self.use_tiff = True # multi-channel images > 8bit - else: - self.use_tiff = False - def __getitem__(self, index): """Return a data point and its metadata information. diff --git a/data/base_dataset.py b/data/base_dataset.py index 15869a64a..ce9023fc2 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -29,6 +29,7 @@ import imgaug.augmenters as iaa import os import warnings +import tifffile class BaseDataset(data.Dataset, ABC): @@ -63,6 +64,11 @@ def __init__(self, opt, phase, name=""): self.warning_mode = self.opt.warning_mode self.set_dataset_dirs_and_dims() + if opt.data_image_bits > 8 and opt.model_input_nc > 1: + self.use_tiff = True # multi-channel images > 8bit + else: + self.use_tiff = False + @staticmethod def modify_commandline_options(parser, is_train): """Add new dataset-specific options, and rewrite default values for existing options. diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index e4f562124..081b809d1 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -1,5 +1,5 @@ import os.path -from data.base_dataset import BaseDataset, get_transform +from data.base_dataset import BaseDataset, get_transform, get_params from data.utils import load_image from data.image_folder import make_dataset, make_ref_path_list from PIL import Image @@ -35,8 +35,12 @@ def __init__(self, opt, phase, name=""): self.A_size = len(self.A_img_paths) # get the size of dataset A self.B_size = len(self.B_img_paths) # get the size of dataset B - self.transform_A = get_transform(self.opt, grayscale=(self.input_nc == 1)) - self.transform_B = get_transform(self.opt, grayscale=(self.output_nc == 1)) + if self.opt.data_image_bits == 8: + self.grayscale = self.input_nc == 1 + else: # for > 8bit, no explicit conversion + self.grayscale = False + + A = load_image(self.A_img_paths[0]) # temporarily load first image self.header = ["img"] @@ -56,11 +60,24 @@ def get_img( B_label_cls, index, ): - A_img = load_image(A_img_path) - B_img = load_image(B_img_path) + A_img = load_image(A_img_path, self.opt.data_image_bits, self.use_tiff) + B_img = load_image(B_img_path, self.opt.data_image_bits, self.use_tiff) + + if self.use_tiff: + transform_params = get_params(self.opt, A_img[:2]) + else: + transform_params = get_params(self.opt, A_img.size) + + transform_A = get_transform( + self.opt, params=transform_params, grayscale=self.grayscale + ) + transform_B = get_transform( + self.opt, params=transform_params, grayscale=self.grayscale + ) + # apply image transformation - A = self.transform_A(A_img) - B = self.transform_B(B_img) + A = transform_A(A_img) + B = transform_B(B_img) result = { "A": A, diff --git a/data/utils.py b/data/utils.py index 8097b146d..4a1a6320b 100644 --- a/data/utils.py +++ b/data/utils.py @@ -1,22 +1,34 @@ from PIL import Image -def load_image(img_path): - image = Image.open(img_path) +def load_image(img_path, img_bits=8, use_tiff=False): + if use_tiff: + img = tifffile.imread(img_path) + else: + img = Image.open(img_path) - # Define the color for transparency (e.g., transparent black) - transparent_black = (0, 0, 0, 0) + if img_bits == 8: + img = img.convert("RGB") - # Convert the image to RGBA mode if needed - image = image.convert("RGBA") + return img - # Create a new image with the specified color for transparency - transparent_color = Image.new("RGBA", image.size, transparent_black) - # Use alpha_composite to make the specified color transparent - result = Image.alpha_composite(transparent_color, image) +# def load_image(img_path): +# image = Image.open(img_path) - # Convert the result back to RGB mode - result_rgb = result.convert("RGB") +# # Define the color for transparency (e.g., transparent black) +# transparent_black = (0, 0, 0, 0) - return result_rgb +# # Convert the image to RGBA mode if needed +# image = image.convert("RGBA") + +# # Create a new image with the specified color for transparency +# transparent_color = Image.new("RGBA", image.size, transparent_black) + +# # Use alpha_composite to make the specified color transparent +# result = Image.alpha_composite(transparent_color, image) + +# # Convert the result back to RGB mode +# result_rgb = result.convert("RGB") + +# return result_rgb