From 518d91542234c44617b8c44fd2109f9f478809d0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 17 Dec 2020 11:13:48 +0100 Subject: [PATCH] add doctests for example 1/n (#5079) * define tests * fix basic * fix gans * unet * test * drop * format * fix * revert Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pl_examples/basic_examples/autoencoder.py | 7 + .../backbone_image_classifier.py | 13 ++ .../basic_examples/conv_sequential_example.py | 6 + .../basic_examples/mnist_datamodule.py | 3 + .../basic_examples/simple_image_classifier.py | 7 + pl_examples/bug_report_model.py | 15 +- .../computer_vision_fine_tuning.py | 36 +++-- .../generative_adversarial_net.py | 63 ++++++-- pl_examples/domain_templates/imagenet.py | 20 ++- .../domain_templates/reinforce_learn_Qnet.py | 136 +++++++++++------- .../domain_templates/semantic_segmentation.py | 57 ++++---- pl_examples/domain_templates/unet.py | 52 +++++-- 12 files changed, 288 insertions(+), 127 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index 72bfcb17c0872..91f7ac0a1569d 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -31,6 +31,13 @@ class LitAutoEncoder(pl.LightningModule): + """ + >>> LitAutoEncoder() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitAutoEncoder( + (encoder): ... + (decoder): ... + ) + """ def __init__(self): super().__init__() diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index b0ca2efd5d76b..bb1daad301d08 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -29,6 +29,13 @@ class Backbone(torch.nn.Module): + """ + >>> Backbone() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Backbone( + (l1): Linear(...) + (l2): Linear(...) + ) + """ def __init__(self, hidden_dim=128): super().__init__() self.l1 = torch.nn.Linear(28 * 28, hidden_dim) @@ -42,6 +49,12 @@ def forward(self, x): class LitClassifier(pl.LightningModule): + """ + >>> LitClassifier(Backbone()) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitClassifier( + (backbone): ... + ) + """ def __init__(self, backbone, learning_rate=1e-3): super().__init__() self.save_hyperparameters() diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index 1d178c32a3ce3..84efb4bea7670 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -55,6 +55,12 @@ def forward(self, x): class LitResnet(pl.LightningModule): + """ + >>> LitResnet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitResnet( + (sequential_module): Sequential(...) + ) + """ def __init__(self, lr=0.05, batch_size=32, manual_optimization=False): super().__init__() diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index eb1415cf8b981..95e20d22e1fdd 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -29,6 +29,9 @@ class MNISTDataModule(LightningDataModule): """ Standard MNIST, train, val, test splits and transforms + + >>> MNISTDataModule() # doctest: +ELLIPSIS + <...mnist_datamodule.MNISTDataModule object at ...> """ name = "mnist" diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 6b8457e0e4897..894eeea619ba9 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -24,6 +24,13 @@ class LitClassifier(pl.LightningModule): + """ + >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitClassifier( + (l1): Linear(...) + (l2): Linear(...) + ) + """ def __init__(self, hidden_dim=128, learning_rate=1e-3): super().__init__() self.save_hyperparameters() diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index e2201db12f894..30345122e251f 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -28,6 +28,10 @@ class RandomDataset(Dataset): + """ + >>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS + <...bug_report_model.RandomDataset object at ...> + """ def __init__(self, size, length): self.len = length self.data = torch.randn(length, size) @@ -40,6 +44,12 @@ def __len__(self): class BoringModel(LightningModule): + """ + >>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + BoringModel( + (layer): Linear(...) + ) + """ def __init__(self): """ @@ -113,10 +123,9 @@ def configure_optimizers(self): # parser = ArgumentParser() # args = parser.parse_args(opt) -def run_test(): +def test_run(): class TestModel(BoringModel): - def on_train_epoch_start(self) -> None: print('override any method to prove your bug') @@ -140,4 +149,4 @@ def on_train_epoch_start(self) -> None: if __name__ == '__main__': cli_lightning_logo() - run_test() + test_run() diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 1c60e3aa6d23f..4392ac47e837f 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -159,20 +159,30 @@ def _unfreeze_and_add_param_group(module: Module, class TransferLearningModel(pl.LightningModule): """Transfer Learning with pre-trained ResNet50. - Args: - hparams: Model hyperparameters - dl_path: Path where the data will be downloaded + >>> with TemporaryDirectory(dir='.') as tmp_dir: + ... TransferLearningModel(tmp_dir) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + TransferLearningModel( + (feature_extractor): Sequential(...) + (fc): Sequential(...) + ) """ - def __init__(self, - dl_path: Union[str, Path], - backbone: str = 'resnet50', - train_bn: bool = True, - milestones: tuple = (5, 10), - batch_size: int = 8, - lr: float = 1e-2, - lr_scheduler_gamma: float = 1e-1, - num_workers: int = 6, **kwargs) -> None: - super().__init__() + def __init__( + self, + dl_path: Union[str, Path], + backbone: str = 'resnet50', + train_bn: bool = True, + milestones: tuple = (5, 10), + batch_size: int = 8, + lr: float = 1e-2, + lr_scheduler_gamma: float = 1e-1, + num_workers: int = 6, + **kwargs, + ) -> None: + """ + Args: + dl_path: Path where the data will be downloaded + """ + super().__init__(**kwargs) self.dl_path = dl_path self.backbone = backbone self.train_bn = train_bn diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 210a80721d9a9..b0c324c193574 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -37,7 +37,13 @@ class Generator(nn.Module): - def __init__(self, latent_dim, img_shape): + """ + >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Generator( + (model): Sequential(...) + ) + """ + def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): super().__init__() self.img_shape = img_shape @@ -64,6 +70,12 @@ def forward(self, z): class Discriminator(nn.Module): + """ + >>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Discriminator( + (model): Sequential(...) + ) + """ def __init__(self, img_shape): super().__init__() @@ -83,6 +95,37 @@ def forward(self, img): class GAN(LightningModule): + """ + >>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + GAN( + (generator): Generator( + (model): Sequential(...) + ) + (discriminator): Discriminator( + (model): Sequential(...) + ) + ) + """ + def __init__( + self, + img_shape: tuple = (1, 28, 28), + lr: float = 0.0002, + b1: float = 0.5, + b2: float = 0.999, + latent_dim: int = 100, + ): + super().__init__() + + self.save_hyperparameters() + + # networks + self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape) + self.discriminator = Discriminator(img_shape=img_shape) + + self.validation_z = torch.randn(8, self.hparams.latent_dim) + + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) + @staticmethod def add_argparse_args(parent_parser: ArgumentParser): parser = ArgumentParser(parents=[parent_parser], add_help=False) @@ -96,20 +139,6 @@ def add_argparse_args(parent_parser: ArgumentParser): return parser - def __init__(self, hparams: Namespace): - super().__init__() - - self.hparams = hparams - - # networks - mnist_shape = (1, 28, 28) - self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape) - self.discriminator = Discriminator(img_shape=mnist_shape) - - self.validation_z = torch.randn(8, self.hparams.latent_dim) - - self.example_input_array = torch.zeros(2, self.hparams.latent_dim) - def forward(self, z): return self.generator(z) @@ -180,6 +209,10 @@ def on_epoch_end(self): class MNISTDataModule(LightningDataModule): + """ + >>> MNISTDataModule() # doctest: +ELLIPSIS + <...generative_adversarial_net.MNISTDataModule object at ...> + """ def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): super().__init__() self.batch_size = batch_size diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index b1eea307478f9..cc36f3542a1c8 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -50,6 +50,12 @@ class ImageNetLightningModel(LightningModule): + """ + >>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ImageNetLightningModel( + (model): ResNet(...) + ) + """ # pull out resnet names from torchvision models MODEL_NAMES = sorted( name for name in models.__dict__ @@ -58,14 +64,14 @@ class ImageNetLightningModel(LightningModule): def __init__( self, - arch: str, - pretrained: bool, - lr: float, - momentum: float, - weight_decay: int, data_path: str, - batch_size: int, - workers: int, + arch: str = 'resnet18', + pretrained: bool = False, + lr: float = 0.1, + momentum: float = 0.9, + weight_decay: float = 1e-4, + batch_size: int = 4, + workers: int = 2, **kwargs, ): super().__init__() diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index a8b9db095f377..6aee8bb6038c1 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -53,13 +53,19 @@ class DQN(nn.Module): """ Simple MLP network - Args: - obs_size: observation/state size of the environment - n_actions: number of discrete actions available in the environment - hidden_size: size of hidden layers + >>> DQN(10, 5) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DQN( + (net): Sequential(...) + ) """ def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128): + """ + Args: + obs_size: observation/state size of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ super(DQN, self).__init__() self.net = nn.Sequential( nn.Linear(obs_size, hidden_size), @@ -81,11 +87,15 @@ class ReplayBuffer: """ Replay Buffer for storing past experiences allowing the agent to learn from them - Args: - capacity: size of the buffer + >>> ReplayBuffer(5) # doctest: +ELLIPSIS + <...reinforce_learn_Qnet.ReplayBuffer object at ...> """ def __init__(self, capacity: int) -> None: + """ + Args: + capacity: size of the buffer + """ self.buffer = deque(maxlen=capacity) def __len__(self) -> int: @@ -113,12 +123,16 @@ class RLDataset(IterableDataset): Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training - Args: - buffer: replay buffer - sample_size: number of experiences to sample at a time + >>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS + <...reinforce_learn_Qnet.RLDataset object at ...> """ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None: + """ + Args: + buffer: replay buffer + sample_size: number of experiences to sample at a time + """ self.buffer = buffer self.sample_size = sample_size @@ -132,12 +146,18 @@ class Agent: """ Base Agent class handling the interaction with the environment - Args: - env: training environment - replay_buffer: replay buffer storing experiences + >>> env = gym.make("CartPole-v0") + >>> buffer = ReplayBuffer(10) + >>> Agent(env, buffer) # doctest: +ELLIPSIS + <...reinforce_learn_Qnet.Agent object at ...> """ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: + """ + Args: + env: training environment + replay_buffer: replay buffer storing experiences + """ self.env = env self.replay_buffer = replay_buffer self.reset() @@ -204,20 +224,34 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') - class DQNLightning(pl.LightningModule): - """ Basic DQN Model """ - - def __init__(self, - replay_size, - warm_start_steps: int, - gamma: float, - eps_start: int, - eps_end: int, - eps_last_frame: int, - sync_rate, - lr: float, - episode_length, - batch_size, **kwargs) -> None: - super().__init__() + """ Basic DQN Model + + >>> DQNLightning(env="CartPole-v0") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DQNLightning( + (net): DQN( + (net): Sequential(...) + ) + (target_net): DQN( + (net): Sequential(...) + ) + ) + """ + def __init__( + self, + env: str, + replay_size: int = 200, + warm_start_steps: int = 200, + gamma: float = 0.99, + eps_start: float = 1.0, + eps_end: float = 0.01, + eps_last_frame: int = 200, + sync_rate: int = 10, + lr: float = 1e-2, + episode_length: int = 50, + batch_size: int = 4, + **kwargs, + ) -> None: + super().__init__(**kwargs) self.replay_size = replay_size self.warm_start_steps = warm_start_steps self.gamma = gamma @@ -229,7 +263,7 @@ def __init__(self, self.episode_length = episode_length self.batch_size = batch_size - self.env = gym.make(self.env) + self.env = gym.make(env) obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n @@ -302,8 +336,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O Training loss and log metrics """ device = self.get_device(batch) - epsilon = max(self.eps_end, self.eps_start - - self.global_step + 1 / self.eps_last_frame) + epsilon = max(self.eps_end, self.eps_start - self.global_step + 1 / self.eps_last_frame) # step through environment with agent reward, done = self.agent.play_step(self.net, epsilon, device) @@ -349,6 +382,30 @@ def get_device(self, batch) -> str: """Retrieve device currently being used by minibatch""" return batch[0].device.index if self.on_gpu else 'cpu' + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = argparse.ArgumentParser(parents=[parent_parser]) + parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") + parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") + parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + parser.add_argument("--sync_rate", type=int, default=10, + help="how many frames do we update the target network") + parser.add_argument("--replay_size", type=int, default=1000, + help="capacity of the replay buffer") + parser.add_argument("--warm_start_size", type=int, default=1000, + help="how many samples do we use to fill our buffer at the start of training") + parser.add_argument("--eps_last_frame", type=int, default=1000, + help="what frame should epsilon stop decaying") + parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") + parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") + parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") + parser.add_argument("--max_episode_reward", type=int, default=200, + help="max episode reward in the environment") + parser.add_argument("--warm_start_steps", type=int, default=1000, + help="max episode reward in the environment") + return parser + def main(args) -> None: model = DQNLightning(**vars(args)) @@ -368,26 +425,7 @@ def main(args) -> None: np.random.seed(0) parser = argparse.ArgumentParser() - parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") - parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") - parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") - parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") - parser.add_argument("--sync_rate", type=int, default=10, - help="how many frames do we update the target network") - parser.add_argument("--replay_size", type=int, default=1000, - help="capacity of the replay buffer") - parser.add_argument("--warm_start_size", type=int, default=1000, - help="how many samples do we use to fill our buffer at the start of training") - parser.add_argument("--eps_last_frame", type=int, default=1000, - help="what frame should epsilon stop decaying") - parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") - parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") - parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") - parser.add_argument("--max_episode_reward", type=int, default=200, - help="max episode reward in the environment") - parser.add_argument("--warm_start_steps", type=int, default=1000, - help="max episode reward in the environment") - + parser = DQNLightning.add_model_specific_args(parser) args = parser.parse_args() main(args) diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 08bdc1140916a..7bcad597a9a68 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -142,15 +142,17 @@ class SegModel(pl.LightningModule): Adam optimizer is used along with Cosine Annealing learning rate scheduler. """ - - def __init__(self, - data_path: str, - batch_size: int, - lr: float, - num_layers: int, - features_start: int, - bilinear: bool, **kwargs): - super().__init__() + def __init__( + self, + data_path: str, + batch_size: int = 4, + lr: float = 1e-3, + num_layers: int = 3, + features_start: int = 64, + bilinear: bool = False, + **kwargs, + ): + super().__init__(**kwargs) self.data_path = data_path self.batch_size = batch_size self.lr = lr @@ -204,6 +206,18 @@ def train_dataloader(self): def val_dataloader(self): return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False) + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = ArgumentParser(parents=[parent_parser]) + parser.add_argument("--data_path", type=str, help="path where dataset is stored") + parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") + parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") + parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") + parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") + parser.add_argument("--bilinear", action='store_true', default=False, + help="whether to use bilinear interpolation or transposed") + return parser + def main(hparams: Namespace): # ------------------------ @@ -224,14 +238,7 @@ def main(hparams: Namespace): # ------------------------ # 3 INIT TRAINER # ------------------------ - trainer = pl.Trainer( - gpus=hparams.gpus, - logger=logger, - max_epochs=hparams.epochs, - accumulate_grad_batches=hparams.grad_batches, - accelerator=hparams.accelerator, - precision=16 if hparams.use_amp else 32, - ) + trainer = pl.Trainer.from_argparse_args(hparams) # ------------------------ # 5 START TRAINING @@ -242,21 +249,7 @@ def main(hparams: Namespace): if __name__ == '__main__': cli_lightning_logo() parser = ArgumentParser() - parser.add_argument("--data_path", type=str, help="path where dataset is stored") - parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs") - parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'), - help='supports three options dp, ddp, ddp2') - parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision') - parser.add_argument("--batch_size", type=int, default=4, help="size of the batches") - parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") - parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") - parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") - parser.add_argument("--bilinear", action='store_true', default=False, - help="whether to use bilinear interpolation or transposed") - parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate") - parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train") - parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases") - + parser = SegModel.add_model_specific_args(parser) hparams = parser.parse_args() main(hparams) diff --git a/pl_examples/domain_templates/unet.py b/pl_examples/domain_templates/unet.py index 20b4bdb2a4bf9..2314e19ddbfc9 100644 --- a/pl_examples/domain_templates/unet.py +++ b/pl_examples/domain_templates/unet.py @@ -22,20 +22,33 @@ class UNet(nn.Module): Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation Link - https://arxiv.org/abs/1505.04597 - Parameters: - num_classes: Number of output classes required (default 19 for KITTI dataset) - num_layers: Number of layers in each side of U-net - features_start: Number of features in first layer - bilinear: Whether to use bilinear interpolation or transposed - convolutions for upsampling. + >>> UNet(num_classes=2, num_layers=3) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + UNet( + (layers): ModuleList( + (0): DoubleConv(...) + (1): Down(...) + (2): Down(...) + (3): Up(...) + (4): Up(...) + (5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + ) """ def __init__( - self, num_classes: int = 19, + self, + num_classes: int = 19, num_layers: int = 5, features_start: int = 64, - bilinear: bool = False + bilinear: bool = False, ): + """ + Args: + num_classes: Number of output classes required (default 19 for KITTI dataset) + num_layers: Number of layers in each side of U-net + features_start: Number of features in first layer + bilinear: Whether to use bilinear interpolation or transposed convolutions for upsampling. + """ super().__init__() self.num_layers = num_layers @@ -69,6 +82,11 @@ class DoubleConv(nn.Module): """ Double Convolution and BN and ReLU (3x3 conv -> BN -> ReLU) ** 2 + + >>> DoubleConv(4, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DoubleConv( + (net): Sequential(...) + ) """ def __init__(self, in_ch: int, out_ch: int): @@ -89,6 +107,16 @@ def forward(self, x): class Down(nn.Module): """ Combination of MaxPool2d and DoubleConv in series + + >>> Down(4, 8) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Down( + (net): Sequential( + (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (1): DoubleConv( + (net): Sequential(...) + ) + ) + ) """ def __init__(self, in_ch: int, out_ch: int): @@ -107,6 +135,14 @@ class Up(nn.Module): Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map from contracting path, followed by double 3x3 convolution. + + >>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Up( + (upsample): ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2)) + (conv): DoubleConv( + (net): Sequential(...) + ) + ) """ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):