
Scaling MLPs 🔥

Overview

This repository contains the code accompanying our paper Scaling MLPs: A Tale of Inductive Bias. In this work we explore the limits of the multi-layer perceptron (MLP for short) when given larger amounts of compute. More precisely, we study architectures built from simple residual MLP blocks (see the sketch below).

Why? We argue that such an MLP has minimal inductive bias (compared to convolutional networks, vision transformers, MLP-Mixers, etc.) and thus offers an interesting test bed for exploring whether simply scaling compute can make even the simplest models work (to some degree). The importance of inductive bias has recently been questioned, as vision transformers and MLP-Mixers have eclipsed the more structured convolutional models on standard benchmarks. Moreover, MLPs remain the main protagonists of ML theory works, yet surprisingly little is known about their empirical performance at scale! We aim to close this gap here and provide the community with very performant MLPs to analyse!
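As a rough illustration, one such block could look as follows in PyTorch. This is a minimal sketch assuming a pre-LayerNorm residual layout with an inverted bottleneck; the expansion factor and activation below are illustrative assumptions, see the paper for the exact block.

import torch
import torch.nn as nn

class Block(nn.Module):
    # Residual MLP block: x + project(act(expand(norm(x))))
    def __init__(self, width: int, expansion: int = 4):
        super().__init__()
        self.norm = nn.LayerNorm(width)
        self.expand = nn.Linear(width, expansion * width)    # widen
        self.act = nn.GELU()                                 # illustrative choice
        self.project = nn.Linear(expansion * width, width)   # project back

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.project(self.act(self.expand(self.norm(x))))

An architecture name like B_12-Wi_1024 then roughly corresponds to stacking 12 such blocks of width 1024 on top of a linear embedding of the flattened input image.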

Explore

You can easily explore our pre-trained and fine-tuned models by specifying the checkpoint flag. For instance, to load a BottleneckMLP with 12 blocks of width 1024, pre-trained on ImageNet21k, simply run

model = get_model(architecture='B_12-Wi_1024', resolution=64, num_classes=11230, checkpoint='in21k')

If you need an already fine-tuned model, you can specify

model = get_model(architecture='B_12-Wi_1024', resolution=64, num_classes=10, checkpoint='in21k_cifar10')

Check out the Jupyter notebook explore.ipynb to play around with the models.
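Once a model is loaded as above, inference looks roughly as follows. This is a minimal sketch assuming the MLP expects flattened RGB images at the pre-training resolution (a 64x64 RGB image becomes a vector of length 3*64*64 = 12288); see explore.ipynb for the exact input format.

import torch

model = get_model(architecture='B_12-Wi_1024', resolution=64, num_classes=10, checkpoint='in21k_cifar10')
model.eval()

# Dummy batch of 8 images, flattened into vectors (assumed input format).
x = torch.randn(8, 3, 64, 64).reshape(8, -1)
with torch.no_grad():
    logits = model(x)   # shape (8, 10)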

Pretrained Models

We further publish our models pre-trained on ImageNet21k for various numbers of epochs at an image resolution of $64\times 64$ here. Fine-tuning the $800$-epoch models for $100$ epochs should give you roughly the following downstream performance (see the Fine-tuning section for hyperparameter details):

Model               #Params   CIFAR10   CIFAR100   STL10   TinyImageNet   ImageNet   ImageNetReal
B_6-Wi_512          24M       88.5%     71.2%      79.9%   53.2%          33.3%      38.2%
B_12-Wi_512         37M       91.4%     75.1%      84.4%   60.0%          38.0%      42.8%
B_6-Wi_1024         74M       92.5%     77.1%      86.5%   64.3%          40.0%      47.0%
B_12-Wi_1024        124M      94.2%     80.0%      89.9%   69.9%          43.2%      48.6%
B_12-Wi_1024 + TTA  124M      95.5%     82.6%      92.2%   73.1%          51.4%      57.9%

Make sure that you also download the config.txt file and place it in the same folder as the corresponding checkpoint.

Environment Setup

For installing the FFCV data loading framework, we refer to the original repository. To install the remaining packages, activate the FFCV environment and run

pip install -r requirements.txt

Creating .beton Files

To use the efficiency of MLPs to the fullest, we need a more optimised data loading framework than the standard one provided by torch; otherwise, the data transfer from CPU to GPU, not the gradient computation, becomes the bottleneck of training. To ensure a faster data transfer, we use the FFCV framework, which requires first converting your dataset to the .beton format. This can be achieved by specifying your dataset as a torchvision.dataset object.

If your dataset is implemented in the torchvision.datasets library, simply add the corresponding lines of code to the get_dataset function in dataset_to_beton.py. We provide implementations for CIFAR10 and CIFAR100.

If your dataset is stored in the standard hierarchical subfolder structure, i.e. it consists of subfolders each corresponding to a separate class, you can simply specify the dataset_path argument in create_beton to obtain the .beton file.

Conversion to .beton accepts a resolution parameter res, specifying the resolution of the images. We recommend using --res 64 for very large datasets such as ImageNet21k in order to keep the computational requirements manageable for users with fewer resources.
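For reference, the conversion essentially boils down to FFCV's DatasetWriter. A minimal sketch for the CIFAR10 training set is given below; the field names and settings are illustrative, and dataset_to_beton.py handles all of this for you.

from torchvision.datasets import CIFAR10
from ffcv.writer import DatasetWriter
from ffcv.fields import IntField, RGBImageField

# Any indexed dataset returning (image, label) pairs works here.
dataset = CIFAR10(root='./data', train=True, download=True)

# max_resolution plays the role of the res argument: larger images are
# downscaled to this resolution when written to the .beton file.
writer = DatasetWriter('cifar10_train.beton', {
    'image': RGBImageField(max_resolution=32),
    'label': IntField(),
})
writer.from_indexed_dataset(dataset)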

Downloading the CIFAR10 training set and converting it to the .beton format can, for instance, be achieved by running

python3 data_utils/dataset_to_beton.py --dataset_name cifar10 --mode train --res 32

A subfolder-structured dataset can be converted to the .beton format at resolution 64 by running

python3 data_utils/dataset_to_beton.py --data_path path/to/folders --mode train --res 64
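Once you have a .beton file, FFCV loads it via pipelines that decode the data and move it to the GPU, so that the CPU-to-GPU transfer does not become the bottleneck discussed above. The following is a minimal loading sketch; the field names 'image'/'label' and the pipelines are illustrative, while the actual loaders used for training are defined in the repository's data loading code.

import torch
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.transforms import ToDevice, ToTensor, ToTorchImage, Squeeze

device = torch.device('cuda:0')

# Decode images/labels and move them to the GPU inside the loader.
image_pipeline = [SimpleRGBImageDecoder(), ToTensor(),
                  ToDevice(device, non_blocking=True), ToTorchImage()]
label_pipeline = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]

loader = Loader('cifar10_train.beton', batch_size=1024, num_workers=8,
                order=OrderOption.RANDOM,
                pipelines={'image': image_pipeline, 'label': label_pipeline})

for images, labels in loader:
    pass  # training step goes here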

Pre-training

ImageNet21k. Due to legal reasons, we cannot provide ImageNet21k in the .beton format directly. We recommend applying here to download it, but in case you cannot get access, you can use the torrent here. The same applies to ImageNet1k. Once you have downloaded the dataset, we recommend pre-processing it as detailed in this repository to remove faulty images and classes with only very few examples. Then produce the .beton file as outlined above.

Pre-training. For pre-training the B_12-Wi_1024 BottleneckMLP on ImageNet21k at resolution $64 \times 64$, you can use the following command:

python3 train.py --dataset imagenet21 --model BottleneckMLP --architecture B_12-Wi_1024 --batch_size 16384 --resolution 64

For more specific configurations, we encourage the user to check out all the available flags in train.py. In case you run into memory issues, try reducing the batch size. We remark, however, that smaller batch sizes tend to lead to worse results; check out our paper, where we highlight this effect. During training, the parameters are automatically saved to the checkpoints folder.

Fine-tuning

You can fine-tune our pre-trained checkpoints, or your own, using the script finetune.py. For instance, the following command fine-tunes a pre-trained B_12-Wi_1024 model on CIFAR10, provided you have converted the CIFAR10 dataset to the .beton format:

python3 finetune.py --architecture B_12-Wi_1024 --checkpoint res_64_in21k --dataset cifar10 --data_resolution 32 --batch_size 2048 --epochs 100 --lr 0.01 --weight_decay 0.0001 --data_path /local/home/stuff/ --crop_scale 0.4 1. --crop_ratio 1. 1. --optimizer sgd --augment --mode finetune --smooth 0.3

You can also train a linear layer on top by specifying the flag --mode linear instead.
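For instance, a linear probe on CIFAR10 could look like this; the hyperparameters here are illustrative rather than tuned values from the paper:

python3 finetune.py --architecture B_12-Wi_1024 --checkpoint res_64_in21k --dataset cifar10 --data_resolution 32 --batch_size 2048 --epochs 50 --lr 0.01 --data_path /local/home/stuff/ --optimizer sgd --mode linear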

Own Dataset

If you want to add your own dataset, convert it to the FFCV format as detailed above and make sure to fill in the same statistics in data_utils/data_stats that are provided for the other datasets, such as the number of classes, number of samples, etc.
