diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index ce6e66955..b3c5cc423 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch): def main(args): + r""" + transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト + # GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用 org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように @@ -65,6 +68,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs) return input_ids GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch + """ print(f"load images from {args.train_data_dir}") train_data_dir_path = Path(args.train_data_dir) @@ -81,7 +85,7 @@ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs) def run_batch(path_imgs): imgs = [im for _, im in path_imgs] - curr_batch_size[0] = len(path_imgs) + # curr_batch_size[0] = len(path_imgs) inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式 generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length) captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)