Investigate inconsistent casting inside functional_tensor.py #3067

Closed
datumbox opened this issue Nov 30, 2020 · 7 comments · Fixed by #3472

Comments

@datumbox
Contributor

datumbox commented Nov 30, 2020

The operators in functional_tensor.py perform casting in two ways:

  • Using the tensor.to(dtype=dtype) PyTorch method
  • Using the convert_image_dtype() Transformation method

The former does a direct cast from one dtype to the other. The latter has more complex logic that handles corner cases and rescales values. Sometimes both are used in the same operator, for example:

    result = img
    dtype = img.dtype
    if not torch.is_floating_point(img):
        result = convert_image_dtype(result, torch.float32)

    result = (gain * result ** gamma).clamp(0, 1)

    result = convert_image_dtype(result, dtype)
    result = result.to(dtype)
    return result
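For context, the two casting paths behave very differently on integer inputs. Below is a minimal sketch contrasting them on a uint8 tensor; the division by 255 approximates what convert_image_dtype does for a uint8-to-float conversion (the real implementation handles many more dtype pairs):

```python
import torch

img = torch.tensor([0, 128, 255], dtype=torch.uint8)

# tensor.to(dtype): a direct cast, values keep their magnitude.
direct = img.to(torch.float32)  # 128 stays 128.0

# convert_image_dtype-style cast: rescale so uint8's full range
# maps onto [0, 1] (approximated here by dividing by 255).
rescaled = img.to(torch.float32) / 255.0  # 128 becomes ~0.502
```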

We should investigate whether the use of the two different approaches across operators is justified and fix any potential inconsistencies.

cc @vfdev-5

@vfdev-5
Collaborator

vfdev-5 commented Nov 30, 2020

@datumbox it's a bad merge of two PRs I made :) We can remove line 405, result = result.to(dtype):

https://github.com/pytorch/vision/pull/2485/files#diff-497b983ae7c82237c4a0722bc6f6637525e0cdf8369dbd78ee19cf986bffe258R404

Thanks for catching it!

We should investigate whether the use of the two different approaches across operators is justified and fix any potential inconsistencies.

I'd say the two approaches can both exist, depending on whether we'd like to rescale when changing the dtype. For example, for color adjustments it makes sense to rescale to [0, 1] before applying the op.
On the other hand, when we change the dtype to float32 in resize/affine etc., which use interpolate or grid_sample, I think rescaling would add overhead to already slow computations...
What do you think?
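To illustrate why the rescale matters for color ops specifically, here is a small sketch (hypothetical values, plain PyTorch) of gamma correction with and without the [0, 1] rescale. For resize/affine the values are only interpolated, so their magnitude is irrelevant and a direct cast suffices:

```python
import torch

img = torch.tensor([0, 64, 128, 255], dtype=torch.uint8)
gamma = 2.0

# Direct cast, then gamma: raw 0-255 values explode (255**2 = 65025),
# so the result is no longer a meaningful image.
unscaled = img.to(torch.float32) ** gamma

# Rescale to [0, 1] first (as convert_image_dtype would for uint8);
# the gamma-corrected result is then still a valid image in [0, 1].
scaled = (img.to(torch.float32) / 255.0) ** gamma
```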

@datumbox
Contributor Author

datumbox commented Dec 1, 2020

@vfdev-5 Thanks for the reply.

I agree with you that some cases can be handled only by one or the other. As you said, the main focus of this ticket is to investigate the current uses and simplify the code where necessary. :)

@sanketsans
Contributor

@datumbox I'd like to contribute to this.
I tried running the code to check. First I removed line no. 405:
result = result.to(dtype)

and then I instead removed line no. 404, keeping line 405:
result = convert_image_dtype(result, dtype)

Both times the image gets successfully converted to the target dtype. I tried with a double tensor, a float tensor, a uint8 image and a float32 image.

I also tried removing both lines. Since the image provided is already a tensor, maybe we don't need to convert it back at all.
Besides, I checked by converting the img to a uint8 tensor and passing it to the function; it also gets successfully converted by line no. 400:

result = convert_image_dtype(result, torch.float32)

I'd like to know if any further cases need to be tried to validate this.
Thanks.

@datumbox
Contributor Author

@sanketsans I believe only line 405 needs to be removed. Line 404 is necessary to convert the image back to its original type (in case someone passed a uint8). @vfdev-5 could you confirm?

It might be worth sending a PR with all the proposed clean-ups. Any simplification of the current logic is useful, I think.

@sanketsans
Contributor

@datumbox I tried sending a uint8 image, but that case is taken care of by line 400:

 if not torch.is_floating_point(img): 
     result = convert_image_dtype(result, torch.float32) 

That converts it to a floating-point tensor.

@vfdev-5
Collaborator

vfdev-5 commented Feb 24, 2021

Looking at the code again, I'd do the following:

    result = img
    dtype = img.dtype
    if not torch.is_floating_point(img):
        result = convert_image_dtype(result, torch.float32)

    result = (gain * result ** gamma).clamp(0, 1)

    if result.dtype != dtype:
        result = convert_image_dtype(result, dtype)
    return result

What do you think @datumbox ?
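For reference, the proposed flow can be sketched as a runnable function. The _convert helper below is a hypothetical, minimal stand-in for torchvision's convert_image_dtype, covering only the uint8/float32 cases discussed in this thread:

```python
import torch

def _convert(img, dtype):
    """Minimal stand-in for torchvision's convert_image_dtype,
    handling only the uint8 <-> float cases used here."""
    if img.dtype == dtype:
        return img
    if img.dtype == torch.uint8 and dtype.is_floating_point:
        # Rescale uint8's 0-255 range onto [0, 1].
        return img.to(dtype) / 255.0
    if img.dtype.is_floating_point and dtype == torch.uint8:
        # Rescale [0, 1] back to 0-255, rounding to nearest.
        return (img * 255.0 + 0.5).clamp(0, 255).to(torch.uint8)
    return img.to(dtype)

def adjust_gamma(img, gamma, gain=1.0):
    # The flow proposed above: convert once on the way in,
    # and convert back only if the dtype actually changed.
    result = img
    dtype = img.dtype
    if not torch.is_floating_point(img):
        result = _convert(result, torch.float32)

    result = (gain * result ** gamma).clamp(0, 1)

    if result.dtype != dtype:
        result = _convert(result, dtype)
    return result

out = adjust_gamma(torch.tensor([0, 128, 255], dtype=torch.uint8), gamma=1.0)
# With gamma=1.0 and gain=1.0, uint8 values round-trip unchanged.
```

Note that float inputs never hit either conversion branch, so the redundant trailing result.to(dtype) from the current code is gone.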

@datumbox
Contributor Author

@vfdev-5 Yeah that works fine.

Could any of you send a PR that fixes this?
