From c8f2243c8b7988a0f49b440a338195fc1d4030ec Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Fri, 19 Jan 2024 14:50:29 +0800 Subject: [PATCH 1/4] * support prompt for generate_caption_mapper --- data_juicer/ops/mapper/generate_caption_mapper.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/data_juicer/ops/mapper/generate_caption_mapper.py b/data_juicer/ops/mapper/generate_caption_mapper.py index 2303a2b60..f306f40f6 100644 --- a/data_juicer/ops/mapper/generate_caption_mapper.py +++ b/data_juicer/ops/mapper/generate_caption_mapper.py @@ -38,6 +38,7 @@ def __init__(self, caption_num: PositiveInt = 1, keep_candidate_mode: str = 'random_any', keep_original_sample: bool = True, + prompt: str = None, *args, **kwargs): """ @@ -64,6 +65,8 @@ def __init__(self, it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. + :param prompt: a string prompt to guide the generation of blip2 model. + It's None in default, which means no prompt provided. :param args: extra args :param kwargs: extra args """ @@ -87,6 +90,7 @@ def __init__(self, self.caption_num = caption_num self.keep_candidate_mode = keep_candidate_mode self.keep_original_sample = keep_original_sample + self.prompt = prompt self.extra_args = kwargs if keep_candidate_mode in ['random_any', 'similar_one_simhash']: @@ -154,6 +158,8 @@ def _process_single_sample(self, ori_sample): # the $i$-th generated candidate for the $j$-th image inputs = self.img_processor_in_ctx(images=image_chunk, + text=[self.prompt] * + len(image_chunk), return_tensors='pt') for i in range(self.caption_num): generated_ids = self.model_in_ctx.generate(**inputs, From 24ad8552d0fe47970c9acd88959647927e6dbd88 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Fri, 19 Jan 2024 15:26:17 +0800 Subject: [PATCH 2/4] * update docs --- configs/config_all.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 4274e4cf5..a6b695a3a 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -55,6 +55,7 @@ process: caption_num: 1 # how many candidate captions to generate for each image keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"]. keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. + prompt: null # a string prompt to guide the generation of blip2 model. It's None in default, which means no prompt provided. - image_blur_mapper: # mapper to blur images. p: 0.2 # probability of the image being blured blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] From 4e459d6fc0c119d5d7de0bb969cf11f549334f61 Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Fri, 19 Jan 2024 17:06:32 +0800 Subject: [PATCH 3/4] * support prompt for generate_caption_mapper --- configs/config_all.yaml | 3 ++- .../ops/mapper/generate_caption_mapper.py | 25 ++++++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index a6b695a3a..f04ca74ee 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -55,7 +55,8 @@ process: caption_num: 1 # how many candidate captions to generate for each image keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"]. keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. - prompt: null # a string prompt to guide the generation of blip2 model. It's None in default, which means no prompt provided. + prompt: null # a string prompt to guide the generation of blip2 model for all samples globally. It's None in default, which means no prompt provided. + prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used for set different prompts for different samples. If it's none, use prompt in parameter "prompt". It's None in default. - image_blur_mapper: # mapper to blur images. p: 0.2 # probability of the image being blured blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] diff --git a/data_juicer/ops/mapper/generate_caption_mapper.py b/data_juicer/ops/mapper/generate_caption_mapper.py index f306f40f6..4abdcdb71 100644 --- a/data_juicer/ops/mapper/generate_caption_mapper.py +++ b/data_juicer/ops/mapper/generate_caption_mapper.py @@ -39,6 +39,7 @@ def __init__(self, keep_candidate_mode: str = 'random_any', keep_original_sample: bool = True, prompt: str = None, + prompt_key: str = None, *args, **kwargs): """ @@ -65,8 +66,13 @@ def __init__(self, it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. - :param prompt: a string prompt to guide the generation of blip2 model. - It's None in default, which means no prompt provided. + :param prompt: a string prompt to guide the generation of blip2 model + for all samples globally. It's None in default, which means no + prompt provided. + :param prompt_key: the key name of fields in samples to store prompts + for each sample. It's used for set different prompts for different + samples. If it's none, use prompt in parameter "prompt". It's None + in default. :param args: extra args :param kwargs: extra args """ @@ -91,6 +97,7 @@ def __init__(self, self.keep_candidate_mode = keep_candidate_mode self.keep_original_sample = keep_original_sample self.prompt = prompt + self.prompt_key = prompt_key self.extra_args = kwargs if keep_candidate_mode in ['random_any', 'similar_one_simhash']: @@ -157,9 +164,19 @@ def _process_single_sample(self, ori_sample): # generated_text_candidates_single_chunk[i][j] indicates # the $i$-th generated candidate for the $j$-th image + # construct prompts + if self.prompt_key \ + and isinstance(ori_sample[self.prompt_key], str): + # check prompt_key is not None, and it's a str in the sample + prompt_texts = [ori_sample[self.prompt_key]] * len(image_chunk) + elif self.prompt and isinstance(self.prompt, str): + # check prompt is not None, and it's a str + prompt_texts = [self.prompt] * len(image_chunk) + else: + prompt_texts = None + inputs = self.img_processor_in_ctx(images=image_chunk, - text=[self.prompt] * - len(image_chunk), + text=prompt_texts, return_tensors='pt') for i in range(self.caption_num): generated_ids = self.model_in_ctx.generate(**inputs, From 37d6b5a9f38561cb21bc91617965d7af1cc9f0db Mon Sep 17 00:00:00 2001 From: "lielin.hyl" Date: Fri, 19 Jan 2024 17:29:12 +0800 Subject: [PATCH 4/4] + Add a warning when both prompt and prompt_key are set --- data_juicer/ops/mapper/generate_caption_mapper.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/data_juicer/ops/mapper/generate_caption_mapper.py b/data_juicer/ops/mapper/generate_caption_mapper.py index 4abdcdb71..056ebe20c 100644 --- a/data_juicer/ops/mapper/generate_caption_mapper.py +++ b/data_juicer/ops/mapper/generate_caption_mapper.py @@ -3,6 +3,7 @@ import numpy as np from jsonargparse.typing import PositiveInt +from loguru import logger from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys @@ -107,6 +108,12 @@ def __init__(self, else: self.num_newly_generated_samples = 0 + # report a warning when both prompt and prompt_key are set + if self.prompt and self.prompt_key: + logger.warning( + 'Both the parameter `prompt` and `prompt_key` are ' + 'set. Data-Juicer will consider `prompt_key` first.') + def _process_single_sample(self, ori_sample): """