UniformAugment

An unofficial PyTorch reimplementation of UniformAugment. Most of the code is from Fast AutoAugment and PyTorch RandAugment.

Introduction

UniformAugment is an automated data augmentation approach that completely avoids a search phase. Its effectiveness is comparable to that of existing automated augmentation methods, and it remains highly efficient because no search is required.

Install

pip install git+https://github.com/tgilewicz/uniformaugment/

Usage

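from torchvision import transforms
from UniformAugment import UniformAugment

# Commonly used CIFAR-10 per-channel mean and standard deviation.
_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])
# Insert UniformAugment at position 0 so it runs before the other transforms.
# ops_num is the number of operations sampled per image (ops_num=2 is optimal).
transform_train.transforms.insert(0, UniformAugment(ops_num=2))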

Experiment

The experimental details were confirmed with the authors of the UniformAugment paper.

You can run an example experiment with:

$ python UniformAugment/train.py -c confs/wresnet28x10_cifar.yaml --dataset cifar10 \
    --save cifar10_wres28x10.pth --dataroot ~/data --tag v1

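CIFAR-10 Classification, Top-1 Accuracy (%)

| Model             | Paper's Result | Run 1 | Run 2 | Run 3 | Run 4 | Avg (Ours) |
|-------------------|----------------|-------|-------|-------|-------|------------|
| Wide-ResNet 28x10 | 97.33          | 97.26 | 97.31 | 97.33 | 97.42 | 97.33      |
| Wide-ResNet 40x2  | 96.25          | 96.27 | 96.36 | 96.5  | 96.54 | 96.41      |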

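CIFAR-100 Classification, Top-1 Accuracy (%)

| Model             | Paper's Result | Run 1 | Run 2 | Run 3 | Run 4 | Avg (Ours) |
|-------------------|----------------|-------|-------|-------|-------|------------|
| Wide-ResNet 28x10 | 82.82          | 83.55 | 82.56 | 82.66 | 82.72 | 82.87      |
| Wide-ResNet 40x2  | 79.01          | 79.06 | 79.08 | 79.09 | 78.77 | 79.00      |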

ImageNet Classification

| Model      | Paper's Result | Ours       |
|------------|----------------|------------|
| ResNet-50  | 77.63          | 77.80      |
| ResNet-200 | 80.4           | Stay tuned |

Core class

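import random

# augment_list is defined elsewhere in this repository; each entry unpacks as
# (augment_fn, low, high): an operation and its magnitude range.


class UniformAugment:
    def __init__(self, ops_num=2):
        self._augment_list = augment_list(for_autoaug=False)
        self._ops_num = ops_num

    def __call__(self, img):
        # random.choices samples with replacement, so the same operation can be
        #   picked more than once. Selecting unique ops_num transforms for each
        #   image would help the training procedure.
        ops = random.choices(self._augment_list, k=self._ops_num)

        for augment_fn, low, high in ops:
            # Draw the application probability uniformly from [0, 1).
            probability = random.random()
            if random.random() < probability:
                # Draw the magnitude uniformly from the operation's range.
                img = augment_fn(img.copy(), random.uniform(low, high))

        return img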
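
Two properties of the sampling logic in the core class are easy to miss, so the following is a small stdlib-only simulation of them rather than part of the package. Because the application probability is itself drawn uniformly from [0, 1), each selected operation ends up being applied about half the time on average. Because random.choices samples with replacement, both picks coincide with probability 1/N for a list of N operations. The helper names and the list size of 14 are assumptions made for illustration; the actual length of the list returned by augment_list is not stated here.

```python
import random


def apply_rate(trials: int, rng: random.Random) -> float:
    """Fraction of trials in which a selected operation would be applied."""
    applied = 0
    for _ in range(trials):
        probability = rng.random()      # p ~ U(0, 1), as in __call__
        if rng.random() < probability:  # apply with probability p
            applied += 1
    return applied / trials


def duplicate_rate(num_available: int, trials: int, rng: random.Random) -> float:
    """Fraction of trials in which choices(k=2) picks the same operation twice."""
    ops = list(range(num_available))    # stand-ins for (augment_fn, low, high)
    same = 0
    for _ in range(trials):
        first, second = rng.choices(ops, k=2)  # sampling with replacement
        if first == second:
            same += 1
    return same / trials


rng = random.Random(0)                  # fixed seed for reproducibility
rate = apply_rate(100_000, rng)
dup = duplicate_rate(14, 100_000, rng)  # 14 is a hypothetical list size

print(f"apply rate: {rate:.3f} (theory: 0.5)")
print(f"duplicate rate: {dup:.3f} (theory: 1/14, about 0.071)")
```

If duplicate picks are undesirable, replacing random.choices with random.sample in the core class would draw distinct operations, at the cost of requiring ops_num to be no larger than the list length.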

References