Merge pull request #26 from zwkkk/master
Add FLIP training functionality
yangapku authored Dec 12, 2022
2 parents 1924b1b + 0b6d4c2 commit 1e618b0
Showing 6 changed files with 203 additions and 9 deletions.
32 changes: 25 additions & 7 deletions cn_clip/clip/model.py
@@ -240,12 +240,30 @@ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: i
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable

def forward(self, x: torch.Tensor):
def random_masking(self, x, mask_ratio):
N, L, D = x.shape # batch, length, dim
len_keep = int((L - 1) * (1 - mask_ratio))

noise = torch.rand(N, L - 1, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1) + torch.ones(N, L - 1, device=x.device,
dtype=int)
ids_keep = ids_shuffle[:, :len_keep]

x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

x0 = x[:, 0, :]
x0 = x0.reshape(N, 1, D)
x_masked_add = torch.cat([x0, x_masked], axis=1)
return x_masked_add

def forward(self, x: torch.Tensor, mask_ratio: float = 0.0):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
if mask_ratio != 0:
x = self.random_masking(x, mask_ratio)
x = self.ln_pre(x)

x = x.permute(1, 0, 2) # NLD -> LND
@@ -282,7 +300,7 @@ def __init__(self,
text_type_vocab_size: int,
tokenizer = _tokenizer,
# vision head width, added this param for ViT-H
vision_head_width: int = 64,
vision_head_width: int = 64,
):
super().__init__()

@@ -357,23 +375,23 @@ def set_grad_checkpointing(self, enable=True):
def dtype(self):
return self.visual.conv1.weight.dtype

def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_image(self, image, mask_ratio=0):
return self.visual(image.type(self.dtype), mask_ratio)

def encode_text(self, text):
pad_index = self.tokenizer.vocab['[PAD]']
attn_mask = text.ne(pad_index).type(self.dtype)
x = self.bert(text, attention_mask=attn_mask)[0].type(self.dtype) # [batch_size, seq_length, hidden_size]
return x[:, 0, :] @ self.text_projection

def forward(self, image, text):
def forward(self, image, text, mask_ratio=0):
assert image is not None or text is not None, "text and image cannot both be None!"

if image is None:
return self.encode_text(text)
elif text is None:
return self.encode_image(image)
image_features = self.encode_image(image)
image_features = self.encode_image(image, mask_ratio)
text_features = self.encode_text(text)

image_features = image_features / image_features.norm(dim=-1, keepdim=True)
@@ -500,4 +518,4 @@ def parse(x):
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
to_ntuple = lambda n, x: _ntuple(n)(x)
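
For orientation, a minimal standalone sketch of the FLIP-style patch masking that the new random_masking method implements: it keeps the class token, scores every patch token with uniform noise, and retains a random (1 - mask_ratio) fraction of the patches. The tensor shapes and mask_ratio value below are illustrative assumptions, not values fixed by this commit.

import torch

def random_masking_sketch(x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
    # x: [N, L, D] token sequence where position 0 is the class token
    N, L, D = x.shape
    len_keep = int((L - 1) * (1 - mask_ratio))     # number of patch tokens to keep
    noise = torch.rand(N, L - 1, device=x.device)  # one random score per patch token
    # Sorting the noise gives a random permutation of patch positions;
    # adding 1 shifts the indices past the class token at position 0.
    ids_shuffle = torch.argsort(noise, dim=1) + 1
    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    # Re-attach the class token in front of the surviving patch tokens.
    return torch.cat([x[:, :1, :], x_kept], dim=1)

# Illustrative shapes: batch of 2, 197 tokens (1 class token + 196 patches
# for a 224x224 ViT-B/16 input), width 768.
tokens = torch.randn(2, 197, 768)
masked = random_masking_sketch(tokens, mask_ratio=0.5)
print(masked.shape)  # torch.Size([2, 99, 768]) -> class token + 98 kept patches

With mask_ratio=0.5 the vision transformer processes only about half of the patch tokens per step, which is the FLIP trade-off: cheaper forward and backward passes in exchange for a noisier image representation during training.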
2 changes: 1 addition & 1 deletion cn_clip/training/main.py
@@ -88,7 +88,7 @@ def main():
model_info['vision_layers'] = eval(model_info['vision_layers'])
for k, v in json.load(ft).items():
model_info[k] = v

model = CLIP(**model_info)
if args.clip_weight_path is not None:
assert os.path.exists(args.clip_weight_path), "Pretrained CLIP weight not exists!"
6 changes: 6 additions & 0 deletions cn_clip/training/params.py
@@ -129,6 +129,12 @@ def parse_args():
default="ViT-B-16",
help="Name of the vision backbone to use.",
)
parser.add_argument(
"--mask_ratio",
default=0,
type=float,
help="mask ratio of patches.",
)
parser.add_argument(
"--clip-weight-path",
default=None,
2 changes: 1 addition & 1 deletion cn_clip/training/train.py
@@ -16,7 +16,7 @@ def is_master(args):
return args.rank == 0

def get_loss(model, images, texts, loss_img, loss_txt, args):
image_features, text_features, logit_scale = model(images, texts)
image_features, text_features, logit_scale = model(images, texts, args.mask_ratio)
logit_scale = logit_scale.mean()
if args.aggregate:
world_size = dist.get_world_size()
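
The remainder of get_loss is unchanged and not shown in this hunk. For orientation only, a standard CLIP-style symmetric contrastive loss over the returned (L2-normalized) features looks roughly like the sketch below; this is an illustrative reconstruction under that assumption, not the repository's exact code.

import torch
import torch.nn.functional as F

def clip_contrastive_loss_sketch(image_features, text_features, logit_scale):
    # Both feature tensors are assumed L2-normalized, shape [batch, dim];
    # logit_scale is assumed to be the learned temperature, already exponentiated.
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    labels = torch.arange(image_features.shape[0], device=image_features.device)
    # Symmetric cross-entropy: the i-th image should match the i-th text and vice versa.
    return (F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)) / 2

Masking only changes how image_features are produced (fewer patch tokens per image); the contrastive loss itself is the same with or without FLIP.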
85 changes: 85 additions & 0 deletions run_scripts/flickr30k_finetune_vit-b-16_rbt-base_flip.sh
@@ -0,0 +1,85 @@
#!/usr/bin/env bash

# Guide:
# This script supports distributed training on multi-gpu workers (as well as single-worker training).
# Please set the options below according to the comments.
# For multi-worker training, these options should be set manually for each worker.
# After setting the options, please run the script on each worker.
# Command: bash run_scripts/flickr30k_finetune_vit-b-16_rbt-base_flip.sh ${DATAPATH}

# Number of GPUs per GPU worker
GPUS_PER_NODE=8
# Number of GPU workers, for single-worker training, please set to 1
WORKER_CNT=1
# The ip address of the rank-0 worker, for single-worker training, please set to localhost
export MASTER_ADDR=XX.XX.XX.XX
# The port for communication
export MASTER_PORT=8514
# The rank of this worker, should be in {0, ..., WORKER_CNT-1}, for single-worker training, please set to 0
export RANK=0

export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip/

DATAPATH=${1}

# data options
train_data=${DATAPATH}/datasets/Flickr30k-CN/lmdb/train
val_data=${DATAPATH}/datasets/Flickr30k-CN/lmdb/valid # if val_data is not specified, the validation will be automatically disabled

# restore options
resume=${DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt # or specify your custom ckpt path to resume
reset_data_offset="--reset-data-offset"
reset_optimizer="--reset-optimizer"
# reset_optimizer=""

# output options
output_base_dir=${DATAPATH}/experiments/
name=flickr30k_finetune_vit-b-16_roberta-base_bs128_8gpu
save_step_frequency=999999 # disable it
save_epoch_frequency=1
log_interval=1
report_training_batch_acc="--report-training-batch-acc"
# report_training_batch_acc=""

# training hyper-params
context_length=52
warmup=100
batch_size=128
valid_batch_size=128
lr=5e-5
wd=0.001
max_epochs=3 # number of finetuning epochs
valid_step_interval=150
valid_epoch_interval=1
vision_model=ViT-B-16
text_model=RoBERTa-wwm-ext-base-chinese
mask_ratio=0.5 # FLIP: fraction of image patches randomly masked during training (0.5 drops about half the patch tokens)
use_augment="--use-augment"
# use_augment=""

python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
--master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
--train-data=${train_data} \
--val-data=${val_data} \
--resume=${resume} \
${reset_data_offset} \
${reset_optimizer} \
--logs=${output_base_dir} \
--name=${name} \
--save-step-frequency=${save_step_frequency} \
--save-epoch-frequency=${save_epoch_frequency} \
--log-interval=${log_interval} \
${report_training_batch_acc} \
--context-length=${context_length} \
--warmup=${warmup} \
--batch-size=${batch_size} \
--valid-batch-size=${valid_batch_size} \
--valid-step-interval=${valid_step_interval} \
--valid-epoch-interval=${valid_epoch_interval} \
--lr=${lr} \
--wd=${wd} \
--max-epochs=${max_epochs} \
--vision-model=${vision_model} \
--mask_ratio=${mask_ratio} \
${use_augment} \
--text-model=${text_model}
85 changes: 85 additions & 0 deletions run_scripts/muge_finetune_vit-b-16_rbt-base_flip.sh
@@ -0,0 +1,85 @@
#!/usr/bin/env bash

# Guide:
# This script supports distributed training on multi-gpu workers (as well as single-worker training).
# Please set the options below according to the comments.
# For multi-worker training, these options should be set manually for each worker.
# After setting the options, please run the script on each worker.
# Command: bash run_scripts/muge_finetune_vit-b-16_rbt-base_flip.sh ${DATAPATH}

# Number of GPUs per GPU worker
GPUS_PER_NODE=8
# Number of GPU workers, for single-worker training, please set to 1
WORKER_CNT=1
# The ip address of the rank-0 worker, for single-worker training, please set to localhost
export MASTER_ADDR=XX.XX.XX.XX
# The port for communication
export MASTER_PORT=8514
# The rank of this worker, should be in {0, ..., WORKER_CNT-1}, for single-worker training, please set to 0
export RANK=0

export PYTHONPATH=${PYTHONPATH}:`pwd`/cn_clip/

DATAPATH=${1}

# data options
train_data=${DATAPATH}/datasets/MUGE/lmdb/train
val_data=${DATAPATH}/datasets/MUGE/lmdb/valid # if val_data is not specified, the validation will be automatically disabled

# restore options
resume=${DATAPATH}/pretrained_weights/clip_cn_vit-b-16.pt # or specify your custom ckpt path to resume
reset_data_offset="--reset-data-offset"
reset_optimizer="--reset-optimizer"
# reset_optimizer=""

# output options
output_base_dir=${DATAPATH}/experiments/
name=muge_finetune_vit-b-16_roberta-base_bs128_8gpu
save_step_frequency=999999 # disable it
save_epoch_frequency=1
log_interval=1
report_training_batch_acc="--report-training-batch-acc"
# report_training_batch_acc=""

# training hyper-params
context_length=52
warmup=100
batch_size=128
valid_batch_size=128
lr=5e-5
wd=0.001
max_epochs=3 # number of finetuning epochs
valid_step_interval=150
valid_epoch_interval=1
vision_model=ViT-B-16
text_model=RoBERTa-wwm-ext-base-chinese
mask_ratio=0.5 # FLIP: fraction of image patches randomly masked during training (0.5 drops about half the patch tokens)
use_augment="--use-augment"
# use_augment=""

python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --nnodes=${WORKER_CNT} --node_rank=${RANK} \
--master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} cn_clip/training/main.py \
--train-data=${train_data} \
--val-data=${val_data} \
--resume=${resume} \
${reset_data_offset} \
${reset_optimizer} \
--logs=${output_base_dir} \
--name=${name} \
--save-step-frequency=${save_step_frequency} \
--save-epoch-frequency=${save_epoch_frequency} \
--log-interval=${log_interval} \
${report_training_batch_acc} \
--context-length=${context_length} \
--warmup=${warmup} \
--batch-size=${batch_size} \
--valid-batch-size=${valid_batch_size} \
--valid-step-interval=${valid_step_interval} \
--valid-epoch-interval=${valid_epoch_interval} \
--lr=${lr} \
--wd=${wd} \
--max-epochs=${max_epochs} \
--vision-model=${vision_model} \
--mask_ratio=${mask_ratio} \
${use_augment} \
--text-model=${text_model}
