using `np.random.RandomState(seed)` instead of `np.random.seed(seed)` #4250

Conversation
test/test_datasets.py (Outdated)

```diff
@@ -21,6 +21,8 @@
 import torch.nn.functional as F
 from torchvision import datasets
+
+random_state_numpy = np.random.RandomState(0)
```
We should create a local `RandomState` object in each test function instead of having a global one in each file. Otherwise, tests are still dependent on each other within a single module.
Also, no strong opinion on that, but `np_rng` would be shorter and still be a descriptive name, so I'd suggest using that instead.
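A minimal sketch of the local-RNG pattern this comment describes (the test body and name here are hypothetical, not the actual torchvision test):

```python
import numpy as np

def test_fake_dataset_image():
    # A local RandomState keeps the seed contained to this test,
    # unlike a module-level np.random.seed(0) that leaks into
    # every other test in the module.
    np_rng = np.random.RandomState(0)
    img = np_rng.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    assert img.shape == (32, 32, 3)
    assert img.dtype == np.uint8
```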
torchvision/models/inception.py (Outdated)

```diff
@@ -123,7 +125,9 @@ def __init__(
 import scipy.stats as stats
 stddev = m.stddev if hasattr(m, 'stddev') else 0.1
 X = stats.truncnorm(-2, 2, scale=stddev)
-values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
+values = torch.as_tensor(
+    X.rvs(m.weight.numel(), random_state=random_state_numpy),
```
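The pattern in this diff, passing an explicit `random_state` to scipy's sampler so the global numpy seed is never touched, can be sketched in isolation (tensor-free for brevity; the sizes are illustrative):

```python
import numpy as np
import scipy.stats as stats

# Frozen truncated normal: cutoffs at +/-2 standard deviations,
# then scaled by stddev, as in the inception init code.
stddev = 0.1
X = stats.truncnorm(-2, 2, scale=stddev)

# Passing random_state to rvs draws from this RandomState instead
# of numpy's global RNG, so global state stays untouched.
rng = np.random.RandomState(0)
values = X.rvs(10, random_state=rng)
assert values.shape == (10,)
assert np.all(np.abs(values) <= 2 * stddev)
```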
Oh, during our convo I had missed that the `rvs` call was within the model's code; I thought it was just in the test. That's a bigger problem than I thought.
@fmassa, googlenet and inception call scipy's `rvs` methods, which draw from numpy's RNG.
That's fairly unexpected, I think. Shouldn't we just be relying on pytorch's RNG? I'm thinking of the following workarounds:
- add a new `np_random_state` parameter to the constructor to control that RNG
- rely on torch instead of numpy to draw samples from a truncated normal

I think the second would make much more sense, although I don't know how easy this will be. WDYT?
As a temporary workaround we could use a pytest fixture that sets numpy's RNG and restores it, as e.g. in https://gist.github.com/VictorDarvariu/6cede9c79900c6215b5f848993d283c6, but ugh
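A hedged sketch of such a save/seed/restore fixture (names are hypothetical, and this is the "ugh" fallback, not the approach the thread ultimately settles on):

```python
import numpy as np
import pytest

@pytest.fixture
def np_global_seed():
    # Save numpy's global RNG state, seed it for the test,
    # and restore the original state afterwards so the test
    # leaves no trace on the global stream.
    saved = np.random.get_state()
    np.random.seed(0)
    yield
    np.random.set_state(saved)

def test_uses_global_rng(np_global_seed):
    # Inside the fixture, global draws are reproducible.
    a = np.random.rand(3)
    assert a.shape == (3,)
```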
> That's fairly unexpected I think. Shouldn't we just be relying on pytorch's RNG? I'm thinking of the following workarounds:

We should, but PyTorch didn't implement `trunc_normal` back when we first implemented this model.
It seems it has since been implemented in pytorch/pytorch#32397, so we should replace it to use PyTorch's implementation. It should be fairly straightforward, but it would be good to check whether the PyTorch sampling is much slower than scipy's or not, and to make this change in a separate PR.
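Assuming the replacement suggested here, the scipy-based init could become a torch-only one along these lines (a sketch under the stated assumption, not the final torchvision code):

```python
import torch
import torch.nn as nn

m = nn.Conv2d(3, 8, kernel_size=3)
stddev = 0.1
# scipy's truncnorm(-2, 2, scale=stddev) truncates at +/-2 standard
# deviations; trunc_normal_ takes absolute cutoffs, hence 2 * stddev.
nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev,
                      a=-2 * stddev, b=2 * stddev)
assert float(m.weight.abs().max()) <= 2 * stddev
```

This draws from torch's RNG, so seeding reduces to `torch.manual_seed` and the numpy dependency disappears from the init path.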
@vmoens Would you like to submit a PR to change googlenet and inception to rely on `torch.nn.init.trunc_normal_` instead?
I think we can keep this one on hold until then.
Sure, let me do this:
- keep `np.random.seed(0)` in `set_seed` for now in this PR
- do a new PR where `np.random.seed(0)` is taken away and we use `trunc_normal_` for inception
To keep matters separate, we should aim at removing all numpy seedings in one go, so the first step might be unnecessary. We'll need to get rid of the `rvs` calls first before merging this PR, IMO.
torchvision/models/inception.py (Outdated)

```diff
@@ -69,7 +70,8 @@ def __init__(
 aux_logits: bool = True,
 transform_input: bool = False,
 inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
-init_weights: Optional[bool] = None
+init_weights: Optional[bool] = None,
+random_state_numpy: np.random.RandomState = np.random.RandomState(),
```
If we ever add a new parameter (I'm not sure we should for now, but we might have to), the default should be `None`.
Hmm, we don't pass RNGs in any of the torchvision functions as of now, so I would rather not do it here, as it would involve a larger discussion.
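For reference, if such a parameter were ever added, the usual way to avoid the mutable-default pitfall flagged above is the `None` idiom (the function and names here are purely illustrative, not a torchvision API):

```python
from typing import Optional, Tuple
import numpy as np

def make_noise(shape: Tuple[int, ...],
               np_rng: Optional[np.random.RandomState] = None) -> np.ndarray:
    # A default of np.random.RandomState() would be constructed once
    # at function-definition time and silently shared by every call;
    # defaulting to None and building it inside avoids that.
    if np_rng is None:
        np_rng = np.random.RandomState()
    return np_rng.standard_normal(shape)
```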
Thanks @vmoens !! I have some minor comments below, but I'll approve now so you can address and merge once ready
test/test_image.py (Outdated)

```diff
@@ -273,8 +273,9 @@ def test_write_file_non_ascii():
 ])
 def test_read_1_bit_png(shape):
     with get_tmp_dir() as root:
+        np_rng = np.random.RandomState(0)
```
As a very minor nit (nitpick = feel free not to address): the declaration of `np_rng` doesn't need to be within the context manager. Same below.
So is the rest of the test :p Strictly speaking, only `image_path = os.path.join(root, f'test_{shape}.png')` needs to be there.
Well, we do need `image_path` to be valid, so technically we could only remove the last 2 lines. But yeah, it's a nit anyway :)
test/test_models.py (Outdated)

```diff
@@ -193,7 +192,7 @@ def _check_fx_compatible(model, inputs):
 # the _test_*_model methods.
 _model_params = {
     'inception_v3': {
-        'input_shape': (1, 3, 299, 299)
+        'input_shape': (1, 3, 299, 299),
```
I realize this is due to previous changes that were reverted, but in general we try to avoid unrelated changes, as they obfuscate git blame. Same for the removal of `import sys` above, which on its own is actually relevant (and doesn't hurt git blame), but having lots of those in a single PR makes review more difficult. Would you mind reverting those changes?
I agree, I'll revert these!
```diff
@@ -200,18 +200,20 @@ class TestToTensor:
     def test_to_tensor(self, channels):
         height, width = 4, 4
         trans = transforms.ToTensor()
+        np_rng = np.random.RandomState(0)
```
I think we can remove this one, as you're already declaring another RandomState below
My bad, yeah, I missed that one.
```diff
@@ -225,22 +227,25 @@ def test_to_tensor(self, channels):
     def test_to_tensor_errors(self):
        height, width = 4, 4
        trans = transforms.ToTensor()
+       np_rng = np.random.RandomState(0)
```
Let's remove this one too
I think they come from a rather chaotic git revert :) I should have checked before pushing, though.
Hey @vmoens! You merged this PR, but no labels were added.
using `np.random.RandomState(seed)` instead of `np.random.seed(seed)` (#4250)
Summary: Co-authored-by: Vincent Moens <vmoens@fb.com>
Reviewed By: NicolasHug
Differential Revision: D30417196
fbshipit-source-id: f53bc950aea4935c164939cab0e14b266e3dd1cb
Closing #4247.
This PR makes use of `np.random.RandomState()` locally in the tests instead of the previous `np.random.seed()`, so as to contain numpy's RNG locally within test functions instead of leaking the RNG seed globally.
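The distinction the PR description draws can be seen in a few lines (a standalone illustration, not torchvision code):

```python
import numpy as np

# Global seeding mutates process-wide state: everything that draws
# from np.random afterwards sees the reseeded stream.
np.random.seed(0)
a = np.random.rand(3)

# A local RandomState owns its own stream; seeding it cannot leak
# into other tests or into library code.
rng = np.random.RandomState(0)
b = rng.rand(3)

# Same seed, same generator, same draws -- but only the first
# version touched global state.
assert np.allclose(a, b)
```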