Skip to content

Commit

Permalink
add doctests for example 1/n (#5079)
Browse files Browse the repository at this point in the history
* define tests

* fix basic

* fix gans

* unet

* test

* drop

* format

* fix

* revert

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
  • Loading branch information
Borda and justusschock committed Jan 5, 2021
1 parent 3b83666 commit 518d915
Show file tree
Hide file tree
Showing 12 changed files with 288 additions and 127 deletions.
7 changes: 7 additions & 0 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@


class LitAutoEncoder(pl.LightningModule):
"""
>>> LitAutoEncoder() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitAutoEncoder(
(encoder): ...
(decoder): ...
)
"""

def __init__(self):
super().__init__()
Expand Down
13 changes: 13 additions & 0 deletions pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@


class Backbone(torch.nn.Module):
"""
>>> Backbone() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Backbone(
(l1): Linear(...)
(l2): Linear(...)
)
"""
def __init__(self, hidden_dim=128):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
Expand All @@ -42,6 +49,12 @@ def forward(self, x):


class LitClassifier(pl.LightningModule):
"""
>>> LitClassifier(Backbone()) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitClassifier(
(backbone): ...
)
"""
def __init__(self, backbone, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
Expand Down
6 changes: 6 additions & 0 deletions pl_examples/basic_examples/conv_sequential_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def forward(self, x):


class LitResnet(pl.LightningModule):
"""
>>> LitResnet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitResnet(
(sequential_module): Sequential(...)
)
"""
def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
super().__init__()

Expand Down
3 changes: 3 additions & 0 deletions pl_examples/basic_examples/mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
class MNISTDataModule(LightningDataModule):
"""
Standard MNIST, train, val, test splits and transforms
>>> MNISTDataModule() # doctest: +ELLIPSIS
<...mnist_datamodule.MNISTDataModule object at ...>
"""

name = "mnist"
Expand Down
7 changes: 7 additions & 0 deletions pl_examples/basic_examples/simple_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@


class LitClassifier(pl.LightningModule):
"""
>>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitClassifier(
(l1): Linear(...)
(l2): Linear(...)
)
"""
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
Expand Down
15 changes: 12 additions & 3 deletions pl_examples/bug_report_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@


class RandomDataset(Dataset):
"""
>>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
<...bug_report_model.RandomDataset object at ...>
"""
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
Expand All @@ -40,6 +44,12 @@ def __len__(self):


class BoringModel(LightningModule):
"""
>>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
BoringModel(
(layer): Linear(...)
)
"""

def __init__(self):
"""
Expand Down Expand Up @@ -113,10 +123,9 @@ def configure_optimizers(self):
# parser = ArgumentParser()
# args = parser.parse_args(opt)

def run_test():
def test_run():

class TestModel(BoringModel):

def on_train_epoch_start(self) -> None:
print('override any method to prove your bug')

Expand All @@ -140,4 +149,4 @@ def on_train_epoch_start(self) -> None:

if __name__ == '__main__':
cli_lightning_logo()
run_test()
test_run()
36 changes: 23 additions & 13 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,30 @@ def _unfreeze_and_add_param_group(module: Module,
class TransferLearningModel(pl.LightningModule):
"""Transfer Learning with pre-trained ResNet50.
Args:
hparams: Model hyperparameters
dl_path: Path where the data will be downloaded
>>> with TemporaryDirectory(dir='.') as tmp_dir:
... TransferLearningModel(tmp_dir) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
TransferLearningModel(
(feature_extractor): Sequential(...)
(fc): Sequential(...)
)
"""
def __init__(self,
dl_path: Union[str, Path],
backbone: str = 'resnet50',
train_bn: bool = True,
milestones: tuple = (5, 10),
batch_size: int = 8,
lr: float = 1e-2,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6, **kwargs) -> None:
super().__init__()
def __init__(
self,
dl_path: Union[str, Path],
backbone: str = 'resnet50',
train_bn: bool = True,
milestones: tuple = (5, 10),
batch_size: int = 8,
lr: float = 1e-2,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6,
**kwargs,
) -> None:
"""
Args:
dl_path: Path where the data will be downloaded
"""
super().__init__(**kwargs)
self.dl_path = dl_path
self.backbone = backbone
self.train_bn = train_bn
Expand Down
63 changes: 48 additions & 15 deletions pl_examples/domain_templates/generative_adversarial_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@


class Generator(nn.Module):
def __init__(self, latent_dim, img_shape):
"""
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Generator(
(model): Sequential(...)
)
"""
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape

Expand All @@ -64,6 +70,12 @@ def forward(self, z):


class Discriminator(nn.Module):
"""
>>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Discriminator(
(model): Sequential(...)
)
"""
def __init__(self, img_shape):
super().__init__()

Expand All @@ -83,6 +95,37 @@ def forward(self, img):


class GAN(LightningModule):
"""
>>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
GAN(
(generator): Generator(
(model): Sequential(...)
)
(discriminator): Discriminator(
(model): Sequential(...)
)
)
"""
def __init__(
self,
img_shape: tuple = (1, 28, 28),
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
latent_dim: int = 100,
):
super().__init__()

self.save_hyperparameters()

# networks
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape)
self.discriminator = Discriminator(img_shape=img_shape)

self.validation_z = torch.randn(8, self.hparams.latent_dim)

self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

@staticmethod
def add_argparse_args(parent_parser: ArgumentParser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
Expand All @@ -96,20 +139,6 @@ def add_argparse_args(parent_parser: ArgumentParser):

return parser

def __init__(self, hparams: Namespace):
super().__init__()

self.hparams = hparams

# networks
mnist_shape = (1, 28, 28)
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape)
self.discriminator = Discriminator(img_shape=mnist_shape)

self.validation_z = torch.randn(8, self.hparams.latent_dim)

self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

def forward(self, z):
return self.generator(z)

Expand Down Expand Up @@ -180,6 +209,10 @@ def on_epoch_end(self):


class MNISTDataModule(LightningDataModule):
"""
>>> MNISTDataModule() # doctest: +ELLIPSIS
<...generative_adversarial_net.MNISTDataModule object at ...>
"""
def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
super().__init__()
self.batch_size = batch_size
Expand Down
20 changes: 13 additions & 7 deletions pl_examples/domain_templates/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@


class ImageNetLightningModel(LightningModule):
"""
>>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
ImageNetLightningModel(
(model): ResNet(...)
)
"""
# pull out resnet names from torchvision models
MODEL_NAMES = sorted(
name for name in models.__dict__
Expand All @@ -58,14 +64,14 @@ class ImageNetLightningModel(LightningModule):

def __init__(
self,
arch: str,
pretrained: bool,
lr: float,
momentum: float,
weight_decay: int,
data_path: str,
batch_size: int,
workers: int,
arch: str = 'resnet18',
pretrained: bool = False,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
batch_size: int = 4,
workers: int = 2,
**kwargs,
):
super().__init__()
Expand Down
Loading

0 comments on commit 518d915

Please sign in to comment.