Skip to content

Commit

Permalink
remove workaround for transfomers bs>1 close #869
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 10, 2023
1 parent 3e81bd6 commit 17813ff
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion finetune/make_captions_by_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):


def main(args):
r"""
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
Expand All @@ -65,6 +68,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""

print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
Expand All @@ -81,7 +85,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs)
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]

curr_batch_size[0] = len(path_imgs)
# curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
Expand Down

0 comments on commit 17813ff

Please sign in to comment.