Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/watermark filter #256

Merged
merged 7 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ process:
vertical_flip: false # flip image vertically (top to bottom).
reduce_mode: avg # reduce mode when one text corresponds to multiple images in a chunk, must be one of ['avg','max', 'min'].
any_or_all: any # keep this sample when any/all images meet the filter condition
- image_watermark_filter: # filter samples according to the predicted watermark probabilities of images in them
hf_watermark_model: amrul-hzz/watermark_detector # Huggingface model name for watermark classification
prob_threshold: 0.8 # the predicted watermark probability threshold for samples, range from 0 to 1
any_or_all: any # keep this sample when any/all images meet the filter condition
- language_id_score_filter: # filter text in specific language with language scores larger than a specific max value
lang: en # keep text in what language
min_score: 0.8 # the min language scores to filter text
Expand Down Expand Up @@ -370,6 +374,13 @@ process:
min_height: 480 # the min resolution of vertical resolution filter range (unit p)
max_height: 1080 # the max resolution of vertical resolution filter range (unit p)
any_or_all: any # keep this sample when any/all videos meet the filter condition
- video_watermark_filter: # filter samples according to the predicted watermark probabilities of videos in them
hf_watermark_model: amrul-hzz/watermark_detector # Huggingface model name for watermark classification
prob_threshold: 0.8 # the predicted watermark probability threshold for samples, range from 0 to 1
frame_sampling_method: all_keyframes # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes".
frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration.
reduce_mode: avg # reduce mode for multiple sampled video frames to compute final predicted watermark probabilities of videos, must be one of ['avg','max', 'min'].
any_or_all: any # keep this sample when any/all images meet the filter condition
- words_num_filter: # filter text with number of words out of specific range
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
Expand Down
17 changes: 9 additions & 8 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
image_aspect_ratio_filter, image_face_ratio_filter,
image_nsfw_filter, image_shape_filter, image_size_filter,
image_text_matching_filter, image_text_similarity_filter,
language_id_score_filter, maximum_line_length_filter,
perplexity_filter, phrase_grounding_recall_filter,
special_characters_filter, specified_field_filter,
specified_numeric_field_filter, stopwords_filter, suffix_filter,
text_action_filter, text_entity_dependency_filter,
text_length_filter, token_num_filter, video_aesthetics_filter,
image_watermark_filter, language_id_score_filter,
maximum_line_length_filter, perplexity_filter,
phrase_grounding_recall_filter, special_characters_filter,
specified_field_filter, specified_numeric_field_filter,
stopwords_filter, suffix_filter, text_action_filter,
text_entity_dependency_filter, text_length_filter,
token_num_filter, video_aesthetics_filter,
video_aspect_ratio_filter, video_duration_filter,
video_frames_text_similarity_filter, video_motion_score_filter,
video_nsfw_filter, video_ocr_area_ratio_filter,
video_resolution_filter, word_num_filter,
word_repetition_filter)
video_resolution_filter, video_watermark_filter,
word_num_filter, word_repetition_filter)

# yapf: enable
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/image_nsfw_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self,
"""
Initialization method.

:param hf_nsfw_model: nsfw detection model name on huggingfacet.
:param hf_nsfw_model: nsfw detection model name on huggingface.
:param score_threshold: the nsfw score threshold for samples.
range from 0 to 1.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
Expand Down
101 changes: 101 additions & 0 deletions data_juicer/ops/filter/image_watermark_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_data_with_context, load_image
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_watermark_filter'

with AvailabilityChecking(['torch', 'transformers'], OP_NAME):
import torch
import transformers # noqa: F401

# avoid hanging when calling watermark detection in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageWatermarkFilter(Filter):
"""
Filter to keep samples whose images have no watermark with high
probability.
"""

def __init__(self,
hf_watermark_model='amrul-hzz/watermark_detector',
prob_threshold: ClosedUnitInterval = 0.8,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.

:param hf_watermark_model: watermark detection model name on
huggingface.
:param prob_threshold: the predicted watermark probability threshold
for samples. range from 0 to 1.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.prob_threshold = prob_threshold
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=hf_watermark_model)
self._accelerator = 'cuda'

def compute_stats(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.image_watermark_prob in sample[Fields.stats]:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[Fields.stats][StatsKeys.image_watermark_prob] = np.array(
[], dtype=np.float64)
return sample

# load images
loaded_image_keys = sample[self.image_key]
sample, images = load_data_with_context(sample, context,
loaded_image_keys, load_image)

model, processor = get_model(self.model_key, rank=rank)

images = [images[key] for key in images]
inputs = processor(images=images, return_tensors='pt').to(model.device)
outputs = model(**inputs)
logits = outputs.logits
watermark_probs = [probs[1] for probs in torch.softmax(logits, dim=-1)]

sample[Fields.stats][StatsKeys.image_watermark_prob] = watermark_probs

return sample

def process(self, sample, rank=None):
itm_probs = sample[Fields.stats][StatsKeys.image_watermark_prob]
if len(itm_probs) <= 0:
return True

keep_bools = np.array(
[itm_prob < self.prob_threshold for itm_prob in itm_probs])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
159 changes: 159 additions & 0 deletions data_juicer/ops/filter/video_watermark_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval, PositiveInt

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import (extract_key_frames,
extract_video_frames_uniformly,
load_data_with_context, load_video)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_VIDEOS

OP_NAME = 'video_watermark_filter'

with AvailabilityChecking(['torch', 'transformers'], OP_NAME):

import torch
import transformers # noqa: F401

# avoid hanging when calling watermark detection in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoWatermarkFilter(Filter):
"""
Filter to keep samples whose videos have no watermark with high
probability.
"""

def __init__(self,
hf_watermark_model='amrul-hzz/watermark_detector',
prob_threshold: ClosedUnitInterval = 0.8,
frame_sampling_method: str = 'all_keyframes',
frame_num: PositiveInt = 3,
reduce_mode: str = 'avg',
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.

:param hf_watermark_model: watermark detection model name on
huggingface.
:param prob_threshold: the predicted watermark probability threshold
for samples. range from 0 to 1.
:param frame_sampling_method: sampling method of extracting frame
images from the videos.
Should be one of ["all_keyframes", "uniform"].
The former one extracts all key frames (the number of which depends
on the duration of the video) and the latter one extract specified
number of frames uniformly from the video.
Default: "all_keyframes".
:param frame_num: the number of frames to be extracted uniformly from
the video. Only works when frame_sampling_method is "uniform". If
it's 1, only the middle frame will be extracted. If it's 2, only
the first and the last frames will be extracted. If it's larger
than 2, in addition to the first and the last frames, other frames
will be extracted uniformly within the video duration.
:param reduce_mode: reduce mode for multiple sampled video frames.
'avg': Take the average of multiple values
'max': Take the max of multiple values
'min': Take the min of multiple values
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all videos. 'any': keep this sample if any videos meet the
condition. 'all': keep this sample only if all videos meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.prob_threshold = prob_threshold
if frame_sampling_method not in ['all_keyframes', 'uniform']:
raise ValueError(
f'Frame sampling method '
f'[{frame_sampling_method}] is not supported. '
f'Can only be one of ["all_keyframes", "uniform"].')
if reduce_mode not in ['avg', 'max', 'min']:
raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
f'Can only be one of ["avg", "max", "min"].')
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(
model_type='huggingface',
pretrained_model_name_or_path=hf_watermark_model)
self._accelerator = 'cuda'
self.reduce_mode = reduce_mode
self.frame_sampling_method = frame_sampling_method
self.frame_num = frame_num

def compute_stats(self, sample, rank=None, context=False):
# check if it's computed already
if StatsKeys.video_watermark_prob in sample[Fields.stats]:
return sample

# there is no videos in this sample
if self.video_key not in sample or not sample[self.video_key]:
sample[Fields.stats][StatsKeys.video_watermark_prob] = np.array(
[], dtype=np.float64)
return sample

# load videos
loaded_video_keys = sample[self.video_key]
sample, videos = load_data_with_context(sample, context,
loaded_video_keys, load_video)

watermark_probs = []
model, processor = get_model(self.model_key, rank=rank)

for video_key, video in videos.items():

# extract frame images
if self.frame_sampling_method == 'all_keyframes':
frames = extract_key_frames(video)
elif self.frame_sampling_method == 'uniform':
frames = extract_video_frames_uniformly(video, self.frame_num)
else:
frames = []

frame_images = [frame.to_image() for frame in frames]
inputs = processor(images=frame_images, return_tensors='pt')
inputs = inputs.to(model.device)
outputs = model(**inputs)
logits = outputs.logits
cur_probs = [probs[1] for probs in torch.softmax(logits, dim=-1)]
cur_probs = torch.Tensor(cur_probs)

if self.reduce_mode == 'avg':
watermark_probs.append(cur_probs.mean())
elif self.reduce_mode == 'max':
watermark_probs.append(cur_probs.max())
else:
watermark_probs.append(cur_probs.min())

sample[Fields.stats][StatsKeys.video_watermark_prob] = watermark_probs

if not context:
for vid_key in videos:
videos[vid_key].close()

return sample

def process(self, sample, rank=None):
itm_probs = sample[Fields.stats][StatsKeys.video_watermark_prob]
if len(itm_probs) <= 0:
return True

keep_bools = np.array(
[itm_prob < self.prob_threshold for itm_prob in itm_probs])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class StatsKeysConstant(object):
face_detections = 'face_detections'
image_aesthetics_scores = 'image_aesthetics_scores'
image_nsfw_score = 'image_nsfw_score'
image_watermark_prob = 'image_watermark_prob'

# audios
audio_duration = 'audio_duration'
Expand All @@ -150,6 +151,7 @@ class StatsKeysConstant(object):
video_frames_aesthetics_score = 'video_frames_aesthetics_score'
video_motion_score = 'video_motion_score'
video_nsfw_score = 'video_nsfw_score'
video_watermark_prob = 'video_watermark_prob'

# multimodal
# image-text
Expand Down
Loading
Loading