datumbox changed the title from "Refactor Transformations to avoid calling torch.rand on forward()" to "Avoid calling torch.rand in Transformation.forward()" on Nov 30, 2020
The preferred way to structure the Transformation classes is to put the initialization of random weights/params in a static `get_params()` method. The method should receive any hyperparameters necessary for the sampling and should return all the necessary random variables. This method should be called by `forward()` during the transformation process. This is an example of how this would look:

- vision/torchvision/transforms/transforms.py, lines 530 to 531 in 9e71fda
- vision/torchvision/transforms/transforms.py, line 589 in 9e71fda
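Since the referenced snippets are not reproduced inline, here is a minimal sketch of the pattern described above. `RandomShift` and its `max_shift` hyperparameter are hypothetical names made up for illustration and are not part of torchvision. The point is only that the static `get_params()` does all the sampling and `forward()` consumes the result.

```python
import torch
from torch import nn


class RandomShift(nn.Module):
    """Hypothetical transform (not in torchvision): rolls a tensor image
    horizontally by a random number of pixels."""

    def __init__(self, max_shift: int):
        super().__init__()
        self.max_shift = max_shift

    @staticmethod
    def get_params(max_shift: int) -> int:
        # All random sampling lives here. The hyperparameters needed for
        # sampling are passed in explicitly, so the method stays static.
        return int(torch.randint(-max_shift, max_shift + 1, (1,)).item())

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # forward() asks get_params() for the random values and applies them.
        shift = self.get_params(self.max_shift)
        return torch.roll(img, shifts=shift, dims=-1)
```

Because `get_params()` is static, it can also be called on its own, for example to draw the parameters once and apply the same shift to an image and its mask.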
Unfortunately, many `forward()` methods call `torch.rand` directly. Here are a few examples:

- vision/torchvision/transforms/transforms.py, lines 452 to 453 in 9e71fda
- vision/torchvision/transforms/transforms.py, line 619 in 9e71fda
- vision/torchvision/transforms/transforms.py, line 649 in 9e71fda
- vision/torchvision/transforms/transforms.py, line 700 in 9e71fda
- vision/torchvision/transforms/transforms.py, line 1454 in 9e71fda
- vision/torchvision/transforms/transforms.py, line 1560 in 9e71fda
There might be others. We should refactor the codebase so that all of the above calls happen within a static `get_params()` method. See #3065 (`RandomInvert`) for an example of how to structure it.

cc @vfdev-5
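As a rough illustration of the proposed refactor, the sketch below shows a probability-gated transform before and after the change. `RandomNegateBefore` and `RandomNegateAfter` are hypothetical classes written for this example. They are not the code at the lines listed above, and they do not reproduce the `RandomInvert` implementation from #3065.

```python
import torch
from torch import nn


class RandomNegateBefore(nn.Module):
    """Hypothetical 'before' version: samples directly inside forward()."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if torch.rand(1) < self.p:  # random draw buried in forward()
            return -img
        return img


class RandomNegateAfter(nn.Module):
    """Hypothetical 'after' version: sampling moved to a static get_params()."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    @staticmethod
    def get_params(p: float) -> bool:
        # Same single draw as before, now isolated from the transform logic.
        return bool(torch.rand(1).item() < p)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if self.get_params(self.p):
            return -img
        return img
```

Both versions consume a single `torch.rand(1)` draw per call, so under the same seed they should make the same decision. The refactor changes where the sampling happens, not the sampled distribution.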