-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ColorJitter Enhancement #548
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
I've left a few comments that I think could improve the functionality and make it less error-prone if in the future we want to extend its functionality.
torchvision/transforms/transforms.py
Outdated
@@ -754,20 +758,53 @@ def get_params(brightness, contrast, saturation, hue): | |||
saturation in a random order. | |||
""" | |||
transforms = [] | |||
if brightness > 0: | |||
brightness_factor = None | |||
if isinstance(brightness, float) and brightness >= 0: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) | ||
elif isinstance(brightness, tuple) and len(brightness) == 2: |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
I've improved it according to your advice. Please review it. @fmassa |
@yaox12 just wonder in which case the following conditions are met like if brightness_factor is not None: in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made another round of comments. Let me know what you think
torchvision/transforms/transforms.py
Outdated
@@ -753,21 +757,39 @@ def get_params(brightness, contrast, saturation, hue): | |||
Transform which randomly adjusts brightness, contrast and | |||
saturation in a random order. | |||
""" | |||
def _sample_from(value, name): | |||
factor = None |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
factor = None | ||
if isinstance(value, numbers.Number) and value >= 0: | ||
if name == 'hue': | ||
factor = random.uniform(-value, value) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torchvision/transforms/transforms.py
Outdated
brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) | ||
|
||
brightness_factor = _sample_from(brightness, 'brightness') | ||
if brightness_factor is not None: | ||
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
In the previous code, if |
@yaox12 Maybe something like in the previous code could help. Previously, if brightness/saturation/... value is 0 than ignore it, we can replace it with None and check as previously : if saturation is not None:
saturation_factor = _sample_from(saturation, 'saturation')
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) or if we want to keep it as before if saturation > 0:
saturation_factor = _sample_from(saturation, 'saturation')
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) HTH |
Check input in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Set custom min and max bounds by inputing a tuple for each parameter in ColorJitter.
See issue #545