Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance/support prompt for caption generation #191

Merged
merged 4 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ process:
caption_num: 1 # how many candidate captions to generate for each image
keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"].
keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default.
prompt: null # a string prompt to guide the generation of blip2 model for all samples globally. It's None in default, which means no prompt provided.
prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used to set different prompts for different samples. If it's none, use prompt in parameter "prompt". It's None in default.
- image_blur_mapper: # mapper to blur images.
p: 0.2 # probability of the image being blurred
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
Expand Down
30 changes: 30 additions & 0 deletions data_juicer/ops/mapper/generate_caption_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from jsonargparse.typing import PositiveInt
from loguru import logger

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import HashKeys
Expand Down Expand Up @@ -38,6 +39,8 @@ def __init__(self,
caption_num: PositiveInt = 1,
keep_candidate_mode: str = 'random_any',
keep_original_sample: bool = True,
prompt: str = None,
prompt_key: str = None,
*args,
**kwargs):
"""
Expand All @@ -64,6 +67,13 @@ def __init__(self,
it's set to False, there will be only generated captions in the
final datasets and the original captions will be removed. It's True
in default.
:param prompt: a string prompt to guide the generation of blip2 model
for all samples globally. It's None in default, which means no
prompt provided.
:param prompt_key: the key name of fields in samples to store prompts
for each sample. It's used to set different prompts for different
samples. If it's none, use prompt in parameter "prompt". It's None
in default.
:param args: extra args
:param kwargs: extra args
"""
Expand All @@ -87,6 +97,8 @@ def __init__(self,
self.caption_num = caption_num
self.keep_candidate_mode = keep_candidate_mode
self.keep_original_sample = keep_original_sample
self.prompt = prompt
self.prompt_key = prompt_key
self.extra_args = kwargs

if keep_candidate_mode in ['random_any', 'similar_one_simhash']:
Expand All @@ -96,6 +108,12 @@ def __init__(self,
else:
self.num_newly_generated_samples = 0

# report a warning when both prompt and prompt_key are set
if self.prompt and self.prompt_key:
logger.warning(
'Both the parameter `prompt` and `prompt_key` are '
'set. Data-Juicer will consider `prompt_key` first.')

def _process_single_sample(self, ori_sample):
"""

Expand Down Expand Up @@ -153,7 +171,19 @@ def _process_single_sample(self, ori_sample):
# generated_text_candidates_single_chunk[i][j] indicates
# the $i$-th generated candidate for the $j$-th image

# construct prompts
if self.prompt_key \
and isinstance(ori_sample[self.prompt_key], str):
# check prompt_key is not None, and it's a str in the sample
prompt_texts = [ori_sample[self.prompt_key]] * len(image_chunk)
elif self.prompt and isinstance(self.prompt, str):
HYLcool marked this conversation as resolved.
Show resolved Hide resolved
# check prompt is not None, and it's a str
prompt_texts = [self.prompt] * len(image_chunk)
else:
prompt_texts = None

inputs = self.img_processor_in_ctx(images=image_chunk,
text=prompt_texts,
return_tensors='pt')
for i in range(self.caption_num):
generated_ids = self.model_in_ctx.generate(**inputs,
Expand Down