Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] Rewrite trainer with PyTorch Lightning #3359

Merged
merged 32 commits into from
Feb 14, 2021

Conversation

ultmaster
Copy link
Contributor

@ultmaster ultmaster commented Feb 2, 2021

Work items

  • Infrastructure: refactor serialization and supports serializing any class.
  • Introduce lightning module and adopt lightning trainer in execution engine.
  • Refine and finalize trainer format in IR for mutation.
  • Support 2-3 commonly used lightning modules (classification/regression/self-supervision).

Future items

  • Design the interface of LightningModule for multi-graph optimization.
  • Rename the APIs for better user experience.

@ultmaster ultmaster marked this pull request as draft February 2, 2021 11:33
@ultmaster ultmaster marked this pull request as ready for review February 4, 2021 08:38
@@ -1,6 +1,7 @@
tensorflow
torch >= 1.6+cpu, != 1.7+cpu -f https://download.pytorch.org/whl/torch_stable.html
torchvision >= 0.8+cpu -f https://download.pytorch.org/whl/torch_stable.html
pytorch-lightning
onnx
peewee
thop
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@colorjam do we still need thop?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove it in another PR.

from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = blackbox(MNIST, root='data/mnist', train=True, download=True, transform=transform)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blackbox is strange here, because i don't understand why it creates train_dataset

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

o... MNIST is a class name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. This is limited by serialization.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible we also wrap the dataset class by default? when users want to define their own dataset class, they decorate this class with for example @register_dataset

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another option, for this pr, is renaming blackbox to make_serializable

def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
@blackbox_module
class AutoEncoder(LightningModule):
def __init__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is model configured?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_model

@register_trainer
class MnistTrainer(BaseTrainer):
def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
@blackbox_module
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blackbox_module used on trainer is not that clear, let's discuss the name then

applied_mutators: Mutator = None, strategy: BaseStrategy = None):
def __init__(self, base_model: nn.Module, trainer: Union[TrainingConfig, BaseOneShotTrainer],
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None):
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trainer could be both configuration and instantiated trainer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have to admit that this is very confusing. Maybe we can find a better name than "TrainingConfig".

@@ -1,3 +1,5 @@
# This file is deprecated.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not directly remove it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I don't want to remove the cgo_engine. If I remove this file and do not remove cgo_engine, linter will complain.

Class for optimizer (not an instance). default: ``Adam``
"""

def __init__(self, criterion: nn.Module = nn.MSELoss,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the difference between Regression and Classification?

@QuanluZhang
Copy link
Contributor

the code looks great, one thing is that it is not easy to understand how to write/use trainer

class MnistTrainer(BaseTrainer):
def __init__(self, model, optimizer_class_name='SGD', learning_rate=0.1):
@blackbox_module
class AutoEncoder(LightningModule):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we do not define model in LightningModule, if we still use the name LightningModule, it may be misleading. we can use more understandable name, for example, BaseTrainer, TrainingModule, etc.

If the ``lightning_module`` has a predefined val_dataloaders method this will be skipped.
"""

def __init__(self, lightning_module: LightningModule, trainer: Trainer,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am confused, what is the difference between lightning_module and trainer?

Copy link
Contributor

@QuanluZhang QuanluZhang Feb 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's think about how to support standalone mode, which would help us to think about how extensible is our current design

@J-shang J-shang mentioned this pull request Feb 5, 2021
94 tasks
With FunctionalTrainer
^^^^^^^^^^^^^^^^^^^^^^

There is another way to customize a new trainer with functional APIs, which provides more flexibility. Users only need to write a fit function that wraps everything. This function takes one positional arguments (model) and possible keyword arguments. In this way, users get everything under their control, but exposes less information to the framework and thus fewer opportunities for possible optimization. An example is as belows:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should consistently use the word "trainer"

@ultmaster ultmaster merged commit 445e7e0 into microsoft:master Feb 14, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants