"多语言文本-视频跨模态检索的新基线模型"论文源代码
- CUDA 10.1
- Python 3.8
- PyTorch 1.5.1
我们使用Anaconda设置了一个支持PyTorch的深度学习工作区,请运行以下脚本以安装所需的程序包。
conda create --name mlcmr python=3.8
conda activate mlcmr
git clone https://github.com/HuiGuanLab/MLCMR.git
cd MLCMR
pip install -r requirements.txt
conda deactivate
我们使用两种公开数据集: VATEX, MSR-VTT. 预训练提取的特征请放置在 $HOME/VisualSearch/
.
我们已经在项目文件的 VisualSearch
里准备好了所需的文本文件。
其中,带有_google_aa2bb的文本文件表示由原始人工标注的aa语言通过谷歌翻译至bb语言得到,其余均为原始人工标注文本
训练集:
平行多语言场景: 原始标注语言a + 原始标注语言b
伪平行多语言场景: 原始标注语言a + 由a翻译得到的标注语言b
不平行多语言场景: 原始标注语言a + 原始标注语言b,其中,a和b描述的是不同的视频
验证集:
原始标注语言a + 原始标注语言b
测试集:
原始标注的目标语言a + 由a翻译得到的标注语言b
对应的视频特征可通过下方获取
Dataset | feature |
---|---|
VATEX | vatex-i3d.tar.gz, pwd:p3p0 |
MSR-VTT | msrvtt10k-resnext101_resnet152.tar.gz, pwd:p3p0 |
ROOTPATH=$HOME/VisualSearch
mkdir -p $ROOTPATH && cd $ROOTPATH
请组织这些文件成下面的形式:
# 下载VATEX数据[英语, 中文]
VisualSearch/vatex/
FeatureData/
i3d_kinetics/
feature.bin
id.txt
shape.txt
video2frames.txt
TextData/
xx.txt
# 下载MSR-VTT数据[英语, 中文]
VisualSearch/msrvtt10kyu/
FeatureData/
resnext101-resnet152/
feature.bin
id.txt
shape.txt
video2frames.txt
TextData/
xx.txt
运行以下脚本来训练和评估“MLCMR”网络。具体而言,它将训练“MLCMR”网络,并选择在验证集上表现最好的检查点作为最终模型。请注意,我们只在验证集上保存性能最好的检查点,以节省磁盘空间。
ROOTPATH=$HOME/VisualSearch
conda activate mlcmr
# 例子:
# 使用 VATEX 训练 平行多语言 MLCMR 以验证中文性能
./do_all.sh vatex i3d_kinetics parallel human_label zh $ROOTPATH
# 可以通过修改训练完毕后产生的do_testxxx.py的target_language参数为en,直接验证对应的英语性能
# 使用 VATEX 训练 平行多语言 MLCMR 以验证英文性能
./do_all.sh vatex i3d_kinetics parallel human_label en $ROOTPATH
所有的检查点,从百度云盘(url,密码:4qvt) 下载VATEX上经过训练的检查点,并运行以下脚本对其进行评估。
ROOTPATH=$HOME/VisualSearch/
#将mlcmr_human_label_vatex/model_best.pth.tar移动至ROOTPATH/vatex/mlcmr_human_label_vatex/下,没有则创建
#在本项目下创建do_test_mlcmr_vatex.sh文件,内容如下:
#------
rootpath=<yourROOTPATH>
testCollection=vatex
logger_name=<yourROOTPATH>/vatex/mlcmr_human_label_vatex
overwrite=0
train_mode=parallel
label_situation=human_label
target_language=zh #测试英文请改成en
gpu=0
CUDA_VISIBLE_DEVICES=$gpu python tester.py --testCollection $testCollection --train_mode $train_mode --label_situation $label_situation --target_language $target_language --rootpath $rootpath --overwrite $overwrite --logger_name $logger_name
#------
#保存后运行do_test_mlcmr_vatex.sh文件
./do_test_mlcmr_vatex.sh
由于SGD的随机性,本次检查点模型预期性能表现与论文描述稍有不同,预计如下:
Dateset | Text-to-Video Retrieval | Video-to-Text Retrieval | SumR | ||||||||
---|---|---|---|---|---|---|---|---|---|---|---|
R@1 | R@5 | R@10 | MedR | mAP | R@1 | R@5 | R@10 | MedR | mAP | ||
Parllel_VATEX_Chinese | 36.9 | 71.3 | 80.6 | 2.0 | 52.13 | 51.1 | 79.1 | 86.7 | 1.0 | 39.05 | 405.9 |
Parllel_VATEX_English | 39.6 | 74.4 | 83.3 | / | / | 50.1 | 78.1 | 86.8 | / | / | 412.3 |
运行以下脚本来训练和评估伪平行多语言场景下“MLCMR”网络。
ROOTPATH=$HOME/VisualSearch
conda activate mlcmr
# 使用 VATEX 训练 伪平行多语言 MLCMR 以验证中文性能
./do_all.sh vatex i3d_kinetics parallel translate zh $ROOTPATH
# 可以通过修改训练完毕后产生的do_testxxx.py的target_language参数为en,直接验证对应的英语性能
# 请注意,下面这个方式将以中文翻译的英语文本进行训练,与论文中的实验无关,为了便于比较,论文中的英语性能是使用原始英语,即上面的训练方式直接验证英语性能得出的
# 使用 VATEX 训练 伪平行多语言 MLCMR 以验证英文性能
./do_all.sh vatex i3d_kinetics parallel translate en $ROOTPATH
所有的检查点,从百度云盘(url,密码:4qvt) 下载VATEX上经过训练的检查点,并运行以下脚本对其进行评估。
ROOTPATH=$HOME/VisualSearch/
#将mlcmr_translate_vatex/model_best.pth.tar移动至ROOTPATH/vatex/mlcmr_translate_vatex/下,没有则创建
#在本项目下创建do_test_mlcmr_vatex.sh文件,内容如下:
#------
rootpath=<yourROOTPATH>
testCollection=vatex
logger_name=<yourROOTPATH>/vatex/mlcmr_translate_vatex
overwrite=0
train_mode=parallel
label_situation=translate
target_language=zh #测试英文请改成en
gpu=0
CUDA_VISIBLE_DEVICES=$gpu python tester.py --testCollection $testCollection --train_mode $train_mode --label_situation $label_situation --target_language $target_language --rootpath $rootpath --overwrite $overwrite --logger_name $logger_name
#------
#保存后运行do_test_mlcmr_vatex.sh文件
./do_test_mlcmr_vatex.sh
由于SGD的随机性,本次检查点模型预期性能表现与论文描述稍有不同,预计如下:
Dataset | Text-to-Video Retrieval | Video-to-Text Retrieval | SumR | ||||||||
---|---|---|---|---|---|---|---|---|---|---|---|
R@1 | R@5 | R@10 | MedR | mAP | R@1 | R@5 | R@10 | MedR | mAP | ||
Parllel_VATEX_Translate_Chinese | 32.5 | 66.2 | 76.7 | 3.0 | 47.58 | 48.1 | 75.5 | 84.5 | 2.0 | 34.94 | 383.5 |
Parllel_VATEX_Translate_English | 37.5 | 73.7 | 82.5 | / | / | 49.4 | 78.3 | 86.5 | / | / | 407.9 |
运行以下脚本来训练和评估不平行多语言场景下“MLCMR”网络。
ROOTPATH=$HOME/VisualSearch
conda activate mlcmr
# 使用 VATEX 训练 不平行多语言 MLCMR 以验证中文性能
./do_all.sh vatex i3d_kinetics unparallel human_label zh $ROOTPATH
# 可以通过修改训练完毕后产生的do_testxxx.py的target_language参数为en,直接验证对应的英语性能
# 使用 VATEX 训练 不平行多语言 MLCMR 以验证英文性能
./do_all.sh vatex i3d_kinetics unparallel human_label en $ROOTPATH
参考论文中做出的实验,VATEX上不平行多语言场景下的MLCMR预期性能如下:
Dataset | Text-to-Video Retrieval | Video-to-Text Retrieval | SumR | ||||
---|---|---|---|---|---|---|---|
R@1 | R@5 | R@10 | R@1 | R@5 | R@10 | ||
Unparllel_VATEX_Chinese | 31.6 | 64.7 | 76.0 | 44.9 | 75.5 | 84.5 | 377.1 |
Unparllel_VATEX_English | 32.8 | 66.7 | 77.4 | 44.5 | 74.4 | 84.3 | 380.1 |
由于MSRVTT不具备多语言特性,因此仅验证伪平行多语言场景下的性能
运行以下脚本来训练和评估“MLCMR”网络。
ROOTPATH=$HOME/VisualSearch
conda activate mlcmr
# 例子:
# 使用 MSRVTT 训练 伪平行多语言 MLCMR 以验证中文性能
./do_all.sh msrvtt10k resnext101-resnet152 parallel translate zh $ROOTPATH
# 可以通过修改训练完毕后产生的do_testxxx.py的target_language参数为en,直接验证对应的英语性能
所有的检查点,从百度云盘(url,密码:4qvt) 下载MSRVTT上经过训练的检查点,并运行以下脚本对其进行评估。
ROOTPATH=$HOME/VisualSearch/
#将mlcmr_translate_msrvtt10kyu/model_best.pth.tar移动至ROOTPATH/msrvtt10kyu/mlcmr_translate_msrvtt10kyu/下,没有则创建
#在本项目下创建do_test_mlcmr_msrvtt10kyu.sh文件,内容如下:
#------
rootpath=<yourROOTPATH>
testCollection=msrvtt10kyu
logger_name=<yourROOTPATH>/msrvtt10kyu/mlcmr_translate_msrvtt10kyu/
overwrite=0
train_mode=parallel
label_situation=translate
target_language=zh #测试英文请改成en
gpu=0
CUDA_VISIBLE_DEVICES=$gpu python tester.py --testCollection $testCollection --train_mode $train_mode --label_situation $label_situation --target_language $target_language --rootpath $rootpath --overwrite $overwrite --logger_name $logger_name
#------
#保存后运行do_test_mlcmr_msrvtt10kyu.sh文件
./do_test_mlcmr_msrvtt10kyu.sh
检查点性能预计表现如下:
Dataset | Text-to-Video Retrieval | Video-to-Text Retrieval | SumR | ||||||||
---|---|---|---|---|---|---|---|---|---|---|---|
R@1 | R@5 | R@10 | MedR | mAP | R@1 | R@5 | R@10 | MedR | mAP | ||
Parllel_MSRVTT_Translate_Chinese | 28.9 | 55.5 | 69.2 | 4.0 | 41.54 | 30.4 | 57.8 | 69.7 | 4.0 | 43.07 | 311.5 |
Parllel_MSRVTT_Translate_English | 26.9 | 55.3 | 67.7 | 4.0 | 40.20 | 29.8 | 55.7 | 67.7 | 5.0 | 42.16 | 303.1 |