diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index fcba9dcd6bd..27616e5515f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -728,20 +728,44 @@ class ColorJitter(object): """Randomly change the brightness, contrast and saturation of an image. Args: - brightness (float): How much to jitter brightness. brightness_factor - is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. - contrast (float): How much to jitter contrast. contrast_factor - is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. - saturation (float): How much to jitter saturation. saturation_factor - is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. - hue(float): How much to jitter hue. hue_factor is chosen uniformly from - [-hue, hue]. Should be >=0 and <= 0.5. + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. """ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): - self.brightness = brightness - self.contrast = contrast - self.saturation = saturation - self.hue = hue + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value @staticmethod def get_params(brightness, contrast, saturation, hue): @@ -754,20 +778,21 @@ def get_params(brightness, contrast, saturation, hue): saturation in a random order. """ transforms = [] - if brightness > 0: - brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) - if contrast > 0: - contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) - if saturation > 0: - saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation) + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) - if hue > 0: - hue_factor = random.uniform(-hue, hue) + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) random.shuffle(transforms)