Merge pull request #54 from ai-forever/dev
release 1.0.0
boomb0om authored May 29, 2024
2 parents 8d8406d + 441f49c commit 0f1d015
Showing 24 changed files with 1,293 additions and 438 deletions.
17 changes: 16 additions & 1 deletion DPF/filters/images/aesthetic_improved_filter.py
@@ -73,7 +73,22 @@ def get_improved_aesthetic_model(cache_folder: str) -> MLP:

class ImprovedAestheticFilter(ImageFilter):
"""
ImprovedAestheticFilter class
Filter for calculating improved aesthetic scores with the LAION model. This repository is used:
https://github.com/christophschuhmann/improved-aesthetic-predictor
Parameters
----------
weights_folder: str
Path to the folder where the weights are located.
If there are no weights, they will be downloaded automatically
device: str = "cuda:0"
Device to use
workers: int = 16
Number of processes to use for reading data and calculating aesthetic scores
batch_size: int = 64
Batch size for model
pbar: bool = True
Whether to use a progress bar
"""

def __init__(
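For context, a minimal usage sketch of the updated filter, assuming the import path mirrors the file location above and that the keyword arguments match the parameter names in the new docstring (the commented apply call is an assumed DPF processor API, not part of this diff):

from DPF.filters.images.aesthetic_improved_filter import ImprovedAestheticFilter

aesthetic_filter = ImprovedAestheticFilter(
    weights_folder="./weights",  # weights are downloaded here automatically if missing
    device="cuda:0",
    workers=16,
    batch_size=64,
    pbar=True,
)
# processor.apply_data_filter(aesthetic_filter)  # assumed processor API, for orientation only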
11 changes: 10 additions & 1 deletion DPF/filters/images/hash_filters.py
@@ -30,7 +30,16 @@ def get_phash(pil_img: Image.Image, hash_size: int = 8, highfreq_factor: int = 4

class PHashFilter(ImageFilter):
"""
PHashFilter class
Filter for calculating PHash (perceptual hash) for images
Parameters
----------
sim_hash_size: int = 8
Hash size for PHash
workers: int = 16
Number of processes to use for reading data and calculating hashes
pbar: bool = True
Whether to use a progress bar
"""

def __init__(
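A similar sketch for the perceptual-hash filter, with keyword arguments taken from the docstring above (a minimal, assumed usage; the filter is applied through a DPF processor elsewhere):

from DPF.filters.images.hash_filters import PHashFilter

phash_filter = PHashFilter(
    sim_hash_size=8,  # hash size for PHash
    workers=16,
    pbar=True,
)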
9 changes: 8 additions & 1 deletion DPF/filters/images/info_filter.py
@@ -50,7 +50,14 @@ def get_image_info(img_bytes: bytes, data: dict[str, Any], key_column: str) -> I

class ImageInfoFilter(ImageFilter):
"""
ImageInfoFilter class
Filter for gathering basic info about images (width, height, number of channels)
Parameters
----------
workers: int = 16
Number of parallel dataloader workers
pbar: bool = True
Whether to show progress bar
"""

def __init__(self, workers: int = 16, pbar: bool = True, _pbar_position: int = 0):
Expand Down
10 changes: 4 additions & 6 deletions DPF/filters/images/nsfw_filter.py
@@ -26,8 +26,6 @@


def load_safety_model(clip_model: str, cache_folder: str, device: Union[str, torch.device]) -> Any:
"""load the safety model"""

gpus = tf.config.list_physical_devices("GPU")
if gpus:
try:
@@ -38,7 +36,7 @@ def load_safety_model(clip_model: str, cache_folder: str, device: Union[str, tor
pass

if clip_model == "ViT-L/14":
model_dir = cache_folder + "/clip_autokeras_binary_nsfw"
model_dir = os.path.join(cache_folder, "clip_autokeras_binary_nsfw")
url_model = (
"https://raw.githubusercontent.com/LAION-AI/"
"CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip"
@@ -48,12 +46,13 @@ def load_safety_model(clip_model: str, cache_folder: str, device: Union[str, tor

if not os.path.exists(model_dir):
os.makedirs(cache_folder, exist_ok=True)
path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
path_to_zip_file = os.path.join(cache_folder, "clip_autokeras_binary_nsfw.zip")
urlretrieve(url_model, path_to_zip_file)
with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
zip_ref.extractall(cache_folder)

with tf.device(device):
print(model_dir)
loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)

return loaded_model
@@ -72,7 +71,6 @@ class NSFWFilter(ImageFilter):

def __init__(
self,
clip_model: str,
weights_folder: str,
workers: int = 16,
batch_size: int = 64,
@@ -81,7 +79,7 @@ def __init__(
_pbar_position: int = 0
):
super().__init__(pbar, _pbar_position)

clip_model = "ViT-L/14"
self.num_workers = workers
self.batch_size = batch_size
self.device = device
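Note the behavioural change above: clip_model is no longer a constructor argument and is fixed to "ViT-L/14" inside __init__. A sketch of the new call, using only the parameters visible in this hunk:

from DPF.filters.images.nsfw_filter import NSFWFilter

nsfw_filter = NSFWFilter(
    weights_folder="./weights",  # the safety model zip is downloaded and extracted here
    workers=16,
    batch_size=64,
)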
3 changes: 2 additions & 1 deletion DPF/filters/images/ocr_filter.py
@@ -23,6 +23,7 @@ def __init__(
self,
weights_path: str,
model_name: Optional[str] = None,
text_box_col: str = "text_boxes",
device: str = "cuda:0",
workers: int = 16,
pad: int = 5,
@@ -73,7 +74,7 @@ def __init__(

self.AlignCollate = AlignCollate(imgH=self.opt.imgH, imgW=self.opt.imgW, keep_ratio_with_pad=self.opt.PAD)
#
self.text_box_col = "text_boxes"
self.text_box_col = text_box_col
self.ocr_col = f"OCR_{self.model_name}"

@property
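The change above exposes the previously hard-coded "text_boxes" column as a text_box_col parameter. A hypothetical sketch, where the class name OCRFilter and the weights path are assumptions (neither is visible in this hunk):

from DPF.filters.images.ocr_filter import OCRFilter  # class name assumed

ocr_filter = OCRFilter(
    weights_path="./weights/ocr.pth",  # hypothetical path to recognition weights
    text_box_col="text_boxes",         # new parameter; matches the previous hard-coded default
    device="cuda:0",
    workers=16,
)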
24 changes: 0 additions & 24 deletions DPF/filters/images/ocr_model/model.py
@@ -1,19 +1,3 @@
"""
Copyright (c) 2019-present NAVER Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch.nn as nn

from .modules.feature_extraction import (
@@ -34,14 +18,12 @@ def __init__(self, opt):
self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}

""" Transformation """
if opt.Transformation == 'TPS':
self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
else:
print('No Transformation module specified')

""" FeatureExtraction """
if opt.FeatureExtraction == 'VGG':
self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
elif opt.FeatureExtraction == 'RCNN':
@@ -53,7 +35,6 @@ def __init__(self, opt):
self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1

""" Sequence modeling"""
if opt.SequenceModeling == 'BiLSTM':
self.SequenceModeling = nn.Sequential(
BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
@@ -63,7 +44,6 @@ def __init__(self, opt):
print('No SequenceModeling module specified')
self.SequenceModeling_output = self.FeatureExtraction_output

""" Prediction """
if opt.Prediction == 'CTC':
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
elif opt.Prediction == 'Attn':
@@ -72,22 +52,18 @@ def __init__(self, opt):
raise Exception('Prediction is neither CTC or Attn')

def forward(self, input, text, is_train=True):
""" Transformation stage """
if not self.stages['Trans'] == "None":
input = self.Transformation(input)

""" Feature extraction stage """
visual_feature = self.FeatureExtraction(input)
visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
visual_feature = visual_feature.squeeze(3)

""" Sequence modeling stage """
if self.stages['Seq'] == 'BiLSTM':
contextual_feature = self.SequenceModeling(visual_feature)
else:
contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM

""" Prediction stage """
if self.stages['Pred'] == 'CTC':
prediction = self.Prediction(contextual_feature.contiguous())
else:
21 changes: 20 additions & 1 deletion DPF/filters/texts/google_translate_filter.py
@@ -42,7 +42,26 @@ def translate_batch(translator: BaseTranslator, batch: list[str], delimiter: str

class GoogleTranslateFilter(ColumnFilter):
"""
GoogleTranslateFilter class
Filter for translating texts with the Google Translate API
Parameters
----------
text_column_name: str = "text"
Name of column with texts
source_lang: str = "auto"
Source language to translate from
target_lang: str = "en"
Language to translate to
max_symbols_in_batch: int = 3000
Maximum number of symbols in one request to the API
timeout: float = 1
Timeout between requests
timeout_on_error: float = 3
Timeout between requests if an error occurred
num_retries_per_batch: int = 1
Number of retries when an error occurs
pbar: bool = True
Whether to use a progress bar
"""

def __init__(
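A minimal sketch of the translator filter with the defaults documented above (keyword names are taken from the docstring; the commented column-filter apply call is assumed, not shown in this diff):

from DPF.filters.texts.google_translate_filter import GoogleTranslateFilter

translate_filter = GoogleTranslateFilter(
    text_column_name="text",
    source_lang="auto",
    target_lang="en",
    max_symbols_in_batch=3000,  # texts are packed into requests of at most this many symbols
    timeout=1,
    timeout_on_error=3,
    num_retries_per_batch=1,
)
# processor.apply_column_filter(translate_filter)  # assumed processor API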
11 changes: 10 additions & 1 deletion DPF/filters/texts/lang_filter.py
@@ -7,7 +7,16 @@

class LangFilter(ColumnFilter):
"""
LangFilter class
Filter for text language detection
Parameters
----------
text_column_name: str = "text"
Name of column with texts
workers: int = 16
Number of processes to use
pbar: bool = True
Whether to use a progress bar
"""

def __init__(
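And a sketch for the language-detection column filter, again assuming the keyword names follow the docstring above:

from DPF.filters.texts.lang_filter import LangFilter

lang_filter = LangFilter(
    text_column_name="text",
    workers=16,
    pbar=True,
)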
(The diffs for the remaining 16 changed files are not shown.)