Skip to content

TTitcombe/LotteryTicketHypothesis

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

12 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Lottery Ticket Hypothesis

PyTorch implementation of the Lottery Ticket Hypothesis. This implementation uses PyTorch's prune module.

Pre-requisities

Developed in Python 3.7, but other versions should work.

If using conda, run

conda env create -f environment.yml

to create an environment, lottery-ticket, with all required packages.

To Run

main.py trains a LeNet-esque classifier on MNIST with several rounds of pruning.

usage: Finding Lottery Tickets on an MNIST classifier [-h] [--lr LR] [--bs BS]
                                                      [--epochs EPOCHS]
                                                      [--prune_pc PRUNE_PC]
                                                      [--prune_rounds PRUNE_ROUNDS]

optional arguments:
  -h, --help            show this help message and exit
  --lr LR               Learning rate (default 1e-3)
  --bs BS               Batch size (default 128)
  --epochs EPOCHS       Number of epochs (default 8)
  --prune_pc PRUNE_PC   Percentage of parameters to prune over the course of
                        the training process (default 0.2)
  --prune_rounds PRUNE_ROUNDS
                        Number of rounds of pruning to perform (default 5)

Contributions

Contributions are welcome. If opening a PR, ensure the code conforms to black formatting and isort import configurations.

Feel free to open an issue to ask a question, raise a bug, or request new features.

Project Organization

├── README.md          <- The top-level README for developers using this project.
│
├── environment.yml    <- The conda environment file for creating the analysis environment, e.g.
│                         `conda env create -f environment.yml`.
│
├── main.py           <- The training script.
│
├── .gitignore         <- git-ignore configuration file.
│
├── data               <- Directory in which downloaded data will be stored. No data is provided in the repo.
│
├── src                <- source code. Things imported into `main.py`.

TODO

  • Implement basic pruning
  • Recreate experiments from the paper
  • Use tensorboard to visualise model and pruning progress

License

See the full license

Releases

No releases published

Packages

No packages published

Languages