[Retiarii] Rewrite trainer with PyTorch Lightning #3359
Conversation
dependencies/recommended.txt
@@ -1,6 +1,7 @@
 tensorflow
 torch >= 1.6+cpu, != 1.7+cpu -f https://download.pytorch.org/whl/torch_stable.html
 torchvision >= 0.8+cpu -f https://download.pytorch.org/whl/torch_stable.html
+pytorch-lightning
 onnx
 peewee
 thop
@colorjam do we still need thop?
No
Will remove it in another PR.
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = blackbox(MNIST, root='data/mnist', train=True, download=True, transform=transform)
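To make the discussion concrete, here is a minimal, self-contained sketch of what a `blackbox` wrapper could do, under the assumption (suggested by the thread) that it exists for serialization: it instantiates the class normally but records the construction recipe so the object can be re-created in another process. `FakeDataset` and the `_init_spec` attribute are hypothetical stand-ins, not the real NNI implementation.

```python
def blackbox(cls, *args, **kwargs):
    """Hypothetical sketch (not NNI's real code): instantiate cls, but
    remember how it was constructed so it can be re-created elsewhere."""
    instance = cls(*args, **kwargs)
    # Record the construction recipe; a real implementation would also
    # handle nested wrapped objects and non-serializable arguments.
    instance._init_spec = {
        'class': f'{cls.__module__}.{cls.__qualname__}',
        'args': args,
        'kwargs': kwargs,
    }
    return instance

# Stand-in for torchvision's MNIST, so the sketch is self-contained.
class FakeDataset:
    def __init__(self, root, train=True, download=False):
        self.root, self.train, self.download = root, train, download

train_dataset = blackbox(FakeDataset, root='data/mnist', train=True, download=True)
```

This also explains why `MNIST` must be passed as a class rather than an instance: the wrapper needs to intercept construction to capture the arguments.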
blackbox is strange here, because I don't understand why it creates train_dataset.
Oh... MNIST is a class name?
Yes. This is limited by serialization.
Is it possible that we also wrap the dataset class by default? When users want to define their own dataset class, they would decorate it with, for example, @register_dataset.
Another option, for this PR, is renaming blackbox to make_serializable.
    def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
        @blackbox_module
        class AutoEncoder(LightningModule):
            def __init__(self):
Where is model configured?
set_model
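The `set_model` answer refers to a pattern where the framework injects each candidate model into the trainer after construction. A minimal sketch of that pattern, with hypothetical class and method names (only `set_model` is taken from the reply):

```python
class BaseTrainer:
    """Hypothetical sketch: the framework constructs the trainer without
    a model, then injects each sampled model before fitting."""
    def set_model(self, model):
        self.model = model

class MnistTrainer(BaseTrainer):
    def __init__(self, learning_rate=0.1):
        self.learning_rate = learning_rate

    def fit(self):
        # A real trainer would run the training loop on self.model here.
        return f'training {self.model} with lr={self.learning_rate}'

trainer = MnistTrainer()
trainer.set_model('candidate-model-0')  # called by the framework, not the user
```

The design choice here is that trainer construction (user-side, serializable) is decoupled from model binding (framework-side, per trial).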
@register_trainer
class MnistTrainer(BaseTrainer):
    def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
        @blackbox_module
blackbox_module used on a trainer is not that clear; let's discuss the name then.
-                 applied_mutators: Mutator = None, strategy: BaseStrategy = None):
     def __init__(self, base_model: nn.Module, trainer: Union[TrainingConfig, BaseOneShotTrainer],
+                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
     # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
trainer could be both a configuration and an instantiated trainer?
Have to admit that this is very confusing. Maybe we can find a better name than "TrainingConfig".
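The ambiguity being discussed — one `trainer` parameter accepting either a declarative config or a live trainer — could be resolved at runtime with a type check. A self-contained sketch, with hypothetical class bodies (only the two class names come from the signature above):

```python
class TrainingConfig:
    """Declarative description of training for multi-trial search."""
    def __init__(self, max_epochs=10):
        self.max_epochs = max_epochs

class BaseOneShotTrainer:
    """An instantiated trainer that runs training itself (one-shot NAS)."""
    def fit(self):
        return 'one-shot training'

def run_experiment(trainer):
    # Dispatch on the union type, as the __init__ signature implies.
    if isinstance(trainer, TrainingConfig):
        return f'multi-trial search, {trainer.max_epochs} epochs per trial'
    if isinstance(trainer, BaseOneShotTrainer):
        return trainer.fit()
    raise TypeError('trainer must be a TrainingConfig or a BaseOneShotTrainer')
```

The sketch makes the reviewer's point visible: the same argument name triggers two quite different execution paths, which is why a clearer name (or two separate parameters) may be worth considering.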
@@ -1,3 +1,5 @@
+# This file is deprecated.
why not directly remove it?
Because I don't want to remove the cgo_engine. If I remove this file but not cgo_engine, the linter will complain.
        Class for optimizer (not an instance). default: ``Adam``
    """

    def __init__(self, criterion: nn.Module = nn.MSELoss,
What is the difference between Regression and Classification?
The code looks great. One thing is that it is not easy to understand how to write/use a trainer.
class MnistTrainer(BaseTrainer):
    def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
        @blackbox_module
        class AutoEncoder(LightningModule):
Since we do not define the model in LightningModule, keeping the name LightningModule may be misleading. We could use a more understandable name, for example BaseTrainer, TrainingModule, etc.
    If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
    """

    def __init__(self, lightning_module: LightningModule, trainer: Trainer,
I am confused: what is the difference between lightning_module and trainer?
Let's think about how to support standalone mode, which would help us think about how extensible our current design is.
With FunctionalTrainer
^^^^^^^^^^^^^^^^^^^^^^

There is another way to customize a new trainer with functional APIs, which provides more flexibility. Users only need to write a fit function that wraps everything. This function takes one positional argument (model) and possibly some keyword arguments. In this way, users get everything under their control, but expose less information to the framework, and thus there are fewer opportunities for possible optimization. An example is as follows:
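A minimal, self-contained sketch of the functional-trainer idea the paragraph describes. The class body, `fit_fn`/`fit_kwargs` names, and the toy fit function are illustrative assumptions, not the actual NNI implementation:

```python
class FunctionalTrainer:
    """Hypothetical sketch: store a user-supplied fit function plus its
    keyword arguments, and call it with the model when training starts."""
    def __init__(self, fit_fn, **fit_kwargs):
        self.fit_fn = fit_fn
        self.fit_kwargs = fit_kwargs

    def fit(self, model):
        # The framework passes the model positionally; everything else
        # (data, epochs, metric reporting) lives inside the user's function.
        return self.fit_fn(model, **self.fit_kwargs)

def fit(model, dataset):
    # Users own the whole loop: train, evaluate, and report metrics here.
    return f'trained {model} on {len(dataset)} samples'

trainer = FunctionalTrainer(fit, dataset=[1, 2, 3])
```

Because the framework only sees an opaque callable, it cannot inspect or optimize the training loop — which is exactly the trade-off the paragraph describes.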
we should consistently use the word "trainer"
Work items
self-supervision).

Future items