Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define all C++ model constructors explicit #2944

Merged
merged 2 commits into from
Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchvision/csrc/models/alexnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace models {
struct VISION_API AlexNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}, classifier{nullptr};

AlexNetImpl(int64_t num_classes = 1000);
explicit AlexNetImpl(int64_t num_classes = 1000);

torch::Tensor forward(torch::Tensor x);
};
Expand Down
10 changes: 5 additions & 5 deletions torchvision/csrc/models/densenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr};
torch::nn::Linear classifier{nullptr};

DenseNetImpl(
explicit DenseNetImpl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16},
Expand All @@ -35,7 +35,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
};

struct VISION_API DenseNet121Impl : DenseNetImpl {
DenseNet121Impl(
explicit DenseNet121Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16},
Expand All @@ -45,7 +45,7 @@ struct VISION_API DenseNet121Impl : DenseNetImpl {
};

struct VISION_API DenseNet169Impl : DenseNetImpl {
DenseNet169Impl(
explicit DenseNet169Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 32, 32},
Expand All @@ -55,7 +55,7 @@ struct VISION_API DenseNet169Impl : DenseNetImpl {
};

struct VISION_API DenseNet201Impl : DenseNetImpl {
DenseNet201Impl(
explicit DenseNet201Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 48, 32},
Expand All @@ -65,7 +65,7 @@ struct VISION_API DenseNet201Impl : DenseNetImpl {
};

struct VISION_API DenseNet161Impl : DenseNetImpl {
DenseNet161Impl(
explicit DenseNet161Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 48,
const std::vector<int64_t>& block_config = {6, 12, 36, 24},
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/models/googlenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr};

BasicConv2dImpl(torch::nn::Conv2dOptions options);
explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);

torch::Tensor forward(torch::Tensor x);
};
Expand Down Expand Up @@ -71,7 +71,7 @@ struct VISION_API GoogLeNetImpl : torch::nn::Module {
torch::nn::Dropout dropout{nullptr};
torch::nn::Linear fc{nullptr};

GoogLeNetImpl(
explicit GoogLeNetImpl(
int64_t num_classes = 1000,
bool aux_logits = true,
bool transform_input = false,
Expand Down
12 changes: 7 additions & 5 deletions torchvision/csrc/models/inception.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr};

BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);
explicit BasicConv2dImpl(
torch::nn::Conv2dOptions options,
double std_dev = 0.1);

torch::Tensor forward(torch::Tensor x);
};
Expand All @@ -30,7 +32,7 @@ struct VISION_API InceptionAImpl : torch::nn::Module {
struct VISION_API InceptionBImpl : torch::nn::Module {
BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;

InceptionBImpl(int64_t in_channels);
explicit InceptionBImpl(int64_t in_channels);

torch::Tensor forward(const torch::Tensor& x);
};
Expand All @@ -50,7 +52,7 @@ struct VISION_API InceptionDImpl : torch::nn::Module {
BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
branch7x7x3_3, branch7x7x3_4;

InceptionDImpl(int64_t in_channels);
explicit InceptionDImpl(int64_t in_channels);

torch::Tensor forward(const torch::Tensor& x);
};
Expand All @@ -60,7 +62,7 @@ struct VISION_API InceptionEImpl : torch::nn::Module {
branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
branch_pool;

InceptionEImpl(int64_t in_channels);
explicit InceptionEImpl(int64_t in_channels);

torch::Tensor forward(const torch::Tensor& x);
};
Expand Down Expand Up @@ -110,7 +112,7 @@ struct VISION_API InceptionV3Impl : torch::nn::Module {

_inceptionimpl::InceptionAux AuxLogits{nullptr};

InceptionV3Impl(
explicit InceptionV3Impl(
int64_t num_classes = 1000,
bool aux_logits = true,
bool transform_input = false);
Expand Down
13 changes: 8 additions & 5 deletions torchvision/csrc/models/mnasnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,28 @@ struct VISION_API MNASNetImpl : torch::nn::Module {

void _initialize_weights();

MNASNetImpl(double alpha, int64_t num_classes = 1000, double dropout = .2);
explicit MNASNetImpl(
double alpha,
int64_t num_classes = 1000,
double dropout = .2);

torch::Tensor forward(torch::Tensor x);
};

struct MNASNet0_5Impl : MNASNetImpl {
MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet0_75Impl : MNASNetImpl {
MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet1_0Impl : MNASNetImpl {
MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet1_3Impl : MNASNetImpl {
MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
};

TORCH_MODULE(MNASNet);
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/models/mobilenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
int64_t last_channel;
torch::nn::Sequential features, classifier;

MobileNetV2Impl(
explicit MobileNetV2Impl(
int64_t num_classes = 1000,
double width_mult = 1.0,
std::vector<std::vector<int64_t>> inverted_residual_settings = {},
Expand Down
30 changes: 20 additions & 10 deletions torchvision/csrc/models/resnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct ResNetImpl : torch::nn::Module {
int64_t blocks,
int64_t stride = 1);

ResNetImpl(
explicit ResNetImpl(
const std::vector<int>& layers,
int64_t num_classes = 1000,
bool zero_init_residual = false,
Expand Down Expand Up @@ -186,45 +186,55 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
}

struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet18Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet34Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet50Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet101Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet152Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext50_32x4dImpl(
explicit ResNext50_32x4dImpl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext101_32x8dImpl(
explicit ResNext101_32x8dImpl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet50_2Impl(
explicit WideResNet50_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet101_2Impl(
explicit WideResNet101_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/models/shufflenetv2.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
};

struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
};

struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
};

struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
};

struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
};

TORCH_MODULE(ShuffleNetV2);
Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/models/squeezenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
int64_t num_classes;
torch::nn::Sequential features{nullptr}, classifier{nullptr};

SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);
explicit SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);

torch::Tensor forward(torch::Tensor x);
};
Expand All @@ -19,15 +19,15 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
SqueezeNet1_0Impl(int64_t num_classes = 1000);
explicit SqueezeNet1_0Impl(int64_t num_classes = 1000);
};

// SqueezeNet 1.1 model from the official SqueezeNet repo
// <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>.
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
SqueezeNet1_1Impl(int64_t num_classes = 1000);
explicit SqueezeNet1_1Impl(int64_t num_classes = 1000);
};

TORCH_MODULE(SqueezeNet);
Expand Down
34 changes: 25 additions & 9 deletions torchvision/csrc/models/vgg.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct VISION_API VGGImpl : torch::nn::Module {

void _initialize_weights();

VGGImpl(
explicit VGGImpl(
const torch::nn::Sequential& features,
int64_t num_classes = 1000,
bool initialize_weights = true);
Expand All @@ -21,42 +21,58 @@ struct VISION_API VGGImpl : torch::nn::Module {

// VGG 11-layer model (configuration "A")
struct VISION_API VGG11Impl : VGGImpl {
VGG11Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG11Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 13-layer model (configuration "B")
struct VISION_API VGG13Impl : VGGImpl {
VGG13Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG13Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 16-layer model (configuration "D")
struct VISION_API VGG16Impl : VGGImpl {
VGG16Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG16Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 19-layer model (configuration "E")
struct VISION_API VGG19Impl : VGGImpl {
VGG19Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG19Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 11-layer model (configuration "A") with batch normalization
struct VISION_API VGG11BNImpl : VGGImpl {
VGG11BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG11BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 13-layer model (configuration "B") with batch normalization
struct VISION_API VGG13BNImpl : VGGImpl {
VGG13BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG13BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 16-layer model (configuration "D") with batch normalization
struct VISION_API VGG16BNImpl : VGGImpl {
VGG16BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG16BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 19-layer model (configuration 'E') with batch normalization
struct VISION_API VGG19BNImpl : VGGImpl {
VGG19BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG19BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

TORCH_MODULE(VGG);
Expand Down