Skip to content

Commit

Permalink
feat: add choices for all options
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Mar 25, 2022
1 parent d73c081 commit ed43b82
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 22 deletions.
4 changes: 2 additions & 2 deletions models/modules/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def forward(self, x):
return out

# torchvision models
model_classes = {
TORCH_MODEL_CLASSES = {
"alexnet": models.alexnet,
"vgg11": models.vgg11,
"vgg11_bn": models.vgg11_bn,
Expand Down Expand Up @@ -235,7 +235,7 @@ def forward(self, x):
class torch_model(nn.Module):
def __init__(self, input_nc, ndf, nclasses, img_size, template, pretrained):
super().__init__()
self.model = model_classes[template](pretrained=pretrained)
self.model = TORCH_MODEL_CLASSES[template](pretrained=pretrained)
self.input_nc = input_nc
self.model.fc = nn.Linear(512, nclasses)

Expand Down
4 changes: 2 additions & 2 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .modules.resnet_architecture.resnet_generator import ResnetGenerator_attn
from .modules.discriminators import NLayerDiscriminator
from .modules.discriminators import PixelDiscriminator
from .modules.classifiers import Classifier, VGG16_FCN8s, torch_model,model_classes
from .modules.classifiers import Classifier, VGG16_FCN8s, torch_model,TORCH_MODEL_CLASSES
from .modules.UNet_classification import UNet
from .modules.classifiers import Classifier_w
from .modules.fid.pytorch_fid.inception import InceptionV3
Expand Down Expand Up @@ -142,7 +142,7 @@ def define_D(netD, model_input_nc, D_ndf, D_n_layers, D_norm, D_dropout, D_spect
net = PixelDiscriminator(model_input_nc, D_ndf, norm_layer=norm_layer)
elif 'stylegan2' in netD: # global D from sty2 repo
net = StyleGAN2Discriminator(model_input_nc, D_ndf, D_n_layers, no_antialias=D_no_antialias, img_size=data_crop_size,netD=netD)
elif netD in model_classes : # load torchvision model
elif netD in TORCH_MODEL_CLASSES: # load torchvision model
nclasses=1
template=netD
net = torch_model(model_input_nc, D_ndf, nclasses,opt.data_crop_size, template, pretrained=False)
Expand Down
27 changes: 13 additions & 14 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from argparse import _HelpAction, _SubParsersAction, _StoreConstAction
from util.util import MAX_INT
import json
from models.modules.classifiers import TORCH_MODEL_CLASSES

TRAIN_JSON_FILENAME = "train_config.json"

Expand Down Expand Up @@ -95,18 +96,18 @@ def initialize(self, parser):
parser.add_argument('--ddp_port', type=str, default='12355')

# model parameters
parser.add_argument('--model_type', type=str, default='cycle_gan', help='chooses which model to use. [' + " | ".join(models.get_models_names()) + ']')
parser.add_argument('--model_input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--model_output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--model_init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
parser.add_argument('--model_type', type=str, default='cut',choices=['cycle_gan','cut','cycle_gan_semantic','cut_semantic','cycle_gan_semantic_mask','cut_semantic_mask'], help='chooses which model to use.')
parser.add_argument('--model_input_nc', type=int, default=3,choices=[1,3], help='# of input image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--model_output_nc', type=int, default=3,choices=[1,3], help='# of output image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--model_init_type', type=str, default='normal',choices=['normal','xavier','kaiming','orthogonal'], help='network initialization')
parser.add_argument('--model_init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')

# generator
parser.add_argument('--G_ngf', type=int, default=64, help='# of gen filters in the last conv layer')
parser.add_argument('--G_netG', type=str, default='resnet_attn', help='specify generator architecture [resnet_9blocks | resnet_6blocks | resnet_3blocks | resnet_12blocks | mobile_resnet_9blocks | mobile_resnet_3blocks | resnet_attn | mobile_resnet_attn | unet_256 | unet_128 | stylegan2 | smallstylegan2 | segformer_attn_conv | segformer_conv]')
parser.add_argument('--G_netG', type=str, default='mobile_resnet_attn',choices=['resnet_9blocks', 'resnet_6blocks', 'resnet_3blocks','resnet_12blocks', 'mobile_resnet_9blocks', 'mobile_resnet_3blocks', 'resnet_attn','mobile_resnet_attn', 'unet_256', 'unet_128','stylegan2','smallstylegan2','segformer_attn_conv' ,'segformer_conv'], help='specify generator architecture')
parser.add_argument('--G_dropout', action='store_true', help='dropout for the generator')
parser.add_argument('--G_spectral', action='store_true', help='whether to use spectral norm in the generator')
parser.add_argument('--G_padding_type', type=str, help='whether to use padding in the generator, zeros or reflect', default='reflect')
parser.add_argument('--G_padding_type', type=str,choices=['reflect','replicate','zero'], help='whether to use padding in the generator', default='reflect')
parser.add_argument('--G_norm', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
parser.add_argument('--G_stylegan2_num_downsampling',
default=1, type=int,
Expand All @@ -116,9 +117,9 @@ def initialize(self, parser):
parser.add_argument('--G_attn_nb_mask_input',default=1,type=int)

# discriminator
parser.add_argument('--D_ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
parser.add_argument('--D_netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel] or any torchvision model [resnet18...]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--D_netD_global', type=str, default='none', help='specify discriminator architecture, any torchvision model can be used [resnet18...]. By default no global discriminator will be used.')
parser.add_argument('--D_ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
parser.add_argument('--D_netD', type=str, default='basic',choices=['basic','n_layers','pixel','stylegan2','patchstylegan2','smallpatchstylegan2','projected_d']+ list(TORCH_MODEL_CLASSES.keys()), help='specify discriminator architecture, D_n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--D_netD_global', type=str, default='none',choices=['none','basic','n_layers','pixel','stylegan2','patchstylegan2','smallpatchstylegan2','projected_d']+ list(TORCH_MODEL_CLASSES.keys()), help='specify discriminator architecture, any torchvision model can be used. By default no global discriminator will be used.')
parser.add_argument('--D_n_layers', type=int, default=3, help='only used if netD==n_layers')
parser.add_argument('--D_norm', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
parser.add_argument('--D_dropout', action='store_true', help='whether to use dropout in the discriminator')
Expand All @@ -143,17 +144,15 @@ def initialize(self, parser):


# dataset parameters
parser.add_argument('--data_dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
parser.add_argument('--data_direction', type=str, default='AtoB', help='AtoB or BtoA')
parser.add_argument('--data_dataset_mode', type=str, default='unaligned',choices=['unaligned','unaligned_labeled','unaligned_labeled_mask','unaligned_labeled_mask_online'], help='chooses how datasets are loaded.')
parser.add_argument('--data_direction', type=str, default='AtoB',choices=['AtoB','BtoA'], help='AtoB or BtoA')
parser.add_argument('--data_serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--data_num_threads', default=4, type=int, help='# threads for loading data')

parser.add_argument('--data_load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--data_crop_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--data_max_dataset_size', type=int, default=MAX_INT, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--data_preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')


parser.add_argument('--data_preprocess', type=str, default='resize_and_crop',choices=['resize_and_crop','crop','scale_width','scale_width_and_crop','none'], help='scaling and cropping of images at load time')


# Online dataset creation options
Expand Down
8 changes: 4 additions & 4 deletions options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def initialize(self, parser):
parser.add_argument('--train_beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--train_G_lr', type=float, default=0.0002, help='initial learning rate for generator')
parser.add_argument('--train_D_lr', type=float, default=0.0002, help='discriminator separate learning rate')
parser.add_argument('--train_gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--train_gan_mode', type=str, default='lsgan',choices=['vanilla', 'lsgan', 'wgangp','projected'], help='the type of GAN objective. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--train_pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
parser.add_argument('--train_lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--train_lr_policy', type=str, default='linear',choices=['linear', 'step','plateau','cosine'], help='learning rate policy.')
parser.add_argument('--train_lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
parser.add_argument('--train_nb_img_max_fid', type=int, default=MAX_INT, help='Maximum number of samples allowed per dataset to compute fid. If the dataset directory contains more than nb_img_max_fid, only a subset is used.')
parser.add_argument('--train_iter_size', type=int, default=1, help='backward will be apllied each iter_size iterations, it simulate a greater batch size : its value is batch_size*iter_size')
Expand All @@ -75,15 +75,15 @@ def initialize(self, parser):
parser.add_argument('--train_mask_no_train_f_s_A', action='store_true', help='if true f_s wont be trained on domain A')
parser.add_argument('--train_mask_out_mask', action='store_true', help='use loss out mask')
parser.add_argument('--train_mask_lambda_out_mask', type=float, default=10.0, help='weight for loss out mask')
parser.add_argument('--train_mask_loss_out_mask', type=str, default='L1', help='loss for mask, L1, MSE or Charbonnier')
parser.add_argument('--train_mask_loss_out_mask', type=str, default='L1',choices=['L1','MSE','Charbonnier'], help='loss for out mask content (which should not change).')
parser.add_argument('--train_mask_charbonnier_eps', type=float, default=1e-6, help='Charbonnier loss epsilon value')
parser.add_argument('--train_mask_disjoint_f_s',action='store_true', help='whether to use a disjoint f_s with the same exact structure')
parser.add_argument('--train_mask_for_removal',action='store_true',help='if true, object removal mode, domain B images with label 0, cut models only')

# train with re-(cycle/cut)
parser.add_argument('--alg_re_adversarial_loss_p',action='store_true',help='if True, also train the prediction model with an adversarial loss')
parser.add_argument('--alg_re_nuplet_size', type=int, default=3,help='Number of frames loaded')
parser.add_argument('--alg_re_netP', type=str, default='unet_128', help='specify P architecture [resnet_9blocks | resnet_6blocks | resnet_attn | unet_256 | unet_128]')
parser.add_argument('--alg_re_netP', type=str, default='unet_128', choices=['resnet_9blocks','resnet_6blocks', 'resnet_attn', 'unet_256' ,'unet_128'], help='specify P architecture')
parser.add_argument('--alg_re_no_train_P_fake_images',action='store_true',help='if True, P wont be trained over fake images projections')
parser.add_argument('--alg_re_projection_threshold',default=1.0,type=float,help='threshold of the real images projection loss below with fake projection and fake reconstruction losses are applied')
parser.add_argument('--alg_re_P_lr', type=float, default=0.0002, help='initial learning rate for P networks')
Expand Down

0 comments on commit ed43b82

Please sign in to comment.