diff --git a/toolkit/captioning/caption_with_cogvlm.py b/toolkit/captioning/caption_with_cogvlm.py
index 52612645..9c1dbf23 100644
--- a/toolkit/captioning/caption_with_cogvlm.py
+++ b/toolkit/captioning/caption_with_cogvlm.py
@@ -293,6 +293,8 @@ def process_directory(
                 counter += 1
 
             image.save(new_filepath)
+        else:
+            new_filepath = full_filepath
 
         if args.target_backend_id:
             upload_to_s3(s3_client, bucket_name, image, new_filename)
@@ -342,15 +344,23 @@ def main():
     if args.output_dir and not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
     logger.info("Loading CogVLM model. This should only occur once.")
-    from transformers import AutoModelForCausalLM, LlamaTokenizer
+    from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer
 
-    tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
     logger.info(f"Loading CogVLM in {args.precision} precision.")
     if "cogvlm2" in args.model_path and torch.backends.mps.is_available():
         logger.warning(
             "Can not run CogVLM 2 on MPS because Triton is unavailable. Falling back to CogVLM 1.1"
         )
-        args.model_path = "THUDM/cogvlm-chat-hf"
+    elif "cogvlm2" in args.model_path:
+        import sysconfig
+
+        print(sysconfig.get_paths()["include"])
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_path, trust_remote_code=True
+        )
+    else:
+        tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
+
     model = AutoModelForCausalLM.from_pretrained(
         args.model_path,
         torch_dtype=torch_dtype,
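For review, a minimal standalone sketch of the tokenizer-selection logic added in the second hunk. The `load_tokenizer` helper is hypothetical (the patched script does this inline in `main()`); the model IDs and the MPS/Triton caveat are taken from the hunk above, and the sketch omits the `sysconfig` debug print.

```python
# Illustrative sketch only: `load_tokenizer` is a hypothetical helper; the
# patched script performs this selection inline in main().
import torch
from transformers import AutoTokenizer, LlamaTokenizer


def load_tokenizer(model_path: str):
    """Select a tokenizer the way the patched branch in main() does."""
    if "cogvlm2" in model_path and torch.backends.mps.is_available():
        # CogVLM 2 needs Triton, which is unavailable on MPS; the script
        # only logs a warning on this path.
        print("Cannot run CogVLM 2 on MPS because Triton is unavailable.")
        return None
    if "cogvlm2" in model_path:
        # CogVLM 2 repositories ship their own tokenizer.
        return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # CogVLM 1.1 continues to use the Vicuna tokenizer.
    return LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")


# Usage (hypothetical): tokenizer = load_tokenizer(args.model_path)
```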