The OTHER PyTorch boilerplate.
PyTorch Lightning components (with the files that use them):
- LightningModule [flashlight/runner/pl.py]
- Trainer [flashlight/runner/main_pl.py]
- Accelerators
- Callback
- Logging [flashlight/runner/pl.py]
- Metrics
- Plugins
Requirements:
- Python > 3.5 (the code uses f-strings, so in practice Python 3.6+)
- PyTorch 1.5.0 and torchvision 0.6.0, built for your OS/CUDA version
- ... and the packages listed in requirements.txt:
pip install -r requirements.txt
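A quick check of the installed versions can save a debugging session. The script below is a sketch, not part of the repo: it reads `__version__` from each pinned package and strips a local build tag such as `+cu101` before comparing against the versions listed above. It avoids f-strings so it still runs (and reports the problem) on an older interpreter.

```python
import sys
from importlib import import_module

# Versions pinned in the requirement list above.
EXPECTED = {"torch": "1.5.0", "torchvision": "0.6.0"}


def check_environment(expected=EXPECTED):
    """Map "python" and each package name to its installed version, or None if missing."""
    report = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    for name in expected:
        try:
            report[name] = import_module(name).__version__
        except ImportError:
            report[name] = None
    return report


if __name__ == "__main__":
    # The code base uses f-strings, which need Python 3.6+.
    if sys.version_info < (3, 6):
        sys.exit("Python 3.6+ is required")
    for name, version in check_environment().items():
        want = EXPECTED.get(name)
        # Strip local build tags such as "+cu101" before comparing.
        ok = want is None or (version is not None and version.split("+")[0] == want)
        print("{}: {}{}".format(name, version, "" if ok else " (expected {})".format(want)))
```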
The master branch runs MNIST classification (torchvision dataset) with SqueezeNet (torchvision model).
For details, check config/config.py.
- Prepare an environment: GPU Docker, a local Python env, or whatever you prefer
- If using Docker:
docker pull davinnovation/pytorch-boilerplate:alpha
- Run:
python run.py
or, to suppress warnings: python -W ignore run.py
- After the experiment, view the logs:
tensorboard --logdir Logs
Hyperparameter search with NNI:
- Prepare the environment
- nnictl create --config nni_config.yml
- Open the NNI web UI at localhost:8080
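`nnictl` launches the training script once per trial. How run.py consumes NNI's suggestions is not shown here, so the following is only a sketch of the usual pattern, assuming trials call `nni.get_next_parameter()` and report metrics with `nni.report_intermediate_result` / `nni.report_final_result` (all real NNI functions). `DEFAULTS`, `merge_params`, `get_trial_params` and `report` are hypothetical names; the fallback lets the same script run without NNI installed.

```python
import copy

try:
    import nni  # present when NNI is installed; trials are launched via `nnictl create`
except ImportError:
    nni = None

# Hypothetical defaults; the real values would come from config/config.py.
DEFAULTS = {"lr": 1e-3, "batch_size": 64, "optimizer": "adam"}


def merge_params(defaults, overrides):
    """Return a deep copy of `defaults` with any keys from `overrides` applied."""
    merged = copy.deepcopy(defaults)
    merged.update(overrides or {})  # overrides may be None outside an NNI experiment
    return merged


def get_trial_params(defaults=DEFAULTS):
    """Use NNI's suggested hyperparameters when available, else the defaults."""
    overrides = nni.get_next_parameter() if nni is not None else {}
    return merge_params(defaults, overrides)


def report(metric, final=False):
    """Send a metric to the NNI web UI; does nothing when NNI is not installed."""
    if nni is None:
        return
    if final:
        nni.report_final_result(metric)
    else:
        nni.report_intermediate_result(metric)
```

Merging into a copy keeps the defaults untouched, so keys missing from the search space in nni_config.yml still get sensible values.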
- Adding a network: edit flashlight/network/__init__.py
"""Network Define"""
import torchvision

# Add {"Network Name": a function that builds the nn.Module (not an already-initialized instance)}
def _get_squeezenet(num_classes, version: str = "1_0", pretrained=False, progress=True):
    VERSION = {
        "1_0": torchvision.models.squeezenet1_0,
        "1_1": torchvision.models.squeezenet1_1,
    }
    return VERSION[version](pretrained=pretrained, progress=progress, num_classes=num_classes)

NETWORK_DICT = {
    "squeezenet": _get_squeezenet
}
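Because each entry stores a builder rather than an instance, a network can be chosen by name and constructed with run-time arguments. The repo's own lookup code is not shown here; a minimal helper might look like this (`build_from_registry` is a hypothetical name). It also works for `DATA_DICT`, whose values are dataset classes and therefore callable.

```python
def build_from_registry(registry, name, **kwargs):
    """Look up `name` in a {name: builder} dict and call the builder with `kwargs`."""
    try:
        builder = registry[name]
    except KeyError:
        # Re-raise with the valid keys listed, which is more helpful than the bare name.
        raise KeyError("Unknown name {!r}; registered: {}".format(name, sorted(registry))) from None
    return builder(**kwargs)


if __name__ == "__main__":
    # Stand-in builder so the sketch runs without torchvision installed.
    def _get_dummy(num_classes):
        return {"arch": "dummy", "num_classes": num_classes}

    registry = {"dummy": _get_dummy}
    print(build_from_registry(registry, "dummy", num_classes=10))
    # prints {'arch': 'dummy', 'num_classes': 10}
```

With the real registry, `build_from_registry(NETWORK_DICT, "squeezenet", num_classes=10, version="1_1")` would call `_get_squeezenet` and return a torchvision SqueezeNet 1.1 with a 10-class head.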
- Adding a dataset: edit flashlight/dataloader/__init__.py
""" Dataset """
import torchvision

# Add {"Dataset Name": a torch.utils.data.Dataset class}
DATA_DICT = {"MNIST": torchvision.datasets.MNIST}

""" Dataset Transform """
# MNIST images are single-channel; replicate to 3 channels for ImageNet-style networks such as SqueezeNet
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Grayscale(num_output_channels=3), torchvision.transforms.ToTensor()]
)

def get_datalaoder(data, root="../datasets/", split="train"):
    if data in ["MNIST"]:  # torchvision datasets expose only a train/test flag
        if split == "val":
            print(f"{data} dataset doesn't support a validation set; val is replaced by train")
        if split in ["train", "val"]:
            return DATA_DICT[data](root=root, train=True, download=True, transform=transform)
        else:
            return DATA_DICT[data](root=root, train=False, download=True, transform=transform)
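The loader above reuses the training set when `split="val"` is requested, so validation metrics are computed on training data. One way to get a real hold-out set is to split the training indices once with a fixed seed and wrap each part in `torch.utils.data.Subset`. This is a sketch, not part of the repo; `split_indices` is a hypothetical helper written in plain Python so it runs without torch.

```python
import random


def split_indices(n, val_fraction=0.1, seed=0):
    """Shuffle range(n) with a fixed seed and cut it into disjoint (train, val) index lists."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be between 0 and 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG: does not touch the global random state
    n_val = int(n * val_fraction)
    return indices[n_val:], indices[:n_val]


if __name__ == "__main__":
    # The MNIST training set has 60,000 images.
    train_idx, val_idx = split_indices(60000, val_fraction=0.1)
    print(len(train_idx), len(val_idx))  # prints 54000 6000
```

In the loader, `torch.utils.data.Subset(full_train, train_idx)` and `torch.utils.data.Subset(full_train, val_idx)` would then serve the "train" and "val" splits. Keeping the seed fixed makes both calls produce the same partition, so the two splits never overlap.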
- Changing the loss, forward/backward, ... [Research Code]: flashlight/runner/pl.py
- Changing the logger, hardware options, ... [Engineering Code]: flashlight/runner/main_pl.py