Commit: add evaluation codes and their readme (93 changed files, 156,564 additions, 0 deletions).
# Mugs: A Multi-Granular Self-Supervised Learning Framework
Here we provide the evaluation code for assessing models pretrained with **Mugs** on several downstream tasks.

### Environment
To reproduce our results, please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset.
This codebase has been developed with Python 3.8, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. For the full
environment, please refer to our `Dockerfile`.

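Before launching any evaluation, it may help to confirm that the installed versions match the ones above; a small convenience snippet (not part of the repository):
```
import torch
import torchvision

# Versions this codebase was developed with.
print(torch.__version__)          # expected: 1.7.1
print(torchvision.__version__)    # expected: 0.8.2
print(torch.version.cuda)         # expected: 11.0
print(torch.cuda.is_available())  # True if a GPU and the CUDA runtime are visible
```
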
## Evaluation
For all downstream tasks tested in the manuscript, you can directly use the corresponding `.sh` evaluation file
in the `eval` folder. In these `.sh` files, all hyper-parameters are already assigned, so all you need to do
is set the paths of the dataset and the pretrained model.

### Fine-tuning classification on ImageNet-1K
To evaluate a pretrained model by fine-tuning, first enter the `eval` folder. Then run `eval_finetuning.sh`,
which contains the fine-tuning evaluations for all models, or run one of the commands inside it.
#### Step 1. Extract the backbone weights
```
python ./eval_finetuning/extract_backbone_weights_for_finetuning.py --checkpoint $CHECKPOINT --output $OUTPUT --checkpoint_key teacher
```
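Conceptually, this step loads the full pretraining checkpoint, keeps only the sub-model selected by `--checkpoint_key` (here the teacher network), strips wrapper prefixes, and saves the bare backbone weights. A minimal sketch of that logic, with prefix names such as `module.` and `backbone.` assumed for illustration (the script above is the authoritative implementation):
```
import torch

def extract_backbone_weights(checkpoint_path, output_path, checkpoint_key="teacher"):
    # Load the full pretraining checkpoint on CPU.
    state = torch.load(checkpoint_path, map_location="cpu")
    # Keep only the selected sub-model, e.g. the teacher network.
    if checkpoint_key in state:
        state = state[checkpoint_key]
    # Strip common wrapper prefixes (DDP "module.", multi-crop wrapper "backbone.").
    state = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state.items()}
    torch.save(state, output_path)
```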
#### Step 2. Load and fine-tune the backbone weights
```
NPROC_PER_NODE=4
BATCH_SIZE_PER_GPU=256
python -m torch.distributed.launch --nproc_per_node=$NPROC_PER_NODE ./eval_finetuning/eval_finetuning.py --data_path $DATASET_ROOT --finetune $OUTPUT --model vit_small --epochs 200 --batch_size $BATCH_SIZE_PER_GPU --warmup_epochs 20 --drop_path 0.1 --lr 0.0012 --layer_decay 0.55 --mixup 0.8 --cutmix 1.0 --layer_scale_init_value 0.0 --disable_rel_pos_bias --abs_pos_emb --use_cls --imagenet_default_mean_and_std
```
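Among these flags, `--layer_decay 0.55` applies layer-wise learning-rate decay, the usual recipe for fine-tuning ViTs: the closer a layer group is to the input, the smaller its learning rate. The helper below only illustrates that rule; the actual per-parameter grouping lives in the evaluation code, and treating a 12-block ViT-S as 14 groups (patch embedding, blocks, head) follows the BEiT recipe and is an assumption here:
```
def layerwise_lr_scales(depth=12, layer_decay=0.55):
    # One multiplier per layer group: patch embedding (id 0), each transformer
    # block (ids 1..depth), and the classification head (id depth + 1).
    # Groups closer to the output keep a larger fraction of the base lr.
    return [layer_decay ** (depth + 1 - i) for i in range(depth + 2)]

# Example: per-group lr multipliers for ViT-S/16 with the flags above.
for layer_id, scale in enumerate(layerwise_lr_scales()):
    print(f"layer group {layer_id:2d}: lr scale = {scale:.4f}")
```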
All fine-tuning logs can be found through the links in `Table 2` below.

**<p align="center">Table 2. Hyper-parameters, logs and model weights for linear probing and fine-tuning.</p>**
<table>
  <tr>
    <th>arch</th>
    <th>params</th>
    <th>pretraining epochs</th>
    <th>k-NN acc.</th>
    <th>linear acc.</th>
    <th>fine-tune acc.</th>
    <th colspan="2">linear evaluation</th>
    <th colspan="2">fine-tuning evaluation</th>
  </tr>
  <tr>
    <td>ViT-S/16</td>
    <td>21M</td>
    <td>800</td>
    <td>75.6%</td>
    <td>78.9%</td>
    <td>82.6%</td>
    <td><a href="https://drive.google.com/file/d/14LF-T94dCBqLii0qhOfZhZzZm6AuhCi_/view?usp=sharing">linear weights</a></td>
    <td><a href="https://drive.google.com/file/d/12tiO4glWZNB044TYiPPCfbnUX_9AbqVc/view?usp=sharing">eval logs</a></td>
    <td><a href="https://drive.google.com/file/d/1cEkQW72VZv-4aQVbQHP4CBgyPJP22CPv/view?usp=sharing">fine-tune weights</a></td>
    <td><a href="https://drive.google.com/file/d/1LrElU1T4lvHxCuU5LJ-llX-9cCMl8o1L/view?usp=sharing">eval logs</a></td>
  </tr>
  <tr>
    <td>ViT-B/16</td>
    <td>85M</td>
    <td>400</td>
    <td>78.0%</td>
    <td>80.6%</td>
    <td>84.3%</td>
    <td><a href="https://drive.google.com/file/d/1MAz28bBgzPb7MVhfbveL7PTox06xu_Wx/view?usp=sharing">linear weights</a></td>
    <td><a href="https://drive.google.com/file/d/1gOR250QFLZfe40pLNPcOqaLPAnKLuE_C/view?usp=sharing">eval logs</a></td>
    <td><a href="https://drive.google.com/file/d/1YTC9rj5t8onqJ5oAmVPAcXa1tADQtaYe/view?usp=sharing">fine-tune weights</a></td>
    <td><a href="https://drive.google.com/file/d/1L8EixjzZzP62dU3Z6mzykIdjuVHU-bpb/view?usp=sharing">eval logs</a></td>
  </tr>
  <tr>
    <td>ViT-L/16</td>
    <td>307M</td>
    <td>250</td>
    <td>80.3%</td>
    <td>82.1%</td>
    <td>85.2%</td>
    <td><a href="https://drive.google.com/file/d/1j6rQwFTsT3NMLBs4s6qrQbjxn-HK1Mv6/view?usp=sharing">linear weights</a></td>
    <td><a href="https://drive.google.com/file/d/1rqWenRFN0czat_55GY9GNOu7gS6fww3g/view?usp=sharing">eval logs</a></td>
    <td><a href="https://drive.google.com/file/d/10Tcp-EMkNz1Kj1enjTYoGG90jkH9Gx-7/view?usp=sharing">fine-tune weights</a></td>
    <td><a href="https://drive.google.com/file/d/16o19XGdwR9_lsGdJqMTBZACgHOONppx2/view?usp=sharing">eval logs</a></td>
  </tr>
</table>

## License
This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.

## Citation
If you find this repository useful, please consider giving a star :star: and a citation :t-rex::
```
@inproceedings{mugs2022SSL,
  title={Mugs: A Multi-Granular Self-Supervised Learning Framework},
  author={Pan Zhou and Yichen Zhou and Chenyang Si and Weihao Yu and Teck Khim Ng and Shuicheng Yan},
  booktitle={arXiv},
  year={2022}
}
```
File: eval/eval_finetuning/engine_for_finetuning.py
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training code for fine-tuning.
Most of it is copied from the BEiT library:
https://github.com/microsoft/unilm/tree/master/beit
"""

import math
import sys
from typing import Iterable, Optional

import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma

import utils_for_finetuning

def train_class_batch(model, samples, target, criterion):
    outputs = model(samples)
    loss = criterion(outputs, target)
    return loss, outputs


def get_loss_scale_for_deepspeed(model):
    # DeepSpeed stores the current loss scale on its wrapped optimizer; the
    # attribute name varies across optimizer wrappers, hence the fallback.
    optimizer = model.optimizer
    return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
                    start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
                    num_training_steps_per_epoch=None, update_freq=None):
    model.train(True)
    metric_logger = utils_for_finetuning.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils_for_finetuning.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('min_lr', utils_for_finetuning.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if loss_scaler is None:
        # DeepSpeed path: the engine handles gradient zeroing and accumulation.
        model.zero_grad()
        model.micro_steps = 0
    else:
        optimizer.zero_grad()

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        step = data_iter_step // update_freq
        if step >= num_training_steps_per_epoch:
            continue
        it = start_steps + step  # global training iteration
        # Update LR & WD only on the first micro-step of each accumulation cycle.
        if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
            for i, param_group in enumerate(optimizer.param_groups):
                if lr_schedule_values is not None:
                    param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
                if wd_schedule_values is not None and param_group["weight_decay"] > 0:
                    param_group["weight_decay"] = wd_schedule_values[it]

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if loss_scaler is None:
            samples = samples.half()
            loss, output = train_class_batch(
                model, samples, targets, criterion)
        else:
            with torch.cuda.amp.autocast():
                loss, output = train_class_batch(
                    model, samples, targets, criterion)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        if loss_scaler is None:
            loss /= update_freq
            model.backward(loss)
            model.step()

            if (data_iter_step + 1) % update_freq == 0:
                # model.zero_grad()
                # DeepSpeed calls step() & model.zero_grad() automatically.
                if model_ema is not None:
                    model_ema.update(model)
            grad_norm = None
            loss_scale_value = get_loss_scale_for_deepspeed(model)
        else:
            # this attribute is added by timm on one optimizer (adahessian)
            is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
            loss /= update_freq
            grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
                                    parameters=model.parameters(), create_graph=is_second_order,
                                    update_grad=(data_iter_step + 1) % update_freq == 0)
            if (data_iter_step + 1) % update_freq == 0:
                optimizer.zero_grad()
                if model_ema is not None:
                    model_ema.update(model)
            loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        if mixup_fn is None:
            class_acc = (output.max(-1)[-1] == targets).float().mean()
        else:
            class_acc = None
        metric_logger.update(loss=loss_value)
        metric_logger.update(class_acc=class_acc)
        metric_logger.update(loss_scale=loss_scale_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)
        metric_logger.update(min_lr=min_lr)
        weight_decay_value = None
        for group in optimizer.param_groups:
            if group["weight_decay"] > 0:
                weight_decay_value = group["weight_decay"]
        metric_logger.update(weight_decay=weight_decay_value)
        metric_logger.update(grad_norm=grad_norm)

        if log_writer is not None:
            log_writer.update(loss=loss_value, head="loss")
            log_writer.update(class_acc=class_acc, head="loss")
            log_writer.update(loss_scale=loss_scale_value, head="opt")
            log_writer.update(lr=max_lr, head="opt")
            log_writer.update(min_lr=min_lr, head="opt")
            log_writer.update(weight_decay=weight_decay_value, head="opt")
            log_writer.update(grad_norm=grad_norm, head="opt")

            log_writer.set_step()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils_for_finetuning.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0]
        target = batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
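For orientation, a minimal smoke test of `evaluate` on random data is sketched below. It is not part of the commit; it assumes this file is importable as `engine_for_finetuning` (with `utils_for_finetuning` and `timm` on the path) and also runs on CPU, where `torch.cuda.amp.autocast` degrades to a no-op with a warning.
```
import torch
from torch.utils.data import DataLoader, TensorDataset

from engine_for_finetuning import evaluate  # assumed import path

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 64 random "images" and 10 fake classes stand in for ImageNet.
    images = torch.randn(64, 3, 224, 224)
    labels = torch.randint(0, 10, (64,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=16)
    # A toy classifier in place of a fine-tuned ViT.
    model = torch.nn.Sequential(torch.nn.Flatten(),
                                torch.nn.Linear(3 * 224 * 224, 10)).to(device)
    stats = evaluate(loader, model, device)
    print(stats)  # {'loss': ..., 'acc1': ..., 'acc5': ...}
```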