using np.random.RandomState(seed) instead of np.random.seed(seed) #4250
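For context on the title: `np.random.seed(seed)` mutates NumPy's process-wide RNG state, which any other code can disturb, while a `np.random.RandomState(seed)` instance owns its state. A minimal sketch of the difference (my illustration, not part of the diff):

```python
import numpy as np

# Global seeding: mutates RNG state shared by the whole process,
# so any unrelated np.random call in between perturbs reproducibility.
np.random.seed(0)
a = np.random.rand(3)

# Local generator: the state belongs to this object, so draws are
# reproducible no matter what other code does with np.random.
rng = np.random.RandomState(0)
b = rng.rand(3)

# The legacy global RNG is itself a RandomState, so the same seed
# yields the same sequence here.
assert np.allclose(a, b)
```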
Changes to the test module:

```diff
@@ -14,10 +14,13 @@
 import pytest
 import warnings
 import traceback
+import numpy as np

 ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1'

+random_state_numpy = np.random.RandomState(0)
+

 def get_available_classification_models():
     # TODO add a registration mechanism to torchvision.models
@@ -84,7 +87,7 @@ def _assert_expected(output, name, prec):
     else:
         expected = torch.load(expected_file)
         rtol = atol = prec
-        torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)
+        torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, msg=f'for query {name}')


 def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
@@ -193,7 +196,8 @@ def _check_fx_compatible(model, inputs):
 # the _test_*_model methods.
 _model_params = {
     'inception_v3': {
-        'input_shape': (1, 3, 299, 299)
+        'input_shape': (1, 3, 299, 299),
```
Review comment: I realize this is due to previous changes that were reverted, but in general we try to avoid unrelated changes, as it obfuscates git blame. Same for the removal of […].

Reply: I agree, I'll revert these!
```diff
+        'random_state_numpy': random_state_numpy,
     },
     'retinanet_resnet50_fpn': {
         'num_classes': 20,
```
Changes to the Inception model definition:

```diff
@@ -1,6 +1,7 @@
 from collections import namedtuple
 import warnings
 import torch
+import numpy as np
 from torch import nn, Tensor
 import torch.nn.functional as F
 from .._internally_replaced_utils import load_state_dict_from_url
@@ -69,7 +70,8 @@ def __init__(
         aux_logits: bool = True,
         transform_input: bool = False,
         inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
-        init_weights: Optional[bool] = None
+        init_weights: Optional[bool] = None,
+        random_state_numpy: np.random.RandomState = np.random.RandomState(),
```
Review comment: If we ever add a new parameter (I'm not sure we should for now, but we might have to), the default should be […].

Reply: Hum, we don't pass RNGs in any torchvision functions as of now, so I would rather not do it here, as it would involve a larger discussion.
```diff
     ) -> None:
         super(Inception3, self).__init__()
         if inception_blocks is None:
```
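A side note on why `None` defaults are the usual pattern for arguments like this (my explanation, not from the thread): a default such as `np.random.RandomState()` is evaluated once, when the function is defined, so every call that omits the argument shares a single generator instance. A small sketch:

```python
import numpy as np

def draw_shared(n, rng=np.random.RandomState(0)):
    # The default RandomState is created once, at definition time,
    # so successive calls advance the *same* generator.
    return rng.rand(n)

def draw_fresh(n, rng=None):
    # Conventional pattern: resolve the default inside the body,
    # so each call that omits `rng` gets a fresh generator.
    if rng is None:
        rng = np.random.RandomState(0)
    return rng.rand(n)

print(draw_shared(2))  # first two draws from the shared generator
print(draw_shared(2))  # different values: same generator, advanced state
print(draw_fresh(2))   # a fresh generator each call...
print(draw_fresh(2))   # ...so this repeats the previous values
```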
```diff
@@ -123,7 +125,9 @@ def __init__(
                 import scipy.stats as stats
                 stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                 X = stats.truncnorm(-2, 2, scale=stddev)
-                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
+                values = torch.as_tensor(
+                    X.rvs(m.weight.numel(), random_state=random_state_numpy),
```
Review comment: Oh, during our convo I had missed that the […]. @fmassa, googlenet and inception call scipy's `rvs` methods, which draw from numpy's RNG. That's fairly unexpected, I think. Shouldn't we just be relying on pytorch's RNG? I'm thinking of the following workarounds: […]. I think the second would make much more sense, although I don't know how easy this will be. WDYT? As a temporary workaround we could use a pytest fixture that sets numpy's RNG and restores it, as e.g. in https://gist.github.com/VictorDarvariu/6cede9c79900c6215b5f848993d283c6, but ugh.

Reply: We should, but PyTorch didn't implement […].

Reply: @vmoens Would you like to submit a PR to change googlenet and inception to rely on […]? I think we can keep this one on hold until then.

Reply: Sure, let me do this: […].

Reply: To keep matters separate, we should aim at removing all numpy seedings in one go, so the first step might be unnecessary. We'll need to get rid of the […].
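A sketch of the kind of fixture floated above (my own guess at the linked gist's approach, names hypothetical): snapshot NumPy's global RNG state around each test, so scipy draws that rely on the global RNG become deterministic without leaking state into other tests.

```python
import numpy as np
import pytest

@pytest.fixture
def seeded_np_rng():
    # Save the global RNG state, seed deterministically for this test,
    # then restore the original state so other tests are unaffected.
    saved_state = np.random.get_state()
    np.random.seed(0)
    yield
    np.random.set_state(saved_state)

def test_scipy_draw_is_reproducible(seeded_np_rng):
    import scipy.stats as stats
    # With no random_state argument, scipy's rvs draws from numpy's
    # global RNG, which the fixture has just seeded.
    values = stats.truncnorm(-2, 2, scale=0.1).rvs(4)
    assert values.shape == (4,)
```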
```diff
+                    dtype=m.weight.dtype)
                 values = values.view(m.weight.size())
                 with torch.no_grad():
                     m.weight.copy_(values)
```
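For reference, a hedged sketch of the torch-native direction the thread points toward, assuming `torch.nn.init.trunc_normal_` (available in recent PyTorch releases) is an acceptable replacement for the scipy draw; the `init_truncated_normal_` helper name is mine, not from the PR:

```python
import torch
from torch import nn

def init_truncated_normal_(module: nn.Module, stddev: float = 0.1) -> None:
    # scipy's truncnorm(-2, 2, scale=stddev) is supported on
    # [-2 * stddev, 2 * stddev]; trunc_normal_ takes absolute bounds,
    # and draws from PyTorch's own RNG instead of numpy's.
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=stddev,
                              a=-2 * stddev, b=2 * stddev)

torch.manual_seed(0)  # reproducibility via PyTorch's RNG only
layer = nn.Conv2d(3, 8, kernel_size=3)
init_truncated_normal_(layer, stddev=0.1)
```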
Review comment: We should create a local RandomState object in each test function instead of having a global one in each file; otherwise, tests are still dependent on each other within a single module. Also, no strong opinion on that, but `np_rng` would be shorter and still be a descriptive name, so I'd suggest using that instead.
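A sketch of what that suggestion could look like (test names hypothetical, using the reviewer's `np_rng` naming): each test constructs its own generator, so no module-level RNG state couples the tests.

```python
import numpy as np

def test_inception_v3_input():
    # A generator local to this test: reordering or skipping other
    # tests cannot change the values drawn here.
    np_rng = np.random.RandomState(0)
    x = np_rng.rand(1, 3, 299, 299).astype(np.float32)
    assert x.shape == (1, 3, 299, 299)

def test_googlenet_input():
    # An independent generator, seeded separately for this test.
    np_rng = np.random.RandomState(1)
    x = np_rng.rand(1, 3, 224, 224).astype(np.float32)
    assert x.shape == (1, 3, 224, 224)
```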