Learning to learn by gradient descent by gradient descent [PDF]
This is a PyTorch implementation of the LSTM-based meta optimizer.
- For quadratic functions
- For MNIST
- Meta modules for PyTorch (resnet_meta.py is provided, with support for loading pretrained weights.)
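The learned-optimizer loop on the quadratic task can be sketched as follows. This is a minimal sketch under stated assumptions, not the repository's code: LSTMOptimizer, quadratic_loss and unroll are hypothetical names; the optimizee is f(θ) = ||Wθ − y||² with a random 10×10 W and a 10-dimensional y; a single coordinatewise LSTMCell with 20 hidden units stands in for the paper's two-layer LSTM; and the paper's gradient preprocessing is omitted. The optimizee update θ ← θ + g is an ordinary out-of-place tensor operation, so it keeps its grad_fn back to the LSTM.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class LSTMOptimizer(nn.Module):
    """Hypothetical coordinatewise LSTM: each coordinate of theta is one batch row sharing weights."""

    def __init__(self, hidden_size=20):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.LSTMCell(1, hidden_size)
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, grad, state):
        h, c = self.cell(grad.unsqueeze(1), state)  # (n, 1) -> (n, hidden_size)
        return self.out(h).squeeze(1), (h, c)       # one proposed update per coordinate


def quadratic_loss(W, y, theta):
    # f(theta) = ||W theta - y||^2
    return ((W @ theta - y) ** 2).sum()


def unroll(opt_net, W, y, steps=20):
    """Run the learned optimizer for `steps` steps and return the summed optimizee losses."""
    n = W.shape[1]
    theta = torch.zeros(n, requires_grad=True)
    state = (torch.zeros(n, opt_net.hidden_size), torch.zeros(n, opt_net.hidden_size))
    meta_loss = torch.zeros(())
    for _ in range(steps):
        loss = quadratic_loss(W, y, theta)
        meta_loss = meta_loss + loss
        # Optimizee gradient; retain_graph keeps the graph alive for meta_loss.backward().
        (grad,) = torch.autograd.grad(loss, theta, retain_graph=True)
        # Detaching drops second derivatives, as assumed in the paper.
        update, state = opt_net(grad.detach(), state)
        theta = theta + update  # non-leaf result: keeps the grad_fn back to opt_net
    return meta_loss


opt_net = LSTMOptimizer()
meta_opt = torch.optim.Adam(opt_net.parameters(), lr=1e-3)
meta_losses = []
for _ in range(5):  # a few meta-iterations, each on a freshly sampled quadratic
    W, y = torch.randn(10, 10), torch.randn(10)
    meta_loss = unroll(opt_net, W, y)
    meta_opt.zero_grad()
    meta_loss.backward()  # backpropagates through all unrolled optimizee updates
    meta_opt.step()
    meta_losses.append(meta_loss.item())
```

Only the LSTM weights are trained here; each meta-iteration samples a new W and y so the optimizer learns a rule rather than a single solution.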
- Ubuntu
- Python 3
- NVIDIA GPU
This repository has been tested on a GTX 1080 Ti.
- Clone this repo:
git clone https://github.com/chenwydj/learning-to-learn-by-gradient-descent-by-gradient-descent.git
cd learning-to-learn-by-gradient-descent-by-gradient-descent
- Install dependencies:
pip install -r requirements.txt
- To reproduce the paper, simply go through the notebook Grad_buffer.ipynb. Note that some images that do not load properly in the browser will show up in a downloaded local copy.
- To implement your own Learning-to-Optimize work, feel free to use meta_module.py:
from meta_module import *
- Replace every torch.nn.XXX with MetaXXX, where "XXX" is one of [Module, Linear, Conv2d, ConvTranspose2d, BatchNorm2d, Sequential, ModuleList, ModuleDict].
- resnet_meta.py is provided, and pretrained weights can be loaded. Use the meta ResNet the same way as the standard one (e.g. model = resnet101(pretrained=True)).
The core of reproducing the LSTM meta optimizer is to update the nn.Parameter tensors of the optimizee in place while retaining their grad_fn. In PyTorch, nn.Parameter tensors are designed to be leaf nodes, and autograd raises a RuntimeError when a leaf tensor that requires grad is modified in place. The usual way to change a parameter's value is therefore something like p.data.add_ (take the last line of sgd.py in PyTorch as an example). However, modifying the .data of a tensor does not produce a grad_fn, which is vital because the meta optimizer is trained by backpropagating through these updates. More discussion can be found here and here.
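Both failure modes can be reproduced in a few lines. In this small sketch, lr is a hypothetical scalar standing in for the meta optimizer's output and w is a one-tensor optimizee; the point is only to show that an in-place update of a leaf is rejected, while a .data update succeeds but cuts the meta optimizer out of the graph.

```python
import torch
import torch.nn as nn

lr = torch.tensor(0.1, requires_grad=True)  # hypothetical stand-in for the meta optimizer's output
w = nn.Parameter(torch.ones(3))             # a leaf tensor of the optimizee

loss = (w ** 2).sum()
(grad,) = torch.autograd.grad(loss, w, create_graph=True)

# 1) A differentiable in-place update of a leaf Parameter is rejected by autograd.
try:
    w.sub_(lr * grad)
    inplace_error = None
except RuntimeError as err:
    inplace_error = str(err)

# 2) Writing through .data succeeds, but the update is invisible to autograd.
w.data.sub_((lr * grad).detach())
print(w.grad_fn)  # → None: w is still a leaf with no history

# 3) Consequently, the post-update loss carries no gradient back to lr.
(lr_grad,) = torch.autograd.grad((w ** 2).sum(), lr, allow_unused=True)
print(lr_grad)  # → None
```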
One way to bypass this problem is to leverage buffers in PyTorch. A buffer is also a "parameter" of the model and is saved in the state_dict, but it is not returned by model.parameters(). Typical examples of buffers are running_mean and running_var in BatchNorm layers. Buffer tensors can be treated as weights, while also having the flexibility to retain a grad_fn when they are overwritten inside the model. We therefore add parameters via nn.Module.register_buffer().
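The buffer trick can be illustrated on a toy model. This is a minimal sketch, not code from the repository: TinyModel is a hypothetical module whose only weight is registered as a buffer, and lr is a hypothetical scalar standing in for the meta optimizer. Overwriting the buffer with a non-leaf tensor is accepted, so the post-update loss can be differentiated with respect to lr.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # The weight is registered as a buffer, not as an nn.Parameter.
        self.register_buffer("weight", torch.ones(3, requires_grad=True))

    def forward(self, x):
        return (self.weight * x).sum()


model = TinyModel()
print(list(model.parameters()))        # → []: buffers are not returned here
print("weight" in model.state_dict())  # → True: but they are saved

lr = torch.tensor(0.01, requires_grad=True)  # hypothetical stand-in for the meta optimizer's output
x = torch.tensor([1.0, 2.0, 3.0])

loss = model(x) ** 2
(grad,) = torch.autograd.grad(loss, model.weight, create_graph=True)

# A buffer may be overwritten with a non-leaf tensor, so the update keeps its grad_fn.
model.weight = model.weight - lr * grad
print(model.weight.grad_fn is not None)  # → True

# The loss after the update is differentiable w.r.t. lr, which is what meta-training needs.
meta_loss = model(x) ** 2
meta_loss.backward()
print(lr.grad)
```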
This is why meta_module.py is provided. The core class is MetaModule, which inherits from nn.Module but manually returns its buffers as parameters. On top of MetaModule, we build MetaLinear, MetaConv2d, MetaConvTranspose2d, MetaBatchNorm2d, and MetaSequential, each with registered buffers.
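The structure described above can be sketched as follows. This is a hypothetical illustration of the idea, not the contents of meta_module.py: the method names meta_named_params, meta_params and set_meta_param are illustrative rather than the repository's API, and lr again stands in for the LSTM optimizer's output. It performs one differentiable inner step θ' = θ − lr · ∇θ and writes the result back into the buffers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MetaModule(nn.Module):
    """Hypothetical sketch: weights live in buffers and are exposed as 'meta params'."""

    def meta_named_params(self):
        # named_buffers() recurses into submodules, yielding dotted names like "layers.0.weight".
        yield from self.named_buffers()

    def meta_params(self):
        for _, p in self.meta_named_params():
            yield p

    def set_meta_param(self, name, value):
        # Walk to the submodule that owns the buffer and re-assign it there.
        *path, leaf = name.split(".")
        module = self
        for part in path:
            module = getattr(module, part)
        setattr(module, leaf, value)


class MetaLinear(MetaModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        ref = nn.Linear(in_features, out_features)  # borrow PyTorch's default initialization
        self.register_buffer("weight", ref.weight.detach().clone().requires_grad_())
        self.register_buffer("bias", ref.bias.detach().clone().requires_grad_())

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)


class MetaSequential(MetaModule):
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


torch.manual_seed(0)
net = MetaSequential(MetaLinear(4, 8), nn.ReLU(), MetaLinear(8, 1))
lr = torch.tensor(0.1, requires_grad=True)  # hypothetical stand-in for the LSTM optimizer's output
x, y = torch.randn(16, 4), torch.randn(16, 1)

# One differentiable inner step, written back into the buffers.
names, params = zip(*net.meta_named_params())
loss = F.mse_loss(net(x), y)
grads = torch.autograd.grad(loss, params, create_graph=True)
for name, p, g in zip(names, params, grads):
    net.set_meta_param(name, p - lr * g)

# The loss of the updated optimizee now backpropagates into lr.
meta_loss = F.mse_loss(net(x), y)
meta_loss.backward()
```

In this sketch a convolutional variant would follow the same pattern, registering its kernel as a buffer and calling F.conv2d in forward.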
- Original L2O code from AdrienLE/learning_by_grad_by_grad_repro.
- Meta modules from danieltan07/learning-to-reweight-examples.