Merge pull request #936 from burgalon/fix-cogvlm2
Fix caption_with_cogvlm.py for cogvlm2 + textfile strategy
bghira authored Sep 4, 2024
2 parents 3fcca14 + 8899037 commit 5dcd649
Showing 1 changed file with 13 additions and 3 deletions.
toolkit/captioning/caption_with_cogvlm.py: 13 additions & 3 deletions
@@ -293,6 +293,8 @@ def process_directory(
                 counter += 1

             image.save(new_filepath)
+        else:
+            new_filepath = full_filepath
         if args.target_backend_id:
             upload_to_s3(s3_client, bucket_name, image, new_filename)
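
For context on the first hunk: when no output directory is set, the image is not re-saved, and before this fix new_filepath was left undefined, which broke the textfile caption strategy downstream. Below is a minimal sketch of the repaired control flow; the helper name and the simplified collision handling are assumptions for illustration, and only new_filepath, full_filepath, and image.save come from the diff.

    import os

    def resolve_output_path(image, full_filepath, output_dir=None):
        # Hypothetical helper illustrating the fixed branch, not the script's API.
        if output_dir:
            # Save a copy into the output directory under the original basename.
            new_filepath = os.path.join(output_dir, os.path.basename(full_filepath))
            image.save(new_filepath)
        else:
            # The fix: fall back to the original path so later steps (e.g. writing
            # a .txt caption next to the image) always have a defined path.
            new_filepath = full_filepath
        return new_filepath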

@@ -342,15 +344,23 @@ def main():
     if args.output_dir and not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
     logger.info("Loading CogVLM model. This should only occur once.")
-    from transformers import AutoModelForCausalLM, LlamaTokenizer
+    from transformers import AutoModelForCausalLM, LlamaTokenizer, AutoTokenizer

-    tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
     logger.info(f"Loading CogVLM in {args.precision} precision.")
+    if "cogvlm2" in args.model_path and torch.backends.mps.is_available():
+        logger.warning(
+            "Can not run CogVLM 2 on MPS because Triton is unavailable. Falling back to CogVLM 1.1"
+        )
+        args.model_path = "THUDM/cogvlm-chat-hf"
+    elif "cogvlm2" in args.model_path:
+        import sysconfig
+
+        print(sysconfig.get_paths()["include"])
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.model_path, trust_remote_code=True
+        )
+    else:
+        tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")

     model = AutoModelForCausalLM.from_pretrained(
         args.model_path,
         torch_dtype=torch_dtype,
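The second hunk selects the tokenizer by model family: CogVLM 2 ships its own tokenizer, so it must be loaded with AutoTokenizer and trust_remote_code=True, while CogVLM 1.1 reuses the Vicuna 7B tokenizer; on Apple MPS the script falls back to CogVLM 1.1 because CogVLM 2 depends on Triton, which is unavailable there. The print(sysconfig.get_paths()["include"]) call appears to be a debugging aid for locating the Python headers Triton compiles against; the sketch below omits it. This is a minimal standalone sketch of the selection logic, assuming only the model paths shown in the diff.

    import torch
    from transformers import AutoTokenizer, LlamaTokenizer

    def load_cogvlm_tokenizer(model_path: str):
        # CogVLM 2 needs Triton, which is unavailable on Apple MPS; downgrade
        # to CogVLM 1.1 in that case.
        if "cogvlm2" in model_path and torch.backends.mps.is_available():
            model_path = "THUDM/cogvlm-chat-hf"
        if "cogvlm2" in model_path:
            # CogVLM 2 bundles a custom tokenizer in its model repository.
            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        else:
            # CogVLM 1.1 was trained with the Vicuna 7B tokenizer.
            tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
        return tokenizer, model_path

    # Example (model id is an assumption, not taken from the diff):
    # tokenizer, model_path = load_cogvlm_tokenizer("THUDM/cogvlm2-llama3-chat-19B")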
