Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add doctests for example 1/n #5079

Merged
merged 9 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pl_examples/basic_examples/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@


class LitAutoEncoder(pl.LightningModule):
"""
>>> LitAutoEncoder() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitAutoEncoder(
(encoder): ...
(decoder): ...
)
"""

def __init__(self):
super().__init__()
Expand Down
13 changes: 13 additions & 0 deletions pl_examples/basic_examples/backbone_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@


class Backbone(torch.nn.Module):
"""
>>> Backbone() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Backbone(
(l1): Linear(...)
(l2): Linear(...)
)
"""
def __init__(self, hidden_dim=128):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
Expand All @@ -42,6 +49,12 @@ def forward(self, x):


class LitClassifier(pl.LightningModule):
"""
>>> LitClassifier(Backbone()) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitClassifier(
(backbone): ...
)
"""
def __init__(self, backbone, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
Expand Down
6 changes: 6 additions & 0 deletions pl_examples/basic_examples/conv_sequential_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def forward(self, x):


class LitResnet(pl.LightningModule):
"""
>>> LitResnet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitResnet(
(sequential_module): Sequential(...)
)
"""
def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
super().__init__()

Expand Down
3 changes: 3 additions & 0 deletions pl_examples/basic_examples/mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
class MNISTDataModule(LightningDataModule):
"""
Standard MNIST, train, val, test splits and transforms

>>> MNISTDataModule() # doctest: +ELLIPSIS
<...mnist_datamodule.MNISTDataModule object at ...>
"""

name = "mnist"
Expand Down
7 changes: 7 additions & 0 deletions pl_examples/basic_examples/simple_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@


class LitClassifier(pl.LightningModule):
"""
>>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
LitClassifier(
(l1): Linear(...)
(l2): Linear(...)
)
"""
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
Expand Down
15 changes: 12 additions & 3 deletions pl_examples/bug_report_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@


class RandomDataset(Dataset):
"""
>>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS
<...bug_report_model.RandomDataset object at ...>
"""
def __init__(self, size, length):
self.len = length
self.data = torch.randn(length, size)
Expand All @@ -40,6 +44,12 @@ def __len__(self):


class BoringModel(LightningModule):
"""
>>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
BoringModel(
(layer): Linear(...)
)
"""

def __init__(self):
"""
Expand Down Expand Up @@ -113,10 +123,9 @@ def configure_optimizers(self):
# parser = ArgumentParser()
# args = parser.parse_args(opt)

def run_test():
def test_run():

class TestModel(BoringModel):

def on_train_epoch_start(self) -> None:
print('override any method to prove your bug')

Expand All @@ -140,4 +149,4 @@ def on_train_epoch_start(self) -> None:

if __name__ == '__main__':
cli_lightning_logo()
run_test()
test_run()
36 changes: 23 additions & 13 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,30 @@ def _unfreeze_and_add_param_group(module: Module,
class TransferLearningModel(pl.LightningModule):
"""Transfer Learning with pre-trained ResNet50.

Args:
hparams: Model hyperparameters
dl_path: Path where the data will be downloaded
>>> with TemporaryDirectory(dir='.') as tmp_dir:
... TransferLearningModel(tmp_dir) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
TransferLearningModel(
(feature_extractor): Sequential(...)
(fc): Sequential(...)
)
"""
def __init__(self,
dl_path: Union[str, Path],
backbone: str = 'resnet50',
train_bn: bool = True,
milestones: tuple = (5, 10),
batch_size: int = 8,
lr: float = 1e-2,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6, **kwargs) -> None:
super().__init__()
def __init__(
self,
dl_path: Union[str, Path],
backbone: str = 'resnet50',
train_bn: bool = True,
milestones: tuple = (5, 10),
batch_size: int = 8,
lr: float = 1e-2,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6,
**kwargs,
) -> None:
"""
Args:
dl_path: Path where the data will be downloaded
"""
super().__init__(**kwargs)
self.dl_path = dl_path
self.backbone = backbone
self.train_bn = train_bn
Expand Down
63 changes: 48 additions & 15 deletions pl_examples/domain_templates/generative_adversarial_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@


class Generator(nn.Module):
def __init__(self, latent_dim, img_shape):
"""
>>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Generator(
(model): Sequential(...)
)
"""
def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)):
super().__init__()
self.img_shape = img_shape

Expand All @@ -64,6 +70,12 @@ def forward(self, z):


class Discriminator(nn.Module):
"""
>>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Discriminator(
(model): Sequential(...)
)
"""
def __init__(self, img_shape):
super().__init__()

Expand All @@ -83,6 +95,37 @@ def forward(self, img):


class GAN(LightningModule):
"""
>>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
GAN(
(generator): Generator(
(model): Sequential(...)
)
(discriminator): Discriminator(
(model): Sequential(...)
)
)
"""
def __init__(
self,
img_shape: tuple = (1, 28, 28),
lr: float = 0.0002,
b1: float = 0.5,
b2: float = 0.999,
latent_dim: int = 100,
):
super().__init__()

self.save_hyperparameters()

# networks
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape)
self.discriminator = Discriminator(img_shape=img_shape)

self.validation_z = torch.randn(8, self.hparams.latent_dim)

self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

@staticmethod
def add_argparse_args(parent_parser: ArgumentParser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
Expand All @@ -96,20 +139,6 @@ def add_argparse_args(parent_parser: ArgumentParser):

return parser

def __init__(self, hparams: Namespace):
super().__init__()

self.hparams = hparams

# networks
mnist_shape = (1, 28, 28)
self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape)
self.discriminator = Discriminator(img_shape=mnist_shape)

self.validation_z = torch.randn(8, self.hparams.latent_dim)

self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

def forward(self, z):
return self.generator(z)

Expand Down Expand Up @@ -180,6 +209,10 @@ def on_epoch_end(self):


class MNISTDataModule(LightningDataModule):
"""
>>> MNISTDataModule() # doctest: +ELLIPSIS
<...generative_adversarial_net.MNISTDataModule object at ...>
"""
def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
super().__init__()
self.batch_size = batch_size
Expand Down
20 changes: 13 additions & 7 deletions pl_examples/domain_templates/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@


class ImageNetLightningModel(LightningModule):
"""
>>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
ImageNetLightningModel(
(model): ResNet(...)
)
"""
# pull out resnet names from torchvision models
MODEL_NAMES = sorted(
name for name in models.__dict__
Expand All @@ -58,14 +64,14 @@ class ImageNetLightningModel(LightningModule):

def __init__(
self,
arch: str,
pretrained: bool,
lr: float,
momentum: float,
weight_decay: int,
data_path: str,
batch_size: int,
workers: int,
arch: str = 'resnet18',
pretrained: bool = False,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
batch_size: int = 4,
workers: int = 2,
**kwargs,
):
super().__init__()
Expand Down
Loading