Skip to content

Commit

Permalink
Fix caption_with_cogvlm.py captioning: fix caption_strategy == "text"…
Browse files Browse the repository at this point in the history
… where new_filepath was not defined when no outdir was specified
  • Loading branch information
burgalon committed Sep 4, 2024
1 parent 7a7e889 commit 8899037
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions toolkit/captioning/caption_with_cogvlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def load_filter_list(filter_list_path):


def eval_image(
image: Image.Image,
model,
tokenizer,
torch_dtype,
query: str,
image: Image.Image,
model,
tokenizer,
torch_dtype,
query: str,
):
inputs = model.build_conversation_input_ids(
tokenizer, query=query, history=[], images=[image]
Expand Down Expand Up @@ -183,18 +183,18 @@ def content_to_filename(content, filter_terms, disable_filename_cleaning: bool =


def process_directory(
args,
image_dir,
output_dir,
model,
tokenizer,
processed_files,
caption_strategy,
save_interval,
progress_file,
filter_terms,
torch_dtype,
query_str: str,
args,
image_dir,
output_dir,
model,
tokenizer,
processed_files,
caption_strategy,
save_interval,
progress_file,
filter_terms,
torch_dtype,
query_str: str,
):
processed_file_counter = 0
bucket_name = None
Expand Down Expand Up @@ -234,12 +234,12 @@ def process_directory(
)

for filename in tqdm(
os.listdir(image_dir),
desc=f"Processing directory {image_dir}",
unit="images",
leave=True,
position=0,
mininterval=0.5,
os.listdir(image_dir),
desc=f"Processing directory {image_dir}",
unit="images",
leave=True,
position=0,
mininterval=0.5,
):
full_filepath = os.path.join(image_dir, filename)
if os.path.isdir(full_filepath):
Expand Down Expand Up @@ -353,10 +353,10 @@ def main():
)
elif "cogvlm2" in args.model_path:
import sysconfig
print(sysconfig.get_paths()['include'])

print(sysconfig.get_paths()["include"])
tokenizer = AutoTokenizer.from_pretrained(
args.model_path,
trust_remote_code=True
args.model_path, trust_remote_code=True
)
else:
tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
Expand Down

0 comments on commit 8899037

Please sign in to comment.