Skip to content

Commit

Permalink
Revert "Change all torch::nn::init::Nonlinearity::{name} and torch::n…
Browse files Browse the repository at this point in the history
…n::init::FanMode::{name} to torch::k{name} (#1394)" (#1428)

This reverts commit 8c3cea7.
  • Loading branch information
fmassa authored Oct 8, 2019
1 parent 2060576 commit ef0ffb8
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
5 changes: 4 additions & 1 deletion torchvision/csrc/models/mnasnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ void MNASNetImpl::_initialize_weights() {
for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get()))
torch::nn::init::kaiming_normal_(
M->weight, 0, torch::kFanOut, torch::kReLU);
M->weight,
0,
torch::nn::init::FanMode::FanOut,
torch::nn::init::Nonlinearity::ReLU);
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::ones_(M->weight);
torch::nn::init::zeros_(M->bias);
Expand Down
3 changes: 2 additions & 1 deletion torchvision/csrc/models/mobilenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ MobileNetV2Impl::MobileNetV2Impl(

for (auto& module : modules(/*include_self=*/false)) {
if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
torch::nn::init::kaiming_normal_(M->weight, 0, torch::kFanOut);
torch::nn::init::kaiming_normal_(
M->weight, 0, torch::nn::init::FanMode::FanOut);
if (M->options.with_bias())
torch::nn::init::zeros_(M->bias);
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/models/resnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ ResNetImpl<Block>::ResNetImpl(
torch::nn::init::kaiming_normal_(
M->weight,
/*a=*/0,
torch::kFanOut,
torch::kReLU);
torch::nn::init::FanMode::FanOut,
torch::nn::init::Nonlinearity::ReLU);
else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
torch::nn::init::constant_(M->bias, 0);
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/models/vgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ void VGGImpl::_initialize_weights() {
torch::nn::init::kaiming_normal_(
M->weight,
/*a=*/0,
torch::kFanOut,
torch::kReLU);
torch::nn::init::FanMode::FanOut,
torch::nn::init::Nonlinearity::ReLU);
torch::nn::init::constant_(M->bias, 0);
} else if (auto M = dynamic_cast<torch::nn::BatchNormImpl*>(module.get())) {
torch::nn::init::constant_(M->weight, 1);
Expand Down

0 comments on commit ef0ffb8

Please sign in to comment.