Skip to content

Commit

Permalink
Fix typo distillation
Browse files Browse the repository at this point in the history
  • Loading branch information
successwang committed Sep 25, 2023
1 parent 2fce425 commit a3ea2ef
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 26 deletions.
6 changes: 3 additions & 3 deletions cn_clip/training/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def main():
# only do so if it is the 0th worker.
args.should_save = (args.logs is not None and args.logs != '' and args.logs.lower() != 'none') and is_master(args)

# load teacher model to distllation
if args.distllation:
# load teacher model to distillation
if args.distillation:
try:
from modelscope.models import Model
except:
Expand Down Expand Up @@ -292,7 +292,7 @@ def main():
for epoch in range(start_epoch, args.max_epochs):
if is_master(args) == 0:
logging.info(f'Start epoch {epoch + 1}')
if args.distllation:
if args.distillation:
num_steps_this_epoch = train(model, data, epoch, optimizer, scaler, scheduler, args, steps, teacher_model)
else:
num_steps_this_epoch = train(model, data, epoch, optimizer, scaler, scheduler, args, steps)
Expand Down
4 changes: 2 additions & 2 deletions cn_clip/training/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def parse_args():
default=123,
help="Random seed."
)
# arguments for distllation
# arguments for distillation
parser.add_argument(
"--distllation",
"--distillation",
default=False,
action="store_true",
help="If true, more information is logged."
Expand Down
26 changes: 13 additions & 13 deletions cn_clip/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
if args.accum_freq == 1:
image_features, text_features, logit_scale = model(images, texts, args.mask_ratio)

if args.distllation:
if args.distillation:
with torch.no_grad():
# different teacher model has different output
output = teacher_model.module.get_feature(images)
Expand All @@ -34,7 +34,7 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
assert accum_image_features and accum_text_features and accum_idx != -1
chunk_image_features, chunk_text_features, logit_scale = model(images, texts, args.mask_ratio)

if args.distllation:
if args.distillation:
with torch.no_grad():
# different teacher model has different output
output = teacher_model.module.get_feature(images)
Expand All @@ -59,7 +59,7 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)

if args.distllation:
if args.distillation:
all_teacher_image_features = torch.cat(torch.distributed.nn.all_gather(teacher_image_features), dim=0)
else:
gathered_image_features = [
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
logits_per_image = logit_scale * all_image_features @ all_text_features.t()
logits_per_text = logits_per_image.t()

if args.distllation:
if args.distillation:
gathered_teacher_image_features = [
torch.zeros_like(teacher_image_features) for _ in range(world_size)
]
Expand All @@ -103,7 +103,7 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()

if args.distllation:
if args.distillation:
kd_loss = cosineSimilarityLoss(teacher_image_features, image_features)

ground_truth = torch.arange(len(logits_per_image)).long()
Expand All @@ -120,7 +120,7 @@ def get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_feature
t2i_acc = (logits_per_text.argmax(-1) == ground_truth).sum() / len(logits_per_text)
acc = {"i2t": i2t_acc, "t2i": t2i_acc}

if args.distllation:
if args.distillation:
total_loss += kd_loss * args.kd_loss_weight

return total_loss, acc
Expand Down Expand Up @@ -156,7 +156,7 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained

if args.accum_freq > 1:
accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
if args.distllation:
if args.distillation:
teacher_accum_image_features = []

end = time.time()
Expand Down Expand Up @@ -188,7 +188,7 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
# with automatic mixed precision.
if args.precision == "amp":
with autocast():
if args.distllation:
if args.distillation:
total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, teacher_model=teacher_model)
else:
total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
Expand All @@ -197,7 +197,7 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
scaler.update()

else:
if args.distllation:
if args.distillation:
total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, teacher_model=teacher_model)
else:
total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args)
Expand All @@ -208,15 +208,15 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
with torch.no_grad():
with autocast(enabled=(args.precision == "amp")):
chunk_image_features, chunk_text_features, _ = model(images, texts)
if args.distllation:
if args.distillation:
output = teacher_model.module.get_feature(images)
if(len(output) == 2):
teacher_chunk_image_features = output[0]
else:
teacher_chunk_image_features = output
accum_image_features.append(chunk_image_features)
accum_text_features.append(chunk_text_features)
if args.distllation:
if args.distillation:
teacher_accum_image_features.append(teacher_chunk_image_features)

accum_images.append(images)
Expand All @@ -237,7 +237,7 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
with autocast(enabled=(args.precision == "amp")):
# `total_loss` and `acc` are coarsely sampled, taking only the last result in the loop.
# Although each result should be the same in theory, it will be slightly different in practice
if args.distllation:
if args.distillation:
total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j, teacher_model, teacher_accum_image_features)
else:
total_loss, acc = get_loss(model, images, texts, loss_img, loss_txt, args, accum_image_features, accum_text_features, j)
Expand All @@ -255,7 +255,7 @@ def train(model, data, epoch, optimizer, scaler, scheduler, args, global_trained
# reset gradient accum, if enabled
if args.accum_freq > 1:
accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
if args.distllation:
if args.distillation:
teacher_accum_image_features = []

# Note: we clamp to 4.6052 = ln(100), as in the original paper.
Expand Down
6 changes: 3 additions & 3 deletions distillation.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

## 在Chinese-CLIP中用起来!

在Chinese-CLIP finetune中对于图像端应用知识蒸馏并不复杂。只需要在finetune的sh脚本中加入`--distllation`配置项。
在Chinese-CLIP finetune中对于图像端应用知识蒸馏并不复杂。只需要在finetune的sh脚本中加入`--distillation`配置项。
然后在配置项`--teacher-model-name`填入所要使用的Teacher model名称。现在支持的Teacher model包括以下四种。
<table border="1" width="120%">
<tr align="center">
Expand All @@ -39,11 +39,11 @@


其中各配置项定义如下:
+ `distllation`: 是否启用知识蒸馏微调模型图像端。
+ `distillation`: 是否启用知识蒸馏微调模型图像端。
+ `teacher-model-name`: 指定使用的Teacher model。目前支持以上四个Teacher model,如填入`damo/multi-modal_team-vit-large-patch14_multi-modal-similarity`
+ `kd_loss_weight`(可选): 蒸馏损失的权值,默认值是0.5。

我们提供了样例脚本`run_scripts/muge_finetune_vit-b-16_rbt-base_distllation.sh`,使用的是`TEAM图文检索模型-中文-large`作为Teacher model。
我们提供了样例脚本`run_scripts/muge_finetune_vit-b-16_rbt-base_distillation.sh`,使用的是`TEAM图文检索模型-中文-large`作为Teacher model。

## 效果验证
这里是我们模型(finetune+distillation) vs 预训练模型 vs finetune模型的图像检索Top10结果。左上角图像作为query,右边按顺序Top1到Top10检索结果。本次实验的support数据集有10万电商数据量(包括鞋子、衣服、裤子等物品)。
Expand Down
6 changes: 3 additions & 3 deletions distillation_En.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Here we provide an example of knowledge distillation for Chinese-CLIP fine-tunin
+ Other dependencies as required in [requirements.txt](requirements.txt).

## Use it in Chinese-CLIP!
It is not complicated to apply knowledge distillation to the image side in Chinese-CLIP finetune. Just add the `--distllation` configuration item to the sh script of finetune.
It is not complicated to apply knowledge distillation to the image side in Chinese-CLIP finetune. Just add the `--distillation` configuration item to the sh script of finetune.
Then fill in the name of the Teacher model to be used in the configuration item `--teacher-model-name`. The currently supported Teacher models include the following four ModelScope-supported models.
<table border="1" width="120%">
<tr align="center">
Expand All @@ -37,11 +37,11 @@ Then fill in the name of the Teacher model to be used in the configuration item
Finally, fill in the weight of the distillation loss in the configuration item `--kd_loss_weight`, the default value is 0.5.

The configuration items are defined as follows:
+ `distllation`: Whether to enable knowledge distillation to fine-tune the image side of the model.
+ `distillation`: Whether to enable knowledge distillation to fine-tune the image side of the model.
+ `teacher-model-name`: Specify the Teacher model to use. Currently supports the above four Teacher models, such as filling in `damo/multi-modal_team-vit-large-patch14_multi-modal-similarity`.
+ `kd_loss_weight` (optional): Distillation loss weight, default value is 0.5.

We provide a sample script `run_scripts/muge_finetune_vit-b-16_rbt-base_distllation.sh`, we take the `TEAM image-text retrieval model-Chinese-large` as Teacher model.
We provide a sample script `run_scripts/muge_finetune_vit-b-16_rbt-base_distillation.sh`, we take the `TEAM image-text retrieval model-Chinese-large` as Teacher model.

## Effect verification
Image retrieval Top10 results of our model (finetune+distillation) v.s. pre-trained model v.s. finetune model. The image in the upper left corner is used as a query, and the search results are in order from Top1 to Top10 on the right. The support data set in this experiment has 100,000 e-commerce data (including shoes, clothes, pants, etc.).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ valid_epoch_interval=1
vision_model=ViT-B-16
text_model=RoBERTa-wwm-ext-base-chinese
use_augment="--use-augment"
distllation="--distllation"
distillation="--distillation"
teacher_model_name="damo/multi-modal_team-vit-large-patch14_multi-modal-similarity"
# use_augment=""

Expand Down Expand Up @@ -85,5 +85,5 @@ python3 -m torch.distributed.launch --use_env --nproc_per_node=${GPUS_PER_NODE}
--vision-model=${vision_model} \
${use_augment} \
--text-model=${text_model} \
${distllation} \
${distillation} \
--teacher-model-name=${teacher_model_name} \

0 comments on commit a3ea2ef

Please sign in to comment.