diff --git a/dataset.py b/dataset.py index a5a8e32718..e98f874042 100755 --- a/dataset.py +++ b/dataset.py @@ -274,8 +274,9 @@ def __call__(self, batch): if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper resized_max_w = self.imgW - transform = NormalizePAD((1, self.imgH, resized_max_w)) - + input_channel = 3 if images[0].mode == 'RGB' else 1 + transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) + resized_images = [] for image in images: w, h = image.size