Use torch instead of scipy for random initialization of inception and googlenet weights #4256
Conversation
Thanks for the PR @vmoens,
I realize this is still WIP, but I'm seeing lots of changed pkl files. According to https://app.circleci.com/pipelines/github/pytorch/vision/9719/workflows/9ad29d07-4475-418c-88e9-54e8f09b8c5c/jobs/721968 it seems that only one test was failing, so we probably don't need to update all of those.
Also, it seems we still have a typing error for `float(m.stddev)`.
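The `float(m.stddev)` cast mentioned above can be sketched as follows. This is a hypothetical illustration, not the exact torchvision code: the `init_weights` helper and the tensor-valued `stddev` attribute are assumptions made for the example.

```python
import torch
from torch import nn

def init_weights(m: nn.Module) -> None:
    """Initialize conv/linear weights with a truncated normal.

    Some modules may carry a `stddev` attribute to control the init
    spread (hypothetical here); the float(...) cast avoids a typing
    error when `stddev` is stored as a tensor rather than a float.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1
        nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)

m = nn.Conv2d(3, 8, 3)
m.stddev = torch.tensor(0.01)  # stored as a tensor; the cast is needed
init_weights(m)
# trunc_normal_ truncates at the absolute bounds a=-2, b=2
assert m.weight.abs().max().item() <= 2.0
```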
Thanks @vmoens !
I tagged as BC breaking because the results might change slightly, but it's fine as we don't have guarantees w.r.t. randomness
@vmoens Thanks for the PR. The change looks useful. Though it's correctly marked as BC-breaking, hopefully the two methods don't produce significantly different results, and the average user should be OK.

@NicolasHug It might be worth confirming that a model initialized under the new scheme does not diverge. A rudimentary check in this case would be to run the model for 1-2 epochs and confirm that the loss decreases on the new branch. Is this something we have already run, or plan to run in the future?
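The divergence check suggested above can be sketched as a quick smoke test. The toy model, random data, and hyperparameters below are placeholders, not the actual torchvision training recipe: the idea is only to verify that training under the new init makes the loss decrease.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-in for a model initialized with the new scheme
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)

opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
x = torch.randn(64, 16)
y = torch.randint(0, 4, (64,))

first = loss_fn(model(x), y).item()
for _ in range(20):  # a brief "epoch" on the fixed toy batch
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()
last = loss_fn(model(x), y).item()

# The loss decreasing is a rough signal that the init is sane
assert last < first
```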
…ption and googlenet weights (#4256)

Summary: using `nn.init.trunc_normal_` instead of `scipy.stats.truncnorm`
Reviewed By: NicolasHug
Differential Revision: D30417203
fbshipit-source-id: 6b04f6bf7f6d30dfbc65980a4036a9dc539e4651
Co-authored-by: Vincent Moens <vmoens@fb.com>
Follow-up to #4250 (comment)

Weights generated with `truncnorm` from scipy are conditioned on NumPy's global seed (or a `RandomState`). We want to avoid setting `np.random.seed` in our tests, and we also don't want to pass a `RandomState` object in the model args. As such, we move to generating weights with `torch.nn.init.trunc_normal_` instead, which is conditioned on `torch.manual_seed`.
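A minimal sketch of the switch, with illustrative `std`, `a`, and `b` values (not the exact ones used in the models). Note that scipy's `truncnorm` takes its bounds in standard-deviation units, while `trunc_normal_`'s `a`/`b` are absolute values, which is one reason the two methods can differ slightly.

```python
import torch

# Before (sketch): scipy's truncnorm draws were conditioned on NumPy's
# global random state, e.g. something along the lines of:
#   X = scipy.stats.truncnorm(-2, 2, scale=stddev)  # bounds in std units
#   values = X.rvs(m.weight.numel())
#   m.weight.data.copy_(torch.as_tensor(values).view_as(m.weight))

# After: trunc_normal_ is conditioned on torch.manual_seed, so tests
# control initialization without touching np.random.seed.
torch.manual_seed(42)
w = torch.empty(64, 3, 7, 7)
torch.nn.init.trunc_normal_(w, mean=0.0, std=0.01, a=-2.0, b=2.0)

torch.manual_seed(42)
w2 = torch.empty(64, 3, 7, 7)
torch.nn.init.trunc_normal_(w2, mean=0.0, std=0.01, a=-2.0, b=2.0)

assert torch.equal(w, w2)  # reproducible under torch.manual_seed alone
```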