diff --git a/README.md b/README.md index 7c2288b25..9731845e6 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,37 @@ -English | [**中文**](README_ZH.md) +[[中文主页]](README_ZH.md) | [[Docs]](README.md#documentation-index--文档索引-a-namedocumentationindex) | [[API]](https://alibaba.github.io/data-juicer) | [[*DJ-SORA*]](docs/DJ_SORA.md) # Data-Juicer: A One-Stop Data Processing System for Large Language Models -![Data-Juicer](https://img.alicdn.com/imgextra/i3/O1CN017Eq5kf27AlA2NUKef_!!6000000007757-0-tps-1280-720.jpg "Data-Juicer") + Data-Juicer -[![Paper](http://img.shields.io/badge/cs.LG-arXiv%3A2309.02033-B31B1B?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2309.02033) ![](https://img.shields.io/badge/language-Python-214870.svg) ![](https://img.shields.io/badge/license-Apache--2.0-000000.svg) -[![Contributing](https://img.shields.io/badge/Contribution-welcome-brightgreen.svg)](docs/DeveloperGuide.md) - [![pypi version](https://img.shields.io/pypi/v/py-data-juicer?logo=pypi&color=026cad)](https://pypi.org/project/py-data-juicer) [![Docker version](https://img.shields.io/docker/v/datajuicer/data-juicer?logo=docker&label=Docker&color=498bdf)](https://hub.docker.com/r/datajuicer/data-juicer) -[![Document_List](https://img.shields.io/badge/Docs-English-blue?logo=Markdown)](README.md#documentation) -[![文档列表](https://img.shields.io/badge/文档-中文-blue?logo=Markdown)](README_ZH.md#documentation) + +[![DataModality](https://img.shields.io/badge/DataModality-Text,Image,Audio,Video-brightgreen.svg)](docs/DeveloperGuide.md) +[![Usage](https://img.shields.io/badge/Usage-Cleaning,Generation,Analysis-FFD21E.svg)](docs/DeveloperGuide.md) +[![ModelScope- Demos](https://img.shields.io/badge/ModelScope-Demos-4e29ff.svg?logo=)](https://modelscope.cn/studios?name=Data-Jiucer&page=1&sort=latest&type=1) +[![HuggingFace- Demos](https://img.shields.io/badge/🤗HuggingFace-Demos-4e29ff.svg)](https://huggingface.co/spaces?&search=datajuicer) + + + 
+[![Document_List](https://img.shields.io/badge/Docs-English-blue?logo=Markdown)](README.md#documentation-index--文档索引-a-namedocumentationindex) +[![文档列表](https://img.shields.io/badge/文档-中文-blue?logo=Markdown)](README_ZH.md#documentation-index--文档索引-a-namedocumentationindex) [![API Reference](https://img.shields.io/badge/Docs-API_Reference-blue?logo=Markdown)](https://alibaba.github.io/data-juicer/) +[![Paper](http://img.shields.io/badge/cs.LG-arXiv%3A2309.02033-B31B1B?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2309.02033) -[![ModelScope-10+ Demos](https://img.shields.io/badge/ModelScope-10+_Demos-4e29ff.svg?logo=)](https://modelscope.cn/studios?name=Data-Jiucer&page=1&sort=latest&type=1) -[![ModelScope-20+_Refined_Datasets](https://img.shields.io/badge/ModelScope-20+_Refined_Datasets-4e29ff.svg?logo=)](https://modelscope.cn/datasets?organization=Data-Juicer&page=1) -[![ModelScope-Reference_Models](https://img.shields.io/badge/ModelScope-Reference_Models-4e29ff.svg?logo=)](https://modelscope.cn/models?organization=Data-Juicer&page=1) -[![HuggingFace-10+ Demos](https://img.shields.io/badge/🤗HuggingFace-10+_Demos-FFD21E.svg)](https://huggingface.co/spaces?&search=datajuicer) -[![HuggingFace-20+_Refined_Datasets](https://img.shields.io/badge/🤗HuggingFace-20+_Refined_Datasets-FFD21E.svg)](https://huggingface.co/datasets?&search=datajuicer) -[![HuggingFace-Reference_Models](https://img.shields.io/badge/🤗HuggingFace-Reference_Models-FFD21E.svg)](https://huggingface.co/models?&search=datajuicer) -[![QualityClassifier](https://img.shields.io/badge/Tools-Quality_Classifier-saddlebrown?logo=Markdown)](tools/quality_classifier/README.md) -[![AutoEvaluation](https://img.shields.io/badge/Tools-Auto_Evaluation-saddlebrown?logo=Markdown)](tools/evaluator/README.md) -Data-Juicer is a one-stop data processing system to make data higher-quality, +Data-Juicer is a one-stop **multimodal** data processing system to make data higher-quality, juicier, and more digestible for LLMs. 
-This project is being actively updated and maintained, and we will periodically enhance and add more features and data recipes. We welcome you to join us in promoting LLM data development and research! -If you find Data-Juicer useful for your research or development, please kindly -cite our [work](#references). +Data-Juicer (including [DJ-SORA](docs/DJ_SORA.md)) is being actively updated and maintained. We will periodically enhance and add more features, data recipes and datasets. +We welcome you to join us in promoting LLM data development and research! -Welcome to join our [Slack channel](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw), [DingDing group](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8253f30mgpjw&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11), or WeChat group (scan the QR code below with WeChat) for discussion. +If you find Data-Juicer useful for your research or development, please kindly +cite our [work](#references). Welcome to join our [Slack channel](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8253f30mgpjw), [DingDing group](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8253f30mgpjw&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11), or WeChat group (scan the QR code below with WeChat) for discussion. QR Code for WeChat group @@ -41,7 +39,9 @@ Welcome to join our [Slack channel](https://join.slack.com/t/data-juicer/shared_ ---- ## News -- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-20] We have actively maintained an awesome list of LLM-Data, welcome to [visit](docs/awesome_llm_data.md) and contribute! 
+- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-03-07] We release **Data-Juicer [v0.2.0](https://github.com/alibaba/data-juicer/releases/tag/v0.2.0)** now! +In this new version, we support more features for **multimodal data (including video now)**, and introduce **[DJ-SORA](docs/DJ_SORA.md)** to provide open large-scale, high-quality datasets for SORA-like models. +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-20] We have actively maintained an *awesome list of LLM-Data*, welcome to [visit](docs/awesome_llm_data.md) and contribute! - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-05] Our paper has been accepted by SIGMOD'24 industrial track! - [2024-01-10] Discover new horizons in "Data Mixture"—Our second data-centric LLM competition has kicked off! Please visit the competition's [official website](https://tianchi.aliyun.com/competition/entrance/532174) for more information. - [2024-01-05] We release **Data-Juicer v0.1.3** now! @@ -56,9 +56,11 @@ Besides, our paper is also updated to [v3](https://arxiv.org/abs/2309.02033). 
Table of Contents ================= -* [Data-Juicer: A One-Stop Data Processing System for Large Language Models](#data-juicer--a-one-stop-data-processing-system-for-large-language-models) +* [Data-Juicer: A One-Stop Data Processing System for Large Language Models](#data-juicer-a-one-stop-data-processing-system-for-large-language-models) * [Table of Contents](#table-of-contents) * [Features](#features) + * [Documentation Index | 文档索引](#documentation-index--文档索引-a-namedocumentationindex) + * [Demos](#demos) * [Prerequisites](#prerequisites) * [Installation](#installation) * [From Source](#from-source) @@ -67,25 +69,25 @@ Table of Contents * [Installation check](#installation-check) * [Quick Start](#quick-start) * [Data Processing](#data-processing) + * [Distributed Data Processing](#distributed-data-processing) * [Data Analysis](#data-analysis) * [Data Visualization](#data-visualization) * [Build Up Config Files](#build-up-config-files) * [Preprocess raw data (Optional)](#preprocess-raw-data-optional) * [For Docker Users](#for-docker-users) - * [Documentation | 文档](#documentation) * [Data Recipes](#data-recipes) - * [Demos](#demos) * [License](#license) * [Contributing](#contributing) * [Acknowledgement](#acknowledgement) * [References](#references) + ## Features ![Overview](https://img.alicdn.com/imgextra/i2/O1CN01IMPeD11xYRUYLmXKO_!!6000000006455-2-tps-3620-1604.png) - **Systematic & Reusable**: - Empowering users with a systematic library of 20+ reusable [config recipes](configs), 50+ core [OPs](docs/Operators.md), and feature-rich + Empowering users with a systematic library of 80+ core [OPs](docs/Operators.md), 20+ reusable [config recipes](configs), and 20+ feature-rich dedicated [toolkits](#documentation), designed to function independently of specific LLM datasets and processing pipelines. 
@@ -95,7 +97,7 @@ Table of Contents - **Comprehensive Data Processing Recipes**: Offering tens of [pre-built data processing recipes](configs/data_juicer_recipes/README.md) for pre-training, fine-tuning, en, zh, and more scenarios. Validated on - reference LLaMA models. + reference LLaMA and LLaVA models. ![exp_llama](https://img.alicdn.com/imgextra/i2/O1CN019WtUPP1uhebnDlPR8_!!6000000006069-2-tps-2530-1005.png) - **Enhanced Efficiency**: Providing a speedy data processing pipeline @@ -107,9 +109,47 @@ Table of Contents - **User-Friendly Experience**: Designed for simplicity, with [comprehensive documentation](#documentation), [easy start guides](#quick-start) and [demo configs](configs/README.md), and intuitive configuration with simple adding/removing OPs from [existing configs](configs/config_all.yaml). + + +## Documentation Index | 文档索引 + +- [Overview](README.md) | [概览](README_ZH.md) +- [Operator Zoo](docs/Operators.md) | [算子库](docs/Operators_ZH.md) +- [Configs](configs/README.md) | [配置系统](configs/README_ZH.md) +- [Developer Guide](docs/DeveloperGuide.md) | [开发者指南](docs/DeveloperGuide_ZH.md) +- ["Bad" Data Exhibition](docs/BadDataExhibition.md) | [“坏”数据展览](docs/BadDataExhibition_ZH.md) +- Dedicated Toolkits | 专用工具箱 + - [Quality Classifier](tools/quality_classifier/README.md) | [质量分类器](tools/quality_classifier/README_ZH.md) + - [Auto Evaluation](tools/evaluator/README.md) | [自动评测](tools/evaluator/README_ZH.md) + - [Preprocess](tools/preprocess/README.md) | [前处理](tools/preprocess/README_ZH.md) + - [Postprocess](tools/postprocess/README.md) | [后处理](tools/postprocess/README_ZH.md) +- [Third-parties (LLM Ecosystems)](thirdparty/README.md) | [第三方库(大语言模型生态)](thirdparty/README_ZH.md) +- [API references](https://alibaba.github.io/data-juicer/) +- [Awesome LLM-Data](docs/awesome_llm_data.md) +- [DJ-SORA](docs/DJ_SORA.md) + + +## Demos +- Introduction to Data-Juicer [[ModelScope](https://modelscope.cn/studios/Data-Juicer/overview_scan/summary)] 
[[HuggingFace](https://huggingface.co/spaces/datajuicer/overview_scan)] +- Data Visualization: + - Basic Statistics [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_statistics/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_statistics)] + - Lexical Diversity [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_diversity/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_diversity)] + - Operator Insight (Single OP) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visualization_op_insight/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_op_insight)] + - Operator Effect (Multiple OPs) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_op_effect/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_op_effect)] +- Data Processing: + - Scientific Literature (e.g. [arXiv](https://info.arxiv.org/help/bulk_data_s3.html)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sci_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_sci_data)] + - Programming Code (e.g. [TheStack](https://huggingface.co/datasets/bigcode/the-stack)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_code_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_code_data)] + - Chinese Instruction Data (e.g. 
[Alpaca-CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sft_zh_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_cft_zh_data)] +- Tool Pool: + - Dataset Splitting by Language [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_dataset_splitting_by_language/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_dataset_splitting_by_language)] + - Quality Classifier for CommonCrawl [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_quality_classifier/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_quality_classifier)] + - Auto Evaluation on [HELM](https://github.com/stanford-crfm/helm) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/auto_evaluation_helm/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/auto_evaluation_helm)] + - Data Sampling and Mixture [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_mixture/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_mixture)] +- Data Processing Loop [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_process_loop)] + ## Prerequisites -- Recommend Python>=3.7,<=3.10 +- Recommend Python>=3.8,<=3.10 - gcc >= 5 (at least C++14 support) ## Installation @@ -207,6 +247,23 @@ export DATA_JUICER_MODELS_CACHE="/path/to/another/directory/models" export DATA_JUICER_ASSETS_CACHE="/path/to/another/directory/assets" ``` +### Distributed Data Processing + +We have now implemented multi-machine distributed data processing based on RAY. 
The corresponding demos can be run using the following commands: + +```shell +# Run text data processing +python tools/process_data.py --config ./demos/process_on_ray/configs/demo.yaml +# Run video data processing +python tools/process_data.py --config ./demos/process_video_on_ray/configs/demo.yaml +``` + +- To run multimodal data processing across multiple machines, make sure every node can access the data paths, for example by mounting them on a shared file system such as NAS. + +- Users can also skip RAY: split the dataset and process the parts on a cluster with Slurm/DLC. + + + ### Data Analysis - Run `analyze_data.py` tool or `dj-analyze` command line tool with your config as the argument to analyse your dataset. @@ -295,45 +352,14 @@ docker run -dit \ # run the container in the background docker exec -it bash ``` -## Documentation | 文档 - -- [Overview](README.md) | [概览](README_ZH.md) -- [Operator Zoo](docs/Operators.md) | [算子库](docs/Operators_ZH.md) -- [Configs](configs/README.md) | [配置系统](configs/README_ZH.md) -- [Developer Guide](docs/DeveloperGuide.md) | [开发者指南](docs/DeveloperGuide_ZH.md) -- ["Bad" Data Exhibition](docs/BadDataExhibition.md) | [“坏”数据展览](docs/BadDataExhibition_ZH.md) -- Dedicated Toolkits | 专用工具箱 - - [Quality Classifier](tools/quality_classifier/README.md) | [质量分类器](tools/quality_classifier/README_ZH.md) - - [Auto Evaluation](tools/evaluator/README.md) | [自动评测](tools/evaluator/README_ZH.md) - - [Preprocess](tools/preprocess/README.md) | [前处理](tools/preprocess/README_ZH.md) - - [Postprocess](tools/postprocess/README.md) | [后处理](tools/postprocess/README_ZH.md) -- [Third-parties (LLM Ecosystems)](thirdparty/README.md) | [第三方库(大语言模型生态)](thirdparty/README_ZH.md) -- [API references](https://alibaba.github.io/data-juicer/) -- [Awesome LLM-Data](docs/awesome_llm_data.md) - ## Data Recipes - [Recipes for data process in BLOOM](configs/reproduced_bloom/README.md) -- [Recipes 
for data process in RedPajama](configs/reproduced_redpajama/README.md) -- [Refined recipes for pre-training data](configs/data_juicer_recipes/README.md) -- [Refined recipes for fine-tuning data](configs/data_juicer_recipes/README.md#before-and-after-refining-for-alpaca-cot-dataset) +- [Recipes for data process in RedPajama](configs/reproduced_redpajama/README.md) +- [Refined recipes for pre-training text data](configs/data_juicer_recipes/README.md) +- [Refined recipes for fine-tuning text data](configs/data_juicer_recipes/README.md#before-and-after-refining-for-alpaca-cot-dataset) +- [Refined recipes for pre-training multimodal data](configs/data_juicer_recipes/README.md#before-and-after-refining-for-multimodal-dataset) + -## Demos -- Introduction to Data-Juicer [[ModelScope](https://modelscope.cn/studios/Data-Juicer/overview_scan/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/overview_scan)] -- Data Visualization: - - Basic Statistics [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_statistics/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_statistics)] - - Lexical Diversity [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_diversity/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_diversity)] - - Operator Effect [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_op_effect/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_op_effect)] -- Data Processing: - - Scientific Literature (e.g. [arXiv](https://info.arxiv.org/help/bulk_data_s3.html)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sci_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_sci_data)] - - Programming Code (e.g. 
[TheStack](https://huggingface.co/datasets/bigcode/the-stack)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_code_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_code_data)] - - Chinese Instruction Data (e.g. [Alpaca-CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sft_zh_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_cft_zh_data)] -- Tool Pool: - - Dataset Splitting by Language [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_dataset_splitting_by_language/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_dataset_splitting_by_language)] - - Quality Classifier for CommonCrawl [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_quality_classifier/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_quality_classifier)] - - Auto Evaluation on [HELM](https://github.com/stanford-crfm/helm) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/auto_evaluation_helm/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/auto_evaluation_helm)] - - Data Sampling and Mixture [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_mixture/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_mixture)] -- Data Processing Loop [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_process_loop)] -- Data Processing HPO [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_hpo/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_process_hpo)] ## License Data-Juicer is released under Apache License 2.0. 
diff --git a/README_ZH.md b/README_ZH.md index 5d1148474..405adaed5 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -1,33 +1,28 @@ -[**English**](README.md) | 中文 +[[English Page]](README.md) | [[文档]](README_ZH.md#documentation-index--文档索引-a-namedocumentationindex) | [[API]](https://alibaba.github.io/data-juicer) | [[*DJ-SORA*]](docs/DJ_SORA_ZH.md) # Data-Juicer: 为大语言模型提供更高质量、更丰富、更易“消化”的数据 -![Data-Juicer](https://img.alicdn.com/imgextra/i3/O1CN017Eq5kf27AlA2NUKef_!!6000000007757-0-tps-1280-720.jpg "Data-Juicer") + Data-Juicer -[![Paper](http://img.shields.io/badge/cs.LG-arXiv%3A2309.02033-B31B1B?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2309.02033) ![](https://img.shields.io/badge/language-Python-214870.svg) ![](https://img.shields.io/badge/license-Apache--2.0-000000.svg) -[![Contributing](https://img.shields.io/badge/Contribution-welcome-brightgreen.svg)](docs/DeveloperGuide_ZH.md) - [![pypi version](https://img.shields.io/pypi/v/py-data-juicer?logo=pypi&color=026cad)](https://pypi.org/project/py-data-juicer) [![Docker version](https://img.shields.io/docker/v/datajuicer/data-juicer?logo=docker&label=Docker&color=498bdf)](https://hub.docker.com/r/datajuicer/data-juicer) -[![Document_List](https://img.shields.io/badge/Docs-English-blue?logo=Markdown)](README.md#documentation) -[![文档列表](https://img.shields.io/badge/文档-中文-blue?logo=Markdown)](README_ZH.md#documentation) -[![API Reference](https://img.shields.io/badge/Docs-API_Reference-blue?logo=Markdown)](https://alibaba.github.io/data-juicer/) -[![ModelScope-10+ Demos](https://img.shields.io/badge/ModelScope-10+_Demos-4e29ff.svg?logo=)](https://modelscope.cn/studios?name=Data-Jiucer&page=1&sort=latest&type=1) -[![ModelScope-20+_Refined_Datasets](https://img.shields.io/badge/ModelScope-20+_Refined_Datasets-4e29ff.svg?logo=)](https://modelscope.cn/datasets?organization=Data-Juicer&page=1) 
-[![ModelScope-Reference_Models](https://img.shields.io/badge/ModelScope-Reference_Models-4e29ff.svg?logo=)](https://modelscope.cn/models?organization=Data-Juicer&page=1) +[![DataModality](https://img.shields.io/badge/DataModality-Text,Image,Audio,Video-brightgreen.svg)](docs/DeveloperGuide_ZH.md) +[![Usage](https://img.shields.io/badge/Usage-Cleaning,Generation,Analysis-FFD21E.svg)](docs/DeveloperGuide_ZH.md) +[![ModelScope- Demos](https://img.shields.io/badge/ModelScope-Demos-4e29ff.svg?logo=)](https://modelscope.cn/studios?name=Data-Jiucer&page=1&sort=latest&type=1) +[![HuggingFace- Demos](https://img.shields.io/badge/🤗HuggingFace-Demos-4e29ff.svg)](https://huggingface.co/spaces?&search=datajuicer) + +[![Document_List](https://img.shields.io/badge/Docs-English-blue?logo=Markdown)](README.md#documentation-index--文档索引-a-namedocumentationindex) +[![文档列表](https://img.shields.io/badge/文档-中文-blue?logo=Markdown)](README_ZH.md#documentation-index--文档索引-a-namedocumentationindex) +[![API Reference](https://img.shields.io/badge/Docs-API_Reference-blue?logo=Markdown)](https://alibaba.github.io/data-juicer/) +[![Paper](http://img.shields.io/badge/cs.LG-arXiv%3A2309.02033-B31B1B?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2309.02033) -[![HuggingFace-10+ Demos](https://img.shields.io/badge/🤗HuggingFace-10+_Demos-FFD21E.svg)](https://huggingface.co/spaces?&search=datajuicer) -[![HuggingFace-20+_Refined_Datasets](https://img.shields.io/badge/🤗HuggingFace-20+_Refined_Datasets-FFD21E.svg)](https://huggingface.co/datasets?&search=datajuicer) -[![HuggingFace-Reference_Models](https://img.shields.io/badge/🤗HuggingFace-Reference_Models-FFD21E.svg)](https://huggingface.co/models?&search=datajuicer) -[![QualityClassifier](https://img.shields.io/badge/Tools-Quality_Classifier-saddlebrown?logo=Markdown)](tools/quality_classifier/README_ZH.md) -[![AutoEvaluation](https://img.shields.io/badge/Tools-Auto_Evaluation-saddlebrown?logo=Markdown)](tools/evaluator/README_ZH.md) +Data-Juicer 
是一个一站式**多模态**数据处理系统,旨在为大语言模型 (LLM) 提供更高质量、更丰富、更易“消化”的数据。 -Data-Juicer 是一个一站式数据处理系统,旨在为大语言模型 (LLM) 提供更高质量、更丰富、更易“消化”的数据。 -本项目在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。欢迎您加入我们推进 LLM 数据的开发和研究工作! +Data-Juicer(包含[DJ-SORA](docs/DJ_SORA_ZH.md))正在积极更新和维护中,我们将定期强化和新增更多的功能和数据菜谱。热烈欢迎您加入我们,一起推进LLM数据的开发和研究! 如果Data-Juicer对您的研发有帮助,请引用我们的[工作](#参考文献) 。 @@ -39,7 +34,8 @@ Data-Juicer 是一个一站式数据处理系统,旨在为大语言模型 (LLM ---- ## 新消息 -- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-20] 我们在积极维护一份关于LLM-Data的精选列表,欢迎[访问](docs/awesome_llm_data.md)并参与贡献! +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-03-07] 我们现在发布了 **Data-Juicer [v0.2.0](https://github.com/alibaba/data-juicer/releases/tag/v0.2.0)**! 在这个新版本中,我们支持了更多的 **多模态数据(包括视频)** 相关特性。我们还启动了 **[DJ-SORA](docs/DJ_SORA_ZH.md)** ,为SORA-like大模型构建开放的大规模高质量数据集! +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-20] 我们在积极维护一份关于LLM-Data的*精选列表*,欢迎[访问](docs/awesome_llm_data.md)并参与贡献! - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-05] 我们的论文被SIGMOD'24 industrial track接收! 
- [2024-01-10] 开启“数据混合”新视界——第二届Data-Juicer大模型数据挑战赛已经正式启动!立即访问[竞赛官网](https://tianchi.aliyun.com/competition/entrance/532174),了解赛事详情。 @@ -54,10 +50,11 @@ Data-Juicer 是一个一站式数据处理系统,旨在为大语言模型 (LLM 目录 === - * [Data-Juicer: 为大语言模型提供更高质量、更丰富、更易“消化”的数据](#data-juicer-为大语言模型提供更高质量更丰富更易消化的数据) * [目录](#目录) * [特点](#特点) + * [Documentation Index | 文档索引](#documentation-index--文档索引-a-namedocumentationindex) + * [演示样例](#演示样例) * [前置条件](#前置条件) * [安装](#安装) * [从源码安装](#从源码安装) @@ -66,28 +63,28 @@ Data-Juicer 是一个一站式数据处理系统,旨在为大语言模型 (LLM * [安装校验](#安装校验) * [快速上手](#快速上手) * [数据处理](#数据处理) + * [分布式数据处理](#分布式数据处理) * [数据分析](#数据分析) * [数据可视化](#数据可视化) * [构建配置文件](#构建配置文件) * [预处理原始数据(可选)](#预处理原始数据可选) * [对于 Docker 用户](#对于-docker-用户) - * [Documentation | 文档](#documentation) * [数据处理菜谱](#数据处理菜谱) - * [演示样例](#演示样例) * [开源协议](#开源协议) * [贡献](#贡献) * [致谢](#致谢) * [参考文献](#参考文献) + ## 特点 ![Overview](https://img.alicdn.com/imgextra/i2/O1CN01IMPeD11xYRUYLmXKO_!!6000000006455-2-tps-3620-1604.png) -* **系统化 & 可复用**:为用户提供系统化且可复用的20+[配置菜谱](configs/README_ZH.md),50+核心[算子](docs/Operators_ZH.md)和专用[工具池](#documentation),旨在让数据处理独立于特定的大语言模型数据集和处理流水线。 +* **系统化 & 可复用**:为用户提供系统化且可复用的80+核心[算子](docs/Operators_ZH.md),20+[配置菜谱](configs/README_ZH.md)和20+专用[工具池](#documentation),旨在让数据处理独立于特定的大语言模型数据集和处理流水线。 * **数据反馈回路**:支持详细的数据分析,并提供自动报告生成功能,使您深入了解您的数据集。结合多维度自动评估功能,支持在 LLM 开发过程的多个阶段进行及时反馈循环。 ![Data-in-the-loop](https://img.alicdn.com/imgextra/i1/O1CN011E99C01ndLZ55iCUS_!!6000000005112-0-tps-2701-1050.jpg) -* **全面的数据处理菜谱**:为pre-training、fine-tuning、中英文等场景提供数十种[预构建的数据处理菜谱](configs/data_juicer_recipes/README_ZH.md)。 ![exp_llama](https://img.alicdn.com/imgextra/i2/O1CN019WtUPP1uhebnDlPR8_!!6000000006069-2-tps-2530-1005.png) +* **全面的数据处理菜谱**:为pre-training、fine-tuning、中英文等场景提供数十种[预构建的数据处理菜谱](configs/data_juicer_recipes/README_ZH.md)。 在LLaMA、LLaVA等模型上有效验证。 ![exp_llama](https://img.alicdn.com/imgextra/i2/O1CN019WtUPP1uhebnDlPR8_!!6000000006069-2-tps-2530-1005.png) * **效率增强**:提供高效的数据处理流水线,减少内存占用和CPU开销,提高生产力。 
![sys-perf](https://img.alicdn.com/imgextra/i4/O1CN01Sk0q2U1hdRxbnQXFg_!!6000000004300-0-tps-2438-709.jpg) @@ -96,9 +93,47 @@ Data-Juicer 是一个一站式数据处理系统,旨在为大语言模型 (LLM * **灵活 & 易扩展**:支持大多数数据格式(如jsonl、parquet、csv等),并允许灵活组合算子。支持[自定义算子](docs/DeveloperGuide_ZH.md#构建自己的算子),以执行定制化的数据处理。 +## Documentation Index | 文档索引 + +* [Overview](README.md) | [概览](README_ZH.md) +* [Operator Zoo](docs/Operators.md) | [算子库](docs/Operators_ZH.md) +* [Configs](configs/README.md) | [配置系统](configs/README_ZH.md) +* [Developer Guide](docs/DeveloperGuide.md) | [开发者指南](docs/DeveloperGuide_ZH.md) +* ["Bad" Data Exhibition](docs/BadDataExhibition.md) | [“坏”数据展览](docs/BadDataExhibition_ZH.md) +* Dedicated Toolkits | 专用工具箱 + * [Quality Classifier](tools/quality_classifier/README.md) | [质量分类器](tools/quality_classifier/README_ZH.md) + * [Auto Evaluation](tools/evaluator/README.md) | [自动评测](tools/evaluator/README_ZH.md) + * [Preprocess](tools/preprocess/README.md) | [前处理](tools/preprocess/README_ZH.md) + * [Postprocess](tools/postprocess/README.md) | [后处理](tools/postprocess/README_ZH.md) +* [Third-parties (LLM Ecosystems)](thirdparty/README.md) | [第三方库(大语言模型生态)](thirdparty/README_ZH.md) +* [API references](https://alibaba.github.io/data-juicer/) +* [Awesome LLM-Data](docs/awesome_llm_data.md) +* [DJ-SORA](docs/DJ_SORA_ZH.md) + + +## 演示样例 + +* Data-Juicer 介绍 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/overview_scan/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/overview_scan)] +* 数据可视化: + * 基础指标统计 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_statistics/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_statistics)] + * 词汇多样性 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_diversity/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_diversity)] + * 算子洞察(单OP) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visualization_op_insight/summary)] 
[[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_op_insight)] + * 算子效果(多OP) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_op_effect/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_op_effect)] +* 数据处理: + * 科学文献 (例如 [arXiv](https://info.arxiv.org/help/bulk_data_s3.html)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sci_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_sci_data)] + * 编程代码 (例如 [TheStack](https://huggingface.co/datasets/bigcode/the-stack)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_code_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_code_data)] + * 中文指令数据 (例如 [Alpaca-CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sft_zh_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_cft_zh_data)] +* 工具池: + * 按语言分割数据集 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_dataset_splitting_by_language/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_dataset_splitting_by_language)] + * CommonCrawl 质量分类器 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_quality_classifier/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_quality_classifier)] + * 基于 [HELM](https://github.com/stanford-crfm/helm) 的自动评测 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/auto_evaluation_helm/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/auto_evaluation_helm)] + * 数据采样及混合 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_mixture/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_mixture)] +* 数据处理回路 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_process_loop)] + + ## 前置条件 -* 推荐 
Python>=3.7,<=3.10 +* 推荐 Python>=3.8,<=3.10 * gcc >= 5 (at least C++14 support) ## 安装 @@ -189,6 +224,25 @@ export DATA_JUICER_MODELS_CACHE="/path/to/another/directory/models" export DATA_JUICER_ASSETS_CACHE="/path/to/another/directory/assets" ``` +### 分布式数据处理 + +现在基于RAY对多机分布式的数据处理进行了实现。 +对应Demo可以通过如下命令运行: + +```shell + +# 运行文字数据处理 +python tools/process_data.py --config ./demos/process_on_ray/configs/demo.yaml + +# 运行视频数据处理 +python tools/process_data.py --config ./demos/process_video_on_ray/configs/demo.yaml + +``` + + - 如果需要在多机上使用RAY运行多模态数据处理,需要确保各分布式节点可以访问对应的数据路径,将对应的数据路径挂载在文件共享系统(如NAS)中 + + - 用户也可以不使用RAY,拆分数据集后使用Slurm/DLC在集群上运行 + ### 数据分析 - 以配置文件路径为参数运行 `analyze_data.py` 或者 `dj-analyze` 命令行工具来分析数据集。 @@ -273,48 +327,14 @@ docker run -dit \ # 在后台启动容器 docker exec -it bash ``` -## Documentation | 文档 - -* [Overview](README.md) | [概览](README_ZH.md) -* [Operator Zoo](docs/Operators.md) | [算子库](docs/Operators_ZH.md) -* [Configs](configs/README.md) | [配置系统](configs/README_ZH.md) -* [Developer Guide](docs/DeveloperGuide.md) | [开发者指南](docs/DeveloperGuide_ZH.md) -* ["Bad" Data Exhibition](docs/BadDataExhibition.md) | [“坏”数据展览](docs/BadDataExhibition_ZH.md) -* Dedicated Toolkits | 专用工具箱 - * [Quality Classifier](tools/quality_classifier/README.md) | [质量分类器](tools/quality_classifier/README_ZH.md) - * [Auto Evaluation](tools/evaluator/README.md) | [自动评测](tools/evaluator/README_ZH.md) - * [Preprocess](tools/preprocess/README.md) | [前处理](tools/preprocess/README_ZH.md) - * [Postprocess](tools/postprocess/README.md) | [后处理](tools/postprocess/README_ZH.md) -* [Third-parties (LLM Ecosystems)](thirdparty/README.md) | [第三方库(大语言模型生态)](thirdparty/README_ZH.md) -* [API references](https://alibaba.github.io/data-juicer/) -* [Awesome LLM-Data](docs/awesome_llm_data.md) - ## 数据处理菜谱 * [BLOOM 数据处理菜谱](configs/reproduced_bloom/README_ZH.md) * [RedPajama 数据处理菜谱](configs/reproduced_redpajama/README_ZH.md) -* [预训练数据增强菜谱](configs/data_juicer_recipes/README_ZH.md) -* 
[Fine-tuning数据增强菜谱](configs/data_juicer_recipes/README_ZH.md#完善前后的alpaca-cot数据集) - -## 演示样例 - -* Data-Juicer 介绍 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/overview_scan/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/overview_scan)] -* 数据可视化: - * 基础指标统计 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_statistics/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_statistics)] - * 词汇多样性 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_diversity/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_diversity)] - * 算子效果 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_visulization_op_effect/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_visualization_op_effect)] -* 数据处理: - * 科学文献 (例如 [arXiv](https://info.arxiv.org/help/bulk_data_s3.html)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sci_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_sci_data)] - * 编程代码 (例如 [TheStack](https://huggingface.co/datasets/bigcode/the-stack)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_code_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_code_data)] - * 中文指令数据 (例如 [Alpaca-CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)) [[ModelScope](https://modelscope.cn/studios/Data-Juicer/process_sft_zh_data/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/process_cft_zh_data)] -* 工具池: - * 按语言分割数据集 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_dataset_splitting_by_language/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_dataset_splitting_by_language)] - * CommonCrawl 质量分类器 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/tool_quality_classifier/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/tool_quality_classifier)] - * 基于 
[HELM](https://github.com/stanford-crfm/helm) 的自动评测 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/auto_evaluation_helm/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/auto_evaluation_helm)] - * 数据采样及混合 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_mixture/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_mixture)] -* 数据处理回路 [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_loop/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_process_loop)] -* 数据处理 HPO [[ModelScope](https://modelscope.cn/studios/Data-Juicer/data_process_hpo/summary)] [[HuggingFace](https://huggingface.co/spaces/datajuicer/data_process_hpo)] +* [预训练文本数据增强菜谱](configs/data_juicer_recipes/README_ZH.md) +* [Fine-tuning文本数据增强菜谱](configs/data_juicer_recipes/README_ZH.md#完善前后的alpaca-cot数据集) +* [预训练多模态数据增强菜谱](configs/data_juicer_recipes/README_ZH.md#before-and-after-refining-for-multimodal-dataset) ## 开源协议 diff --git a/configs/config_all.yaml b/configs/config_all.yaml index f2a2bf5e3..6d55de0ad 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -28,6 +28,8 @@ image_key: 'images' # key name of field image_special_token: '<__dj__image>' # the special token that represents an image in the text. In default, it's "<__dj__image>". You can specify your own special token according to your input dataset. audio_key: 'audios' # key name of field to store the list of sample audio paths. audio_special_token: '<__dj__audio>' # the special token that represents an audio in the text. In default, it's "<__dj__audio>". You can specify your own special token according to your input dataset. +video_key: 'videos' # key name of field to store the list of sample video paths. +video_special_token: '<__dj__video>' # the special token that represents a video in the text. In default, it's "<__dj__video>". You can specify your own special token according to your input dataset. 
eoc_special_token: '<|__dj__eoc|>' # the special token that represents the end of a chunk in the text. In default, it's "<|__dj__eoc|>". You can specify your own special token according to your input dataset. @@ -41,6 +43,7 @@ save_stats_in_one_file: false # whether to store a # process schedule: a list of several process operators with their arguments process: # Mapper ops. Most of these ops need no arguments. + - audio_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg audio filters - chinese_convert_mapper: # convert Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji. mode: 's2t' # choose the mode to convert Chinese: ['s2t', 't2s', 's2tw', 'tw2s', 's2hk', 'hk2s', 's2twp', 'tw2sp', 't2tw', 'tw2t', 'hk2t', 't2hk', 't2jp', 'jp2t'] - clean_email_mapper: # remove emails from text. @@ -50,14 +53,11 @@ process: - clean_copyright_mapper: # remove copyright comments. - expand_macro_mapper: # expand macro definitions in Latex text. - fix_unicode_mapper: # fix unicode errors in text. - - generate_caption_mapper: # generate captions for images to augment datasets - hf_blip2: 'Salesforce/blip2-opt-2.7b' # blip2 model name on huggingface to generate caption - caption_num: 1 # how many candidate captions to generate for each image - keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"]. - keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. - prompt: null # a string prompt to guide the generation of blip2 model for all samples globally. It's None in default, which means no prompt provided. - prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used for set different prompts for different samples. If it's none, use prompt in parameter "prompt". 
It's None in default. - - gpt4v_generate_mapper: # generate samples whose texts are generated based on gpt-4-visison and the image + - image_blur_mapper: # mapper to blur images. + p: 0.2 # probability of the image being blured + blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] + radius: 2 # radius of blur kernel + - image_captioning_from_gpt4v_mapper: # generate samples whose texts are generated based on gpt-4-visison and the image mode: 'description' # mode of text generated from images, can be one of ['resoning', 'description', 'conversation', 'custom'] api_key: '' # the API key to authenticate the request max_token: 500 # the maximum number of tokens to generate. Default is 500. @@ -67,10 +67,13 @@ process: user_prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used for set different prompts for different samples. If it's none, use prompt in parameter "prompt". It's None in default keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated text in the final datasets and the original text will be removed. It's True in default any_or_all: 'any' # keep this sample with 'any' or 'all' strategy of all images. 'any': keep this sample if any images meet the condition. 'all': keep this sample only if all images meet the condition - - image_blur_mapper: # mapper to blur images. - p: 0.2 # probability of the image being blured - blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] - radius: 2 # radius of blur kernel + - image_captioning_mapper: # generate captions for images to augment datasets + hf_img2seq: 'Salesforce/blip2-opt-2.7b' # model name on huggingface to generate caption + caption_num: 1 # how many candidate captions to generate for each image + keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"]. 
+ keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. + prompt: null # a string prompt to guide the generation of blip2 model for all samples globally. It's None in default, which means no prompt provided. + prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used for set different prompts for different samples. If it's none, use prompt in parameter "prompt". It's None in default. - image_diffusion_mapper: # generate images by diffusion model floating_point: 'fp32' # the floating point used to load the diffusion model. hf_diffusion: 'CompVis/stable-diffusion-v1-4' # stable diffusion model name on huggingface to generate image @@ -79,7 +82,7 @@ process: aug_num: 1 # the number of images to generate keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated images in the final datasets and the original images will be removed. It's True in default. caption_key: null # the key name of fields in samples to store captions for each images, the caption guide the diffusion model to produce what the image is - hf_blip2: 'Salesforce/blip2-opt-2.7b' # blip2 model name on huggingface to generate caption if caption_key is null + hf_img2seq: 'Salesforce/blip2-opt-2.7b' # model name on huggingface to generate caption if caption_key is null - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently. aug_num: 1 # number of augmented samples to be generated. If `sequential` is True, there will be total aug_num augmented samples generated. 
If it's False, there will be (aug_num * #opened_aug_method) augmented samples generated. @@ -132,6 +135,47 @@ process: substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove - sentence_split_mapper: # split text to multiple sentences and join them with '\n' lang: 'en' # split text in what language + - video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model + keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only captioned sample in the final datasets and the original sample will be removed. It's True in default. + - video_captioning_from_video_mapper: # generate captions by frame images extracted from video to augment datasets + hf_video_blip: 'kpyu/video-blip-opt-2.7b-ego4d' # video-blip model name on huggingface to generate caption + caption_num: 1 # how many candidate captions to generate for each video + keep_candidate_mode: 'random_any' # retain strategy for the generated $caption_num$ candidates. should be in ["random_any", "similar_one_simhash", "all"]. + keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. + prompt: null # a string prompt to guide the generation of video-blip model for all samples globally. It's None in default, which means no prompt provided. + prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used for set different prompts for different samples. If it's none, use prompt in parameter "prompt". It's None in default. + frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". 
+ frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + horizontal_flip: false # flip frame image horizontally (left to right). + vertical_flip: false # flip frame image vertically (top to bottom). + - video_split_by_scene_mapper: # split videos into scene clips + detector: 'ContentDetector' # PySceneDetect scene detector. Should be one of ['ContentDetector', 'ThresholdDetector', 'AdaptiveDetector`] + threshold: 27.0 # threshold passed to the detector + min_scene_len: 15 # minimum length of any scene + show_progress: false # whether to show progress from scenedetect + - video_resize_aspect_ratio_mapper: # resize videos aspect ratios of videos (a fraction of width by height, r=w/h) to a specified range + min_ratio: 9/21 # the minimum aspect ratio to enforce videos with an aspect ratio below `min_ratio` will be resized to match this minimum ratio. The ratio should be provided as a string in the format "9:21" or "9/21". + max_ratio: 21/9 # the maximum aspect ratio to enforce videos with an aspect ratio above `max_ratio` will be resized to match this maximum ratio. The ratio should be provided as a string in the format "21:9" or "21/9". + strategy: increase # the resizing strategy to apply when adjusting the video dimensions. It can be either 'decrease' to reduce the dimension or 'increase' to enlarge it. Accepted values are ['decrease', 'increase']. 
+ - video_resize_resolution_mapper: # map videos to ones with given resolution range + min_width: 640 # the min horizontal resolution (unit p), videos with width less than 'min_width' will be mapped to videos with equal or bigger width + max_width: 1280 # the max horizontal resolution (unit p), videos with width more than 'max_width' will be mapped to videos with equal of smaller width + min_height: 480 # the min vertical resolution (unit p), videos with height less than 'min_height' will be mapped to videos with equal or bigger height + max_height: 1080 # the max vertical resolution (unit p), videos with height more than 'max_height' will be mapped to videos with equal or smaller height + force_original_aspect_ratio: 'increase' # Enable decreasing or increasing output video width or height if necessary to keep the original aspect ratio + force_divisible_by: 4 # Ensures that both the output dimensions, width and height, are divisible by the given integer when used together with force_original_aspect_ratio + - video_ffmpeg_wrapped_mapper: # simple wrapper for FFmpeg video filters + - video_split_by_duration_mapper: # Mapper to split video by duration. + split_duration: 10 # duration of each video split in seconds. + min_last_split_duration: 0 # the minimum allowable duration in seconds for the last video split. If the duration of the last split is less than this value, it will be discarded. + keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only cut sample in the final datasets and the original sample will be removed. It's True in default + - video_split_by_key_frame_mapper: # Mapper to split video by key frame. + keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only cut sample in the final datasets and the original sample will be removed. 
It's True in default + - video_tagging_from_audio_mapper: # Mapper to generate video tags from audio streams extracted from the video. + hf_ast: 'MIT/ast-finetuned-audioset-10-10-0.4593' # Huggingface model name for the audio classification model. + - video_tagging_from_frames_mapper: # Mapper to generate video tags from frames extracted from the video. + frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". + frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. - whitespace_normalization_mapper: # normalize different kinds of whitespaces to English whitespace. # Filter ops @@ -175,6 +219,11 @@ process: min_ratio: 0.333 # the min aspect ratio of filter range max_ratio: 3.0 # the max aspect ratio of filter range any_or_all: any # keep this sample when any/all images meet the filter condition + - image_aesthetics_filter: # filter samples according to the aesthetics score of images. 
+ hf_scorer_model: shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE # Huggingface model name for the aesthetics predictor + min_score: 0.3 # the min aesthetics score of filter range + max_score: 1.0 # the max aesthetics score of filter range + any_or_all: any # keep this sample when any/all images meet the filter condition - image_shape_filter: # filter samples according to the widths and heights of images in them min_width: 200 # the min width of width filter range max_width: 5000 # the max width of width filter range @@ -211,7 +260,7 @@ process: lang: en # compute perplexity in what language max_ppl: 1500 # the max perplexity score to filter text - phrase_grounding_recall_filter: # filter samples according to the locating recall of phrases extracted from text in the images. - hf_owlvit: openai/clip-vit-base-patch32 # name of used Hugging Face Owl-ViT + hf_owlvit: openai/clip-vit-base-patch32 # name of used Hugging Face Owl-ViT min_recall: 0.1 # the min phrase grounding recall of filter range max_recall: 1.0 # the max phrase grounding recall of filter range horizontal_flip: false # flip image horizontally (left to right). @@ -224,6 +273,13 @@ process: - special_characters_filter: # filter text with special-char ratio out of specific range min_ratio: 0.0 # the min ratio of filter range max_ratio: 0.25 # the max ratio of filter range + - specified_field_filter: # filter text with the specified field info out of specific range + field_key: '' # the target key corresponding to multi-level field information need to be separated by '.' + target_value: [] # the range of specified field information corresponding to the samples that need to be retained + - specified_numeric_field_filter: # filter text with the specified numeric field info out of specific range + field_key: '' # the target key corresponding to multi-level field information need to be separated by '.' 
+ min_value: 0 # the min filter value in SpecifiedNumericField op + max_value: 10000 # the max filter value in SpecifiedNumericField op - stopwords_filter: # filter text with stopword ratio smaller than a specific min value lang: en # consider stopwords in what language tokenization: false # whether to use model to tokenize documents @@ -232,6 +288,8 @@ process: use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese words_aug_group_sizes: [2] # the group size of words to augment words_aug_join_char: "" # the join char between words to augment + - suffix_filter: # filter to keep samples with specified suffix. + suffixes: [] # the suffix of text that will be keep. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx'] - text_action_filter: # filter text according the number of action verb lang: en # consider the words in what language min_action_num: 1 # text will be filtered whose verbs less the min action number @@ -246,6 +304,49 @@ process: hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer min_num: 10 # the min number of filter range max_num: 10000 # the max number of filter range + - video_aesthetics_filter: # filter samples according to the aesthetics score of frame images extracted from videos. + hf_scorer_model: shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE # Huggingface model name for the aesthetics predictor + min_score: 0.3 # the min aesthetics score of filter range + max_score: 1.0 # the max aesthetics score of filter range + frame_sampling_method: 'uniform' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframe", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "uniform" with frame_num=3, considering that the number of keyframes can be large while their difference is usually small in terms of their aesthetics. 
+ frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + reduce_mode: avg # reduce mode to the all frames extracted from videos, must be one of ['avg','max', 'min']. + any_or_all: any # keep this sample when any/all images meet the filter condition + - video_aspect_ratio_filter: # filter samples according to the aspect ratios of videos (a fraction of width by height, r=w/h) in them + min_ratio: 9/21 # the minimum aspect ratio to keep samples, supported format is a string, such as "9:21" or "9/21". + max_ratio: 21/9 # the maximum aspect ratio to keep samples, supported format is a string, such as "21:9" or "21/9". + any_or_all: any # keep this sample when any/all videos meet the filter condition + - video_duration_filter: # Keep data samples whose videos' durations are within a specified range. + min_duration: 0 # the min video duration of filter range (in seconds) + max_duration: 10 # the max video duration of filter range (in seconds) + any_or_all: any # keep this sample when any/all videos meet the filter condition + - video_frames_text_similarity_filter: # keep samples those similarities between sampled video frame images and text within a specific range. + hf_clip: 'openai/clip-vit-base-patch32' # clip model name on huggingface to compute the similarity between frame image and text. It's kind of language-related. For example, for Chinese datasets, ChineseCLIP might be a better choice. + min_score: 0.1 # the min similarity to keep samples. + max_score: 1.0 # the max similarity to keep samples. + frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. 
Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". + frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + horizontal_flip: false # flip frame image horizontally (left to right). + vertical_flip: false # flip frame image vertically (top to bottom). + reduce_mode: avg # reduce mode when one text corresponds to multiple videos in a chunk, must be one of ['avg','max', 'min']. + any_or_all: any # keep this sample when any/all videos meet the filter condition + - video_motion_score_filter: # Keep samples with video motion scores within a specific range. + min_score: 0.25 # the minimum motion score to keep samples + max_score: 10000.0 # the maximum motion score to keep samples + sampling_fps: 2 # the samplig rate of frames_per_second to compute optical flow + any_or_all: any # keep this sample when any/all videos meet the filter condition + - video_ocr_area_ratio_filter: # Keep data samples whose detected text area ratios for specified frames in the video are within a specified range. + min_area_ratio: 0 # the min ocr area ratio to keep samples. It's 0 by default. + max_area_ratio: 1.0 # the max ocr area ratio to keep samples. It's 1.0 by default. + frame_sample_num: 3 # the number of sampled frames to calculate the ocr area ratio. If it's 1, only middle frame will be selected. If it's 2, only the first and the last frames will be selected. If it's larger than 2, in addition to the first and the last frames, other frames will be sampled evenly within the video duration. 
+ languages_to_detect: ['ch_sim', 'en'] # texts in which languages should be detected. Default: ['ch_sim', 'en']. Full language list can be found here: https://www.jaided.ai/easyocr/. + any_or_all: any # keep this sample with 'any' or 'all' strategy of all videos. 'any': keep this sample if any videos meet the condition. 'all': keep this sample only if all videos meet the condition. + - video_resolution_filter: # filter samples according to the resolution of videos in them + min_width: 1280 # the min resolution of horizontal resolution filter range (unit p) + max_width: 4096 # the max resolution of horizontal resolution filter range (unit p) + min_height: 480 # the min resolution of vertical resolution filter range (unit p) + max_height: 1080 # the max resolution of vertical resolution filter range (unit p) + any_or_all: any # keep this sample when any/all videos meet the filter condition - words_num_filter: # filter text with number of words out of specific range lang: en # sample in which language tokenization: false # whether to use model to tokenize documents @@ -257,15 +358,6 @@ process: rep_len: 10 # repetition length for word-level n-gram min_ratio: 0.0 # the min ratio of filter range max_ratio: 0.5 # the max ratio of filter range - - suffix_filter: # filter to keep samples with specified suffix. - suffixes: [] # the suffix of text that will be keep. For example: '.txt', 'txt' or ['txt', '.pdf', 'docx'] - - specified_field_filter: # filter text with the specified field info out of specific range - field_key: '' # the target key corresponding to multi-level field information need to be separated by '.' - target_value: [] # the range of specified field information corresponding to the samples that need to be retained - - specified_numeric_field_filter: # filter text with the specified numeric field info out of specific range - field_key: '' # the target key corresponding to multi-level field information need to be separated by '.' 
- min_value: 0 # the min filter value in SpecifiedNumericField op - max_value: 10000 # the max filter value in SpecifiedNumericField op # Deduplicator ops - document_deduplicator: # deduplicate text samples using md5 hashing exact matching method @@ -289,6 +381,7 @@ process: ignore_pattern: null # whether to ignore sub-strings with specific pattern when computing simhash. - image_deduplicator: # deduplicator to deduplicate samples at document-level using exact matching of images between documents. method: phash # hash method for image. One of [phash, dhash, whash, ahash] + - video_deduplicator: # deduplicator to deduplicate samples at document-level using exact matching of videos between documents. # Selector ops - topk_specified_field_selector: # selector to select top samples based on the sorted specified field diff --git a/configs/data_juicer_recipes/README.md b/configs/data_juicer_recipes/README.md index f04894f91..1f6ad757c 100644 --- a/configs/data_juicer_recipes/README.md +++ b/configs/data_juicer_recipes/README.md @@ -4,7 +4,7 @@ We found that there are still some "bad" samples in existing processed datasets We use simple 3-σ rule to set the hyperparameters for ops in each recipe. 
-## Before and after refining for Pretraining Dataset +## Before and after refining for Pretraining Text Dataset | subset | #samples before | #samples after | keep ratio | config link | data link | source | |----------------------|:---------------------------:|:--------------:|:----------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------| @@ -35,3 +35,17 @@ We use simple 3-σ rule to set the hyperparameters for ops in each recipe. |------------------|:-------------------------:|:--------------------------------------:|:----------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------| | Alpaca-Cot EN | 136,219,879 | 72,855,345 | 54.48% | [alpaca-cot-en-refine.yaml](alpaca_cot/alpaca-cot-en-refine.yaml) | [Aliyun](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/CFT/alpaca-cot-en-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-en-refined-by-data-juicer/summary)
[HuggingFace](https://huggingface.co/datasets/datajuicer/alpaca-cot-en-refined-by-data-juicer) | [39 Subsets of Alpaca-CoT](alpaca_cot/README.md#refined-alpaca-cot-dataset-meta-info) |
 | Alpaca-Cot ZH | 21,197,246 | 9,873,214 | 46.58% | [alpaca-cot-zh-refine.yaml](alpaca_cot/alpaca-cot-zh-refine.yaml) | [Aliyun](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/CFT/alpaca-cot-zh-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-zh-refined-by-data-juicer/summary)
[HuggingFace](https://huggingface.co/datasets/datajuicer/alpaca-cot-zh-refined-by-data-juicer) | [28 Subsets of Alpaca-CoT](alpaca_cot/README.md#refined-alpaca-cot-dataset-meta-info) |
+
+## Before and after refining for Multimodal Dataset
+
+| subset | #samples before | #samples after | keep ratio | config link | data link | source |
+|---------------------------|:---------------------------:|:--------------:|:----------:|--------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
+| LLaVA pretrain (LCS-558k) | 558,128 | 500,380 | 89.65% | [llava-pretrain-refine.yaml](llava-pretrain-refine.yaml) | [Aliyun](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/MM_data/our_refined_data/LLaVA-1.5/public/llava-pretrain-refine-result.json)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/llava-pretrain-refined-by-data-juicer/summary)
[HuggingFace](https://huggingface.co/datasets/datajuicer/llava-pretrain-refined-by-data-juicer) | [LLaVA-1.5](https://github.com/haotian-liu/LLaVA) |
+
+### Evaluation Results
+- LLaVA pretrain (LCS-558k): the model **pretrained with the refined dataset** and fine-tuned with the original instruct dataset outperforms the baseline (LLaVA-1.5-13B) on 10 out of 12 benchmarks.
+
+| model | VQAv2 | GQA | VizWiz | SQA | TextVQA | POPE | MME | MM-Bench | MM-Bench-CN | SEED | LLaVA-Bench-Wild | MM-Vet |
+|-------------------------------|-------| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| LLaVA-1.5-13B
(baseline) | **80.0** | 63.3 | 53.6 | 71.6 | **61.3** | 85.9 | 1531.3 | 67.7 | 63.6 | 61.6 | 72.5 | 36.1 | +| LLaVA-1.5-13B
(refined pretrain dataset) | 79.94 | **63.5** | **54.09** | **74.20** | 60.82 | **86.67** | **1565.53** | **68.2** | **63.9** | **61.8** | **75.9** | **37.4** | diff --git a/configs/data_juicer_recipes/README_ZH.md b/configs/data_juicer_recipes/README_ZH.md index af8d1d697..d7dd848d7 100644 --- a/configs/data_juicer_recipes/README_ZH.md +++ b/configs/data_juicer_recipes/README_ZH.md @@ -35,3 +35,17 @@ |-------------------|:------------------------:|:----------------------------------:|:---------:|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------| | Alpaca-Cot EN | 136,219,879 | 72,855,345 | 54.48% | [alpaca-cot-en-refine.yaml](alpaca_cot/alpaca-cot-en-refine.yaml) | [Aliyun](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/CFT/alpaca-cot-en-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-en-refined-by-data-juicer/summary)
[HuggingFace](https://huggingface.co/datasets/datajuicer/alpaca-cot-en-refined-by-data-juicer) | [来自Alpaca-CoT的39个子集](alpaca_cot/README_ZH.md#完善的-alpaca-cot-数据集元信息) | | Alpaca-Cot ZH | 21,197,246 | 9,873,214 | 46.58% | [alpaca-cot-zh-refine.yaml](alpaca_cot/alpaca-cot-zh-refine.yaml) | [Aliyun](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/LLM_data/our_refined_datasets/CFT/alpaca-cot-zh-refine_result.jsonl)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/alpaca-cot-zh-refined-by-data-juicer/summary)
[HuggingFace](https://huggingface.co/datasets/datajuicer/alpaca-cot-zh-refined-by-data-juicer) | [来自Alpaca-CoT的28个子集](alpaca_cot/README_ZH.md#完善的-alpaca-cot-数据集元信息) | + +## 完善前后的多模态数据集 + +| 数据子集 | 完善前的样本数目 | 完善后的样本数目 | 样本保留率 | 配置链接 | 数据链接 | 来源 | +|---------------------------|:---------------------------:|:--------------:|:----------:|--------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| +| LLaVA pretrain (LCS-558k) | 558,128 | 500,380 | 89.65% | [llava-pretrain-refine.yaml](llava-pretrain-refine.yaml) | [Aliyun](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/MM_data/our_refined_data/LLaVA-1.5/public/llava-pretrain-refine-result.json)
[ModelScope](https://modelscope.cn/datasets/Data-Juicer/llava-pretrain-refined-by-data-juicer/summary)
[HuggingFace](https://huggingface.co/datasets/datajuicer/llava-pretrain-refined-by-data-juicer) | [LLaVA-1.5](https://github.com/haotian-liu/LLaVA) | + +### 评测结果 +- LLaVA pretrain (LCS-558k): 使用**完善后的预训练数据集**预训练并使用原始的指令数据集微调后的模型在12个评测集上有10个超过了基线模型LLaVA-1.5-13B。 + +| 模型 | VQAv2 | GQA | VizWiz | SQA | TextVQA | POPE | MME | MM-Bench | MM-Bench-CN | SEED | LLaVA-Bench-Wild | MM-Vet | +|---------------------------------|-------| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| LLaVA-1.5-13B
(基线) | **80.0** | 63.3 | 53.6 | 71.6 | **61.3** | 85.9 | 1531.3 | 67.7 | 63.6 | 61.6 | 72.5 | 36.1 | +| LLaVA-1.5-13B
(完善后的预训练数据集) | 79.94 | **63.5** | **54.09** | **74.20** | 60.82 | **86.67** | **1565.53** | **68.2** | **63.9** | **61.8** | **75.9** | **37.4** | diff --git a/configs/data_juicer_recipes/llava-pretrain-refine.yaml b/configs/data_juicer_recipes/llava-pretrain-refine.yaml new file mode 100644 index 000000000..03a1bf23c --- /dev/null +++ b/configs/data_juicer_recipes/llava-pretrain-refine.yaml @@ -0,0 +1,60 @@ +project_name: 'llava-1.5-pretrain-dataset-refine-recipe' +dataset_path: 'blip_laion_cc_sbu_558k_dj_fmt_only_caption.jsonl' # converted LLaVA pretrain dataset in Data-Juicer format with only_keep_caption is True. See tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py +export_path: 'blip_laion_cc_sbu_558k_dj_fmt_only_caption_refined.jsonl' + +np: 42 # number of subprocess to process your dataset +text_keys: 'text' # the key name of field where the sample texts to be processed, e.g., `text`, `instruction`, `output`, ... + +# for multimodal data processing +image_key: 'images' # Key name of field to store the list of sample image paths. +image_special_token: '' # The special token that represents an image in the text. For LLaVA, it's "". Should be aligned with the args when running conversion tools. +eoc_special_token: '<|__dj__eoc|>' # The special token that represents the end of a chunk in the text. In default, it's "<|__dj__eoc|>". You can specify your own special token according to your input dataset. Should be aligned with the args when running conversion tools. + +open_tracer: true + +# process schedule: a list of several process operators with their arguments +process: + - fix_unicode_mapper: # fix unicode errors in text. + - punctuation_normalization_mapper: # normalize unicode punctuations to English punctuations. + + # 558128 + # Filter ops + - alphanumeric_filter: #558087 # filter text with alphabet/numeric ratio out of specific range. + tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens. 
+ min_ratio: 0.60 # the min ratio of filter range + - character_repetition_filter: #546105 # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + max_ratio: 0.09373663 # the max ratio of filter range + - flagged_words_filter: #543960 # filter text with the flagged-word ratio larger than a specific max value + lang: en # consider flagged words in what language + tokenization: false # whether to use model to tokenize documents + max_ratio: 0.0 # the max ratio to filter text + - perplexity_filter: #532029 # filter text with perplexity score out of specific range + lang: en # compute perplexity in what language + max_ppl: 14435.5806 # the max perplexity score to filter text + - special_characters_filter: #531968 # filter text with special-char ratio out of specific range + min_ratio: 0.16534802 # the min ratio of filter range + max_ratio: 0.42023757 # the max ratio of filter range + - word_repetition_filter: # 530773 # filter text with the word repetition ratio out of specific range + lang: en # sample in which language + tokenization: false # whether to use model to tokenize documents + rep_len: 10 # repetition length for word-level n-gram + max_ratio: 0.03085751 # the max ratio of filter range + + - image_aspect_ratio_filter: #542389 # filter samples according to the aspect ratios of images (a fraction of width by height, r=w/h) in them + min_ratio: 0.333 # the min aspect ratio of filter range + max_ratio: 3.0 # the max aspect ratio of filter range + any_or_all: any # keep this sample when any/all images meet the filter condition + - image_shape_filter: #533966 # filter samples according to the widths and heights of images in them + max_width: 727.8798422276 # the max width of width filter range + max_height: 606.2421072264 # the max height of height filter range + any_or_all: any # keep this sample when any/all images meet the filter condition + - image_size_filter: # 533966 # filter samples 
according to the size of images (in bytes) within them + max_size: "124KB" # the max size of filter range + any_or_all: any # keep this sample when any/all images meet the filter condition + - image_text_similarity_filter: #544202 # filter samples according to the similarity between text and images. + hf_clip: openai/clip-vit-base-patch32 # name of used Hugging Face clip + min_score: 0.20315419 # the min similarity of filter range + - image_text_matching_filter: # filter samples according to the matching score between image and text. + hf_blip: Salesforce/blip-itm-base-coco # name of used Hugging Face blip + min_score: 0.44930778 # the min matching score of filter range diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index c728dfa66..f75c4bef3 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -149,6 +149,18 @@ def init_configs(args=None): help='The special token that represents an audio in the text. In ' 'default, it\'s "<__dj__audio>". You can specify your own special' ' token according to your input dataset.') + parser.add_argument( + '--video_key', + type=str, + default='videos', + help='Key name of field to store the list of sample video paths.') + parser.add_argument( + '--video_special_token', + type=str, + default=SpecialTokens.video, + help='The special token that represents a video in the text. In ' + 'default, it\'s "<__dj__video>". 
You can specify your own special' + ' token according to your input dataset.') parser.add_argument( '--eoc_special_token', type=str, diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 8587f6a2e..da8eac70a 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -1,3 +1,8 @@ +import os +from functools import partial + +import pandas as pd +import pyarrow as pa from loguru import logger from data_juicer.config import init_configs @@ -10,6 +15,48 @@ import ray.data as rd +def is_valid_path(item, dataset_dir): + full_path = os.path.abspath(os.path.join(dataset_dir, item)) + return os.path.exists(full_path) + + +def convert_to_absolute_paths(dict_with_paths, dataset_dir): + for key, value in dict_with_paths.items(): + if isinstance(value, list): + dict_with_paths[key] = [ + os.path.abspath(os.path.join(dataset_dir, item)) + if isinstance(item, str) and is_valid_path(item, dataset_dir) + else item for item in value + ] + elif isinstance(value, str): + dict_with_paths[key] = os.path.abspath( + os.path.join( + dataset_dir, + value)) if isinstance(value, str) and is_valid_path( + value, dataset_dir) else value + return dict_with_paths + + +def set_dataset_to_absolute_path(dataset, dataset_path): + """ + Set all the paths in the input data to absolute paths. + Checks dataset_dir and project_dir for valid paths. + """ + dataset_dir = os.path.dirname(dataset_path) + dataset = dataset.map( + lambda item: convert_to_absolute_paths(item, dataset_dir)) + print(f"transfer {dataset.count()} sample's paths") + return dataset + + +def ray_batch_mapper_wrapper(samples, fn): + samples = samples.to_pandas() + res = fn(samples) + if not isinstance(res, pd.DataFrame): + res = pd.DataFrame(res) + return pa.Table.from_pandas(res) + + class RayExecutor: """ Executor based on Ray [Experimental]. 
@@ -47,6 +94,10 @@ def run(self, load_data_np=None): logger.info('Loading dataset with Ray...') dataset = rd.read_json(self.cfg.dataset_path) + # convert all the paths in the dataset to absolute paths + dataset = set_dataset_to_absolute_path(dataset, self.cfg.dataset_path) + for items in dataset.iter_rows(): + print('item is:', items) # 2. extract processes logger.info('Preparing process operators...') self.process_list, self.ops = load_ops(self.cfg.process, @@ -57,14 +108,27 @@ def run(self, load_data_np=None): # - If checkpoint is open, clean the cache files after each process if Fields.stats not in dataset.columns(fetch_if_missing=False): logger.info(f'columns {dataset.columns(fetch_if_missing=False)}') - dataset = dataset.add_column( - Fields.stats, lambda df: [{} for _ in range(len(df))]) + + def process_batch_arrow(table: pa.Table) -> pa.Table: + new_column_data = [{} for _ in range(len(table))] + new_table = table.append_column(Fields.stats, + [new_column_data]) + return new_table + + dataset = dataset.map_batches(process_batch_arrow, + batch_format='pyarrow') + logger.info('Processing data...') for op_cfg, op in zip(self.process_list, self.ops): op_name, _ = list(op_cfg.items())[0] try: if isinstance(op, Mapper): - dataset = dataset.map(op.process) + if op.is_batched_op(): + dataset = dataset.map_batches(partial( + ray_batch_mapper_wrapper, fn=op.process), + batch_format='pyarrow') + else: + dataset = dataset.map(op.process) elif isinstance(op, Filter): dataset = dataset.map(op.compute_stats) dataset = dataset.filter(op.process) @@ -79,10 +143,6 @@ def run(self, load_data_np=None): traceback.print_exc() exit(1) - # clean up cache files and record processed ops - logger.info(f'Op [{op_name}] Done. Left ' - f'{dataset.count()} samples.') - # 4. 
data export logger.info('Exporting dataset to disk...') dataset.write_json(self.cfg.export_path, force_ascii=False) diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py index 653a6b251..987a3667b 100644 --- a/data_juicer/format/formatter.py +++ b/data_juicer/format/formatter.py @@ -218,12 +218,15 @@ def non_empty_text(sample, target_keys): ds_dir = global_cfg.dataset_dir image_key = global_cfg.image_key audio_key = global_cfg.audio_key + video_key = global_cfg.video_key data_path_keys = [] if image_key in dataset.features: data_path_keys.append(image_key) if audio_key in dataset.features: data_path_keys.append(audio_key) + if video_key in dataset.features: + data_path_keys.append(video_key) if len(data_path_keys) == 0: # no image/audios path list in dataset, no need to convert return dataset diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 6b494e5ec..d60d34f1c 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -1,3 +1,5 @@ +import copy + from data_juicer.utils.registry import Registry OPERATORS = Registry('Operators') @@ -10,6 +12,7 @@ def __init__( text_key: str = None, image_key: str = None, audio_key: str = None, + video_key: str = None, ): """ Base class of operators. 
@@ -20,6 +23,8 @@ def __init__( to be processed :param audio_key: the key name of field that stores sample audio list to be processed + :param video_key: the key name of field that stores sample video list + to be processed """ # init data keys if text_key is None: @@ -31,6 +36,9 @@ def __init__( if audio_key is None: audio_key = 'audios' self.audio_key = audio_key + if video_key is None: + video_key = 'videos' + self.video_key = video_key self._accelerator = 'cpu' from data_juicer.core.data import wrap_func_with_nested_access @@ -39,15 +47,42 @@ def __init__( def process(self, *args, **kwargs): raise NotImplementedError + def remove_extra_parameters(self, param_dict, keys=None): + """ + at the beginning of the init of the mapper op, call + self.remove_extra_parameters(locals()) + to get the init parameter dict of the op for convenience + + """ + if keys is None: + param_dict = { + k: v + for k, v in param_dict.items() if not k.startswith('_') + } + param_dict.pop('self', None) + else: + param_dict = {k: v for k, v in param_dict.items() if k not in keys} + return param_dict + + def add_parameters(self, init_parameter_dict, **extra_param_dict): + """ + add parameters for each sample, need to keep extra_param_dict + and init_parameter_dict unchanged. + """ + related_parameters = copy.deepcopy(init_parameter_dict) + related_parameters.update(extra_param_dict) + return related_parameters + class Mapper(OP): - def __init__( - self, - text_key: str = None, - image_key: str = None, - audio_key: str = None, - ): + def __init__(self, + text_key: str = None, + image_key: str = None, + audio_key: str = None, + video_key: str = None, + *args, + **kwargs): """ Base class that conducts data editing. 
@@ -57,8 +92,10 @@ def __init__( to be processed :param audio_key: the key name of field that stores sample audio list to be processed + :param video_key: the key name of field that stores sample video list + to be processed """ - super(Mapper, self).__init__(text_key, image_key, audio_key) + super(Mapper, self).__init__(text_key, image_key, audio_key, video_key) # In default, it's a normal OP instead of batched OP self._batched_op = False @@ -78,12 +115,13 @@ def is_batched_op(self): class Filter(OP): - def __init__( - self, - text_key: str = None, - image_key: str = None, - audio_key: str = None, - ): + def __init__(self, + text_key: str = None, + image_key: str = None, + audio_key: str = None, + video_key: str = None, + *args, + **kwargs): """ Base class that removes specific info. @@ -93,8 +131,10 @@ def __init__( to be processed :param audio_key: the key name of field that stores sample audio list to be processed + :param video_key: the key name of field that stores sample video list + to be processed """ - super(Filter, self).__init__(text_key, image_key, audio_key) + super(Filter, self).__init__(text_key, image_key, audio_key, video_key) from data_juicer.core.data import wrap_func_with_nested_access self.compute_stats = wrap_func_with_nested_access(self.compute_stats) @@ -123,12 +163,13 @@ def process(self, sample): class Deduplicator(OP): - def __init__( - self, - text_key: str = None, - image_key: str = None, - audio_key: str = None, - ): + def __init__(self, + text_key: str = None, + image_key: str = None, + audio_key: str = None, + video_key: str = None, + *args, + **kwargs): """ Base class that conducts deduplication. 
@@ -138,8 +179,11 @@ def __init__( to be processed :param audio_key: the key name of field that stores sample audio list to be processed + :param video_key: the key name of field that stores sample video list + to be processed """ - super(Deduplicator, self).__init__(text_key, image_key, audio_key) + super(Deduplicator, self).__init__(text_key, image_key, audio_key, + video_key) from data_juicer.core.data import wrap_func_with_nested_access self.compute_hash = wrap_func_with_nested_access(self.compute_hash) @@ -167,12 +211,13 @@ def process(self, dataset, show_num=0): class Selector(OP): - def __init__( - self, - text_key: str = None, - image_key: str = None, - audio_key: str = None, - ): + def __init__(self, + text_key: str = None, + image_key: str = None, + audio_key: str = None, + video_key: str = None, + *args, + **kwargs): """ Base class that conducts selection in dataset-level. @@ -182,8 +227,11 @@ def __init__( to be processed :param audio_key: the key name of field that stores sample audio list to be processed + :param video_key: the key name of field that stores sample video list + to be processed """ - super(Selector, self).__init__(text_key, image_key, audio_key) + super(Selector, self).__init__(text_key, image_key, audio_key, + video_key) def process(self, dataset): """ diff --git a/data_juicer/ops/deduplicator/__init__.py b/data_juicer/ops/deduplicator/__init__.py index 7cee3ebb7..65dfb05c2 100644 --- a/data_juicer/ops/deduplicator/__init__.py +++ b/data_juicer/ops/deduplicator/__init__.py @@ -1,2 +1,3 @@ from . 
import (document_deduplicator, document_minhash_deduplicator, - document_simhash_deduplicator, image_deduplicator) + document_simhash_deduplicator, image_deduplicator, + video_deduplicator) diff --git a/data_juicer/ops/deduplicator/image_deduplicator.py b/data_juicer/ops/deduplicator/image_deduplicator.py index 0553d5c8c..2ca191c66 100644 --- a/data_juicer/ops/deduplicator/image_deduplicator.py +++ b/data_juicer/ops/deduplicator/image_deduplicator.py @@ -4,8 +4,8 @@ import numpy as np from data_juicer.utils.availability_utils import AvailabilityChecking -from data_juicer.utils.constant import Fields, HashKeys -from data_juicer.utils.mm_utils import load_image +from data_juicer.utils.constant import HashKeys +from data_juicer.utils.mm_utils import load_data_with_context, load_image from ..base_op import OPERATORS, Deduplicator from ..op_fusion import LOADED_IMAGES @@ -57,20 +57,8 @@ def compute_hash(self, sample, context=False): # load images loaded_image_keys = sample[self.image_key] - images = {} - for loaded_image_key in loaded_image_keys: - if context and loaded_image_key in sample[Fields.context]: - # load from context - images[loaded_image_key] = sample[ - Fields.context][loaded_image_key] - else: - if loaded_image_key not in images: - # avoid load the same images - image = load_image(loaded_image_key) - images[loaded_image_key] = image - if context: - # store the image data into context - sample[Fields.context][loaded_image_key] = image + sample, images = load_data_with_context(sample, context, + loaded_image_keys, load_image) # compute hash for key in images: diff --git a/data_juicer/ops/deduplicator/video_deduplicator.py b/data_juicer/ops/deduplicator/video_deduplicator.py new file mode 100644 index 000000000..d06de9f69 --- /dev/null +++ b/data_juicer/ops/deduplicator/video_deduplicator.py @@ -0,0 +1,104 @@ +import hashlib +from collections import defaultdict +from typing import Dict, Set + +from data_juicer.utils.constant import HashKeys +from 
data_juicer.utils.mm_utils import load_data_with_context, load_video + +from ..base_op import OPERATORS, Deduplicator +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_deduplicator' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoDeduplicator(Deduplicator): + """ + Deduplicator to deduplicate samples at document-level using exact matching + of videos between documents. + """ + + def __init__(self, *args, **kwargs): + """ + Initialization. + + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + def compute_hash(self, sample, context=False): + # check if it's computed already + if HashKeys.videohash in sample: + return sample + + # there is no video in this sample + sample[HashKeys.videohash] = '' + if self.video_key not in sample or not sample[self.video_key]: + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + # compute hash + md5_hash = hashlib.md5() + for key in videos: + # consider the multi stream of video in one container + for packet in videos[key].demux(): + if packet.stream.type == 'video': + md5_hash.update(packet.to_bytes()) + + sample[HashKeys.videohash] = md5_hash.hexdigest() + return sample + + def process(self, dataset, show_num=0): + """ + For doc-level, dataset --> dataset. + + :param dataset: input dataset + :param show_num: number of traced samples used when tracer is + open. + :return: deduplicated dataset and the sampled duplicate pairs. 
+ """ + # no need to deduplicate because too few samples + if len(dataset) <= 1: + return dataset, {} + + dup_hashes = None + if show_num > 0: + # sample duplicate pairs + hash2ids: Dict[int, Set[int]] = defaultdict(set) + for sid, hash_val in enumerate(dataset[HashKeys.videohash]): + if hash_val: + hash2ids[hash_val].add(sid) + dup_samples = sorted(list(hash2ids.items()), + key=lambda x: len(x[1]), + reverse=True) + dup_hashes = set([ + item[0] for item in dup_samples if len(item[1]) > 1 + ][:show_num]) + + def _filter_dup_helper(sample, hashes): + hash = sample[HashKeys.videohash] + if not hash: + return True + if show_num > 0 and hash in dup_hashes \ + and len(dup_pairs[hash]) < 2: + # tracer is open and not enough duplicate sample pairs + dup_pairs[hash].append(sample) + if hash in hashes: + return False + else: + hashes.add(hash) + return True + + hashes = set() + dup_pairs = {hash_v: [] for hash_v in dup_hashes} if dup_hashes else {} + dataset = dataset.filter( + _filter_dup_helper, + fn_kwargs=dict(hashes=hashes), + load_from_cache_file=False if show_num > 0 else True) # num_proc=1 + return dataset, dup_pairs diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index d5233eabd..5c94a7acd 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -2,7 +2,7 @@ from . 
import (alphanumeric_filter, audio_duration_filter, audio_nmf_snr_filter, audio_size_filter, average_line_length_filter, character_repetition_filter, - face_area_filter, flagged_words_filter, + face_area_filter, flagged_words_filter, image_aesthetics_filter, image_aspect_ratio_filter, image_shape_filter, image_size_filter, image_text_matching_filter, image_text_similarity_filter, language_id_score_filter, @@ -11,6 +11,10 @@ specified_field_filter, specified_numeric_field_filter, stopwords_filter, suffix_filter, text_action_filter, text_entity_dependency_filter, text_length_filter, - token_num_filter, word_num_filter, word_repetition_filter) + token_num_filter, video_aesthetics_filter, + video_aspect_ratio_filter, video_duration_filter, + video_frames_text_similarity_filter, video_motion_score_filter, + video_ocr_area_ratio_filter, video_resolution_filter, + word_num_filter, word_repetition_filter) # yapf: enable diff --git a/data_juicer/ops/filter/image_aesthetics_filter.py b/data_juicer/ops/filter/image_aesthetics_filter.py new file mode 100644 index 000000000..1f69ef681 --- /dev/null +++ b/data_juicer/ops/filter/image_aesthetics_filter.py @@ -0,0 +1,127 @@ +import numpy as np +from jsonargparse.typing import ClosedUnitInterval +from loguru import logger + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import load_data_with_context, load_image + +from ...utils.model_utils import get_model, prepare_model +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_IMAGES + +OP_NAME = 'image_aesthetics_filter' +CHECK_PKGs = ['torch', 'transformers', 'simple-aesthetics-predictor'] + +with AvailabilityChecking(CHECK_PKGs, OP_NAME): + + import aesthetics_predictor # noqa: F401 + import torch + import transformers # noqa: F401 + + # avoid hanging when calling clip in multiprocessing + torch.set_num_threads(1) + + 
+@OPERATORS.register_module(OP_NAME) +@LOADED_IMAGES.register_module(OP_NAME) +class ImageAestheticsFilter(Filter): + """Filter to keep samples with aesthetics scores within a specific range. + """ + + def __init__(self, + hf_scorer_model='', + min_score: ClosedUnitInterval = 0.5, + max_score: ClosedUnitInterval = 1.0, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param hf_scorer_model: Huggingface model name for the aesthetics + predictor. By default, we will use + 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE', + refer to pypi.org/project/simple-aesthetics-predictor + :param min_score: Min score for the predicted aesthetics in an image. + :param max_score: Max score for the predicted aesthetics in an image. + :param any_or_all: Keep this sample with 'any' or 'all' strategy of + all images. 'any': keep this sample if any images meet the + condition. 'all': keep this sample only if all images meet the + condition. + :param args: Extra positional arguments. + :param kwargs: Extra keyword arguments. + """ + + super().__init__(*args, **kwargs) + if hf_scorer_model == '': + hf_scorer_model = \ + 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE' + self.min_score = min_score + self.max_score = max_score + + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. 
' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + self.model_key = prepare_model( + model_type='simple_aesthetics', + pretrained_model_name_or_path=hf_scorer_model) + # the original score predicted by laion-ai's scorer is within [0, 10] + self.need_normalized_by_ten = ('shunk031/aesthetics-predictor' + in hf_scorer_model) + self._accelerator = 'cuda' + + def compute_stats(self, sample, rank=None, context=False): + # check if it's computed already + if StatsKeys.image_aesthetics_scores in sample[Fields.stats]: + return sample + + # there is no image in this sample + if self.image_key not in sample or not sample[self.image_key]: + sample[Fields.stats][StatsKeys.image_aesthetics_scores] = np.array( + [], dtype=np.float64) + return sample + + # load images + loaded_image_keys = sample[self.image_key] + sample, images = load_data_with_context(sample, context, + loaded_image_keys, load_image) + + # compute aesthetics_scores + model, processor = get_model(self.model_key, rank=rank) + inputs = processor(images=list(images.values()), return_tensors='pt') + with torch.no_grad(): + outputs = model(**inputs) + if self.need_normalized_by_ten: + aesthetics_scores = (outputs.logits / 10.0).detach().cpu() + else: + aesthetics_scores = outputs.logits.detach().cpu() + + aesthetics_scores = [ + aesthetics_score.item() for aesthetics_score in aesthetics_scores + ] + + logger.debug(f'aesthetics_scores: {aesthetics_scores}') + + sample[Fields.stats][StatsKeys.image_aesthetics_scores] =\ + aesthetics_scores + return sample + + def process(self, sample): + aesthetics_scores = ( + sample)[Fields.stats][StatsKeys.image_aesthetics_scores] + if len(aesthetics_scores) <= 0: + return True + + keep_bools = np.array([ + self.min_score <= aesthetics_score <= self.max_score + for aesthetics_score in aesthetics_scores + ]) + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git 
a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py index 5efe2f107..a129d8abb 100644 --- a/data_juicer/ops/filter/image_text_similarity_filter.py +++ b/data_juicer/ops/filter/image_text_similarity_filter.py @@ -25,7 +25,7 @@ @OPERATORS.register_module(OP_NAME) @LOADED_IMAGES.register_module(OP_NAME) class ImageTextSimilarityFilter(Filter): - """Filter to keep samples those similarity between image and text + """Filter to keep samples those similarities between image and text within a specific range.""" def __init__(self, diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py new file mode 100644 index 000000000..cf5505dfa --- /dev/null +++ b/data_juicer/ops/filter/video_aesthetics_filter.py @@ -0,0 +1,190 @@ +import numpy as np +from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from loguru import logger + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import (extract_key_frames, + extract_video_frames_uniformly, + load_data_with_context, load_video) + +from ...utils.model_utils import get_model, prepare_model +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_aesthetics_filter' +CHECK_PKGS = ['torch', 'transformers', 'simple-aesthetics-predictor'] + +with AvailabilityChecking(CHECK_PKGS, OP_NAME): + + import aesthetics_predictor # noqa: F401 + import torch + import transformers # noqa: F401 + + # avoid hanging when calling clip in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoAestheticsFilter(Filter): + """Filter to keep data samples with aesthetics scores for specified frames + in the videos within a specific range. 
+ """ + + def __init__(self, + hf_scorer_model='', + min_score: ClosedUnitInterval = 0.4, + max_score: ClosedUnitInterval = 1.0, + frame_sampling_method: str = 'uniform', + frame_num: PositiveInt = 3, + any_or_all: str = 'any', + reduce_mode: str = 'avg', + *args, + **kwargs): + """ + Initialization method. + + :param hf_scorer_model: Huggingface model name for the aesthetics + predictor. By default, we will use + 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE', + refer to pypi.org/project/simple-aesthetics-predictor + :param min_score: Min score for the predicted aesthetics in a video. + :param max_score: Max score for the predicted aesthetics in a video. + :param frame_sampling_method: sampling method of extracting frame + images from the videos. + Should be one of ["all_keyframes", "uniform"]. + The former one extracts all key frames and the latter one extract + specified number of frames uniformly from the video. + Default: "uniform" with frame_num=3, considering that the number of + keyframes can be large while their difference is usually small + in terms of their aesthetics. + :param frame_num: the number of frames to be extracted uniformly from + the video. Only works when frame_sampling_method is "uniform". If + it's 1, only the middle frame will be extracted. If it's 2, only + the first and the last frames will be extracted. If it's larger + than 2, in addition to the first and the last frames, other frames + will be extracted uniformly within the video duration. + :param any_or_all: Keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param reduce_mode: reduce mode when one sample corresponds to + multiple frames, must be one of ['avg','max', 'min']. 
+ 'avg': Take the average of multiple values + 'max': Take the max of multiple values + 'min': Take the min of multiple values + :param args: Extra positional arguments. + :param kwargs: Extra keyword arguments. + """ + + super().__init__(*args, **kwargs) + if hf_scorer_model == '': + hf_scorer_model = \ + 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE' + self.min_score = min_score + self.max_score = max_score + + if frame_sampling_method not in ['all_keyframes', 'uniform']: + raise ValueError( + f'Frame sampling method ' + f'[{frame_sampling_method}] is not supported. ' + f'Can only be one of ["all_keyframes", "uniform"].') + + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + if reduce_mode not in ['avg', 'max', 'min']: + raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. ' + f'Can only be one of ["avg", "max", "min"].') + self.any = (any_or_all == 'any') + self.reduce_mode = reduce_mode + + self.model_key = prepare_model( + model_type='simple_aesthetics', + pretrained_model_name_or_path=hf_scorer_model) + # the original score predicted by laion-ai's scorer is within [0, 10] + self.need_normalized_by_ten = ('shunk031/aesthetics-predictor' + in hf_scorer_model) + self._accelerator = 'cuda' + self.frame_sampling_method = frame_sampling_method + self.frame_num = frame_num + + def compute_stats(self, sample, rank=None, context=False): + # check if it's computed already + if StatsKeys.video_frames_aesthetics_score in sample[Fields.stats]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.stats][StatsKeys.video_frames_aesthetics_score] = ( + np.array([], dtype=np.float64)) + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + 
aesthetics_scores = [] + # `videos` maps video keys to loaded containers, so iterate its values + for video in videos.values(): + all_frames = [] + if video is None: + continue + else: + # extract frame images + if self.frame_sampling_method == 'all_keyframes': + frames = extract_key_frames(video) + elif self.frame_sampling_method == 'uniform': + frames = extract_video_frames_uniformly( + video, self.frame_num) + else: + frames = [] + all_frames.extend(frames) + frame_images = [frame.to_image() for frame in all_frames] + + # compute aesthetics_scores + model, processor = get_model(self.model_key, rank=rank) + inputs = processor(images=frame_images, return_tensors='pt') + with torch.no_grad(): + outputs = model(**inputs) + if self.need_normalized_by_ten: + aesthetics_score = (outputs.logits / 10.0).detach().cpu() + else: + aesthetics_score = outputs.logits.detach().cpu() + + if self.reduce_mode == 'avg': + aesthetics_score = float(aesthetics_score.mean()) + elif self.reduce_mode == 'max': + aesthetics_score = float(aesthetics_score.max()) + else: + aesthetics_score = float(aesthetics_score.min()) + aesthetics_scores.append(aesthetics_score) + + logger.debug(f'aesthetics_score: {aesthetics_scores}') + + sample[Fields.stats][StatsKeys.video_frames_aesthetics_score] = ( + aesthetics_scores) + + if not context: + for vid_key in videos: + videos[vid_key].close() + + return sample + + def process(self, sample): + aesthetics_scores = ( + sample)[Fields.stats][StatsKeys.video_frames_aesthetics_score] + if len(aesthetics_scores) <= 0: + return True + + keep_bools = np.array([ + self.min_score <= aesthetics_score <= self.max_score + for aesthetics_score in aesthetics_scores + ]) + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/filter/video_aspect_ratio_filter.py b/data_juicer/ops/filter/video_aspect_ratio_filter.py new file mode 100644 index 000000000..4bea08827 --- /dev/null +++ b/data_juicer/ops/filter/video_aspect_ratio_filter.py @@ -0,0 +1,94 @@ +from fractions import
Fraction + +import numpy as np + +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import load_data_with_context, load_video + +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + + +@OPERATORS.register_module('video_aspect_ratio_filter') +@LOADED_VIDEOS.register_module('video_aspect_ratio_filter') +class VideoAspectRatioFilter(Filter): + """Filter to keep samples with video aspect ratio within a specific range. + AspectRatio = W / H. + """ + + def __init__(self, + min_ratio: str = '9/21', + max_ratio: str = '21/9', + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param min_ratio: The minimum aspect ratio to keep samples, + supported format is a string, such as "9:21" or "9/21". + :param max_ratio: The maximum aspect ratio to keep samples, + supported format is a string, such as "21:9" or "21/9". + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_ratio = Fraction(str(min_ratio).replace(':', '/')) + self.max_ratio = Fraction(str(max_ratio).replace(':', '/')) + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. 
' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.video_aspect_ratios in sample[Fields.stats]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.stats][StatsKeys.video_aspect_ratios] = np.array( + [], dtype=np.float64) + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + # compute aspect ratios for each video with W/H + video_aspect_ratios = {} + for key, video in videos.items(): + stream = video.streams.video[0] + video_aspect_ratios[key] = str( + Fraction(stream.codec_context.width, + stream.codec_context.height)) + if not context: + video.close() + + sample[Fields.stats][StatsKeys.video_aspect_ratios] = [ + video_aspect_ratios[key] for key in loaded_video_keys + ] + + return sample + + def process(self, sample): + video_aspect_ratios = sample[Fields.stats][ + StatsKeys.video_aspect_ratios] + + keep_bools = np.array([ + self.min_ratio <= Fraction(aspect_ratio) <= self.max_ratio + for aspect_ratio in video_aspect_ratios + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/filter/video_duration_filter.py b/data_juicer/ops/filter/video_duration_filter.py new file mode 100644 index 000000000..e65a05c65 --- /dev/null +++ b/data_juicer/ops/filter/video_duration_filter.py @@ -0,0 +1,93 @@ +import sys + +import numpy as np +from jsonargparse.typing import NonNegativeInt + +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import load_data_with_context, load_video + +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 
'video_duration_filter' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoDurationFilter(Filter): + """Keep data samples whose videos' durations are within a specified range. + """ + + def __init__(self, + min_duration: NonNegativeInt = 0, + max_duration: NonNegativeInt = sys.maxsize, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param min_duration: The min video duration to keep samples in seconds. + It's 0 by default. + :param max_duration: The max video duration to keep samples in seconds. + It's sys.maxsize by default. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_duration = min_duration + self.max_duration = max_duration + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. 
' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.video_duration in sample[Fields.stats]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.stats][StatsKeys.video_duration] = np.array( + [], dtype=np.float64) + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + video_durations = {} + for video_key, video in videos.items(): + stream = video.streams.video[0] + video_durations[video_key] = round(stream.duration * + stream.time_base) + if not context: + video.close() + + # get video durations + sample[Fields.stats][StatsKeys.video_duration] = [ + video_durations[video_key] for video_key in sample[self.video_key] + ] + + return sample + + def process(self, sample): + video_durations = sample[Fields.stats][StatsKeys.video_duration] + keep_bools = np.array([ + self.min_duration <= duration <= self.max_duration + for duration in video_durations + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/filter/video_frames_text_similarity_filter.py b/data_juicer/ops/filter/video_frames_text_similarity_filter.py new file mode 100644 index 000000000..7f1ddf23b --- /dev/null +++ b/data_juicer/ops/filter/video_frames_text_similarity_filter.py @@ -0,0 +1,200 @@ +import numpy as np +from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from PIL import ImageOps + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import (SpecialTokens, extract_key_frames, + extract_video_frames_uniformly, + load_data_with_context, 
load_video, + remove_special_tokens) +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_frames_text_similarity_filter' + +with AvailabilityChecking(['torch', 'transformers'], OP_NAME): + + import torch + import transformers # noqa: F401 + + # avoid hanging when calling clip in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoFramesTextSimilarityFilter(Filter): + """Filter to keep samples whose similarities between sampled video frame + images and text are within a specific range.""" + + def __init__(self, + hf_clip='openai/clip-vit-base-patch32', + min_score: ClosedUnitInterval = 0.1, + max_score: ClosedUnitInterval = 1.0, + frame_sampling_method: str = 'all_keyframes', + frame_num: PositiveInt = 3, + horizontal_flip: bool = False, + vertical_flip: bool = False, + any_or_all: str = 'any', + reduce_mode: str = 'avg', + *args, + **kwargs): + """ + Initialization method. + + :param hf_clip: clip model name on huggingface to compute + the similarity between frame image and text. The choice is + language-dependent. For example, for Chinese datasets, ChineseCLIP + might be a better choice. + :param min_score: the min similarity to keep samples. + :param max_score: the max similarity to keep samples. + :param frame_sampling_method: sampling method of extracting frame + images from the videos. + Should be one of ["all_keyframes", "uniform"]. + The former extracts all key frames (the number of which depends + on the duration of the video) and the latter extracts a specified + number of frames uniformly from the video. + Default: "all_keyframes". + :param frame_num: the number of frames to be extracted uniformly from + the video. Only works when frame_sampling_method is "uniform". If + it's 1, only the middle frame will be extracted.
If it's 2, only + the first and the last frames will be extracted. If it's larger + than 2, in addition to the first and the last frames, other frames + will be extracted uniformly within the video duration. + :param horizontal_flip: flip frame image horizontally (left to right). + :param vertical_flip: flip frame image vertically (top to bottom). + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param reduce_mode: reduce mode when one text corresponds to + multiple video frame images in a chunk. + 'avg': Take the average of multiple values + 'max': Take the max of multiple values + 'min': Take the min of multiple values + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_score = min_score + self.max_score = max_score + if frame_sampling_method not in ['all_keyframes', 'uniform']: + raise ValueError( + f'Frame sampling method ' + f'[{frame_sampling_method}] is not supported. ' + f'Can only be one of ["all_keyframes", "uniform"].') + if reduce_mode not in ['avg', 'max', 'min']: + raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. ' + f'Can only be one of ["avg", "max", "min"].') + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. 
' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_clip) + self._accelerator = 'cuda' + self.reduce_mode = reduce_mode + self.horizontal_flip = horizontal_flip + self.vertical_flip = vertical_flip + self.frame_sampling_method = frame_sampling_method + self.frame_num = frame_num + + def compute_stats(self, sample, rank=None, context=False): + # check if it's computed already + if StatsKeys.video_frames_text_matching_score in sample[Fields.stats]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.stats][ + StatsKeys.video_frames_text_matching_score] = np.array( + [], dtype=np.float64) + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + text = sample[self.text_key] + offset = 0 + similarity = [] + model, processor = get_model(self.model_key, rank=rank) + + for chunk in text.split(SpecialTokens.eoc): + count = chunk.count(SpecialTokens.video) + + # no video or no text + if count == 0 or len(chunk) == 0: + continue + else: + text_chunk = remove_special_tokens(chunk) + video_frame_images_chunk = [] + for video_key in loaded_video_keys[offset:offset + count]: + video = videos[video_key] + + # extract frame images + if self.frame_sampling_method == 'all_keyframes': + frames = extract_key_frames(video) + elif self.frame_sampling_method == 'uniform': + frames = extract_video_frames_uniformly( + video, self.frame_num) + else: + frames = [] + + frame_images = [frame.to_image() for frame in frames] + for image in frame_images: + if self.horizontal_flip: + image = ImageOps.mirror(image) + if self.vertical_flip: + image = ImageOps.flip(image) + video_frame_images_chunk.append(image) + + inputs = processor(text=text_chunk, +
images=video_frame_images_chunk, + return_tensors='pt', + truncation=True, + max_length=model.config.text_config. + max_position_embeddings, + padding=True).to(model.device) + + outputs = model(**inputs) + chunk_logits = outputs.logits_per_text.detach().cpu() / 100.0 + + if self.reduce_mode == 'avg': + chunk_similarity = chunk_logits.mean() + elif self.reduce_mode == 'max': + chunk_similarity = chunk_logits.max() + else: + chunk_similarity = chunk_logits.min() + + similarity.append(float(chunk_similarity)) + offset += count + sample[Fields.stats][ + StatsKeys.video_frames_text_matching_score] = similarity + + if not context: + for vid_key in videos: + videos[vid_key].close() + + return sample + + def process(self, sample, rank=None): + similarity = sample[Fields.stats][ + StatsKeys.video_frames_text_matching_score] + if len(similarity) <= 0: + return True + + keep_bools = np.array([ + self.min_score <= sim_value <= self.max_score + for sim_value in similarity + ]) + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/filter/video_motion_score_filter.py b/data_juicer/ops/filter/video_motion_score_filter.py new file mode 100644 index 000000000..fe94d2485 --- /dev/null +++ b/data_juicer/ops/filter/video_motion_score_filter.py @@ -0,0 +1,162 @@ +import sys +from contextlib import contextmanager + +import numpy as np +from jsonargparse.typing import PositiveInt + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys + +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_motion_score_filter' + +with AvailabilityChecking(['opencv-python'], OP_NAME): + import cv2 + + +@contextmanager +def VideoCapture(*args, **kwargs): + cap = cv2.VideoCapture(*args, **kwargs) + try: + yield cap + finally: + cap.release() + + +@OPERATORS.register_module(OP_NAME) 
+@LOADED_VIDEOS.register_module(OP_NAME) +class VideoMotionScoreFilter(Filter): + """Filter to keep samples with video motion scores within a specific range. + Farneback's algorithm from OpenCV is used to compute dense optical flow. + """ + + _default_kwargs = { + 'pyr_scale': 0.5, + 'levels': 3, + 'winsize': 15, + 'iterations': 3, + 'poly_n': 5, + 'poly_sigma': 1.2, + 'flags': 0 + } + + def __init__(self, + min_score: float = 0.25, + max_score: float = sys.float_info.max, + sampling_fps: PositiveInt = 2, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param min_score: The minimum motion score to keep samples. + :param max_score: The maximum motion score to keep samples. + :param sampling_fps: The sampling rate, in frames per second, used to + compute optical flow. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_score = min_score + self.max_score = max_score + self.sampling_fps = sampling_fps + + self.extra_kwargs = { + k: kwargs.get(k, v) + for k, v in self._default_kwargs.items() + } + + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. 
' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.video_motion_score in sample[Fields.stats]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.stats][StatsKeys.video_motion_score] = np.array( + [], dtype=np.float64) + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + unique_motion_scores = {} + for video_key in loaded_video_keys: + # skip duplicate videos + if video_key in unique_motion_scores: + continue + + video_motion_scores = [] + with VideoCapture(video_key) as cap: + fps = cap.get(cv2.CAP_PROP_FPS) + valid_fps = min(max(self.sampling_fps, 1), fps) + frame_interval = int(fps / valid_fps) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + # if cannot get the second frame, use the last one + frame_interval = min(frame_interval, total_frames - 1) + + prev_frame = None + frame_count = -1 + while cap.isOpened(): + ret, frame = cap.read() + frame_count += 1 + + if not ret: + # If the frame can't be read, it could be due to + # a corrupt frame or reaching the end of the video. 
+ break + + # skip middle frames + if frame_count % frame_interval != 0: + continue + + gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + if prev_frame is None: + prev_frame = gray_frame + continue + + flow = cv2.calcOpticalFlowFarneback( + prev_frame, gray_frame, None, **self.extra_kwargs) + mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + frame_motion_score = np.mean(mag) + video_motion_scores.append(frame_motion_score) + prev_frame = gray_frame + + # no scores were collected, possibly due to frame corruption + if not video_motion_scores: + unique_motion_scores[video_key] = -1 + else: + unique_motion_scores[video_key] = np.mean(video_motion_scores) + + sample[Fields.stats][StatsKeys.video_motion_score] = [ + unique_motion_scores[key] for key in loaded_video_keys + ] + return sample + + def process(self, sample): + video_motion_scores = sample[Fields.stats][ + StatsKeys.video_motion_score] + + keep_bools = np.array([ + self.min_score <= motion_score <= self.max_score + for motion_score in video_motion_scores + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/filter/video_ocr_area_ratio_filter.py b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py new file mode 100644 index 000000000..f61303e46 --- /dev/null +++ b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py @@ -0,0 +1,175 @@ +from typing import List, Union + +import numpy as np +from jsonargparse.typing import ClosedUnitInterval, PositiveInt + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import (extract_video_frames_uniformly, + load_data_with_context, load_video) + +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_ocr_area_ratio_filter' + +with AvailabilityChecking(['easyocr'], OP_NAME): + import easyocr + + +def triangle_area(p1,
p2, p3): + """ + Compute the triangle area according to its coordinates. + """ + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + tri_area = 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - + x1 * y3) + return tri_area + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoOcrAreaRatioFilter(Filter): + """Keep data samples whose detected text area ratios for specified frames + in the video are within a specified range. + """ + + def __init__(self, + min_area_ratio: ClosedUnitInterval = 0, + max_area_ratio: ClosedUnitInterval = 1.0, + frame_sample_num: PositiveInt = 3, + languages_to_detect: Union[str, List[str]] = ['ch_sim', 'en'], + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param min_area_ratio: The min ocr area ratio to keep samples. It's 0 + by default. + :param max_area_ratio: The max ocr area ratio to keep samples. It's 1.0 + by default. + :param frame_sample_num: The number of sampled frames to calculate the + ocr area ratio. If it's 1, only middle frame will be selected. If + it's 2, only the first and the last frames will be selected. If + it's larger than 2, in addition to the first and the last frames, + other frames will be sampled evenly within the video duration. + :param languages_to_detect: texts in which languages should be + detected. Default: ['ch_sim', 'en']. Full language list can be + found here: https://www.jaided.ai/easyocr/. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_area_ratio = min_area_ratio + self.max_area_ratio = max_area_ratio + self.frame_sample_num = frame_sample_num + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. 
' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + # initialize easyocr reader + if isinstance(languages_to_detect, str): + languages_to_detect = [languages_to_detect] + self.reader = easyocr.Reader( + lang_list=languages_to_detect, + recognizer=False, + verbose=False, + ) + + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.video_ocr_area_ratio in sample[Fields.stats]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.stats][StatsKeys.video_ocr_area_ratio] = np.array( + [], dtype=np.float64) + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + # compute ocr area ratios + video_ocr_area_ratios = {} + for video_key, container in videos.items(): + sampled_frames = extract_video_frames_uniformly( + container, self.frame_sample_num) + images = [f.to_image() for f in sampled_frames] + # collect ocr results for each image + frame_ocr_area_ratios = [] + for idx, image in enumerate(images): + # return horizontal detected results and free-form detected + # results + horizontal_list, free_list = self.reader.detect( + np.asarray(image)) + total_area = image.width * image.height + # rectangles + rect_area = 0 + for xmin, xmax, ymin, ymax in horizontal_list[0]: + if xmax < xmin or ymax < ymin: + continue + rect_area += (xmax - xmin) * (ymax - ymin) + # free-form + quad_area = 0 + for points in free_list[0]: + triangle1 = points[:3] + quad_area += triangle_area(*triangle1) + triangle2 = points[3:] + [points[0]] + quad_area += triangle_area(*triangle2) + text_area = rect_area + quad_area + frame_ocr_area_ratios.append(text_area / total_area) + + # for debug + # if False: + # from PIL import ImageDraw + # draw = ImageDraw.Draw(image) + # for xmin, xmax, ymin, ymax in horizontal_list[0]: + # if xmax 
< xmin or ymax < ymin: + # continue + # draw.rectangle((xmin, ymin, xmax, ymax), + # outline='red', + # width=1) + # for points in free_list[0]: + # points = [(int(item[0]), int(item[1])) + # for item in points] + # draw.polygon(points, outline='blue', width=1) + # image.save(f'{video_key}-{idx}.jpg') + video_ocr_area_ratios[video_key] = np.mean(frame_ocr_area_ratios) + + if not context: + container.close() + + # get video durations + sample[Fields.stats][StatsKeys.video_ocr_area_ratio] = [ + video_ocr_area_ratios[video_key] + for video_key in sample[self.video_key] + ] + + return sample + + def process(self, sample): + video_ocr_area_ratios = sample[Fields.stats][ + StatsKeys.video_ocr_area_ratio] + keep_bools = np.array([ + self.min_area_ratio <= ocr_area_ratio <= self.max_area_ratio + for ocr_area_ratio in video_ocr_area_ratios + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/filter/video_resolution_filter.py b/data_juicer/ops/filter/video_resolution_filter.py new file mode 100644 index 000000000..efd0b9147 --- /dev/null +++ b/data_juicer/ops/filter/video_resolution_filter.py @@ -0,0 +1,112 @@ +import sys + +import numpy as np +from jsonargparse.typing import PositiveInt + +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import load_data_with_context, load_video + +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_resolution_filter' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoResolutionFilter(Filter): + """Keep data samples whose videos' resolutions are within a specified range. 
+ """ + + def __init__(self, + min_width: PositiveInt = 1, + max_width: PositiveInt = sys.maxsize, + min_height: PositiveInt = 1, + max_height: PositiveInt = sys.maxsize, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param min_width: The min horizontal resolution. + :param max_width: The max horizontal resolution. + :param min_height: The min vertical resolution. + :param max_height: The max vertical resolution. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.min_width = min_width + self.max_width = max_width + self.min_height = min_height + self.max_height = max_height + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + def compute_stats(self, sample, context=False): + # check if it's computed already + if StatsKeys.video_width in sample[Fields.stats] \ + and StatsKeys.video_height in sample[Fields.stats]: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.stats][StatsKeys.video_width] = np.array( + [], dtype=np.int64) + sample[Fields.stats][StatsKeys.video_height] = np.array( + [], dtype=np.int64) + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + video_width, video_height = dict(), dict() + for video_key, video in videos.items(): + # default to load the first stream + video_stream = video.streams.video[0] + + # fail in loading video + if video_stream is None: + return sample + + if not context: + video.close() + + 
video_width[video_key] = video_stream.codec_context.width + video_height[video_key] = video_stream.codec_context.height + + # get video resolutions + sample[Fields.stats][StatsKeys.video_width] = [ + video_width[video_key] for video_key in sample[self.video_key] + ] + sample[Fields.stats][StatsKeys.video_height] = [ + video_height[video_key] for video_key in sample[self.video_key] + ] + + return sample + + def process(self, sample): + ws = sample[Fields.stats][StatsKeys.video_width] + hs = sample[Fields.stats][StatsKeys.video_height] + keep_bools = np.array([ + self.min_width <= w <= self.max_width + and self.min_height <= h <= self.max_height + for w, h in zip(ws, hs) + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 5f842cf3b..1c5cc2855 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -1,17 +1,24 @@ # yapf: disable -from . import (chinese_convert_mapper, clean_copyright_mapper, - clean_email_mapper, clean_html_mapper, clean_ip_mapper, - clean_links_mapper, expand_macro_mapper, fix_unicode_mapper, - generate_caption_mapper, gpt4v_generate_mapper, - image_blur_mapper, image_diffusion_mapper, nlpaug_en_mapper, - nlpcda_zh_mapper, punctuation_normalization_mapper, - remove_bibliography_mapper, remove_comments_mapper, - remove_header_mapper, remove_long_words_mapper, - remove_non_chinese_character_mapper, +from . 
import (audio_ffmpeg_wrapped_mapper, chinese_convert_mapper, + clean_copyright_mapper, clean_email_mapper, clean_html_mapper, + clean_ip_mapper, clean_links_mapper, expand_macro_mapper, + fix_unicode_mapper, image_blur_mapper, + image_captioning_from_gpt4v_mapper, image_captioning_mapper, + image_diffusion_mapper, nlpaug_en_mapper, nlpcda_zh_mapper, + punctuation_normalization_mapper, remove_bibliography_mapper, + remove_comments_mapper, remove_header_mapper, + remove_long_words_mapper, remove_non_chinese_character_mapper, remove_repeat_sentences_mapper, remove_specific_chars_mapper, remove_table_text_mapper, remove_words_with_incorrect_substrings_mapper, replace_content_mapper, sentence_split_mapper, + video_captioning_from_audio_mapper, + video_captioning_from_video_mapper, video_ffmpeg_wrapped_mapper, + video_resize_aspect_ratio_mapper, + video_resize_resolution_mapper, video_split_by_duration_mapper, + video_split_by_key_frame_mapper, video_split_by_scene_mapper, + video_tagging_from_audio_mapper, + video_tagging_from_frames_mapper, whitespace_normalization_mapper) # yapf: enable diff --git a/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py b/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py new file mode 100644 index 000000000..462a78fcc --- /dev/null +++ b/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py @@ -0,0 +1,75 @@ +from typing import Dict, List, Optional + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.file_utils import transfer_filename +from data_juicer.utils.logger_utils import HiddenPrints + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'audio_ffmpeg_wrapped_mapper' + +with AvailabilityChecking(['ffmpeg-python'], OP_NAME), HiddenPrints(): + import ffmpeg + + +@OPERATORS.register_module(OP_NAME) +class AudioFFmpegWrappedMapper(Mapper): + """Simple wrapper for FFmpeg audio filters. 
+ """ + + def __init__( + self, + filter_name: Optional[str] = None, + filter_kwargs: Optional[Dict] = None, + global_args: Optional[List[str]] = None, + capture_stderr: bool = True, + overwrite_output: bool = True, + *args, + **kwargs, + ): + """ + Initialization method. + + :param filter_name: ffmpeg audio filter name. + :param filter_kwargs: keyword-arguments passed to ffmpeg filter. + :param global_args: list-arguments passed to ffmpeg command-line. + :param capture_stderr: whether to capture stderr. + :param overwrite_output: whether to overwrite output file. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + self.filter_name = filter_name + # fall back to an empty dict so `**self.filter_kwargs` is always valid + self.filter_kwargs = filter_kwargs or {} + self.global_args = global_args + self.capture_stderr = capture_stderr + self.overwrite_output = overwrite_output + + def process(self, sample): + # there is no audio in this sample + if self.audio_key not in sample or not sample[self.audio_key]: + return sample + + if self.filter_name is None: + return sample + + loaded_audio_keys = sample[self.audio_key] + processed = {} + for audio_key in loaded_audio_keys: + if audio_key in processed: + continue + + output_key = transfer_filename(audio_key, OP_NAME, + **self._init_parameters) + stream = (ffmpeg.input(audio_key).filter( + self.filter_name, **self.filter_kwargs).output(output_key)) + if self.global_args is not None: + stream = stream.global_args(*self.global_args) + stream.run(capture_stderr=self.capture_stderr, + overwrite_output=self.overwrite_output) + processed[audio_key] = output_key + + sample[self.audio_key] = [processed[key] for key in loaded_audio_keys] + return sample diff --git a/data_juicer/ops/mapper/image_blur_mapper.py b/data_juicer/ops/mapper/image_blur_mapper.py index 3edb22c20..880823d6c 100644 --- a/data_juicer/ops/mapper/image_blur_mapper.py +++ b/data_juicer/ops/mapper/image_blur_mapper.py @@ -3,14 
+3,17 @@ import numpy as np from data_juicer.utils.constant import Fields +from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.mm_utils import load_data_with_context, load_image from ..base_op import OPERATORS, Mapper from ..op_fusion import LOADED_IMAGES +OP_NAME = 'image_blur_mapper' -@OPERATORS.register_module('image_blur_mapper') -@LOADED_IMAGES.register_module('image_blur_mapper') + +@OPERATORS.register_module(OP_NAME) +@LOADED_IMAGES.register_module(OP_NAME) class ImageBlurMapper(Mapper): """Mapper to blur images. """ @@ -32,6 +35,7 @@ def __init__(self, :param kwargs: extra args """ super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) if blur_type not in ['mean', 'box', 'gaussian']: raise ValueError( f'Blur_type [{blur_type}] is not supported. ' @@ -63,9 +67,8 @@ def process(self, sample, context=False): if self.p < np.random.rand(): continue else: - blured_image_key = os.path.join( - os.path.dirname(value), - '_blured.'.join(os.path.basename(value).split('.'))) + blured_image_key = transfer_filename(value, OP_NAME, + **self._init_parameters) if not os.path.exists( blured_image_key) or blured_image_key not in images: blured_image = images[value].convert('RGB').filter( diff --git a/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py new file mode 100644 index 000000000..8b58f1e3a --- /dev/null +++ b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py @@ -0,0 +1,268 @@ +import copy + +import requests +from jsonargparse.typing import ClosedUnitInterval +from loguru import logger + +from data_juicer.utils.mm_utils import (SpecialTokens, image_byte_to_base64, + insert_texts_after_placeholders, + load_image_byte, + remove_non_special_tokens, + remove_special_tokens) + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_IMAGES + +SYSTEM_PROMPTS = { + 'resoning': + "You are an AI visual 
assistant that can analyze a single image. The task is to use the provided image, create a plausible question about the image, and provide the answer in detail.\n\nYou can create complex questions beyond describing the scene. Make the question challenging by not including the visual content details in the question so that the user needs to reason about that first.\n\nTo answer such questions, you should require first understanding the visual content, then based on the background knowledge or reasoning, either explain why the things are happening that way, or provide guides and help to user's request. \n\nPlease give the Q&A content directly and separate questions and answers with Q and A.", # noqa: E501 + 'description': + 'You are an AI visual assistant that can analyze a single image. The task is to use the provided image, create a reasonable question that describes the content of the image, and provide the answer in detail.\n\nPlease give the Q&A content directly and separate questions and answers with Q and A.', # noqa: E501 + 'conversation': + 'You are an AI visual assistant, and you are seeing a single image.\n\nDesign a conversation between you and a person asking about this image. The answers should be in a tone that a visual AI assistant is seeing the image and answering the question. Ask diverse questions and give corresponding answers.\n\nInclude questions asking about the visual content of the image, including the object types, counting the objects, object actions, object locations, relative positions between objects, etc. 
Only include questions that have definite answers:\n(1) one can see the content in the image that the question asks about and can answer confidently;\n(2) one can determine confidently from the image that it is not in the image.\nDo not ask any question that cannot be answered confidently.\n\nConversation also include complex questions that are relevant to the content in the image, for example, asking about background knowledge of the objects in the image, asking to discuss about events happening in the image, etc. Again, do not ask about uncertain details.\nProvide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. Please give the content of the conversation directly and separate questions and answers with Q and A' # noqa: E501 +} + + +def call_gpt_vision_api(api_key, + system_prompt, + user_prompt, + base64_image, + max_tokens=500, + temperature=1.0, + model='gpt-4-vision-preview'): + api_url = 'https://api.openai.com/v1/chat/completions' + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + data = { + 'model': + model, + 'messages': [{ + 'role': 'system', + 'content': system_prompt + }, { + 'role': + 'user', + 'content': [{ + 'type': 'text', + 'text': user_prompt + }, { + 'type': 'image_url', + 'image_url': { + 'url': f'data:image/jpeg;base64,{base64_image}', + 'detail': 'low' + } + }] + }], + 'max_tokens': + max_tokens, + 'temperature': + temperature + } + try: + response = requests.post(api_url, headers=headers, json=data) + response.raise_for_status() + result = response.json() + + if 'choices' in result and result['choices']: + # chat completions return the text under message.content + return result['choices'][0]['message']['content'] + else: + logger.warning('No results returned from the API, return None.') + return None + + except requests.exceptions.HTTPError as errh: + if errh.response.status_code == 401: + logger.warning('Invalid API key provided.') + elif 
errh.response.status_code == 429: + logger.warning( + 'API request limit has been reached. Please try again later.') + else: + logger.warning(f'HTTP error occurred: {errh}') + except requests.exceptions.ConnectionError: + logger.warning('Network error occurred. Please check your connection.') + except requests.exceptions.Timeout: + logger.warning('The request timed out. Please try again later.') + except requests.exceptions.RequestException as err: + logger.warning(f'An error occurred: {err}') + except Exception as e: + logger.warning(f'An unexpected error occurred: {e}') + + logger.warning('API request failed, return None.') + return None + + +@OPERATORS.register_module('image_captioning_from_gpt4v_mapper') +@LOADED_IMAGES.register_module('image_captioning_from_gpt4v_mapper') +class ImageCaptioningFromGPT4VMapper(Mapper): + """Mapper to generate samples whose texts are generated based on + gpt-4-vision and the image.""" + + def __init__(self, + mode: str = 'description', + api_key: str = '', + max_token: int = 500, + temperature: ClosedUnitInterval = 1.0, + system_prompt: str = '', + user_prompt: str = '', + user_prompt_key: str = None, + keep_original_sample: bool = True, + any_or_all: str = 'any', + *args, + **kwargs): + """ + Initialization method. + + :param mode: mode of text generated from images, can be one of + ['resoning', 'description', 'conversation', 'custom'] + :param api_key: the API key to authenticate the request. + :param max_token: the maximum number of tokens to generate. + Default is 500. + :param temperature: controls the randomness of the output (range + from 0 to 1). Default is 1.0. + :param system_prompt: a string prompt used to set the context of a + conversation and provide global guidance or rules for the + gpt4-vision so that it can generate responses in the expected way. + If `mode` is set to `custom`, this parameter will be used. + :param user_prompt: a string prompt to guide the generation of + gpt4-vision for each sample. 
It's "" by default, which means no + prompt provided. + :param user_prompt_key: the key name of fields in samples to store + prompts for each sample. It's used to set different prompts for + different samples. If it's None, use the prompt in parameter + "user_prompt". It's None by default. + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only generated text in the + final datasets and the original text will be removed. It's True + by default. + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all images. 'any': keep this sample if any images meet the + condition. 'all': keep this sample only if all images meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._batched_op = True + if mode not in ['resoning', 'description', 'conversation', 'custom']: + raise ValueError( + f'Mode [{mode}] is not supported. ' + f'Can only be one of ' + f'["resoning", "description", "conversation", "custom"].') + + if mode == 'custom': + self.system_prompt = system_prompt + logger.info('The parameter `mode` set to `[custom]`. Data-Juicer ' + 'will use `system_prompt` to generate text.') + else: + self.system_prompt = SYSTEM_PROMPTS[mode] + logger.info( + f'The parameter `mode` set to [{mode}]. Data-Juicer will ' + f'use default prompt to generate text.') + + self.mode = mode + self.api_key = api_key + self.max_token = max_token + self.temperature = temperature + self.user_prompt = user_prompt + self.user_prompt_key = user_prompt_key + self.keep_original_sample = keep_original_sample + self.any_or_all = any_or_all + self.extra_args = kwargs + + # report a warning when both user_prompt and user_prompt_key are set + if self.user_prompt and self.user_prompt_key: + logger.warning( + 'Both the parameter `user_prompt` and `user_prompt_key` are ' + 'set. 
Data-Juicer will consider `user_prompt_key` first.') + + def _process_single_sample(self, sample): + # there is no image in this sample + if self.image_key not in sample or not sample[self.image_key]: + return [] + + # the generated results + generated_sample = copy.deepcopy(sample) + generated_sample[self.text_key] = '' + + # load all image(s) + loaded_image_keys = sample[self.image_key] + images = {} + for loaded_image_key in loaded_image_keys: + if loaded_image_key not in images: + # avoid loading the same images + image = load_image_byte(loaded_image_key) + images[loaded_image_key] = image + + # construct user prompts + if self.user_prompt_key and isinstance(sample[self.user_prompt_key], + str): + # check user_prompt_key is not None, and it's a str in the sample + prompt_texts = sample[self.user_prompt_key] + elif self.user_prompt and isinstance(self.user_prompt, str): + # check prompt is not None, and it's a str + prompt_texts = self.user_prompt + else: + prompt_texts = '' + + offset = 0 + # do generation for each image chunk by chunk + for chunk in sample[self.text_key].split(SpecialTokens.eoc): + # skip empty chunks or contents after the last eoc token + if not chunk.strip(): + continue + + else: + img_count = chunk.count(SpecialTokens.image) + text_with_only_special_tokens = remove_non_special_tokens( + chunk) + generated_text_single_chunk = [] + for image_key in loaded_image_keys[offset:offset + img_count]: + image = images[image_key] + res = call_gpt_vision_api(self.api_key, self.system_prompt, + prompt_texts, + image_byte_to_base64(image), + self.max_token, self.temperature) + generated_text_single_chunk.append(res) + if self.any_or_all == 'all' and not all( + generated_text_single_chunk): + return [] + + # insert the generated text according to given mode + place_holders = [SpecialTokens.image] * img_count + new_generated_text_per_chunk = insert_texts_after_placeholders( + original_string=text_with_only_special_tokens, + placeholders=place_holders, + 
new_texts=generated_text_single_chunk) + generated_sample[ + self. + text_key] += f'{new_generated_text_per_chunk}{SpecialTokens.eoc}' # noqa: E501 + offset += img_count + if self.any_or_all == 'any' and not remove_special_tokens( + generated_sample[self.text_key]): + return [] + + return [generated_sample] + + def process(self, samples): + # reconstruct samples from "dict of lists" to "list of dicts" + reconstructed_samples = [] + for i in range(len(samples[self.text_key])): + reconstructed_samples.append( + {key: samples[key][i] + for key in samples}) + samples_after_generation = [] + # do generation for each sample within the batch + for ori_sample in reconstructed_samples: + if self.keep_original_sample: + samples_after_generation.append(ori_sample) + generated_samples = self._process_single_sample(ori_sample) + if len(generated_samples) != 0: + samples_after_generation.extend(generated_samples) + # reconstruct samples from "list of dicts" to "dict of lists" + keys = samples_after_generation[0].keys() + res_samples = {} + for key in keys: + res_samples[key] = [s[key] for s in samples_after_generation] + + return res_samples diff --git a/data_juicer/ops/mapper/image_captioning_mapper.py b/data_juicer/ops/mapper/image_captioning_mapper.py new file mode 100644 index 000000000..65332f543 --- /dev/null +++ b/data_juicer/ops/mapper/image_captioning_mapper.py @@ -0,0 +1,300 @@ +import copy +import random + +import numpy as np +from jsonargparse.typing import PositiveInt +from loguru import logger + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import HashKeys +from data_juicer.utils.mm_utils import (SpecialTokens, + insert_texts_after_placeholders, + load_image, remove_non_special_tokens, + remove_special_tokens) +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_IMAGES + +OP_NAME = 'image_captioning_mapper' + +with 
AvailabilityChecking(['torch', 'transformers', 'simhash-pybind'], + OP_NAME): + import simhash # noqa: F401 + import torch + import transformers # noqa: F401 + + # avoid hanging when calling model in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_IMAGES.register_module(OP_NAME) +class ImageCaptioningMapper(Mapper): + """Mapper to generate samples whose captions are generated based on + another model and the figure.""" + + def __init__(self, + hf_img2seq='Salesforce/blip2-opt-2.7b', + caption_num: PositiveInt = 1, + keep_candidate_mode: str = 'random_any', + keep_original_sample: bool = True, + prompt: str = None, + prompt_key: str = None, + *args, + **kwargs): + """ + Initialization method. + + :param hf_img2seq: model name on huggingface to generate caption + :param caption_num: how many candidate captions to generate + for each image + :param keep_candidate_mode: retain strategy for the generated + $caption_num$ candidates. + 'random_any': Retain a random one from the generated captions + 'similar_one_simhash': Retain the generated one that is most + similar to the original caption + 'all': Retain all generated captions by concatenation + Note: This is a batched_OP, whose input and output type are + both list. Suppose there are $N$ lists of input samples, whose batch + size is $b$, and denote caption_num as $M$. + The number of total samples after generation is $2Nb$ when + keep_original_sample is True and $Nb$ when keep_original_sample is + False for 'random_any' and 'similar_one_simhash' mode. + It's $(1+M)Nb$ for 'all' mode when keep_original_sample is True + and $MNb$ when keep_original_sample is False. + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only generated captions in the + final datasets and the original captions will be removed. It's True + by default. 
+ :param prompt: a string prompt to guide the generation of blip2 model + for all samples globally. It's None in default, which means no + prompt provided. + :param prompt_key: the key name of fields in samples to store prompts + for each sample. It's used for set different prompts for different + samples. If it's none, use prompt in parameter "prompt". It's None + in default. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._batched_op = True + if keep_candidate_mode not in [ + 'random_any', 'similar_one_simhash', 'all' + ]: + raise ValueError( + f'Keep strategy [{keep_candidate_mode}] is not supported. ' + f'Can only be one of ' + f'["random_any", "similar_one_simhash", "all"].') + + self.model_key = prepare_model( + model_type='huggingface', pretrained_model_name_or_path=hf_img2seq) + self._accelerator = 'cuda' + self.caption_num = caption_num + self.keep_candidate_mode = keep_candidate_mode + self.keep_original_sample = keep_original_sample + self.prompt = prompt + self.prompt_key = prompt_key + self.extra_args = kwargs + + if keep_candidate_mode in ['random_any', 'similar_one_simhash']: + self.num_newly_generated_samples = 1 + elif keep_candidate_mode in ['all']: + self.num_newly_generated_samples = self.caption_num + else: + self.num_newly_generated_samples = 0 + + # report a warning when both prompt and prompt_key are set + if self.prompt and self.prompt_key: + logger.warning( + 'Both the parameter `prompt` and `prompt_key` are ' + 'set. 
Data-Juicer will consider `prompt_key` first.') + + def _process_single_sample(self, ori_sample, rank=None): + """ + + :param ori_sample: a single data sample before applying generation + :return: batched results after generation + """ + # there is no image in this sample + if self.image_key not in ori_sample or \ + not ori_sample[self.image_key]: + return [] + + # the generated results + generated_samples = [ + copy.deepcopy(ori_sample) + for _ in range(self.num_newly_generated_samples) + ] + for generated_sample in generated_samples: + generated_sample[self.text_key] = '' + + # 1. load all image(s) + loaded_image_keys = ori_sample[self.image_key] + images = {} + for loaded_image_key in loaded_image_keys: + if loaded_image_key not in images: + # avoid loading the same images + image = load_image(loaded_image_key) + images[loaded_image_key] = image + + offset = 0 + + # we follow such assumption: + # all text/img/video/audio data within a chunk are correlated. + # As a result, + # the original text will be removed, + # the generated text will be placed following each SpecialTokens.img + # and the original special tokens are kept in an order-preserving way. + + model, processor = get_model(self.model_key, rank=rank) + + # do generation for each image chunk by chunk + for chunk in ori_sample[self.text_key].split(SpecialTokens.eoc): + # skip empty chunks or contents after the last eoc token + if not chunk.strip(): + continue + + img_count = chunk.count(SpecialTokens.image) + text_with_only_special_tokens = remove_non_special_tokens(chunk) + image_chunk = [] + for image_key in loaded_image_keys[offset:offset + img_count]: + image = images[image_key] + image_chunk.append(image) + + # 2. 
generate candidate caption(s) in batch manner + generated_text_candidates_single_chunk = \ + [[] for _ in range(self.caption_num)] + # an assistant 2-D array, + # generated_text_candidates_single_chunk[i][j] indicates + # the $i$-th generated candidate for the $j$-th image + + # construct prompts + if self.prompt_key \ + and isinstance(ori_sample[self.prompt_key], str): + # check prompt_key is not None, and it's a str in the sample + prompt_texts = [ori_sample[self.prompt_key]] * len(image_chunk) + elif self.prompt and isinstance(self.prompt, str): + # check prompt is not None, and it's a str + prompt_texts = [self.prompt] * len(image_chunk) + else: + prompt_texts = None + + inputs = processor(images=image_chunk, + text=prompt_texts, + return_tensors='pt').to(model.device) + for i in range(self.caption_num): + generated_ids = model.generate(**inputs, + do_sample=True).to(model.device) + generated_text = processor.batch_decode( + generated_ids, skip_special_tokens=True) + generated_text_candidates_single_chunk[i] = generated_text + + # 3. 
insert a list of generated captions into the positions of + # subsequent placeholders in the original string + new_generated_text_all_images = \ + [[] for _ in range(self.num_newly_generated_samples)] + # new_generated_text_all_images is a helper array, element [i][j] + # denotes the reduced $i$-th result for the $j$-th image + + # reduce the captions according to given mode image by image + for j in range(img_count): + new_generated_text_per_image = self._reduce_captions_per_image( + chunk, [ + captions[j] + for captions in generated_text_candidates_single_chunk + ]) + assert self.num_newly_generated_samples == \ + len(new_generated_text_per_image) + for i in range(len(new_generated_text_per_image)): + new_generated_text_all_images[i].append( + new_generated_text_per_image[i]) + + # insert the captions according to given mode + place_holders = [SpecialTokens.image] * img_count + for i in range(self.num_newly_generated_samples): + new_generated_text_per_chunk = insert_texts_after_placeholders( + original_string=text_with_only_special_tokens, + placeholders=place_holders, + new_texts=new_generated_text_all_images[i]) + generated_samples[i][self.text_key] += \ + f'{new_generated_text_per_chunk}{SpecialTokens.eoc}' + + offset += img_count + + return generated_samples + + def _reduce_captions_per_image(self, chunk, + generated_text_candidates_single_chunk): + new_generated_text_per_chunk = [] + if self.keep_candidate_mode == 'random_any': + new_generated_text_per_chunk.append( + random.choice(generated_text_candidates_single_chunk)) + elif self.keep_candidate_mode == 'all': + new_generated_text_per_chunk.extend( + generated_text_candidates_single_chunk) + elif self.keep_candidate_mode == 'similar_one_simhash': + from simhash import num_differing_bits + + from ..deduplicator.document_simhash_deduplicator import \ + DocumentSimhashDeduplicator + ori_normal_text = remove_special_tokens(chunk) + # using a simhash OP to calculate their similarity + # NOTE: simhash is just 
one method to calculate the similarities + # between texts, but not the most accurate one. More methods (e.g. + # embedding-based, ...) will be added. + op_simhash = DocumentSimhashDeduplicator(window_size=2, + **self.extra_args) + ori_text_hash = np.uint64( + op_simhash.compute_hash({op_simhash.text_key: + ori_normal_text})[HashKeys.simhash]) + generated_text_hashes = [ + np.uint64( + op_simhash.compute_hash( + {op_simhash.text_key: + candidate_text})[HashKeys.simhash]) + for candidate_text in generated_text_candidates_single_chunk + ] + hamming_distances = [ + num_differing_bits(ori_text_hash, generated_text_hash) + for generated_text_hash in generated_text_hashes + ] + # the smallest hamming distance marks the most similar candidate + min_index = min(range(len(hamming_distances)), + key=hamming_distances.__getitem__) + new_generated_text_per_chunk.append( + generated_text_candidates_single_chunk[min_index]) + return new_generated_text_per_chunk + + def process(self, samples, rank=None): + """ + Note: This is a batched_OP, whose input and output type are + both list. Suppose there are $N$ input sample lists with batch + size as $b$, and denote caption_num as $M$. + The number of total samples after generation is $2Nb$ + for 'random_any' and 'similar_one_simhash' mode, + and $(1+M)Nb$ for 'all' mode. 
+ :param samples: + :return: + """ + # reconstruct samples from "dict of lists" to "list of dicts" + reconstructed_samples = [] + for i in range(len(samples[self.text_key])): + reconstructed_samples.append( + {key: samples[key][i] + for key in samples}) + samples_after_generation = [] + # do generation for each sample within the batch + for ori_sample in reconstructed_samples: + if self.keep_original_sample: + samples_after_generation.append(ori_sample) + generated_samples = self._process_single_sample(ori_sample, + rank=rank) + if len(generated_samples) != 0: + samples_after_generation.extend(generated_samples) + # reconstruct samples from "list of dicts" to "dict of lists" + keys = samples_after_generation[0].keys() + res_samples = {} + for key in keys: + res_samples[key] = [s[key] for s in samples_after_generation] + + return res_samples diff --git a/data_juicer/ops/mapper/image_diffusion_mapper.py b/data_juicer/ops/mapper/image_diffusion_mapper.py index 7941d3c1d..925fdc011 100644 --- a/data_juicer/ops/mapper/image_diffusion_mapper.py +++ b/data_juicer/ops/mapper/image_diffusion_mapper.py @@ -5,6 +5,7 @@ from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields +from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.mm_utils import (SpecialTokens, load_data_with_context, load_image, remove_special_tokens) from data_juicer.utils.model_utils import get_model, prepare_model @@ -40,7 +41,7 @@ def __init__(self, aug_num: int = 1, keep_original_sample: bool = True, caption_key: str = None, - hf_blip2='Salesforce/blip2-opt-2.7b', + hf_img2seq='Salesforce/blip2-opt-2.7b', *args, **kwargs): """ @@ -81,10 +82,11 @@ def __init__(self, for each images. It can be a string if there is only one image in each sample. Otherwise, it should be a list. If it's none, ImageDiffusionMapper will produce captions for each images. 
- :param hf_blip2: blip2 model name on huggingface to generate caption if + :param hf_img2seq: model name on huggingface to generate caption if caption_key is None. """ super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) self._batched_op = True self._accelerator = 'cuda' self.strength = strength @@ -94,9 +96,9 @@ def __init__(self, self.caption_key = caption_key self.prompt = 'A photo of a ' if not self.caption_key: - from .generate_caption_mapper import GenerateCaptionMapper - self.op_generate_caption = GenerateCaptionMapper( - hf_blip2=hf_blip2, + from .image_captioning_mapper import ImageCaptioningMapper + self.op_generate_caption = ImageCaptioningMapper( + hf_img2seq=hf_img2seq, keep_original_sample=False, prompt=self.prompt) @@ -175,9 +177,10 @@ def _process_single_sample(self, ori_sample, rank=None, context=False): for aug_id in range(self.aug_num): diffusion_image_keys = [] for index, value in enumerate(loaded_image_keys): - diffusion_image_key = os.path.join( - os.path.dirname(value), (f'_diffusion_{aug_id}.').join( - os.path.basename(value).split('.'))) + related_parameters = self.add_parameters( + self._init_parameters, caption=captions[index]) + diffusion_image_key = transfer_filename( + value, OP_NAME, **related_parameters) diffusion_image_keys.append(diffusion_image_key) # TODO: duplicated generation if image is reused if not os.path.exists(diffusion_image_key diff --git a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py new file mode 100644 index 000000000..7b479b778 --- /dev/null +++ b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py @@ -0,0 +1,147 @@ +import copy +import os + +import regex as re + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.mm_utils import SpecialTokens, extract_audio_from_video +from data_juicer.utils.model_utils import get_model, prepare_model + 
+from ..base_op import OPERATORS, Mapper + +NAME = 'video_captioning_from_audio_mapper' +CHECK_PKGS = [ + 'transformers', 'transformers_stream_generator', 'einops', 'accelerate', + 'tiktoken' +] + +with AvailabilityChecking(CHECK_PKGS, NAME): + import accelerate # noqa: F401 + import einops # noqa: F401 + import tiktoken # noqa: F401 + import transformers # noqa: F401 + import transformers_stream_generator # noqa: F401 + + +@OPERATORS.register_module(NAME) +class VideoCaptioningFromAudioMapper(Mapper): + """Mapper to caption a video according to its audio streams based on + Qwen-Audio model. + """ + + def __init__(self, keep_original_sample: bool = True, *args, **kwargs): + """ + Initialization method. + + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only captioned sample in the + final datasets and the original sample will be removed. It's True + in default. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._batched_op = True + self.keep_original_sample = keep_original_sample + self.extra_args = kwargs + + self._accelerator = 'cuda' + + self._hf_qwen_audio = 'Qwen-Audio' + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=self._hf_qwen_audio, + trust_remote_code=True, + ) + self.prompt = '<|startoftranscription|><|unkown|><|caption|>' \ + '<|unkown|><|notimestamps|><|wo_itn|>' + self.response_remove_pattern = re.compile(r'<\|.*?\|>') + + def _process_single_sample(self, sample, rank=None): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return [] + + # get paths of all video(s) + loaded_video_keys = sample[self.video_key] + + # get models + model, processor = get_model(self.model_key, rank=rank) + + offset = 0 + captioned_sample = copy.deepcopy(sample) + # generate for each video chunk by chunk + captioned_texts = '' + left_video_keys = [] + for chunk in 
sample[self.text_key].split(SpecialTokens.eoc): + # skip empty chunks + if not chunk.strip(): + continue + + vid_count = chunk.count(SpecialTokens.video) + + captioned_text_list = [] + for video in loaded_video_keys[offset:offset + vid_count]: + # only extract audio for index 0 for now + _, _, valid_indexes = extract_audio_from_video( + video, video + '.mp3', stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + continue + extracted_audio_path = video + '_0.mp3' + query = f'{self.prompt}' + + # start to inference + audio_info = processor.process_audio(query) + inputs = processor(query, + return_tensors='pt', + audio_info=audio_info).to(model.device) + outputs = model.generate(**inputs, audio_info=audio_info) + response = processor.decode(outputs.cpu()[0], + skip_special_tokens=True, + audio_info=audio_info) + # remove audio path + response = response.replace(extracted_audio_path, '').replace( + '', '') + response = self.response_remove_pattern.sub('', + response).strip() + if response == '': + # generate failure. Skip! 
+ continue + captioned_text_list.append(f'{SpecialTokens.video} {response}') + left_video_keys.append(video) + # remove extracted audio files + os.remove(extracted_audio_path) + offset += vid_count + captioned_text = ''.join(captioned_text_list) + + # add special tokens + captioned_texts += f'{captioned_text}{SpecialTokens.eoc}' + + captioned_sample[self.text_key] = captioned_texts + captioned_sample[self.video_key] = left_video_keys + return [captioned_sample] + + def process(self, samples, rank=None): + # reconstruct samples from "dict of lists" to "list of dicts" + reconstructed_samples = [] + for i in range(len(samples[self.text_key])): + reconstructed_samples.append( + {key: samples[key][i] + for key in samples}) + samples_after_split = [] + # do split for each sample within the batch + for ori_sample in reconstructed_samples: + if self.keep_original_sample: + samples_after_split.append(ori_sample) + generated_samples = self._process_single_sample(ori_sample, + rank=rank) + if len(generated_samples) != 0: + samples_after_split.extend(generated_samples) + # reconstruct samples from "list of dicts" to "dict of lists" + keys = samples_after_split[0].keys() + res_samples = {} + for key in keys: + res_samples[key] = [s[key] for s in samples_after_split] + + return res_samples diff --git a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py new file mode 100644 index 000000000..7c80d4a0f --- /dev/null +++ b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py @@ -0,0 +1,362 @@ +import copy +import random + +import numpy as np +from jsonargparse.typing import PositiveInt +from loguru import logger +from PIL import ImageOps + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import HashKeys +from data_juicer.utils.mm_utils import (SpecialTokens, extract_key_frames, + extract_video_frames_uniformly, + insert_texts_after_placeholders, + 
load_data_with_context, load_video, + remove_non_special_tokens, + remove_special_tokens) +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_captioning_from_video_mapper' + +with AvailabilityChecking(['torch', 'transformers', 'simhash-pybind'], + OP_NAME): + + import simhash # noqa: F401 + import torch + import transformers # noqa: F401 + + # avoid hanging when calling clip in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoCaptioningFromVideoMapper(Mapper): + """Mapper to generate samples whose captions are generated based on + another model and sampled video frame.""" + + def __init__( + self, + hf_video_blip='kpyu/video-blip-opt-2.7b-ego4d', + caption_num: PositiveInt = 1, + keep_candidate_mode: str = 'random_any', + keep_original_sample: bool = True, + prompt: str = None, + prompt_key: str = None, + frame_sampling_method: str = 'all_keyframes', + frame_num: PositiveInt = 3, + horizontal_flip: bool = False, + vertical_flip: bool = False, + *args, + **kwargs, + ): + """ + Initialization method. + + :param hf_video_blip: video-blip model name on huggingface + to generate caption + :param caption_num: how many candidate captions to generate + for each video + :param keep_candidate_mode: retain strategy for the generated + $caption_num$ candidates. + 'random_any': Retain the random one from generated captions + 'similar_one_simhash': Retain the generated one that is most + similar to the original caption + 'all': Retain all generated captions by concatenation + Note: This is a batched_OP, whose input and output type are + both list. Suppose there are $N$ list of input samples, whose batch + size is $b$, and denote caption_num as $M$. 
+ For 'random_any' and 'similar_one_simhash' mode, the number of + total samples after generation is $2Nb$ when keep_original_sample + is True and $Nb$ when keep_original_sample is False. For 'all' + mode, it's $(1+M)Nb$ when keep_original_sample is True + and $MNb$ when keep_original_sample is False. + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only generated captions in the + final datasets and the original captions will be removed. It's True + by default. + :param prompt: a string prompt to guide the generation of video-blip + model for all samples globally. It's None by default, which means + no prompt provided. + :param prompt_key: the key name of fields in samples to store prompts + for each sample. It's used to set different prompts for different + samples. If it's None, use prompt in parameter "prompt". It's None + by default. + :param frame_sampling_method: sampling method for extracting frames + from the videos. Should be one of ["all_keyframes", + "uniform"]. The former one extracts all key frames (the number + of which depends on the duration of the video) and the latter + one extracts the specified number of frames uniformly from the video. + Default: "all_keyframes". + :param frame_num: the number of frames to be extracted uniformly from + the video. Only works when frame_sampling_method is "uniform". If + it's 1, only the middle frame will be extracted. If it's 2, only + the first and the last frames will be extracted. If it's larger + than 2, in addition to the first and the last frames, other frames + will be extracted uniformly within the video duration. + :param horizontal_flip: flip the frames horizontally (left to right). + :param vertical_flip: flip the frames vertically (top to bottom).
+ :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + + self._batched_op = True + self._accelerator = 'cuda' + + if keep_candidate_mode not in [ + 'random_any', 'similar_one_simhash', 'all' + ]: + raise ValueError( + f'Keep strategy [{keep_candidate_mode}] is not supported. ' + f'Can only be one of ' + f'["random_any", "similar_one_simhash", "all"].') + + if keep_candidate_mode in ['random_any', 'similar_one_simhash']: + self.num_newly_generated_samples = 1 + elif keep_candidate_mode in ['all']: + self.num_newly_generated_samples = caption_num + else: + self.num_newly_generated_samples = 0 + + # report a warning when both prompt and prompt_key are set + if prompt and prompt_key: + logger.warning( + 'Both the parameter `prompt` and `prompt_key` are ' + 'set. Data-Juicer will consider `prompt_key` first.') + + self.caption_num = caption_num + self.keep_candidate_mode = keep_candidate_mode + self.keep_original_sample = keep_original_sample + self.prompt = prompt + self.prompt_key = prompt_key + self.extra_args = kwargs + + if frame_sampling_method not in ['all_keyframes', 'uniform']: + raise ValueError( + f'Frame sampling method ' + f'[{frame_sampling_method}] is not supported. 
' + f'Can only be one of ["all_keyframes", "uniform"].') + + self.horizontal_flip = horizontal_flip + self.vertical_flip = vertical_flip + self.frame_sampling_method = frame_sampling_method + self.frame_num = frame_num + + self.model_key = prepare_model( + model_type='video_blip', + pretrained_model_name_or_path=hf_video_blip, + ) + + def _process_single_sample(self, ori_sample, rank=None, context=False): + + # there is no videos in this sample + if self.video_key not in ori_sample or not ori_sample[self.video_key]: + return [] + + # the generated results + generated_samples = [ + copy.deepcopy(ori_sample) + for _ in range(self.num_newly_generated_samples) + ] + for generated_sample in generated_samples: + generated_sample[self.text_key] = '' + + # load videos + loaded_video_keys = ori_sample[self.video_key] + sample, videos = load_data_with_context(ori_sample, context, + loaded_video_keys, load_video) + + text = sample[self.text_key] + offset = 0 + model, processor = get_model(self.model_key, rank=rank) + + for chunk in text.split(SpecialTokens.eoc): + + video_count = chunk.count(SpecialTokens.video) + + # no video or no text + if video_count == 0 or len(chunk) == 0: + continue + else: + text_with_only_special_tokens = remove_non_special_tokens( + chunk) + # generate candidate caption(s) in batch manner + generated_text_candidates_single_chunk = [ + [] for _ in range(self.caption_num) + ] + for video_key in loaded_video_keys[offset:offset + + video_count]: + video = videos[video_key] + video_frame_videos_chunk = [] + # extract frame videos + if self.frame_sampling_method == 'all_keyframes': + frames = extract_key_frames(video) + elif self.frame_sampling_method == 'uniform': + frames = extract_video_frames_uniformly( + video, self.frame_num) + else: + frames = [] + frame_videos = [frame.to_image() for frame in frames] + for video in frame_videos: + if self.horizontal_flip: + video = ImageOps.mirror(video) + if self.vertical_flip: + video = ImageOps.flip(video) + 
video_frame_videos_chunk.append(video) + + # construct prompts + if self.prompt_key and isinstance( + ori_sample[self.prompt_key], str): + # check prompt_key is not None, and it's a str + # in the sample + prompt_texts = [ori_sample[self.prompt_key]] + elif self.prompt and isinstance(self.prompt, str): + # check prompt is not None, and it's a str + prompt_texts = [self.prompt] + else: + prompt_texts = None + + inputs = processor( + text=prompt_texts, + images=video_frame_videos_chunk, + return_tensors='pt', + truncation=True, + max_length=model.config.text_config. + max_position_embeddings, + padding=True, + ).to(model.device) + # tchw to bcthw + inputs['pixel_values'] = inputs.pixel_values.unsqueeze( + 0).permute(0, 2, 1, 3, 4) + for i in range(self.caption_num): + generated_ids = model.generate(**inputs, + do_sample=True).to( + model.device) + generated_text = processor.batch_decode( + generated_ids, skip_special_tokens=True) + generated_text_candidates_single_chunk[ + i] += generated_text + + # 3. 
insert a list of generated captions into the positions of + # subsequent placeholders in the original string + new_generated_text_all_videos = [ + [] for _ in range(self.num_newly_generated_samples) + ] + # new_generated_text_all_videos is a helper array, + # element [i][j] + # denotes the reduced $i$-th result for the $j$-th video + + # reduce the captions according to given mode video by video + for j in range(video_count): + new_generated_text_per_video = self._reduce_captions( + chunk, + [ + captions[j] for captions in + generated_text_candidates_single_chunk + ], + ) + assert self.num_newly_generated_samples == len( + new_generated_text_per_video) + for i in range(len(new_generated_text_per_video)): + new_generated_text_all_videos[i].append( + new_generated_text_per_video[i]) + + # insert the captions according to given mode + place_holders = [SpecialTokens.video] * video_count + for i in range(self.num_newly_generated_samples): + generated_text_per_chunk = insert_texts_after_placeholders( + original_string=text_with_only_special_tokens, + placeholders=place_holders, + new_texts=new_generated_text_all_videos[i], + ) + generated_samples[i][ + self. 
+ text_key] += f'{generated_text_per_chunk}' \ + f'{SpecialTokens.eoc}' + + offset += video_count + + if not context: + for vid_key in videos: + videos[vid_key].close() + return generated_samples + + def _reduce_captions(self, chunk, generated_text_candidates_single_chunk): + generated_text_per_chunk = [] + if self.keep_candidate_mode == 'random_any': + generated_text_per_chunk.append( + random.choice(generated_text_candidates_single_chunk)) + elif self.keep_candidate_mode == 'all': + generated_text_per_chunk.extend( + generated_text_candidates_single_chunk) + elif self.keep_candidate_mode == 'similar_one_simhash': + from simhash import num_differing_bits + + from ..deduplicator.document_simhash_deduplicator import \ + DocumentSimhashDeduplicator + + ori_normal_text = remove_special_tokens(chunk) + # using a simhash OP to calculate their similarity + # NOTE: simhash is just one method to calculate the similarities + # between texts, but not the most accurate one. More methods (e.g. + # embedding-based, ...) will be added. + op_simhash = DocumentSimhashDeduplicator(window_size=2, + **self.extra_args) + ori_text_hash = np.uint64( + op_simhash.compute_hash({op_simhash.text_key: + ori_normal_text})[HashKeys.simhash]) + generated_text_hashes = [ + np.uint64( + op_simhash.compute_hash( + {op_simhash.text_key: + candidate_text})[HashKeys.simhash]) + for candidate_text in generated_text_candidates_single_chunk + ] + hamming_distances = [ + num_differing_bits(ori_text_hash, generated_text_hash) + for generated_text_hash in generated_text_hashes + ] + max_index = min(range(len(hamming_distances)), + key=hamming_distances.__getitem__) + generated_text_per_chunk.append( + generated_text_candidates_single_chunk[max_index]) + return generated_text_per_chunk + + def process(self, samples, rank=None, context=False): + """ + Note: This is a batched_OP, whose the input and output type are + both list. 
Suppose there are $N$ input sample list with batch + size as $b$, and denote caption_num as $M$. + the number of total samples after generation is $2Nb$ + for 'random_any' and 'similar_one' mode, + and $(1+M)Nb$ for 'all' mode. + :param samples: + :return: + """ + # reconstruct samples from "dict of lists" to "list of dicts" + reconstructed_samples = [] + for i in range(len(samples[self.text_key])): + reconstructed_samples.append( + {key: samples[key][i] + for key in samples}) + samples_after_generation = [] + # do generation for each sample within the batch + for ori_sample in reconstructed_samples: + if self.keep_original_sample: + samples_after_generation.append(ori_sample) + generated_samples = self._process_single_sample(ori_sample, + rank=rank, + context=context) + if len(generated_samples) != 0: + samples_after_generation.extend(generated_samples) + # reconstruct samples from "list of dicts" to "dict of lists" + keys = samples_after_generation[0].keys() + res_samples = {} + for key in keys: + res_samples[key] = [s[key] for s in samples_after_generation] + + return res_samples diff --git a/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py b/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py new file mode 100644 index 000000000..447f13318 --- /dev/null +++ b/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py @@ -0,0 +1,75 @@ +from typing import Dict, List, Optional + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.file_utils import transfer_filename +from data_juicer.utils.logger_utils import HiddenPrints + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'video_ffmpeg_wrapped_mapper' + +with AvailabilityChecking(['ffmpeg-python'], OP_NAME), HiddenPrints(): + import ffmpeg + + +@OPERATORS.register_module(OP_NAME) +class VideoFFmpegWrappedMapper(Mapper): + """Simple wrapper for FFmpeg video filters. 
+ """ + + def __init__( + self, + filter_name: Optional[str] = None, + filter_kwargs: Optional[Dict] = None, + global_args: Optional[List[str]] = None, + capture_stderr: bool = True, + overwrite_output: bool = True, + *args, + **kwargs, + ): + """ + Initialization method. + + :param filter_name: ffmpeg video filter name. + :param filter_kwargs: keyword-arguments passed to ffmpeg filter. + :param global_args: list-arguments passed to ffmpeg command-line. + :param capture_stderr: whether to capture stderr. + :param overwrite_output: whether to overwrite output file. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + self.filter_name = filter_name + self.filter_kwargs = filter_kwargs + self.global_args = global_args + self.capture_stderr = capture_stderr + self.overwrite_output = overwrite_output + + def process(self, sample): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return sample + + if self.filter_name is None: + return sample + + loaded_video_keys = sample[self.video_key] + proceessed = {} + for video_key in loaded_video_keys: + if video_key in proceessed: + continue + + output_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + stream = (ffmpeg.input(video_key).filter( + self.filter_name, **self.filter_kwargs).output(output_key)) + if self.global_args is not None: + stream = stream.global_args(*self.global_args) + stream.run(capture_stderr=self.capture_stderr, + overwrite_output=self.overwrite_output) + proceessed[video_key] = output_key + + sample[self.video_key] = [proceessed[key] for key in loaded_video_keys] + return sample diff --git a/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py new file mode 100644 index 000000000..fa1de22d6 --- /dev/null +++ 
b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py @@ -0,0 +1,143 @@ +import math +import os +from fractions import Fraction + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.file_utils import transfer_filename +from data_juicer.utils.logger_utils import HiddenPrints +from data_juicer.utils.mm_utils import load_video + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'video_resize_aspect_ratio_mapper' + +with AvailabilityChecking(['ffmpeg-python'], OP_NAME), HiddenPrints(): + import ffmpeg + + +def rescale(width, height, ori_ratio, min_ratio, max_ratio, strategy): + + scaled_width = width + scaled_height = height + ori_ratio = Fraction(ori_ratio) + min_ratio = Fraction(min_ratio) + max_ratio = Fraction(max_ratio) + if ori_ratio < min_ratio: + if strategy == 'increase': + # increase width to meet the min ratio + scaled_width = math.ceil(height * min_ratio) + scaled_width += scaled_width % 2 + elif strategy == 'decrease': + # decrease height to meet the min ratio + scaled_height = math.floor(width / min_ratio) + scaled_height -= scaled_height % 2 + + elif ori_ratio > max_ratio: + if strategy == 'increase': + # increase height to meet the max ratio + scaled_height = math.ceil(width / max_ratio) + scaled_height += scaled_height % 2 + + elif strategy == 'decrease': + # decrease width to meet the max ratio + scaled_width = math.floor(height * max_ratio) + scaled_width -= scaled_width % 2 + + assert Fraction(scaled_width, scaled_height) >= min_ratio + assert Fraction(scaled_width, scaled_height) <= max_ratio + + scaled_width = max(2, scaled_width) + scaled_height = max(2, scaled_height) + + return scaled_width, scaled_height + + +@OPERATORS.register_module(OP_NAME) +class VideoResizeAspectRatioMapper(Mapper): + """Mapper to resize videos by aspect ratio. + AspectRatio = W / H. 
+ """ + + STRATEGY = ['decrease', 'increase'] + + def __init__( + self, + min_ratio: str = '9/21', + max_ratio: str = '21/9', + strategy: str = 'increase', + *args, + **kwargs, + ): + """ + Initialization method. + + :param min_ratio: The minimum aspect ratio to enforce videos with + an aspect ratio below `min_ratio` will be resized to match + this minimum ratio. The ratio should be provided as a string + in the format "9:21" or "9/21". + :param max_ratio: The maximum aspect ratio to enforce videos with + an aspect ratio above `max_ratio` will be resized to match + this maximum ratio. The ratio should be provided as a string + in the format "21:9" or "21/9". + :param strategy: The resizing strategy to apply when adjusting the + video dimensions. It can be either 'decrease' to reduce the + dimension or 'increase' to enlarge it. Accepted values are + ['decrease', 'increase']. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + strategy = strategy.lower() + if strategy not in self.STRATEGY: + raise ValueError( + f'force_original_aspect_ratio [{strategy}] is not supported. ' + f'Can only be one of {self.STRATEGY}. 
') + + self.min_ratio = Fraction(str(min_ratio).replace(':', '/')) + self.max_ratio = Fraction(str(max_ratio).replace(':', '/')) + self.strategy = strategy + + def process(self, sample): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return sample + + loaded_video_keys = sample[self.video_key] + for index, video_key in enumerate(loaded_video_keys): + + container = load_video(video_key) + video = container.streams.video[0] + original_width = video.codec_context.width + original_height = video.codec_context.height + original_aspect_ratio = Fraction(original_width, original_height) + container.close() + + if (original_aspect_ratio >= self.min_ratio + and original_aspect_ratio <= self.max_ratio): + continue + + scaled_width, scaled_height = rescale( + original_width, + original_height, + original_aspect_ratio, + self.min_ratio, + self.max_ratio, + self.strategy, + ) + resized_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + if (not os.path.exists(resized_video_key) + or resized_video_key not in loaded_video_keys): + args = ['-nostdin', '-v', 'quiet', '-y'] + stream = ffmpeg.input(video_key) + stream = stream.filter('scale', + width=scaled_width, + height=scaled_height) + stream = stream.output(resized_video_key).global_args(*args) + stream.run() + loaded_video_keys[index] = resized_video_key + + sample[self.video_key] = loaded_video_keys + return sample diff --git a/data_juicer/ops/mapper/video_resize_resolution_mapper.py b/data_juicer/ops/mapper/video_resize_resolution_mapper.py new file mode 100644 index 000000000..5d026f8ae --- /dev/null +++ b/data_juicer/ops/mapper/video_resize_resolution_mapper.py @@ -0,0 +1,167 @@ +import math +import os +import sys + +from jsonargparse.typing import PositiveInt + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.file_utils import transfer_filename +from data_juicer.utils.logger_utils import 
HiddenPrints +from data_juicer.utils.mm_utils import load_video + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_resize_resolution_mapper' + +with AvailabilityChecking(['ffmpeg-python'], OP_NAME), HiddenPrints(): + import ffmpeg + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoResizeResolutionMapper(Mapper): + """ + Mapper to resize video resolution. We leave super resolution + with deep learning for future work. + """ + + def __init__(self, + min_width: PositiveInt = 1, + max_width: PositiveInt = sys.maxsize, + min_height: PositiveInt = 1, + max_height: PositiveInt = sys.maxsize, + force_original_aspect_ratio: str = 'disable', + force_divisible_by: PositiveInt = 2, + *args, + **kwargs): + """ + Initialization method. + + :param min_width: Videos with width less than 'min_width' will be + mapped to videos with equal or bigger width. + :param max_width: Videos with width more than 'max_width' will be + mapped to videos with equal or smaller width. + :param min_height: Videos with height less than 'min_height' will be + mapped to videos with equal or bigger height. + :param max_height: Videos with height more than 'max_height' will be + mapped to videos with equal or smaller height. + :param force_original_aspect_ratio: Enable decreasing or \ + increasing output video width or height if necessary \ + to keep the original aspect ratio, including ['disable', \ + 'decrease', 'increase']. + :param force_divisible_by: Ensures that both the output dimensions, \ + width and height, are divisible by the given integer when used \ + together with force_original_aspect_ratio, must be a positive \ + even number.
+ :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + force_original_aspect_ratio = force_original_aspect_ratio.lower() + + if force_original_aspect_ratio not in [ + 'disable', 'decrease', 'increase' + ]: + raise ValueError( + f'force_original_aspect_ratio [{force_original_aspect_ratio}]' + f' is not supported. ' + f"Can only be one of ['disable', 'decrease', 'increase']. ") + if (force_divisible_by <= 1 or force_divisible_by % 2 + == 1) and force_original_aspect_ratio != 'disable': + raise ValueError( + f'force_divisible_by [{force_divisible_by}] must be a positive' + f' even number. ') + + self.min_width = min_width + self.max_width = max_width + self.min_height = min_height + self.max_height = max_height + self.scale_method = 'scale' + self.force_original_aspect_ratio = force_original_aspect_ratio + self.force_divisible_by = force_divisible_by + + def process(self, sample, context=False): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return sample + + loaded_video_keys = sample[self.video_key] + + for index, video_key in enumerate(loaded_video_keys): + + container = load_video(video_key) + video = container.streams.video[0] + width = video.codec_context.width + height = video.codec_context.height + origin_ratio = width / height + container.close() + + if width >= self.min_width and width <= self.max_width and \ + height >= self.min_height and height <= self.max_height: + continue + + # keep the original aspect ratio as possible + if width < self.min_width: + height = self.min_width / origin_ratio + width = self.min_width + if width > self.max_width: + height = self.max_width / origin_ratio + width = self.max_width + if height < self.min_height: + width = self.min_height * origin_ratio + height = self.min_height + if height > self.max_height: + width = self.max_height * origin_ratio + height = 
self.max_height + + # the width and height of a video must be divisible by 2. + if self.force_original_aspect_ratio == 'disable': + force_divisible_by = 2 + else: + force_divisible_by = self.force_divisible_by + + # make sure in the range if possible + width = int(max(width, self.min_width)) + width = math.ceil(width / force_divisible_by) * force_divisible_by + width = int(min(width, self.max_width)) + width = int(width / force_divisible_by) * force_divisible_by + height = int(max(height, self.min_height)) + height = math.ceil( + height / force_divisible_by) * force_divisible_by + height = int(min(height, self.max_height)) + height = int(height / force_divisible_by) * force_divisible_by + + # keep the origin aspect ratio + if self.force_original_aspect_ratio == 'increase': + if width / height < origin_ratio: + width = height * origin_ratio + elif width / height > origin_ratio: + height = width / origin_ratio + elif self.force_original_aspect_ratio == 'decrease': + if width / height < origin_ratio: + height = width / origin_ratio + elif width / height > origin_ratio: + width = height * origin_ratio + width = int(round(width / force_divisible_by)) * force_divisible_by + height = int(round( + height / force_divisible_by)) * force_divisible_by + + # resize + resized_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + if (not os.path.exists(resized_video_key) + or resized_video_key not in loaded_video_keys): + args = ['-nostdin', '-v', 'quiet', + '-y'] # close the ffmpeg log + stream = ffmpeg.input(video_key) + stream = stream.filter('scale', width=width, height=height) + stream = stream.output(resized_video_key).global_args(*args) + stream.run() + + loaded_video_keys[index] = resized_video_key + + sample[self.video_key] = loaded_video_keys + return sample diff --git a/data_juicer/ops/mapper/video_split_by_duration_mapper.py b/data_juicer/ops/mapper/video_split_by_duration_mapper.py new file mode 100644 index 000000000..c018506a4 --- /dev/null 
+++ b/data_juicer/ops/mapper/video_split_by_duration_mapper.py @@ -0,0 +1,156 @@ +import copy +import re + +import numpy as np + +from data_juicer.utils.file_utils import (add_suffix_to_filename, + transfer_filename) +from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds, + get_video_duration, load_video) + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + + +def create_replacer(replacements): + + def replacer(match): + return replacements.pop(0) + + return replacer + + +OP_NAME = 'video_split_by_duration_mapper' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoSplitByDurationMapper(Mapper): + """Mapper to split video by duration. + """ + + def __init__(self, + split_duration: float = 10, + min_last_split_duration: float = 0, + keep_original_sample: bool = True, + *args, + **kwargs): + """ + Initialization method. + :param split_duration: duration of each video split in seconds. + :param min_last_split_duration: The minimum allowable duration in + seconds for the last video split. If the duration of the last + split is less than this value, it will be discarded. + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only cut sample in the + final datasets and the original sample will be removed. It's True + in default. 
+ :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + self._batched_op = True + self.split_duration = split_duration + self.min_last_split_duration = min_last_split_duration + self.keep_original_sample = keep_original_sample + self.extra_args = kwargs + + def split_videos_by_duration(self, video_key, container): + video_duration = get_video_duration(container) + timestamps = np.arange(0, video_duration, self.split_duration).tolist() + count = 0 + split_video_keys = [] + for i in range(1, len(timestamps)): + split_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + suffix = '_split-by-duration-' + str(count) + split_video_key = add_suffix_to_filename(split_video_key, suffix) + cut_video_by_seconds(container, split_video_key, timestamps[i - 1], + timestamps[i]) + split_video_keys.append(split_video_key) + count += 1 + + if video_duration - timestamps[-1] >= self.min_last_split_duration: + split_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + suffix = '_split-by-duration-' + str(count) + split_video_key = add_suffix_to_filename(split_video_key, suffix) + cut_video_by_seconds(container, split_video_key, timestamps[-1]) + split_video_keys.append(split_video_key) + return split_video_keys + + def _process_single_sample(self, sample): + # there is no video in this sample + if self.video_key not in sample \ + or sample[self.video_key] is None \ + or len(sample[self.video_key]) == 0: + return [] + + # the split results + split_sample = copy.deepcopy(sample) + split_sample[self.text_key] = '' + + # load all video(s) + loaded_video_keys = sample[self.video_key] + videos = {} + for loaded_video_key in loaded_video_keys: + if loaded_video_key not in videos: + # avoid loading the same videos + video = load_video(loaded_video_key) + videos[loaded_video_key] = video + + split_video_keys = [] + offset = 0 + # split 
each video chunk by chunk + for chunk in sample[self.text_key].split(SpecialTokens.eoc): + # skip empty chunks or contents after the last eoc token + if not chunk.strip(): + continue + else: + video_count = chunk.count(SpecialTokens.video) + place_holders = [] + for video_key in loaded_video_keys[offset:offset + + video_count]: + video = videos[video_key] + new_video_keys = self.split_videos_by_duration( + video_key, video) + video.close() + split_video_keys.extend(new_video_keys) + place_holders.append(SpecialTokens.video * + len(new_video_keys)) + + # insert the generated text according to given mode + replacer_function = create_replacer(place_holders) + new_split_text_per_chunk = re.sub(SpecialTokens.video, + replacer_function, chunk) + split_sample[ + self. + text_key] += f'{new_split_text_per_chunk}{SpecialTokens.eoc}' # noqa: E501 + offset += video_count + + split_sample[self.video_key] = split_video_keys + return [split_sample] + + def process(self, samples): + # reconstruct samples from "dict of lists" to "list of dicts" + reconstructed_samples = [] + for i in range(len(samples[self.text_key])): + reconstructed_samples.append( + {key: samples[key][i] + for key in samples}) + samples_after_split = [] + # do split for each sample within the batch + for ori_sample in reconstructed_samples: + if self.keep_original_sample: + samples_after_split.append(ori_sample) + generated_samples = self._process_single_sample(ori_sample) + if len(generated_samples) != 0: + samples_after_split.extend(generated_samples) + # reconstruct samples from "list of dicts" to "dict of lists" + keys = samples_after_split[0].keys() + res_samples = {} + for key in keys: + res_samples[key] = [s[key] for s in samples_after_split] + + return res_samples diff --git a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py new file mode 100644 index 000000000..f33dfd20d --- /dev/null +++ 
b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py @@ -0,0 +1,142 @@ +import copy +import re + +from data_juicer.utils.file_utils import (add_suffix_to_filename, + transfer_filename) +from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds, + get_key_frame_seconds, load_video) + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + + +def create_replacer(replacements): + + def replacer(match): + return replacements.pop(0) + + return replacer + + +OP_NAME = 'video_split_by_key_frame_mapper' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoSplitByKeyFrameMapper(Mapper): + """Mapper to split video by key frame. + """ + + def __init__(self, keep_original_sample: bool = True, *args, **kwargs): + """ + Initialization method. + + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only split sample in the + final datasets and the original sample will be removed. It's True + in default. 
+ :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + self._batched_op = True + self.keep_original_sample = keep_original_sample + self.extra_args = kwargs + + def get_split_key_frame(self, video_key, container): + timestamps = get_key_frame_seconds(container) + + count = 0 + split_video_keys = [] + for i in range(1, len(timestamps)): + split_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + suffix = '_split-by-key-frame-' + str(count) + split_video_key = add_suffix_to_filename(split_video_key, suffix) + cut_video_by_seconds(container, split_video_key, timestamps[i - 1], + timestamps[i]) + split_video_keys.append(split_video_key) + count += 1 + + split_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + suffix = '_split-by-key-frame-' + str(count) + split_video_key = add_suffix_to_filename(split_video_key, suffix) + cut_video_by_seconds(container, split_video_key, timestamps[-1]) + split_video_keys.append(split_video_key) + return split_video_keys + + def _process_single_sample(self, sample): + # there is no video in this sample + if self.video_key not in sample \ + or sample[self.video_key] is None \ + or len(sample[self.video_key]) == 0: + return [] + + # the split results + split_sample = copy.deepcopy(sample) + split_sample[self.text_key] = '' + + # load all video(s) + loaded_video_keys = sample[self.video_key] + videos = {} + for loaded_video_key in loaded_video_keys: + if loaded_video_key not in videos: + # avoid loading the same videos + video = load_video(loaded_video_key) + videos[loaded_video_key] = video + + split_video_keys = [] + offset = 0 + # split each video chunk by chunk + for chunk in sample[self.text_key].split(SpecialTokens.eoc): + # skip empty chunks or contents after the last eoc token + if not chunk.strip(): + continue + else: + video_count = chunk.count(SpecialTokens.video) + 
place_holders = [] + for video_key in loaded_video_keys[offset:offset + + video_count]: + video = videos[video_key] + new_video_keys = self.get_split_key_frame(video_key, video) + video.close() + split_video_keys.extend(new_video_keys) + place_holders.append(SpecialTokens.video * + len(new_video_keys)) + + # insert the generated text according to given mode + replacer_function = create_replacer(place_holders) + new_split_text_per_chunk = re.sub(SpecialTokens.video, + replacer_function, chunk) + split_sample[ + self. + text_key] += f'{new_split_text_per_chunk}{SpecialTokens.eoc}' # noqa: E501 + offset += video_count + + split_sample[self.video_key] = split_video_keys + return [split_sample] + + def process(self, samples): + # reconstruct samples from "dict of lists" to "list of dicts" + reconstructed_samples = [] + for i in range(len(samples[self.text_key])): + reconstructed_samples.append( + {key: samples[key][i] + for key in samples}) + samples_after_split = [] + # do split for each sample within the batch + for ori_sample in reconstructed_samples: + if self.keep_original_sample: + samples_after_split.append(ori_sample) + generated_samples = self._process_single_sample(ori_sample) + if len(generated_samples) != 0: + samples_after_split.extend(generated_samples) + # reconstruct samples from "list of dicts" to "dict of lists" + keys = samples_after_split[0].keys() + res_samples = {} + for key in keys: + res_samples[key] = [s[key] for s in samples_after_split] + + return res_samples diff --git a/data_juicer/ops/mapper/video_split_by_scene_mapper.py b/data_juicer/ops/mapper/video_split_by_scene_mapper.py new file mode 100644 index 000000000..fd8327197 --- /dev/null +++ b/data_juicer/ops/mapper/video_split_by_scene_mapper.py @@ -0,0 +1,143 @@ +import math +import re +from itertools import chain + +from jsonargparse.typing import NonNegativeFloat, NonNegativeInt + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.file_utils 
import (add_suffix_to_filename, + transfer_filename) +from data_juicer.utils.mm_utils import SpecialTokens + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'video_split_by_scene_mapper' + +with AvailabilityChecking(['scenedetect[opencv]'], OP_NAME): + import scenedetect.detectors + from scenedetect import detect, split_video_ffmpeg + + +def replace_func(match, scene_counts_iter): + try: + count = next(scene_counts_iter) + return SpecialTokens.video * count + except StopIteration: + return match.group(0) + + +@OPERATORS.register_module(OP_NAME) +class VideoSplitBySceneMapper(Mapper): + """Mapper to cut videos into scene clips. + """ + + # Define shared detector keys and their properties + avaliable_detectors = { + 'ContentDetector': ['weights', 'luma_only', 'kernel_size'], + 'AdaptiveDetector': [ + 'window_width', 'min_content_val', 'weights', 'luma_only', + 'kernel_size', 'video_manager', 'min_delta_hsv' + ], + 'ThresholdDetector': + ['fade_bias', 'add_final_scene', 'method', 'block_size'] + } + + def __init__(self, + detector: str = 'ContentDetector', + threshold: NonNegativeFloat = 27.0, + min_scene_len: NonNegativeInt = 15, + show_progress: bool = False, + *args, + **kwargs): + """ + Initialization method. + + :param detector: Algorithm from `scenedetect.detectors`. Should be one + of ['ContentDetector', 'ThresholdDetector', 'AdaptiveDetector`]. + :param threshold: Threshold passed to the detector. + :param min_scene_len: Minimum length of any scene. + :param show_progress: Whether to show progress from scenedetect. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + if detector not in self.avaliable_detectors: + raise ValueError( + f'Scene detector {detector} is not supported. 
' + f'Can only be one of {list(self.avaliable_detectors.keys())}') + + self.detector = detector + self.threshold = threshold + self.min_scene_len = min_scene_len + self.show_progress = show_progress + + # prepare detector args + avaliable_kwargs = self.avaliable_detectors[self.detector] + self.detector_class = getattr(scenedetect.detectors, self.detector) + self.detector_kwargs = { + key: kwargs[key] + for key in avaliable_kwargs if key in kwargs + } + + def process(self, sample, context=False): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + output_video_keys = {} + scene_counts = {} + + for video_key in loaded_video_keys: + + # skip duplicate + if video_key in output_video_keys: + continue + + redirected_video_key = transfer_filename(video_key, OP_NAME, + **self._init_parameters) + output_template = add_suffix_to_filename(redirected_video_key, + '_Scene-$SCENE_NUMBER') + + # detect scenes + detector = self.detector_class(self.threshold, self.min_scene_len, + **self.detector_kwargs) + scene_list = detect(video_key, + detector, + show_progress=self.show_progress, + start_in_scene=True) + scene_counts[video_key] = len(scene_list) + + if len(scene_list) > 1: + # sync with split_video_ffmpeg internal + scene_num_format = f'%0{max(3, math.floor(math.log(len(scene_list), 10)) + 1)}d' # noqa: E501 + output_video_keys[video_key] = [ + output_template.replace('$SCENE_NUMBER', + scene_num_format % (i + 1)) + for i in range(len(scene_list)) + ] + # split video into clips + split_video_ffmpeg(video_key, + scene_list, + output_template, + show_progress=self.show_progress) + else: + output_video_keys[video_key] = [video_key] + + # replace splited video tokens + if self.text_key in sample: + scene_counts_iter = iter( + [scene_counts[key] for key in loaded_video_keys]) + updated_text = re.sub( + re.escape(SpecialTokens.video), + lambda match: 
replace_func(match, scene_counts_iter), + sample[self.text_key]) + sample[self.text_key] = updated_text + + sample[self.video_key] = list( + chain.from_iterable( + [output_video_keys[key] for key in loaded_video_keys])) + return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py new file mode 100644 index 000000000..6a9636160 --- /dev/null +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -0,0 +1,85 @@ +import librosa + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import extract_audio_from_video +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper + +OP_NAME = 'video_tagging_from_audio_mapper' + +with AvailabilityChecking(['torch', 'transformers', 'torchaudio'], OP_NAME): + import torch + import torchaudio # noqa: F401 + import transformers # noqa: F401 + + # avoid hanging when calling recognizeAnything in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +class VideoTaggingFromAudioMapper(Mapper): + """Mapper to generate video tags from audio streams extracted by video + using the Audio Spectrogram Transformer. + """ + + def __init__(self, + hf_ast='MIT/ast-finetuned-audioset-10-10-0.4593', + *args, + **kwargs): + """ + Initialization method. 
+ + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.model_key = prepare_model(model_type='huggingface', + pretrained_model_name_or_path=hf_ast) + self._model_sampling_rate = 16000 + self._no_audio_label = 'EMPTY' + + def process(self, sample, rank=None): + # check if it's generated already + if Fields.video_audio_tags in sample: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.video_audio_tags] = [] + return sample + + # load video paths + loaded_video_keys = sample[self.video_key] + + model, feature_extractor = get_model(self.model_key, rank=rank) + video_audio_tags = [] + for video_path in loaded_video_keys: + # only extract audio data and sr for index 0 for now + ys, srs, valid_indexes = extract_audio_from_video( + video_path, stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + video_audio_tags.append(self._no_audio_label) + continue + + # inference + y = ys[0] + sr = srs[0] + # check if it meets the sampling rate condition of the model + if sr != self._model_sampling_rate: + y = librosa.resample(y, + orig_sr=sr, + target_sr=self._model_sampling_rate) + sr = self._model_sampling_rate + inputs = feature_extractor(y, + sampling_rate=sr, + return_tensors='pt') + with torch.no_grad(): + logits = model(**inputs).logits + predicted_tag_id = torch.argmax(logits, dim=-1).item() + predicted_tag = model.config.id2label[predicted_tag_id] + video_audio_tags.append(predicted_tag) + sample[Fields.video_audio_tags] = video_audio_tags + return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py new file mode 100644 index 000000000..0c69461a3 --- /dev/null +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -0,0 +1,111 @@ +from collections import Counter + +from jsonargparse.typing import 
PositiveInt + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import (extract_key_frames, + extract_video_frames_uniformly, + load_data_with_context, load_video) +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_tagging_from_frames_mapper' + +with AvailabilityChecking( + ['torch', 'git+https://github.com/xinyu1205/recognize-anything.git'], + OP_NAME): + import ram # noqa: F401 + import torch + + # avoid hanging when calling recognizeAnything in multiprocessing + torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoTaggingFromFramesMapper(Mapper): + """Mapper to generate video tags from frames extracted from the video. + """ + + def __init__(self, + frame_sampling_method: str = 'all_keyframes', + frame_num: PositiveInt = 3, + *args, + **kwargs): + """ + Initialization method. + + :param frame_sampling_method: sampling method for extracting frame + images from the videos. Should be one of + ["all_keyframes", "uniform"]. + The former extracts all key frames (the number of which depends + on the duration of the video) and the latter extracts the specified + number of frames uniformly from the video. + Default: "all_keyframes". + :param frame_num: the number of frames to be extracted uniformly from + the video. Only takes effect when frame_sampling_method is "uniform". If + it's 1, only the middle frame will be extracted. If it's 2, only + the first and the last frames will be extracted. If it's larger + than 2, in addition to the first and the last frames, other frames + will be extracted uniformly within the video duration. 
+ :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + if frame_sampling_method not in ['all_keyframes', 'uniform']: + raise ValueError( + f'Frame sampling method [{frame_sampling_method}] is not ' + f'supported. Can only be one of ["all_keyframes", "uniform"].') + self.model_key = prepare_model( + model_type='recognizeAnything', + pretrained_model_name_or_path='ram_plus_swin_large_14m.pth', + input_size=384) + self.frame_sampling_method = frame_sampling_method + self.frame_num = frame_num + from ram import get_transform + self.transform = get_transform(image_size=384) + + def process(self, sample, rank=None, context=False): + # check if it's generated already + if Fields.video_frame_tags in sample: + return sample + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.video_frame_tags] = [] + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + + model = get_model(self.model_key, rank=rank) + video_tags = [] + for _, value in enumerate(loaded_video_keys): + video = videos[value] + + # extract frame images + if self.frame_sampling_method == 'all_keyframes': + frames = extract_key_frames(video) + elif self.frame_sampling_method == 'uniform': + frames = extract_video_frames_uniformly(video, self.frame_num) + else: + video_tags.append([]) + frames = [] + + frame_tensor = torch.stack([ + self.transform(frame.to_image()) for frame in frames + ]).to(next(model.parameters()).device) + with torch.no_grad(): + tags, _ = model.generate_tag(frame_tensor) + + words = [word.strip() for tag in tags for word in tag.split('|')] + word_count = Counter(words) + sorted_word_list = [item for item, _ in word_count.most_common()] + video_tags.append(sorted_word_list) + sample[Fields.video_frame_tags] = video_tags + return sample diff --git 
a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py index 69f072dd6..8312b362c 100644 --- a/data_juicer/ops/op_fusion.py +++ b/data_juicer/ops/op_fusion.py @@ -18,8 +18,11 @@ # audios LOADED_AUDIOS = Registry(InterVars.loaded_audios) +# videos +LOADED_VIDEOS = Registry(InterVars.loaded_videos) + # all -ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES] +ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES, LOADED_VIDEOS] def fuse_operators(process_list, ops): @@ -135,12 +138,19 @@ def __init__(self, fused_filters: List): self.fused_filters = fused_filters def compute_stats(self, sample): + import av + # context for the intermediate vars sample[Fields.context] = {} for op in self.fused_filters: # open the context for these fused ops sample = op.compute_stats(sample, context=True) # clean up the contexts after processing + # check if there are containers that need to be closed + for context_key in sample[Fields.context]: + if isinstance(sample[Fields.context][context_key], + av.container.InputContainer): + sample[Fields.context][context_key].close() _ = sample.pop(Fields.context) return sample diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 31049d62f..75e8e3914 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -15,6 +15,10 @@ class Fields(object): context = DEFAULT_PREFIX + 'context__' suffix = DEFAULT_PREFIX + 'suffix__' + # video_frame_tags + video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' + video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' + class StatsKeysMeta(type): """ @@ -125,18 +129,32 @@ class StatsKeysConstant(object): image_sizes = 'image_sizes' face_ratios = 'face_ratios' face_detections = 'face_detections' + image_aesthetics_scores = 'image_aesthetics_scores' # audios audio_duration = 'audio_duration' audio_nmf_snr = 'audio_nmf_snr' audio_sizes = 'audio_sizes' + # videos + video_duration = 'video_duration' + video_aspect_ratios = 'video_aspect_ratios' 
+ video_width = 'video_width' + video_height = 'video_height' + video_ocr_area_ratio = 'video_ocr_area_ratio' + video_aesthetic_score = 'video_aesthetic_score' + video_frames_aesthetics_score = 'video_frames_aesthetics_score' + video_motion_score = 'video_motion_score' + # multimodal # image-text image_text_similarity = 'image_text_similarity' image_text_matching_score = 'image_text_matching_score' phrase_grounding_recall = 'phrase_grounding_recall' + # video-text + video_frames_text_matching_score = 'video_frames_text_matching_score' + class StatsKeys(object, metaclass=StatsKeysMeta): _constants_class = StatsKeysConstant @@ -150,6 +168,9 @@ class HashKeys(object): # image imagehash = DEFAULT_PREFIX + 'imagehash' + # video + videohash = DEFAULT_PREFIX + 'videohash' + class InterVars(object): # text @@ -162,3 +183,6 @@ class InterVars(object): # audios loaded_audios = DEFAULT_PREFIX + 'loaded_audios' # (data, sampling_rate) + + # videos + loaded_videos = DEFAULT_PREFIX + 'loaded_videos' # InputContainer from av diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py index a78572a53..d8f224872 100644 --- a/data_juicer/utils/file_utils.py +++ b/data_juicer/utils/file_utils.py @@ -1,8 +1,15 @@ +import copy +import hashlib +import os +import re +from datetime import datetime, timezone from pathlib import Path from typing import List, Tuple, Union from datasets.utils.extract import ZstdExtractor as Extractor +from data_juicer.utils.constant import DEFAULT_PREFIX + def find_files_with_suffix( path: Union[str, Path], @@ -67,3 +74,105 @@ def is_absolute_path(path: Union[str, Path]) -> bool: path is a relative path. """ return Path(path).is_absolute() + + +def add_suffix_to_filename(filename, suffix): + """ + Add a suffix to the filename. Only regard the content after the last dot + as the file extension. + E.g. + 1. abc.jpg + "_resized" --> abc_resized.jpg + 2. edf.xyz.csv + "_processed" --> edf.xyz_processed.csv + 3. 
/path/to/file.json + "_suf" --> /path/to/file_suf.json + 4. ds.tar.gz + "_whoops" --> ds.tar_whoops.gz (maybe unexpected) + + :param filename: input filename + :param suffix: suffix string to be added + """ + name, ext = os.path.splitext(filename) + new_name = f'{name}{suffix}{ext}' + return new_name + + +def dict_to_hash(input_dict, hash_length=None): + """ + Hash a dict to a string, truncated to hash_length characters if given. + + :param input_dict: the given dict + """ + sorted_items = sorted(input_dict.items()) + dict_string = str(sorted_items).encode() + hasher = hashlib.sha256() + hasher.update(dict_string) + hash_value = hasher.hexdigest() + if hash_length: + hash_value = hash_value[:hash_length] + return hash_value + + +def create_directory_if_not_exists(directory_path): + """ + Create a directory if it does not exist. This function is process-safe. + + :param directory_path: directory path to be created + """ + try: + os.makedirs(directory_path, exist_ok=True) + except FileExistsError: + # Ignore the exception raised by concurrent processes or threads. + # Just make sure the directory exists. + pass + + +def transfer_filename(original_filepath: Union[str, Path], op_name, + **op_kwargs): + """ + Map the original_filepath to another unique file path. The 'hash_val' is + computed by hashing the op parameters 'op_kwargs' in addition to the + process id and the current time. + E.g. + 1. abc.jpg --> + {op_name}/abc__dj_hash_#{hash_val}#.jpg + 2. ./abc.jpg --> + ./{op_name}/abc__dj_hash_#{hash_val}#.jpg + 3. /path/to/abc.jpg --> + /path/to/{op_name}/abc__dj_hash_#{hash_val}#.jpg + 4. /path/to/{op_name}/abc.jpg --> + /path/to/{op_name}/abc__dj_hash_#{hash_val}#.jpg + 5. 
/path/to/{op_name}/abc__dj_hash_#{hash_val1}#.jpg --> + /path/to/{op_name}/abc__dj_hash_#{hash_val2}#.jpg + """ + # produce the directory + original_dir = os.path.dirname(original_filepath) + parent_dir = os.path.basename(original_dir) + if parent_dir == op_name: + new_dir = original_dir + else: + new_dir = os.path.join(original_dir, f'{op_name}') + create_directory_if_not_exists(new_dir) + + # produce the unique hash code + unique_parameters = copy.deepcopy(op_kwargs) + unique_parameters[f'{DEFAULT_PREFIX}pid'] = os.getpid() + unique_parameters[f'{DEFAULT_PREFIX}timestamp'] = str( + datetime.now(timezone.utc)) + unique_hash = dict_to_hash(unique_parameters) + + # if the input data is produced by data-juicer, replace the hash code + # else append hash value to filename + def add_hash_value(text, new_hash_value): + pattern = r'__dj_hash_#(.*?)#' + + match = re.search(pattern, text) + # draw the string produced by data-juicer + if match: + text = text[:match.start()] + + return f'{text}__dj_hash_#{new_hash_value}#' + + original_filename = os.path.basename(original_filepath) + name, ext = os.path.splitext(original_filename) + new_name = add_hash_value(name, unique_hash) + new_filepath = os.path.join(new_dir, f'{new_name}{ext}') + + return new_filepath diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 04e08d695..b6c1db5b4 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -1,10 +1,19 @@ import base64 +import datetime +import os import re +from typing import List, Union +import av import numpy as np from datasets import Audio, Image +from loguru import logger from data_juicer.utils.constant import DEFAULT_PREFIX, Fields +from data_juicer.utils.file_utils import add_suffix_to_filename + +# suppress most warnings from av +av.logging.set_level(av.logging.PANIC) # A class to keep special tokens for multimodal information in the texts @@ -13,11 +22,22 @@ class SpecialTokens(object): # modality image = 
f'<{DEFAULT_PREFIX}image>' audio = f'<{DEFAULT_PREFIX}audio>' + video = f'<{DEFAULT_PREFIX}video>' # others eoc = f'<|{DEFAULT_PREFIX}eoc|>' +AV_STREAM_THREAD_TYPE = 'AUTO' +""" +The av stream thread type supports "SLICE", "FRAME", and "AUTO". + "SLICE": Decode more than one part of a single frame at once + "FRAME": Decode more than one frame at once + "AUTO": Use both "FRAME" and "SLICE" + AUTO is faster when there is no video latency. +""" + + def get_special_tokens(): special_token_dict = { key: value @@ -62,7 +82,7 @@ def load_data_with_context(sample, context, loaded_data_keys, load_func): return sample, data -# Image +# Images def load_images(paths): return [load_image(path) for path in paths] @@ -121,7 +141,7 @@ def iou(box1, box2): return 1.0 * intersection / union -# Audio +# Audios def load_audios(paths): return [load_audio(path) for path in paths] @@ -132,6 +152,432 @@ def load_audio(path, sampling_rate=None): return aud['array'], aud['sampling_rate'] +# Videos +def load_videos(paths): + return [load_video(path) for path in paths] + + +def load_video(path): + """ + Load a video using its path. + + :param path: the path to this video. + :return: a container object from the PyAV library, which contains all streams + in this video (video/audio/...) and can be used to decode these streams + to frames. + """ + container = av.open(path) + return container + + +def get_video_duration(input_video: Union[str, av.container.InputContainer], + video_stream_index=0): + """ + Get the video's duration from the container + :param input_video: the video path or the container object from the PyAV library, which + contains all streams in this video (video/audio/...) and can be used + to decode these streams to frames. + :param video_stream_index: the video stream index to decode, + default set to 0. 
+ :return: duration of the video in second + """ + if isinstance(input_video, str): + container = av.open(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + else: + raise ValueError(f'Unsupported type of input_video. Should be one of ' + f'[str, av.container.InputContainer], but given ' + f'[{type(input_video)}].') + + input_video_stream = container.streams.video[video_stream_index] + duration = input_video_stream.duration * input_video_stream.time_base + return float(duration) + + +def get_decoded_frames_from_video( + input_video: Union[str, av.container.InputContainer], + video_stream_index=0): + """ + Get the video's frames from the container + :param input_video: the container object form PyAv library, which + contains all streams in this video (video/audio/...) and can be used + to decode these streams to frames. + :param video_stream_index: the video stream index to decode, + default set to 0. + :return: an iterator of all the frames of the video + """ + if isinstance(input_video, str): + container = av.open(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + stream = container.streams.video[video_stream_index] + # use "AUTO" thread_type for faster decode + stream.thread_type = AV_STREAM_THREAD_TYPE + return container.decode(stream) + + +def cut_video_by_seconds( + input_video: Union[str, av.container.InputContainer], + output_video: str, + start_seconds: float, + end_seconds: float = None, +): + """ + Cut a video into several segments by times in second. + + :param input_video: the path to input video or the video container. + :param output_video: the path to output video. + :param start_seconds: the start time in second. + :param end_seconds: the end time in second. If it's None, this function + will cut the video from the start_seconds to the end of the video. 
+ """ + # open the original video + if isinstance(input_video, str): + container = av.open(input_video) + else: + container = input_video + + # create the output video + output_container = av.open(output_video, 'w') + + # add the video stream into the output video according to input video + input_video_stream = container.streams.video[0] + codec_name = input_video_stream.codec_context.name + fps = input_video_stream.base_rate + output_video_stream = output_container.add_stream(codec_name, + rate=str(fps)) + output_video_stream.width = input_video_stream.codec_context.width + output_video_stream.height = input_video_stream.codec_context.height + output_video_stream.pix_fmt = input_video_stream.codec_context.pix_fmt + + # add the audio stream into the output video with template of input audio + if len(container.streams.audio) == 0: + input_audio_stream = None + else: + input_audio_stream = container.streams.audio[0] + output_container.add_stream(template=input_audio_stream) + + # seek to the start time, time must be in microsecond if no + # stream is specified + container.seek(int(start_seconds * 1000000), + any_frame=False, + backward=True) + + # copy the video and audio streams until the end time + # NOTICE: for different streams, the time have to be converted to be + # in the corresponding time base. 
+ video_at_the_end = False + # compute the start/end pts for video/audio streams + video_start_pts = int(start_seconds / input_video_stream.time_base) + video_end_pts = (end_seconds / input_video_stream.time_base + if end_seconds else input_video_stream.duration) + if input_audio_stream is not None: + audio_start_pts = int(start_seconds / input_audio_stream.time_base) + audio_end_pts = (end_seconds / input_audio_stream.time_base + if end_seconds else input_audio_stream.duration) + for packet in container.demux(input_video_stream, input_audio_stream): + if packet.stream.type == 'video': + for frame in packet.decode(): + if frame.pts < video_start_pts: + continue + if frame.pts > video_end_pts: + # continue to check until the next P/I frame + if frame.pict_type in {'P', 'I'}: + video_at_the_end = True + break + continue + frame.pts -= video_start_pts # timestamp alignment + for inter_packet in output_video_stream.encode(frame): + output_container.mux(inter_packet) + elif packet.stream.type == 'audio': + if packet.pts is None or packet.dts is None: + continue + if packet.pts < audio_start_pts or packet.pts > audio_end_pts: + continue + packet.pts -= audio_start_pts + packet.dts -= audio_start_pts + output_container.mux(packet) + if video_at_the_end: + break + + # flush all packets + for packet in output_video_stream.encode(): + output_container.mux(packet) + + # close the output videos + if isinstance(input_video, str): + container.close() + output_container.close() + + +def extract_key_frames(input_video: Union[str, av.container.InputContainer]): + """ + Extract key frames from the input video. + + :param input_video: input video path or container. + :return: a list of key frames. + """ + # load the input video + if isinstance(input_video, str): + container = av.open(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + else: + raise ValueError(f'Unsupported type of input_video. 
Should be one of ' + f'[str, av.container.InputContainer], but given ' + f'[{type(input_video)}].') + + key_frames = [] + input_video_stream = container.streams.video[0] + ori_skip_method = input_video_stream.codec_context.skip_frame + input_video_stream.codec_context.skip_frame = 'NONKEY' + # restore to the beginning of the video + container.seek(0, backward=False, any_frame=False) + for frame in container.decode(input_video_stream): + key_frames.append(frame) + # restore to the original skip_type + input_video_stream.codec_context.skip_frame = ori_skip_method + + if isinstance(input_video, str): + container.close() + return key_frames + + +def get_key_frame_seconds(input_video: Union[str, + av.container.InputContainer]): + """ + Get seconds of key frames in the input video. + """ + key_frames = extract_key_frames(input_video) + ts = [float(f.pts * f.time_base) for f in key_frames] + ts.sort() + return ts + + +def extract_video_frames_uniformly( + input_video: Union[str, av.container.InputContainer], + frame_num: int, +): + """ + Extract a number of video frames uniformly within the video duration. + + :param input_video: input video path or container. + :param frame_num: The number of frames to be extracted. If it's 1, only the + middle frame will be extracted. If it's 2, only the first and the last + frames will be extracted. If it's larger than 2, in addition to the + first and the last frames, other frames will be extracted uniformly + within the video duration. + :return: a list of extracted frames. + """ + # load the input video + if isinstance(input_video, str): + container = av.open(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + else: + raise ValueError(f'Unsupported type of input_video. 
Should be one of ' + f'[str, av.container.InputContainer], but given ' + f'[{type(input_video)}].') + + input_video_stream = container.streams.video[0] + total_frame_num = input_video_stream.frames + if total_frame_num < frame_num: + logger.warning('Number of frames to be extracted is larger than the ' + 'total number of frames in this video. Set it to the ' + 'total number of frames.') + frame_num = total_frame_num + # calculate the frame seconds to be extracted + duration = input_video_stream.duration * input_video_stream.time_base + if frame_num == 1: + extract_seconds = [duration / 2] + else: + step = duration / (frame_num - 1) + extract_seconds = [step * i for i in range(0, frame_num)] + + # group durations according to the seconds of key frames + key_frame_seconds = get_key_frame_seconds(container) + if 0.0 not in key_frame_seconds: + key_frame_seconds = [0.0] + key_frame_seconds + if len(key_frame_seconds) == 1: + second_groups = [extract_seconds] + else: + second_groups = [] + idx = 0 + group_id = 0 + curr_group = [] + curr_upper_bound_ts = key_frame_seconds[group_id + 1] + while idx < len(extract_seconds): + curr_ts = extract_seconds[idx] + if curr_ts < curr_upper_bound_ts: + curr_group.append(curr_ts) + idx += 1 + else: + second_groups.append(curr_group) + group_id += 1 + curr_group = [] + if group_id >= len(key_frame_seconds) - 1: + break + curr_upper_bound_ts = key_frame_seconds[group_id + 1] + if len(curr_group) > 0: + second_groups.append(curr_group) + if idx < len(extract_seconds): + second_groups.append(extract_seconds[idx:]) + + # extract frames by their group's key frames + extracted_frames = [] + time_base = input_video_stream.time_base + for i, second_group in enumerate(second_groups): + key_frame_second = key_frame_seconds[i] + if len(second_group) == 0: + continue + if key_frame_second == 0.0: + # search from the beginning + container.seek(0, backward=False, any_frame=True) + search_idx = 0 + curr_pts = second_group[search_idx] / time_base + 
for frame in container.decode(input_video_stream): + if frame.pts >= curr_pts: + extracted_frames.append(frame) + search_idx += 1 + if search_idx >= len(second_group): + break + curr_pts = second_group[search_idx] / time_base + else: + # search from a key frame + container.seek(int(key_frame_second * 1e6)) + search_idx = 0 + curr_pts = second_group[search_idx] / time_base + find_all = False + for packet in container.demux(input_video_stream): + for frame in packet.decode(): + if frame.pts >= curr_pts: + extracted_frames.append(frame) + search_idx += 1 + if search_idx >= len(second_group): + find_all = True + break + curr_pts = second_group[search_idx] / time_base + if find_all: + break + if not find_all and frame is not None: + # add the last frame + extracted_frames.append(frame) + + # if the container is opened in this function, close it + if isinstance(input_video, str): + container.close() + return extracted_frames + + +def extract_audio_from_video( + input_video: Union[str, av.container.InputContainer], + output_audio: str = None, + start_seconds: int = 0, + end_seconds: int = None, + stream_indexes: Union[int, List[int]] = None, +): + """ + Extract audio data for the given video. + + :param input_video: input video. Can be a video path or an + av.container.InputContainer. + :param output_audio: output audio path. If it's None, the audio data won't + be written to file. If stream_indexes is not None, it will output + multiple audio files with original filename and the stream indexes. + Default: None. + :param start_seconds: the start seconds to extract audio data. Default: 0, + which means extract from the start of the video. + :param end_seconds: the end seconds to stop extracting audio data. If it's + None, the extraction won't stop until the end of the video. Default: + None. + :param stream_indexes: there might be multiple audio streams in the video, + so we need to decide which audio streams with stream_indexes will be + extracted. 
It can be a single index or a list of indexes. If it's None, + all audio streams will be extracted. Default: None. + """ + if isinstance(input_video, str): + input_container = av.open(input_video) + elif isinstance(input_video, av.container.InputContainer): + input_container = input_video + else: + raise ValueError(f'Unsupported type of input_video. Should be one of ' + f'[str, av.container.InputContainer], but given ' + f'[{type(input_video)}].') + + if output_audio and not output_audio.endswith('mp3'): + raise ValueError(f'Now we only support export the audios into `mp3` ' + f'format, but given ' + f'[{os.path.splitext(output_audio)[1]}') + + # no audios in the video + num_audio_streams = len(input_container.streams.audio) + if stream_indexes is None: + valid_stream_indexes = list(range(num_audio_streams)) + elif isinstance(stream_indexes, int): + valid_stream_indexes = [stream_indexes] + else: + # remove indexes that are larger than the total number of audio streams + valid_stream_indexes = [ + idx for idx in stream_indexes if idx < num_audio_streams + ] + # no valid expected audio streams + if len(valid_stream_indexes) == 0: + return [], [], valid_stream_indexes + + audio_data_list = [] + audio_sampling_rate_list = [] + for idx in valid_stream_indexes: + # read the current audio stream + input_audio_stream = input_container.streams.audio[idx] + # get the sampling rate + audio_sampling_rate_list.append(float(1 / + input_audio_stream.time_base)) + + if output_audio: + # if the output_audio is not None, prepare the output audio file + this_output_audio = add_suffix_to_filename(output_audio, f'_{idx}') + output_container = av.open(this_output_audio, 'w') + output_stream = output_container.add_stream('mp3') + + # get the start/end pts + start_pts = int(start_seconds / input_audio_stream.time_base) + end_pts = (end_seconds / + input_audio_stream.time_base if end_seconds else None) + + audio_data = [] + for frame in input_container.decode(input_audio_stream): + if 
frame.pts is None or frame.dts is None: + continue + if frame.pts < start_pts: + continue + if end_pts and frame.pts > end_pts: + break + # get frame data + array = frame.to_ndarray()[0] + audio_data.append(array) + + if output_audio: + # compute the right pts when writing an audio file + frame.pts -= start_pts + frame.dts -= start_pts + for packet in output_stream.encode(frame): + output_container.mux(packet) + + # flush + if output_audio: + for packet in output_stream.encode(None): + output_container.mux(packet) + + if isinstance(input_video, str): + input_container.close() + if output_audio: + output_container.close() + audio_data_list.append(np.concatenate(audio_data)) + + return audio_data_list, audio_sampling_rate_list, valid_stream_indexes + + # Others def size_to_bytes(size): alphabets_list = [char for char in size if char.isalpha()] @@ -196,3 +642,18 @@ def insert_texts_after_placeholders(original_string, modified_string[index + len(placeholder):] return modified_string + + +def timecode_string_to_seconds(timecode: str): + """ + Convert a timecode string to seconds as a float. + + :param timecode: the input timecode string. Must be in "HH:MM:SS.fff(fff)" + format. 
+ """ + # parse the timecode string + dt = datetime.datetime.strptime(timecode, '%H:%M:%S.%f') + + # compute the start/end time in seconds + pts = dt.hour * 3600 + dt.minute * 60 + dt.second + dt.microsecond / 1e6 + return pts diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index c92f7fb47..9e61e6089 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -1,6 +1,7 @@ import fnmatch import os from functools import partial +from typing import Optional, Union import multiprocess as mp import wget @@ -31,7 +32,7 @@ # sentence split model from nltk punkt 'punkt.*.pickle': 'https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/' - 'data_juicer/models/' + 'data_juicer/models/', } @@ -176,6 +177,183 @@ def prepare_nltk_model(lang, name_pattern='punkt.{}.pickle'): return nltk_model +def prepare_video_blip_model(pretrained_model_name_or_path, + return_model=True, + trust_remote_code=False): + """ + Prepare and load a video-BLIP model with the corresponding processor. + + :param pretrained_model_name_or_path: model name or path + :param return_model: return model or not + :param trust_remote_code: passed to transformers + :return: a tuple (model, input processor) if `return_model` is True; + otherwise, only the processor is returned. 
+ """ + import torch + import torch.nn as nn + from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, + Blip2Config, Blip2ForConditionalGeneration, + Blip2QFormerModel, Blip2VisionModel) + from transformers.modeling_outputs import BaseModelOutputWithPooling + + class VideoBlipVisionModel(Blip2VisionModel): + """A simple, augmented version of Blip2VisionModel to handle + videos.""" + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutputWithPooling]: + """Flatten `pixel_values` along the batch and time dimension, + pass it through the original vision model, + then unflatten it back. + + :param pixel_values: a tensor of shape + (batch, channel, time, height, width) + + :returns: + last_hidden_state: a tensor of shape + (batch, time * seq_len, hidden_size) + pooler_output: a tensor of shape + (batch, time, hidden_size) + hidden_states: + a tuple of tensors of shape + (batch, time * seq_len, hidden_size), + one for the output of the embeddings + + one for each layer + attentions: + a tuple of tensors of shape + (batch, time, num_heads, seq_len, seq_len), + one for each layer + """ + if pixel_values is None: + raise ValueError('You have to specify pixel_values') + + batch, _, time, _, _ = pixel_values.size() + + # flatten along the batch and time dimension to create a + # tensor of shape + # (batch * time, channel, height, width) + flat_pixel_values = pixel_values.permute(0, 2, 1, 3, + 4).flatten(end_dim=1) + + vision_outputs: BaseModelOutputWithPooling = super().forward( + pixel_values=flat_pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + # now restore the original dimensions + # vision_outputs.last_hidden_state is of shape + # (batch * time, seq_len, hidden_size) + seq_len = 
vision_outputs.last_hidden_state.size(1) + last_hidden_state = vision_outputs.last_hidden_state.view( + batch, time * seq_len, -1) + # vision_outputs.pooler_output is of shape + # (batch * time, hidden_size) + pooler_output = vision_outputs.pooler_output.view(batch, time, -1) + # hidden_states is a tuple of tensors of shape + # (batch * time, seq_len, hidden_size) + hidden_states = (tuple( + hidden.view(batch, time * seq_len, -1) + for hidden in vision_outputs.hidden_states) + if vision_outputs.hidden_states is not None else + None) + # attentions is a tuple of tensors of shape + # (batch * time, num_heads, seq_len, seq_len) + attentions = (tuple( + hidden.view(batch, time, -1, seq_len, seq_len) + for hidden in vision_outputs.attentions) + if vision_outputs.attentions is not None else None) + if return_dict: + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=hidden_states, + attentions=attentions, + ) + return (last_hidden_state, pooler_output, hidden_states, + attentions) + + class VideoBlipForConditionalGeneration(Blip2ForConditionalGeneration): + + def __init__(self, config: Blip2Config) -> None: + # HACK: we call the grandparent super().__init__() to bypass + # Blip2ForConditionalGeneration.__init__() so we can replace + # self.vision_model + super(Blip2ForConditionalGeneration, self).__init__(config) + + self.vision_model = VideoBlipVisionModel(config.vision_config) + + self.query_tokens = nn.Parameter( + torch.zeros(1, config.num_query_tokens, + config.qformer_config.hidden_size)) + self.qformer = Blip2QFormerModel(config.qformer_config) + + self.language_projection = nn.Linear( + config.qformer_config.hidden_size, + config.text_config.hidden_size) + if config.use_decoder_only_language_model: + language_model = AutoModelForCausalLM.from_config( + config.text_config) + else: + language_model = AutoModelForSeq2SeqLM.from_config( + config.text_config) + self.language_model = language_model + 
+ # Initialize weights and apply final processing + self.post_init() + + from transformers import AutoProcessor + processor = AutoProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + if return_model: + model_class = VideoBlipForConditionalGeneration + model = model_class.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + return (model, processor) if return_model else processor + + +def prepare_simple_aesthetics_model(pretrained_model_name_or_path, + return_model=True): + """ + Prepare and load a simple aesthetics model. + + :param pretrained_model_name_or_path: model name or path + :param return_model: return model or not + :return: a tuple (model, input processor) if `return_model` is True; + otherwise, only the processor is returned. + """ + from aesthetics_predictor import (AestheticsPredictorV1, + AestheticsPredictorV2Linear, + AestheticsPredictorV2ReLU) + from transformers import CLIPProcessor + + processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path) + if not return_model: + return processor + else: + if 'v1' in pretrained_model_name_or_path: + model = AestheticsPredictorV1.from_pretrained( + pretrained_model_name_or_path) + elif ('v2' in pretrained_model_name_or_path + and 'linear' in pretrained_model_name_or_path): + model = AestheticsPredictorV2Linear.from_pretrained( + pretrained_model_name_or_path) + elif ('v2' in pretrained_model_name_or_path + and 'relu' in pretrained_model_name_or_path): + model = AestheticsPredictorV2ReLU.from_pretrained( + pretrained_model_name_or_path) + else: + raise ValueError( + 'Not support {}'.format(pretrained_model_name_or_path)) + return (model, processor) + + def prepare_huggingface_model(pretrained_model_name_or_path, return_model=True, trust_remote_code=False): @@ -189,34 +367,26 @@ def prepare_huggingface_model(pretrained_model_name_or_path, otherwise, only the processor is returned. 
""" import transformers - from transformers import (AutoConfig, AutoImageProcessor, AutoProcessor, - AutoTokenizer) - from transformers.models.auto.image_processing_auto import \ - IMAGE_PROCESSOR_MAPPING_NAMES - from transformers.models.auto.processing_auto import \ - PROCESSOR_MAPPING_NAMES - from transformers.models.auto.tokenization_auto import \ - TOKENIZER_MAPPING_NAMES - - config = AutoConfig.from_pretrained(pretrained_model_name_or_path) - # TODO: What happens when there are more than one? - arch = config.architectures[0] - model_class = getattr(transformers, arch) - model_type = config.model_type - if model_type in PROCESSOR_MAPPING_NAMES: - processor = AutoProcessor.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code) - elif model_type in IMAGE_PROCESSOR_MAPPING_NAMES: - processor = AutoImageProcessor.from_pretrained( + from transformers import AutoConfig, AutoProcessor + + processor = AutoProcessor.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + + if return_model: + config = AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code) - elif model_type in TOKENIZER_MAPPING_NAMES: - processor = AutoTokenizer.from_pretrained( + if hasattr(config, 'auto_map'): + class_name = next( + (k for k in config.auto_map if k.startswith('AutoModel')), + 'AutoModel') + else: + # TODO: What happens if more than one + class_name = config.architectures[0] + + model_class = getattr(transformers, class_name) + model = model_class.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code) - else: - processor = None - if return_model: - model = model_class.from_pretrained(pretrained_model_name_or_path) return (model, processor) if return_model else processor @@ -308,14 +478,41 @@ def prepare_diffusion_model(pretrained_model_name_or_path, return model +def prepare_recognizeAnything_model( + 
pretrained_model_name_or_path='ram_plus_swin_large_14m.pth', + input_size=384): + """ + Prepare and load the recognizeAnything model. + + :param pretrained_model_name_or_path: input model name or path. + :param input_size: the input size of the model. + """ + from ram.models import ram_plus + logger.info('Loading recognizeAnything model...') + try: + model = ram_plus(pretrained=check_model(pretrained_model_name_or_path), + image_size=input_size, + vit='swin_l') + except: # noqa: E722 + model = ram_plus(pretrained=check_model(pretrained_model_name_or_path, + force=True), + image_size=input_size, + vit='swin_l') + model.eval() + return model + + MODEL_FUNCTION_MAPPING = { 'fasttext': prepare_fasttext_model, 'sentencepiece': prepare_sentencepiece_model, 'kenlm': prepare_kenlm_model, 'nltk': prepare_nltk_model, 'huggingface': prepare_huggingface_model, + 'simple_aesthetics': prepare_simple_aesthetics_model, 'spacy': prepare_spacy_model, - 'diffusion': prepare_diffusion_model + 'diffusion': prepare_diffusion_model, + 'video_blip': prepare_video_blip_model, + 'recognizeAnything': prepare_recognizeAnything_model } @@ -342,9 +539,6 @@ def move_to_cuda(model, rank): logger.info( f'Moving {module.__class__.__name__} to CUDA device {rank}') module.to(f'cuda:{rank}') - # Optionally, verify the device assignment - logger.debug( - f'{module.__class__.__name__} is on device {module.device}') def get_model(model_key=None, rank=None): diff --git a/data_juicer/utils/unittest_utils.py b/data_juicer/utils/unittest_utils.py index 8b912c494..b9d18dbf1 100644 --- a/data_juicer/utils/unittest_utils.py +++ b/data_juicer/utils/unittest_utils.py @@ -9,6 +9,12 @@ class DataJuicerTestCaseBase(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Set maxDiff for all test cases based on an environment variable + max_diff = os.getenv('TEST_MAX_DIFF', 'None') + cls.maxDiff = None if max_diff == 'None' else int(max_diff) + @classmethod def tearDownClass(cls, hf_model_name=None) -> None: # clean the huggingface model 
cache files diff --git a/demos/data_visualization_op_insight/app.css b/demos/data_visualization_op_insight/app.css new file mode 100644 index 000000000..e3c9c2d85 --- /dev/null +++ b/demos/data_visualization_op_insight/app.css @@ -0,0 +1,171 @@ +/* code highlight: https://python-markdown.github.io/extensions/code_hilite/ */ +.codehilite .hll { background-color: #ffffcc } +.codehilite { background: #f8f8f8; } +.codehilite .c { color: #408080; font-style: italic } /* Comment */ +.codehilite .err { border: 1px solid #FF0000 } /* Error */ +.codehilite .k { color: #008000; font-weight: bold } /* Keyword */ +.codehilite .o { color: #666666 } /* Operator */ +.codehilite .ch { color: #408080; font-style: italic } /* Comment.Hashbang */ +.codehilite .cm { color: #408080; font-style: italic } /* Comment.Multiline */ +.codehilite .cp { color: #BC7A00 } /* Comment.Preproc */ +.codehilite .cpf { color: #408080; font-style: italic } /* Comment.PreprocFile */ +.codehilite .c1 { color: #408080; font-style: italic } /* Comment.Single */ +.codehilite .cs { color: #408080; font-style: italic } /* Comment.Special */ +.codehilite .gd { color: #A00000 } /* Generic.Deleted */ +.codehilite .ge { font-style: italic } /* Generic.Emph */ +.codehilite .gr { color: #FF0000 } /* Generic.Error */ +.codehilite .gh { color: #000080; font-weight: bold } /* Generic.Heading */ +.codehilite .gi { color: #00A000 } /* Generic.Inserted */ +.codehilite .go { color: #888888 } /* Generic.Output */ +.codehilite .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ +.codehilite .gs { font-weight: bold } /* Generic.Strong */ +.codehilite .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ +.codehilite .gt { color: #0044DD } /* Generic.Traceback */ +.codehilite .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ +.codehilite .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ +.codehilite .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 
+.codehilite .kp { color: #008000 } /* Keyword.Pseudo */ +.codehilite .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ +.codehilite .kt { color: #B00040 } /* Keyword.Type */ +.codehilite .m { color: #666666 } /* Literal.Number */ +.codehilite .s { color: #BA2121 } /* Literal.String */ +.codehilite .na { color: #7D9029 } /* Name.Attribute */ +.codehilite .nb { color: #008000 } /* Name.Builtin */ +.codehilite .nc { color: #0000FF; font-weight: bold } /* Name.Class */ +.codehilite .no { color: #880000 } /* Name.Constant */ +.codehilite .nd { color: #AA22FF } /* Name.Decorator */ +.codehilite .ni { color: #999999; font-weight: bold } /* Name.Entity */ +.codehilite .ne { color: #D2413A; font-weight: bold } /* Name.Exception */ +.codehilite .nf { color: #0000FF } /* Name.Function */ +.codehilite .nl { color: #A0A000 } /* Name.Label */ +.codehilite .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ +.codehilite .nt { color: #008000; font-weight: bold } /* Name.Tag */ +.codehilite .nv { color: #19177C } /* Name.Variable */ +.codehilite .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ +.codehilite .w { color: #bbbbbb } /* Text.Whitespace */ +.codehilite .mb { color: #666666 } /* Literal.Number.Bin */ +.codehilite .mf { color: #666666 } /* Literal.Number.Float */ +.codehilite .mh { color: #666666 } /* Literal.Number.Hex */ +.codehilite .mi { color: #666666 } /* Literal.Number.Integer */ +.codehilite .mo { color: #666666 } /* Literal.Number.Oct */ +.codehilite .sa { color: #BA2121 } /* Literal.String.Affix */ +.codehilite .sb { color: #BA2121 } /* Literal.String.Backtick */ +.codehilite .sc { color: #BA2121 } /* Literal.String.Char */ +.codehilite .dl { color: #BA2121 } /* Literal.String.Delimiter */ +.codehilite .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ +.codehilite .s2 { color: #BA2121 } /* Literal.String.Double */ +.codehilite .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */ 
+.codehilite .sh { color: #BA2121 } /* Literal.String.Heredoc */ +.codehilite .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */ +.codehilite .sx { color: #008000 } /* Literal.String.Other */ +.codehilite .sr { color: #BB6688 } /* Literal.String.Regex */ +.codehilite .s1 { color: #BA2121 } /* Literal.String.Single */ +.codehilite .ss { color: #19177C } /* Literal.String.Symbol */ +.codehilite .bp { color: #008000 } /* Name.Builtin.Pseudo */ +.codehilite .fm { color: #0000FF } /* Name.Function.Magic */ +.codehilite .vc { color: #19177C } /* Name.Variable.Class */ +.codehilite .vg { color: #19177C } /* Name.Variable.Global */ +.codehilite .vi { color: #19177C } /* Name.Variable.Instance */ +.codehilite .vm { color: #19177C } /* Name.Variable.Magic */ +.codehilite .il { color: #666666 } /* Literal.Number.Integer.Long */ + + +.project_cover { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + min-height: 650px; + border: 1px solid rgba(229, 231, 235, 0.6); /* slightly transparent border */ + border-radius: 16px; /* larger border radius */ + padding: 40px; /* more inner padding */ + background-color: #ffffff; /* background color */ + box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); /* subtle shadow */ +} + +.project_img { + overflow: hidden; + position: center; + display: flex; + justify-content: center; + align-items: center; + margin-bottom: auto; + /* box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15), 0 6px 20px rgba(0, 0, 0, 0.5); */ +} + +.project_img img { + width: 80%; + height: 80%; +} + +.project_label { + font-size: 18px; /* heading font size */ + color: #333; /* font color, dark gray here */ + font-weight: bold; /* bold font */ + text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.1); /* text shadow */ + transition: all 0.3s ease; /* smooth transition animation */ + padding: 10px; /* padding */ + margin-bottom: 20px; /* bottom margin */ + border-bottom: 2px solid #ddd; /* bottom border style */ +} + +.project_name { + font-size: 30px; /* adjust font size */ + color: #333333; /* darker font color for more contrast */ + margin-top: 20px; /* adjust spacing above the name */ + /* font-weight: bold; bold font */ + /* 
text-transform: uppercase; uppercase text */ + align-items: center; + justify-content: center; + text-align: center; /* center text */ + letter-spacing: 1.5px; /* wider letter spacing */ + transition: all 0.3s ease; /* smooth transition animation */ +} + +.project_desc { + color: #444444; /* darker font color */ + font-size: 18px; /* larger font size */ + margin: 20px 0; /* more vertical spacing */ + text-align: center; /* center text */ + line-height: 1.5; /* larger line height for readability */ + transition: all 0.3s ease; /* smooth transition animation */ +} + +.markdown-body .message { + white-space: pre-wrap; +} + +.markdown-body details { + white-space: nowrap; +} +.markdown-body .bot details:not(:last-child) { + margin-bottom: 1px; +} +.markdown-body summary { + background-color: #4b5563; + color: #eee; + padding: 0 4px; + border-radius: 4px; + font-size: 0.9em; +} + + +.project_intro { + display: grid; + place-items: center; /* center both ways */ + height: 100px; /* height */ + width: + font-size: 15px; /* body font size */ + /* text-align: center; center text */ + color: #555; /* body font color, a lighter gray here */ + border-radius: 8px; /* border radius */ + transition: transform 0.3s ease; /* smooth transition animation */ +} + +/* hover animation */ +.project_desc:hover, +.project_name:hover, +.project_label:hover, +.project_intro:hover { + transform: translateY(-5px); /* move up */ +} \ No newline at end of file diff --git a/demos/data_visualization_op_insight/app.py b/demos/data_visualization_op_insight/app.py new file mode 100644 index 000000000..d2df57c0d --- /dev/null +++ b/demos/data_visualization_op_insight/app.py @@ -0,0 +1,351 @@ +import os +import inspect +import base64 +import yaml +import copy +import shutil +import gradio as gr +from data_juicer.ops.base_op import OPERATORS +from data_juicer.utils.constant import Fields +demo_path = os.path.dirname(os.path.abspath(__file__)) +project_path = os.path.dirname(os.path.dirname(demo_path)) + + +# Convert a local image path to base64 format +def covert_image_to_base64(image_path): + # Get the file extension + ext = image_path.split(".")[-1] + if ext not in ["gif", "jpeg", "png"]: + ext = "jpeg" + + with open(image_path, "rb") as image_file: + # Read the 
file + encoded_string = base64.b64encode(image_file.read()) + + # Convert bytes to string + base64_data = encoded_string.decode("utf-8") + + # Build the base64-encoded data URL + base64_url = f"data:image/{ext};base64,{base64_data}" + return base64_url + + +def format_cover_html(project_img_path): + readme_link = 'https://github.com/alibaba/data-juicer' + config = { + 'name': "Data-Juicer", + 'label': "Op Insight", + 'description': f'A One-Stop Data Processing System for Large Language Models, see more details in
GitHub', + 'introduction': + "This project is being actively updated and maintained, and we will periodically enhance and add more features and data recipes.
" + "We welcome you to join us in promoting LLM data development and research!
", + 'demo':"You can experience the effect of the operators of Data-Juicer" + } + # image_src = covert_image_to_base64(project_img_path) + #
+ #
+ return f""" +
+
{config.get("name", "")}
+
{config.get("description", "")}
+
{config.get("introduction", "")}
+
{config.get("demo", "")}
+
+""" +op_text = '' +docs_file = os.path.join(project_path, 'docs/Operators.md') +if os.path.exists(docs_file): + with open(os.path.join(project_path, 'docs/Operators.md'), 'r') as f: + op_text = f.read() + +def extract_op_desc(markdown_text, header): + start_index = markdown_text.find(header) + end_index = markdown_text.find("\n##", start_index + len(header)) + return markdown_text[start_index+ len(header):end_index].strip() + +op_desc = f"
{extract_op_desc(op_text, '## Overview').split('All the specific ')[0].strip()}
" +op_list_desc = { + 'mapper':extract_op_desc(op_text, '## Mapper '), + 'filter':extract_op_desc(op_text, '## Filter '), + 'deduplicator':extract_op_desc(op_text, '## Deduplicator '), + 'selector':extract_op_desc(op_text, '## Selector '), +} + +op_types = ['mapper', 'filter',]# 'deduplicator'] , 'selector'] +local_ops_dict = {op_type:[] for op_type in op_types} +multimodal = os.getenv('MULTI_MODAL', False) +multimodal = True +text_key = 'text' +image_key = 'images' +audio_key = 'audios' +video_key = 'videos' +def get_op_lists(op_type): + use_local_op = os.getenv('USE_LOCAL_OP', False) + if not use_local_op: + all_ops = list(OPERATORS.modules.keys()) + options = [ + name for name in all_ops if name.endswith(op_type) + ] + else: + options = local_ops_dict.get(op_type, []) + + for exclude in ['image', 'video', 'audio']: + options = [name for name in options if multimodal or exclude not in name] + return options + +def show_code(op_name): + op_class = OPERATORS.modules[op_name] + text = inspect.getsourcelines(op_class) + return ''.join(text[0]) + +def decode_sample(output_sample): + output_text = output_sample[text_key] + output_image = output_sample[image_key][0] if output_sample[image_key] else None + output_video = output_sample[video_key][0] if output_sample[video_key] else None + output_audio = output_sample[audio_key][0] if output_sample[audio_key] else None + def copy_func(file): + filename = None + if file: + filename= os.path.basename(file) + shutil.copyfile(file, filename) + return filename + + image_file = copy_func(output_image) + video_file = copy_func(output_video) + audio_file = copy_func(output_audio) + return output_text, image_file, video_file, audio_file + +def create_mapper_tab(op_type, op_tab): + with op_tab: + options = get_op_lists(op_type) + label = f'Select a {op_type} to show details' + with gr.Row(): + op_selector = gr.Dropdown(value=options[0], label=label, choices=options, interactive=True) + run_button = gr.Button(value="🚀Run") + 
show_code_button = gr.Button(value="🔍Show Code") + gr.Markdown(" **Op Parameters**") + op_params = gr.Code(label="Yaml",language='yaml', interactive=True) + with gr.Column(): + with gr.Group('Inputs'): + gr.Markdown(" **Inputs**") + with gr.Row(): + # img = '/private/var/folders/7b/p5l9gykj1k7_tylkvwjv_sl00000gp/T/gradio/f24972121fd4d4f95f42f1cd70f859bb03839e76/image_blur_mapper/喜欢的书__dj_hash_#14a7b2e1b96410fbe63ea16a70422180db53d644661630938b2773d8efa18dde#.png' + + input_text = gr.TextArea(label="Text",interactive=True,) + input_image = gr.Image(label='Image', type='filepath', visible=multimodal) + input_video = gr.Video(label='Video', visible=multimodal) + input_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal) + with gr.Group('Outputs'): + gr.Markdown(" **Outputs**") + with gr.Row(): + output_text = gr.TextArea(label="Text",interactive=False,) + output_image = gr.Image(label='Image', visible=multimodal) + output_video = gr.Video(label='Video', visible=multimodal) + output_audio = gr.Audio(label='Audio', visible=multimodal) + code = gr.Code(label='Source', language='python') + def run_op(op_name, op_params, input_text, input_image, input_video, input_audio): + op_class = OPERATORS.modules[op_name] + try: + params = yaml.safe_load(op_params) + except: + params = {} + if params is None: + params = {} + op = op_class(**params) + sample = dict() + + sample[text_key] = input_text + sample[image_key] = [input_image] + sample[video_key] = [input_video] + sample[audio_key] = [input_audio] + + output_sample = op.process(copy.deepcopy(sample)) + + return decode_sample(output_sample) + + inputs = [op_selector, op_params, input_text, input_image, input_video, input_audio] + outputs = [output_text, output_image, output_video, output_audio] + run_button.click(run_op, inputs=inputs, outputs=outputs) + show_code_button.click(show_code, inputs=[op_selector], outputs=[code]) + +def create_filter_tab(op_type, op_tab): + with op_tab: + + options = 
get_op_lists(op_type) + label = f'Select a {op_type} to show details' + with gr.Row(): + op_selector = gr.Dropdown(value=options[0], label=label, choices=options, interactive=True) + run_button = gr.Button(value="🚀Run") + show_code_button = gr.Button(value="🔍Show Code") + gr.Markdown(" **Op Parameters**") + op_params = gr.Code(label="Yaml",language='yaml', interactive=True) + with gr.Column(): + with gr.Group('Inputs'): + gr.Markdown(" **Inputs**") + with gr.Row(): + input_text = gr.TextArea(label="Text",interactive=True,) + input_image = gr.Image(label='Image', type='filepath', visible=multimodal) + input_video = gr.Video(label='Video', visible=multimodal) + input_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal) + input_stats = gr.Json(label='Stats') + + with gr.Group('Outputs'): + gr.Markdown(" **Outputs**") + with gr.Row(): + output_text = gr.TextArea(label="Text",interactive=False,) + output_image = gr.Image(label='Image', type='filepath', visible=multimodal) + output_video = gr.Video(label='Video', visible=multimodal) + output_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal) + output_stats = gr.Json(label='Stats') + + code = gr.Code(label='Source', language='python') + def run_op(op_name, op_params, input_text, input_image, input_video, input_audio): + op_class = OPERATORS.modules[op_name] + try: + params = yaml.safe_load(op_params) + except: + params = {} + if params is None: + params = {} + op = op_class(**params) + sample = dict() + sample[Fields.stats] = dict() + sample[text_key] = input_text + sample[image_key] = [input_image] + sample[video_key] = [input_video] + sample[audio_key] = [input_audio] + input_stats = sample[Fields.stats] + output_sample = op.compute_stats(copy.deepcopy(sample)) + output_stats = output_sample[Fields.stats] + return *decode_sample(output_sample), input_stats, output_stats + + inputs = [op_selector, op_params, input_text, input_image, input_video, input_audio] + outputs = [output_text, 
output_image, output_video, output_audio, input_stats, output_stats] + run_button.click(run_op, inputs=inputs, outputs=outputs) + show_code_button.click(show_code, inputs=[op_selector], outputs=[code]) + +def create_deduplicator_tab(op_type, op_tab): + with op_tab: + options = get_op_lists(op_type) + label = f'Select a {op_type} to show details' + with gr.Row(): + op_selector = gr.Dropdown(value=options[0], label=label, choices=options, interactive=True) + run_button = gr.Button(value="🚀Run") + show_code_button = gr.Button(value="🔍Show Code") + gr.Markdown(" **Op Parameters**") + op_params = gr.Code(label="Yaml",language='yaml', interactive=True) + with gr.Column(): + with gr.Group('Inputs'): + gr.Markdown(" **Inputs**") + with gr.Row(): + input_text = gr.TextArea(label="Text",interactive=True,) + input_image = gr.Image(label='Image', type='filepath', visible=multimodal) + input_video = gr.Video(label='Video', visible=multimodal) + input_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal) + + with gr.Group('Outputs'): + gr.Markdown(" **Outputs**") + with gr.Row(): + output_text = gr.TextArea(label="Text",interactive=False,) + output_image = gr.Image(label='Image', type='filepath', visible=multimodal) + output_video = gr.Video(label='Video', visible=multimodal) + output_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal) + + code = gr.Code(label='Source', language='python') + def run_op(op_name, op_params, input_text, input_images, input_video, input_audio): + op_class = OPERATORS.modules[op_name] + try: + params = yaml.safe_load(op_params) + except: + params = {} + if params is None: + params = {} + op = op_class(**params) + sample = dict() + sample[text_key] = input_text + sample[image_key] = input_images + sample[video_key] = [input_video] + sample[audio_key] = [input_audio] + + output_sample = sample #op.compute_hash(copy.deepcopy(sample)) + return decode_sample(output_sample) + + inputs = [op_selector, op_params, input_text, 
input_image, input_video, input_audio] + outputs = [output_text, output_image, output_video, output_audio] + run_button.click(run_op, inputs=inputs, outputs=outputs) + show_code_button.click(show_code, inputs=[op_selector], outputs=[code]) + +def create_selector_tab(op_type, op_tab): + with op_tab: + options = get_op_lists(op_type) + label = f'Select a {op_type} to show details' + with gr.Row(): + op_selector = gr.Dropdown(value=options[0], label=label, choices=options, interactive=True) + run_button = gr.Button(value="🚀Run") + show_code_button = gr.Button(value="🔍Show Code") + gr.Markdown(" **Op Parameters**") + op_params = gr.Code(label="Yaml",language='yaml', interactive=True) + with gr.Column(): + with gr.Group('Inputs'): + gr.Markdown(" **Inputs**") + with gr.Row(): + input_text = gr.TextArea(label="Text",interactive=True,) + input_image = gr.Image(label='Image', type='filepath', visible=multimodal) + input_video = gr.Video(label='Video', visible=multimodal) + input_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal) + input_stats = gr.Json(label='Stats') + + with gr.Group('Outputs'): + gr.Markdown(" **Outputs**") + with gr.Row(): + output_text = gr.TextArea(label="Text",interactive=False,) + output_image = gr.Image(label='Image', type='filepath', visible=multimodal) + output_video = gr.Video(label='Video', visible=multimodal) + output_audio = gr.Audio(label='Audio', type='filepath', visible=multimodal) + output_stats = gr.Json(label='Stats') + + code = gr.Code(label='Source', language='python') + def run_op(op_name, op_params, input_text, input_image, input_video, input_audio): + op_class = OPERATORS.modules[op_name] + try: + params = yaml.safe_load(op_params) + except: + params = {} + if params is None: + params = {} + op = op_class(**params) + sample = dict() + sample[Fields.stats] = dict() + sample[text_key] = input_text + sample[image_key] = [input_image] + sample[video_key] = [input_video] + sample[audio_key] = [input_audio] + 
input_stats = sample[Fields.stats] + output_sample = op.compute_stats(copy.deepcopy(sample)) + output_stats = output_sample[Fields.stats] + + return *decode_sample(output_sample), input_stats, output_stats + + inputs = [op_selector, op_params, input_text, input_image, input_video, input_audio] + outputs = [output_text, output_image, output_video, output_audio, input_stats, output_stats] + run_button.click(run_op, inputs=inputs, outputs=outputs) + show_code_button.click(show_code, inputs=[op_selector], outputs=[code]) + +with gr.Blocks(css="./app.css") as demo: + + dj_image = os.path.join(project_path, 'docs/imgs/data-juicer.jpg') + gr.HTML(format_cover_html(dj_image)) + + with gr.Accordion(label='Op Insight',open=True): + tabs = gr.Tabs() + + with tabs: + op_tabs = {op_type: gr.Tab(label=op_type.capitalize() + 's') for op_type in op_types} + for op_type, op_tab in op_tabs.items(): + create_op_tab_func = globals().get(f'create_{op_type}_tab', None) + if callable(create_op_tab_func): + create_op_tab_func(op_type, op_tab) + else: + gr.Error(f'{op_type} not callable') + + demo.launch() diff --git a/demos/process_on_ray/configs/demo.yaml b/demos/process_on_ray/configs/demo.yaml index 0fefc8d39..1e3e4a55a 100644 --- a/demos/process_on_ray/configs/demo.yaml +++ b/demos/process_on_ray/configs/demo.yaml @@ -3,7 +3,7 @@ # global parameters project_name: 'ray-demo' executor_type: 'ray' -dataset_path: './demos/process_on_ray/data/demo-dataset.json' # path to your dataset directory or file +dataset_path: './demos/process_on_ray/data/demo-dataset.jsonl' # path to your dataset directory or file ray_address: 'auto' # change to your ray cluster address, e.g., ray://: export_path: './outputs/demo/demo-processed' diff --git a/demos/process_on_ray/data/demo-dataset.jsonl b/demos/process_on_ray/data/demo-dataset.jsonl new file mode 100644 index 000000000..a212d42f4 --- /dev/null +++ b/demos/process_on_ray/data/demo-dataset.jsonl @@ -0,0 +1,11 @@ +{"text":"What’s one thing you wish 
everyone knew about the brain?\nibble\nWhat’s one thing you wish everyone knew about the brain?\nThe place to have real conversations and understand each other better. Join a community or build and grow your own with groups, threads, and conversations.\nSee this content immediately after install\nGet The App\n"} +{"text":"JavaScript must be enabled to use the system\n"} +{"text":"中国企业又建成一座海外三峡工程!-科技-高清完整正版视频在线观看-优酷\n"} +{"text":"Skip to content\nPOLIDEPORTES\nPeriodismo especialzado en deportes\nPrimary Menu\nPOLIDEPORTES\nPolideportes\n¿Quiénes somos?\nNoticia\nEntrevistas\nReportaje\nEquipos de Época\nOpinión\nEspeciales\nCopa Poli\nBuscar:\nSteven Villegas Ceballos patinador\nShare this...\nFacebook\nTwitter\nLinkedin\nWhatsapp\nEmail\nSeguir leyendo\nAnterior El imparable campeón Steven Villegas\nTe pueden interesar\nDeportes\nNoticia\nPiezas filatélicas llegan al Museo Olímpico Colombiano\nmarzo 17, 2023"} +{"text":"Redirect Notice\nRedirect Notice\nThe previous page is sending you to http:\/\/sieuthikhoavantay.vn\/chi-tiet\/khoa-van-tay-dessmann-s710fp-duc.\nIf you do not want to visit that page, you can return to the previous page.\n"} +{"text": "Do you need a cup of coffee?"} +{"text": ".cv域名是因特网域名管理机构ICANN为佛得角共和国(The Republic of Cape Verde República de Cabo Verde)国家及地区分配的顶级域(ccTLD),作为其国家及地区因特网顶级域名。- 奇典网络\n专业的互联网服务提供商 登录 注册 控制中心 新闻中心 客户支持 交费方式 联系我们\n首页\n手机AI建站\n建站\n推广\n域名\n主机\n安全\n企业服务\n加盟\nICANN与CNNIC双认证顶级注册商 在中国,奇典网络是域名服务提供商\n.cv\n.cv域名是ICANN为佛得角共和国国家及地区分配的顶级域名,注册期限1年到10年不等。\n价格: 845 元\/1年\n注册要求: 无要求\n.cv\/.com.cv注册要求\n更多国别域名\n更多NewG域名\n相关资质\n1.什么是 .cv\/.com.cv域名?有什么优势?\n.cv域名是因特网域名管理机构ICANN为佛得角共和国(The Republic of Cape Verde República de Cabo Verde)国家及地区分配的顶级域(ccTLD),作为其国家及地区因特网顶级域名。\n2.cv\/.com.cv域名长度为多少?有什么注册规则?"} +{"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément."} +{"text": "欢迎来到阿里巴巴!"} +{"text": "This paper proposed a novel method on LLM pretraining."} 
+{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} diff --git a/demos/process_video_on_ray/configs/demo.yaml b/demos/process_video_on_ray/configs/demo.yaml new file mode 100644 index 000000000..27236c08a --- /dev/null +++ b/demos/process_video_on_ray/configs/demo.yaml @@ -0,0 +1,39 @@ +# Process config example for dataset + +# global parameters +project_name: 'ray-demo' +executor_type: 'ray' +dataset_path: './demos/process_video_on_ray/data/demo-dataset.jsonl' # path to your dataset directory or file +ray_address: 'auto' # change to your ray cluster address, e.g., ray://: +export_path: './outputs/demo/demo-processed-ray-videos' + +# process schedule +# a list of several process operators with their arguments + +# single node passed, multi node still under develop +process: + # Filter ops + - video_duration_filter: + min_duration: 20 + max_duration: 60 + # Mapper ops + - video_split_by_duration_mapper: # Mapper to split video by duration. + split_duration: 10 # duration of each video split in seconds. + min_last_split_duration: 0 # the minimum allowable duration in seconds for the last video split. If the duration of the last split is less than this value, it will be discarded. + keep_original_sample: true + - video_resize_aspect_ratio_mapper: + min_ratio: 1 + max_ratio: 1.1 + strategy: increase + - video_split_by_key_frame_mapper: # Mapper to split video by key frame. + keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only cut sample in the final datasets and the original sample will be removed. It's True in default + - video_split_by_duration_mapper: # Mapper to split video by duration. + split_duration: 10 # duration of each video split in seconds. 
+ min_last_split_duration: 0 # the minimum allowable duration in seconds for the last video split. If the duration of the last split is less than this value, it will be discarded. + keep_original_sample: true + - video_resolution_filter: # filter samples according to the resolution of videos in them + min_width: 1280 # the min resolution of horizontal resolution filter range (unit p) + max_width: 4096 # the max resolution of horizontal resolution filter range (unit p) + min_height: 480 # the min resolution of vertical resolution filter range (unit p) + max_height: 1080 # the max resolution of vertical resolution filter range (unit p) + any_or_all: any diff --git a/demos/process_video_on_ray/data/Note.md b/demos/process_video_on_ray/data/Note.md new file mode 100644 index 000000000..bf3dfece3 --- /dev/null +++ b/demos/process_video_on_ray/data/Note.md @@ -0,0 +1,7 @@ +# Note for dataset path + +The video/image paths here can be either absolute or relative. +Please use a path that is accessible from all nodes (such as a path within a NAS file-sharing system). +Relative paths are resolved against the directory where the dataset file is located (the dataset_path parameter in the config). 
+ - if the dataset_path parameter is a directory, the path is relative to dataset_path + - if the dataset_path parameter is a file, the path is relative to the directory that contains the dataset_path file diff --git a/demos/process_video_on_ray/data/demo-dataset.jsonl b/demos/process_video_on_ray/data/demo-dataset.jsonl new file mode 100644 index 000000000..1c9c006b0 --- /dev/null +++ b/demos/process_video_on_ray/data/demo-dataset.jsonl @@ -0,0 +1,3 @@ +{"videos": ["./videos/video1.mp4"], "text": "<__dj__video> 10s videos <|__dj__eoc|>'}"} +{"videos": ["./videos/video2.mp4"], "text": "<__dj__video> 23s videos <|__dj__eoc|>'}"} +{"videos": ["./videos/video3.mp4"], "text": "<__dj__video> 46s videos <|__dj__eoc|>'}"} \ No newline at end of file diff --git a/demos/process_video_on_ray/data/videos/video1.mp4 b/demos/process_video_on_ray/data/videos/video1.mp4 new file mode 100644 index 000000000..5b0cad49f Binary files /dev/null and b/demos/process_video_on_ray/data/videos/video1.mp4 differ diff --git a/demos/process_video_on_ray/data/videos/video2.mp4 b/demos/process_video_on_ray/data/videos/video2.mp4 new file mode 100644 index 000000000..28acb927f Binary files /dev/null and b/demos/process_video_on_ray/data/videos/video2.mp4 differ diff --git a/demos/process_video_on_ray/data/videos/video3.mp4 b/demos/process_video_on_ray/data/videos/video3.mp4 new file mode 100644 index 000000000..45db64a51 Binary files /dev/null and b/demos/process_video_on_ray/data/videos/video3.mp4 differ diff --git a/docs/DJ_SORA.md b/docs/DJ_SORA.md new file mode 100644 index 000000000..1dce43860 --- /dev/null +++ b/docs/DJ_SORA.md @@ -0,0 +1,106 @@ +English | [中文页面](DJ_SORA_ZH.md) + +--- + +Data is the key to the unprecedented development of large multi-modal models such as SORA. Obtaining and processing data efficiently and scientifically poses new challenges! 
DJ-SORA aims to create a series of large-scale, high-quality open-source multimodal datasets to assist the open-source community in data understanding and model training. + +DJ-SORA builds on Data-Juicer (which includes hundreds of dedicated video, image, audio, text and other multimodal data processing [operators](Operators_ZH.md) and tools) to form a series of systematic and reusable multimodal "data recipes" for analyzing, cleaning, and generating large-scale, high-quality multimodal data. + +This project is being actively updated and maintained. We eagerly invite you to participate and jointly create a more open and higher-quality multimodal data ecosystem to unleash the unlimited potential of large models! + +# Motivation +- SORA only briefly mentions that it uses DALLE-3 to generate high-quality captions and that its input data has varying durations, resolutions and aspect ratios. +- High-quality, large-scale, fine-grained data helps to densify data points, helping models better learn the conditional mapping of "text -> spacetime token" and address a series of existing challenges in text-to-video models: + - Smoothness and consistency of visual flow: some generated videos exhibit dropped frames and static states. + - Text comprehension and fine-grained detail: the generated results match the given prompts poorly. + - Generated content shows distortions and violations of physical laws, especially when entities are in motion. + - Short video content, mostly ~10 seconds, with little to no significant change in scenes or backdrops. 
+ +# Roadmap +## Overview +* [Support high-performance loading and processing of video data](#support-high-performance-loading-and-processing-of-video-data) +* [Basic Operators (video spatio-temporal dimension)](#basic-operators-video-spatio-temporal-dimension) +* [Advanced Operators (fine-grained modal matching and data generation)](#advanced-operators-fine-grained-modal-matching-and-data-generation) +* [Advanced Operators (Video Content)](#advanced-operators-video-content) +* [DJ-SORA Data Recipes and Datasets](#dj-sora-data-recipes-and-datasets) +* [DJ-SORA Data Validation and Model Training](#dj-sora-data-validation-and-model-training) + + +## Support high-performance loading and processing of video data +- [✅] Parallelized data loading and storing: + - [✅] Lazy loading with pyAV and ffmpeg + - [✅] Multi-modal data path signature +- [✅] Parallelized operator processing: + - [✅] Support single-machine multicore running + - [✅] GPU utilization + - [✅] Ray-based multi-machine distributed running +- [ ] [WIP] Distributed scheduling optimization (OP-aware, automated load balancing) --> Aliyun PAI-DLC +- [ ] [WIP] Distributed storage optimization + +## Basic Operators (video spatio-temporal dimension) +- Towards Data Quality + - [✅] video_resolution_filter (targeted resolution) + - [✅] video_aspect_ratio_filter (targeted aspect ratio) + - [✅] video_duration_filter (targeted duration) + - [✅] video_motion_score_filter (video continuity dimension, calculating optical flow and removing static and extremely dynamic videos) + - [✅] video_ocr_area_ratio_filter (remove samples with text areas that are too large) +- Towards Data Diversity & Quantity + - [✅] video_resize_resolution_mapper (enhancement in resolution dimension) + - [✅] video_resize_aspect_ratio_mapper (enhancement in aspect ratio dimension) + - [✅] video_split_by_duration_mapper (enhancement in time dimension) + - [✅] video_split_by_key_frame_mapper (enhancement in time dimension with key information focus) + - [✅] 
video_split_by_scene_mapper (enhancement in time dimension with scene continuity focus) + +## Advanced Operators (fine-grained modal matching and data generation) +- Towards Data Quality + - [✅] video_frames_text_similarity_filter (enhancement in the spatiotemporal consistency dimension, calculating the matching score of key/specified frames and text) +- Towards Diversity & Quantity + - [✅] video_tagging_from_frames_mapper (with lightweight image-to-text models, spatial summary information from dense frames) + - [ ] [WIP] video_captioning_from_frames_mapper (heavier image-to-text models, generating more detailed spatial information from fewer frames) + - [✅] video_tagging_from_audio_mapper (introducing audio classification/category and other meta information) + - [✅] video_captioning_from_audio_mapper (incorporating voice/dialogue information; AudioCaption for environmental and global context) + - [✅] video_captioning_from_video_mapper (video-to-text model, generating spacetime information from continuous frames) + - [ ] [WIP] video_captioning_from_summarizer_mapper (combining the above sub-abilities, using pure text large models for denoising and summarizing different types of caption information) + - [ ] [WIP] video_interleaved_mapper (enhancement in ICL, temporal, and cross-modal dimensions), `interleaved_modes` include: + - text_image_interleaved (placing captions and frames of the same video in temporal order) + - text_audio_interleaved (placing ASR text and frames of the same video in temporal order) + - text_image_audio_interleaved (alternating stitching of the above two types) +## Advanced Operators (Video Content) +- [✅] video_deduplicator (comparing hash values to deduplicate at the file sample level) +- [✅] video_aesthetic_filter (performing aesthetic scoring filters after frame decomposition) +- [✅] Compatibility with existing ffmpeg video commands + - audio_ffmpeg_wrapped_mapper + - video_ffmpeg_wrapped_mapper +- [WIP] Video content compliance and 
privacy protection operators (image, text, audio): + - [✅] Mosaic + - [ ] Copyright watermark + - [ ] Face blurring + - [ ] Violence and Adult Content +- [ ] [TODO] (Beyond Interpolation) Enhancing data authenticity and density + - Collisions, lighting, gravity, 3D, scene and phase transitions, depth of field, etc. + - [ ] Filter-type operators: whether captions describe authenticity, relevance scoring/correctness of that description + - [ ] Mapper-type operators: enhance textual descriptions of physical phenomena in video data + - [ ] ... +## DJ-SORA Data Recipes and Datasets +- Support for unified loading and conversion of representative datasets (other-data <-> dj-data), facilitating DJ operator processing and dataset expansion. + - [✅] **Video-ChatGPT**: 100k video-instruction data: `{}` + - [✅] **Youku-mPLUG-CN**: 36TB video-caption data: `{}` + - [✅] **InternVid**: 234M data sample: `{}` + - [ ] VideoInstruct-100K, Panda70M, MSR-VTT, ...... + - [ ] ModelScope's datasets integration +- [ ] Large-scale high-quality DJ-SORA dataset + - [ ] [WIP] Continuous expansion of data sources: open-datasets, Youku, web, ... + - [ ] [WIP] (Data sandbox) Building and optimizing multimodal data recipes with DJ-video operators (which are also being continuously extended and improved). + - [ ] [WIP] Large-scale analysis and cleaning of high-quality multimodal datasets based on DJ recipes + - [ ] [WIP] Large-scale generation of high-quality multimodal datasets based on DJ recipes. + - ... + +## DJ-SORA Data Validation and Model Training + - [ ] [WIP] Exploring and refining multimodal data evaluation metrics and techniques, establishing benchmarks and insights. + - [ ] [WIP] Integration of SORA-like model training pipelines + - VideoDIT + - VQVAE + - ... + - [ ] [WIP] (Model-Data sandbox) With relatively small models and the DJ-SORA dataset, exploring low-cost, transferable, and instructive data-model co-design, configurations and checkpoints. 
+ - [ ] Training SORA-like models with DJ-SORA data on larger scales and in more scenarios to improve model performance. + - ... diff --git a/docs/DJ_SORA_ZH.md b/docs/DJ_SORA_ZH.md new file mode 100644 index 000000000..4ccdd8866 --- /dev/null +++ b/docs/DJ_SORA_ZH.md @@ -0,0 +1,111 @@ +中文 | [English Page](DJ_SORA.md) + +--- + +数据是SORA等前沿大模型的关键,如何高效科学地获取和处理数据面临新的挑战!DJ-SORA旨在创建一系列大规模高质量开源多模态数据集,助力开源社区数据理解和模型训练。 + +DJ-SORA将基于Data-Juicer(包含上百个专用的视频、图像、音频、文本等多模态数据处理[算子](Operators_ZH.md)及工具),形成一系列系统化可复用的多模态“数据菜谱”,用于分析、清洗及生成大规模高质量多模态数据。 + +本项目正在积极更新和维护中,我们热切地邀请您参与,共同打造一个更开放、更高质的多模态数据生态系统,激发大模型无限潜能! + +# 动机 +- SORA仅简略提及使用了DALLE-3来生成高质量caption,且模型输入数据有变化的时长、分辨率和宽高比。 +- 高质量大规模细粒度数据有助于稠密化数据点,帮助模型学好“文本 -> spacetime token”的条件映射,解决text-2-video模型的一系列现有挑战: + - 画面流畅性和一致性,部分生成的视频有丢帧及静止状态 + - 文本理解能力和细粒度,生成出的结果和prompt匹配度较低 + - 视频内容较短,大多只有~10s,且场景画面不会有大的改变 + - 生成内容存在变形扭曲和物理规则违背情况,特别是在实体做出动作时 + +# 路线图 +## 概览 +* [支持视频数据的高性能加载和处理](#支持视频数据的高性能加载和处理) +* [基础算子(视频时空维度)](#基础算子视频时空维度) +* [进阶算子(细粒度模态间匹配及生成)](#进阶算子细粒度模态间匹配及生成) +* [进阶算子(视频内容)](#进阶算子视频内容) +* [DJ-SORA数据菜谱及数据集](#DJ-SORA数据菜谱及数据集) +* [DJ-SORA数据验证及模型训练](#DJ-SORA数据验证及模型训练) + +## 支持视频数据的高性能加载和处理 +- [✅] 并行化数据加载存储: + - [✅] lazy load with pyAV and ffmpeg + - [✅] 多模态数据路径签名 +- [✅] 并行化算子处理: + - [✅] 支持单机多核 + - [✅] GPU调用 + - [✅] Ray多机分布式 +- [ ] [WIP] 分布式调度优化(OP-aware、自动化负载均衡)--> Aliyun PAI-DLC +- [ ] [WIP] 分布式存储优化 + +## 基础算子(视频时空维度) +- 面向数据质量 + - [✅] video_resolution_filter (在分辨率维度进行过滤) + - [✅] video_aspect_ratio_filter (在宽高比维度进行过滤) + - [✅] video_duration_filter (在时间维度进行过滤) + - [✅] video_motion_score_filter(在视频连续性维度过滤,计算光流,去除静态和极端动态) + - [✅] video_ocr_area_ratio_filter (移除文本区域过大的样本) +- 面向数据多样性及数量 + - [✅] video_resize_resolution_mapper(在分辨率维度进行增强) + - [✅] video_resize_aspect_ratio_mapper(在宽高比维度进行增强) + - [✅] video_split_by_key_frame_mapper(基于关键帧进行切割) + - [✅] video_split_by_duration_mapper(在时间维度进行切割) + - [✅] video_split_by_scene_mapper (基于场景连续性进行切割) + +## 进阶算子(细粒度模态间匹配及生成) +- 面向数据质量 + - [✅] video_frames_text_similarity_filter(在时空一致性维度过滤,计算关键/指定帧 
和文本的匹配分) +- 面向数据多样性及数量 + - [✅] video_tagging_from_frames_mapper (轻量图生文模型,密集帧生成空间 概要信息) + - [ ] [WIP] video_captioning_from_frames_mapper(更重的图生文模型,少量帧生 成更详细空间信息) + - [✅] video_tagging_from_audio_mapper (引入audio classification/category等meta信息) + - [✅] video_captioning_from_audio_mapper(引入人声/对话等信息; AudioCaption环境、场景等全局信息) + - [✅] video_captioning_from_video_mapper(视频生文模型,连续帧生成时序信息) + - [ ] [WIP] video_captioning_from_summarizer_mapper(基于上述子能力的组合,使用纯文本大模型对不同种caption信息去噪、摘要) + - [ ] [WIP] video_interleaved_mapper(在ICL、时间和跨模态维度增强),`interleaved_modes` include + - text_image_interleaved(按时序交叉放置同一视频的的caption和frames) + - text_audio_interleaved(按时序交叉放置同一视频的的ASR文本和frames) + - text_image_audio_interleaved(交替拼接上述两种) + +## 进阶算子(视频内容) +- [✅] video_deduplicator (比较MD5哈希值在文件样本级别去重) +- [✅] video_aesthetic_filter(拆帧后,进行美学度打分过滤) +- [✅]兼容ffmpeg已有的video commands + - audio_ffmpeg_wrapped_mapper + - video_ffmpeg_wrapped_mapper +- [WIP] 视频内容合规和隐私保护算子(图像、文字、音频): + - [✅] 马赛克 + - [ ] 版权水印 + - [ ] 人脸模糊 + - [ ] 黄暴恐 +- [ ] [TODO] (Beyond Interpolation) 增强数据真实性和稠密性 + - 碰撞、光影、重力、3D、场景切换(phase tranisition)、景深等 + - [ ] Filter类算子: caption是否描述真实性,该描述的相关性得分/正确性得分 + - [ ] Mapper类算子:增强video数据中对物理现象的文本描述 + - [ ] ... + + + +## DJ-SORA数据菜谱及数据集 +- 支持代表性数据的统一加载和转换(other-data <-> dj-data),方便DJ算子处理及扩展数据集 + - [✅] **Video-ChatGPT**: 100k video-instruction data:`{}` + - [✅] **Youku-mPLUG-CN**: 36TB video-caption data:`{}` + - [✅] **InternVid**: 234M data sample:`{}` + - [ ] VideoInstruct-100K, Panda70M, MSR-VTT, ...... + - [ ] ModelScope数据集集成 +- [ ] 大规模高质量DJ-SORA数据集 + - [ ] [WIP] 数据源持续扩充:open-datasets, youku, web, ... + - [ ] [WIP] (Data sandbox) 基于DJ-video算子构建和优化多模态数据菜谱 (算子同期持续完善) + - [ ] [WIP] 基于DJ菜谱规模化分析、清洗高质量多模态数据集 + - [ ] [WIP] 基于DJ菜谱规模化生成高质量多模态数据集 + - ... + +## DJ-SORA数据验证及模型训练 + - [ ] [WIP] 探索及完善多模态数据的评估指标和评估技术,形成benchmark和insights + - [ ] [WIP] 类SORA模型训练pipeline集成 + - VideoDIT + - VQVAE + - ... 
+ - [ ] [WIP] (Model-Data sandbox) 在相对小的模型和DJ-SORA数据集上,探索形成低开销、可迁移、有指导性的data-model co-design、配置及检查点 + - [ ] 更大规模、更多场景使用DJ-SORA数据训练类SORA模型,提高模型性能 + - ... + + diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index a658b5e7c..bf248aa82 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -147,6 +147,52 @@ class StatsKeys(object): # ... (same as above) ``` + - In a mapper operator, to avoid process conflicts and data overwriting, we offer an interface that generates a save path for any extra data produced. The save path has the format `{ORIGINAL_DATAPATH}/{OP_NAME}/{ORIGINAL_FILENAME}__dj_hash_#{HASH_VALUE}#.{EXT}`, where `HASH_VALUE` is hashed from the init parameters of the operator, the related parameters in each sample, the process ID, and the timestamp. For convenience, we can call `self.remove_extra_parameters(locals())` at the beginning of the initialization to get the init parameters. Likewise, we can call `self.add_parameters` to add the per-sample parameters related to the extra data produced. Take the operator that enhances images with diffusion models as an example: + ```python + # ... (import some libraries) + OP_NAME = 'image_diffusion_mapper' + @OPERATORS.register_module(OP_NAME) + @LOADED_IMAGES.register_module(OP_NAME) + class ImageDiffusionMapper(Mapper): + def __init__(self, + # ... (OP parameters) + *args, + **kwargs): + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + def process(self, sample, rank=None): + # ... (some code) + # captions[index] is the prompt for diffusion model + related_parameters = self.add_parameters( + self._init_parameters, caption=captions[index]) + new_image_path = transfer_filename( + origin_image_path, OP_NAME, **related_parameters) + # ... (some code) + ``` + For a mapper that produces multiple extra data items from one original data item, we can add a suffix to the save path. 
Take the operator that splits videos according to their key frames as an example: + ```python + # ... (import some libraries) + OP_NAME = 'video_split_by_key_frame_mapper' + @OPERATORS.register_module(OP_NAME) + @LOADED_VIDEOS.register_module(OP_NAME) + class VideoSplitByKeyFrameMapper(Mapper): + def __init__(self, + # ... (OP parameters) + *args, + **kwargs): + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + def process(self, sample, rank=None): + # ... (some code) + split_video_path = transfer_filename( + original_video_path, OP_NAME, **self._init_parameters) + suffix = '_split-by-key-frame-' + str(count) + split_video_path = add_suffix_to_filename(split_video_path, suffix) + # ... (some code) + ``` + 3. After implementation, add it to the OP dictionary in the `__init__.py` file in the `data_juicer/ops/filter/` directory. ```python @@ -172,8 +218,9 @@ process: ```python import unittest from data_juicer.ops.filter.text_length_filter import TextLengthFilter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class TextLengthFilterTest(unittest.TestCase): +class TextLengthFilterTest(DataJuicerTestCaseBase): def test_func1(self): pass @@ -183,6 +230,9 @@ class TextLengthFilterTest(unittest.TestCase): def test_func3(self): pass + +if __name__ == '__main__': + unittest.main() ``` 6. (Strongly Recommended) To make the new OP easier for other users to use, we also need to update this new OP information to diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index f188aecbc..3c6bb2411 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -142,6 +142,52 @@ class StatsKeys(object): # ... 
(same as above) ``` + - 在mapper算子中,我们提供了产生额外数据的存储路径生成接口,避免出现进程冲突和数据覆盖的情况。生成的存储路径格式为`{ORIGINAL_DATAPATH}/{OP_NAME}/{ORIGINAL_FILENAME}__dj_hash_#{HASH_VALUE}#.{EXT}`,其中`HASH_VALUE`是算子初始化参数、每个样本中相关参数、进程ID和时间戳的哈希值。为了方便,可以在OP类初始化开头调用`self.remove_extra_parameters(locals())`获取算子初始化参数,同时可以调用`self.add_parameters`添加每个样本与生成额外数据相关的参数。例如,利用diffusion模型对图像进行增强的算子: + ```python + # ... (import some library) + OP_NAME = 'image_diffusion_mapper' + @OPERATORS.register_module(OP_NAME) + @LOADED_IMAGES.register_module(OP_NAME) + class ImageDiffusionMapper(Mapper): + def __init__(self, + # ... (OP parameters) + *args, + **kwargs): + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + def process(self, sample, rank=None): + # ... (some codes) + # captions[index] is the prompt for diffusion model + related_parameters = self.add_parameters( + self._init_parameters, caption=captions[index]) + new_image_path = transfer_filename( + origin_image_path, OP_NAME, **related_parameters) + # ... (some codes) + ``` + 针对一个数据源衍生出多个额外数据的情况,我们允许在生成的存储路径后面再加后缀。比如,根据关键帧将视频拆分成多个视频: + ```python + # ... (import some library) + OP_NAME = 'video_split_by_key_frame_mapper' + @OPERATORS.register_module(OP_NAME) + @LOADED_VIDEOS.register_module(OP_NAME) + class VideoSplitByKeyFrameMapper(Mapper): + def __init__(self, + # ... (OP parameters) + *args, + **kwargs): + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + def process(self, sample, rank=None): + # ... (some codes) + split_video_path = transfer_filename( + original_video_path, OP_NAME, **self._init_parameters) + suffix = '_split-by-key-frame-' + str(count) + split_video_path = add_suffix_to_filename(split_video_path, suffix) + # ... (some codes) + ``` + 3. 
实现后,将其添加到 `data_juicer/ops/filter` 目录下 `__init__.py` 文件中的算子字典中: ```python @@ -168,8 +214,10 @@ process: ```python import unittest from data_juicer.ops.filter.text_length_filter import TextLengthFilter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + -class TextLengthFilterTest(unittest.TestCase): +class TextLengthFilterTest(DataJuicerTestCaseBase): def test_func1(self): pass @@ -179,6 +227,9 @@ class TextLengthFilterTest(unittest.TestCase): def test_func3(self): pass + +if __name__ == '__main__': + unittest.main() ``` 6. (强烈推荐)为了方便其他用户使用,我们还需要将新增的算子信息更新到相应的文档中,具体包括如下文档: diff --git a/docs/Operators.md b/docs/Operators.md index 28fdb7306..9409449b3 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -2,6 +2,7 @@ Operators are a collection of basic processes that assist in data modification, cleaning, filtering, deduplication, etc. We support a wide range of data sources and file formats, and allow for flexible extension to custom datasets. +This page offers a basic description of the operators (OPs) in Data-Juicer. Users can refer to the [API documentation](https://alibaba.github.io/data-juicer/) for the specific parameters of each operator. Users can refer to and run the unit tests for [examples of operator-wise usage](../tests/ops) as well as the effects of each operator when applied to built-in test data samples. ## Overview @@ -10,9 +11,9 @@ The operators in Data-Juicer are categorized into 5 types. 
| Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 26 | Edits and transforms samples | -| [ Filter ]( #filter ) | 28 | Filters out low-quality samples | -| [ Deduplicator ]( #deduplicator ) | 4 | Detects and removes duplicate samples | +| [ Mapper ]( #mapper ) | 38 | Edits and transforms samples | +| [ Filter ]( #filter ) | 36 | Filters out low-quality samples | +| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 2 | Selects top samples based on ranking | @@ -25,6 +26,7 @@ All the specific operators are listed below, each featured with several capabili - Financial: closely related to financial sector - Image: specific to images or multimodal - Audio: specific to audios or multimodal + - Video: specific to videos or multimodal - Multimodal: specific to multimodal * Language Tags - en: English @@ -46,69 +48,88 @@ All the specific operators are listed below, each featured with several capabili ## Mapper -| Operator | Domain | Lang | Description | -|-----------------------------------------------------|--------------------|--------|----------------------------------------------------------------------------------------------------------------| -| chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | -| clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (:warning: must contain the word *copyright*) | -| clean_email_mapper | General | en, zh | Removes email information | -| clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes | -| clean_ip_mapper | General | en, zh | Removes IP addresses | -| clean_links_mapper | 
General, Code | en, zh | Removes links, such as those starting with http or ftp | -| expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | -| fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | -| generate_caption_mapper | Multimodal | - | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | -| gpt4v_generate_mapper | Multimodal | - | generate samples whose texts are generated based on gpt-4-visison and the image | +| Operator | Domain | Lang | Description | +|-----------------------------------------------------|--------------------|--------|---------------------------------------------------------------------------------------------------------------| +| audio_ffmpeg_wrapped_mapper | Audio | - | Simple wrapper to run an FFmpeg audio filter | +| chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | +| clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (:warning: must contain the word *copyright*) | +| clean_email_mapper | General | en, zh | Removes email information | +| clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes | +| clean_ip_mapper | General | en, zh | Removes IP addresses | +| clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp | +| expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | +| fix_unicode_mapper | General | en, zh | Fixes broken Unicode (by [ftfy](https://ftfy.readthedocs.io/)) | | image_blur_mapper | Multimodal | - | Blur images | +| image_captioning_from_gpt4v_mapper | Multimodal | - | Generates samples whose texts are generated based on gpt-4-vision and the image | 
+| image_captioning_mapper | Multimodal | - | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | | image_diffusion_mapper | Multimodal | - | Generate and augment images by stable diffusion model | -| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | -| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | -| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | -| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | -| remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | -| remove_header_mapper | LaTeX | en, zh | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | -| remove_long_words_mapper | General | en, zh | Removes words with length outside the specified range | +| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | +| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | +| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | +| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | +| remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | +| remove_header_mapper | LaTeX | en, zh | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | +| remove_long_words_mapper | General | en, zh | Removes words with length outside the specified range | | remove_non_chinese_character_mapper | General | en, zh | Remove non Chinese character in text samples. | | remove_repeat_sentences_mapper | General | en, zh | Remove repeat sentences in text samples. 
| -| remove_specific_chars_mapper | General | en, zh | Removes any user-specified characters or substrings | +| remove_specific_chars_mapper | General | en, zh | Removes any user-specified characters or substrings | | remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile) | -| remove_words_with_incorrect_
substrings_mapper | General | en, zh | Removes words containing specified substrings | -| replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string. | -| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics | -| whitespace_normalization_mapper | General | en, zh | Normalizes various Unicode whitespaces to the normal ASCII space (U+0020) | +| remove_words_with_incorrect_
substrings_mapper | General | en, zh | Removes words containing specified substrings |
+| replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string |
+| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics |
+| video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model |
+| video_captioning_from_video_mapper | Multimodal | - | generate samples whose captions are generated based on another model (video-blip) and sampled video frame within the original sample |
+| video_ffmpeg_wrapped_mapper | Video | - | Simple wrapper to run a FFmpeg video filter |
+| video_resize_aspect_ratio_mapper | Video | - | Resize video aspect ratio to a specified range |
+| video_resize_resolution_mapper | Video | - | Map videos to ones with given resolution range |
+| video_split_by_duration_mapper | Multimodal | - | Mapper to split video by duration. |
+| video_split_by_key_frame_mapper | Multimodal | - | Mapper to split video by key frame. |
+| video_split_by_scene_mapper | Multimodal | - | Split videos into scene clips |
+| video_tagging_from_audio_mapper | Multimodal | - | Mapper to generate video tags from audio streams extracted from the video. |
+| video_tagging_from_frames_mapper | Multimodal | - | Mapper to generate video tags from frames extracted from the video. |
+| whitespace_normalization_mapper | General | en, zh | Normalizes various Unicode whitespaces to the normal ASCII space (U+0020) |
 ## Filter
-| Operator | Domain | Lang | Description | -|--------------------------------|------------|--------|-------------------------------------------------------------------------------------------------------------------------------------------------------| -| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range | -| audio_duration_filter | Audio | - | Keep data samples whose audios' durations are within a specified range | +| Operator | Domain | Lang | Description | +|--------------------------------|------------|--------|-----------------------------------------------------------------------------------------------------------------------------------------------------| +| alphanumeric_filter | General | en, zh | Keeps samples with alphanumeric ratio within the specified range | +| audio_duration_filter | Audio | - | Keep data samples whose audios' durations are within a specified range | | audio_nmf_snr_filter | Audio | - | Keep data samples whose audios' Signal-to-Noise Ratios (SNRs, computed based on Non-Negative Matrix Factorization, NMF) are within a specified range. 
| -| audio_size_filter | Audio | - | Keep data samples whose audios' sizes are within a specified range | -| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range | -| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range | -| face_area_filter | Image | - | Keeps samples containing images with face area ratios within the specified range | -| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold | -| image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range | -| image_shape_filter | Image | - | Keeps samples containing images with widths and heights within the specified range | -| image_size_filter | Image | - | Keeps samples containing images whose size in bytes are within the specified range | -| image_text_matching_filter | Multimodal | - | Keeps samples with image-text classification matching score within the specified range based on a BLIP model | -| image_text_similarity_filter | Multimodal | - | Keeps samples with image-text feature cosine similarity within the specified range based on a CLIP model | -| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | -| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | -| perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold | -| phrase_grounding_recall_filter | Multimodal | - | Keeps samples whose locating recalls of phrases extracted from text in the images are within a specified range | -| special_characters_filter | General | en, zh | Keeps samples with special-char ratio within the specified range | -| specified_field_filter | General | en, zh | Filters samples based on field, with 
value lies in the specified targets | -| specified_numeric_field_filter | General | en, zh | Filters samples based on field, with value lies in the specified range (for numeric types) | -| stopwords_filter | General | en, zh | Keeps samples with stopword ratio above the specified threshold | -| suffix_filter | General | en, zh | Keeps samples with specified suffixes | -| text_action_filter | General | en, zh | Keeps samples containing action verbs in their texts | -| text_entity_dependency_filter | General | en, zh | Keeps samples containing entity nouns related to other tokens in the dependency tree of the texts | -| text_length_filter | General | en, zh | Keeps samples with total text length within the specified range | -| token_num_filter | General | en, zh | Keeps samples with token count within the specified range | -| word_num_filter | General | en, zh | Keeps samples with word count within the specified range | -| word_repetition_filter | General | en, zh | Keeps samples with word-level n-gram repetition ratio within the specified range | +| audio_size_filter | Audio | - | Keep data samples whose audios' sizes are within a specified range | +| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range | +| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range | +| face_area_filter | Image | - | Keeps samples containing images with face area ratios within the specified range | +| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold | +| image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range | +| image_shape_filter | Image | - | Keeps samples containing images with widths and heights within the specified range | +| image_size_filter | Image | - | Keeps samples containing images whose size in bytes are within the 
specified range | +| image_aesthetics_filter | Image | - | Keeps samples containing images whose aesthetics scores are within the specified range | +| image_text_matching_filter | Multimodal | - | Keeps samples with image-text classification matching score within the specified range based on a BLIP model | +| image_text_similarity_filter | Multimodal | - | Keeps samples with image-text feature cosine similarity within the specified range based on a CLIP model | +| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score | +| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range | +| perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold | +| phrase_grounding_recall_filter | Multimodal | - | Keeps samples whose locating recalls of phrases extracted from text in the images are within a specified range | +| special_characters_filter | General | en, zh | Keeps samples with special-char ratio within the specified range | +| specified_field_filter | General | en, zh | Filters samples based on field, with value lies in the specified targets | +| specified_numeric_field_filter | General | en, zh | Filters samples based on field, with value lies in the specified range (for numeric types) | +| stopwords_filter | General | en, zh | Keeps samples with stopword ratio above the specified threshold | +| suffix_filter | General | en, zh | Keeps samples with specified suffixes | +| text_action_filter | General | en, zh | Keeps samples containing action verbs in their texts | +| text_entity_dependency_filter | General | en, zh | Keeps samples containing entity nouns related to other tokens in the dependency tree of the texts | +| text_length_filter | General | en, zh | Keeps samples with total text length within the specified range | +| token_num_filter | General | en, zh | Keeps samples with token count 
within the specified range | +| video_aspect_ratio_filter | Video | - | Keeps samples containing videos with aspect ratios within the specified range | +| video_duration_filter | Video | - | Keep data samples whose videos' durations are within a specified range | +| video_aesthetics_filter | Video | - | Keeps samples whose specified frames have aesthetics scores within the specified range | +| video_frames_text_similarity_filter | Multimodal | - | Keep data samples whose similarities between sampled video frame images and text are within a specific range | +| video_motion_score_filter | Video | - | Keep samples with video motion scores within a specific range | +| video_ocr_area_ratio_filter | Video | - | Keep data samples whose detected text area ratios for specified frames in the video are within a specified range | +| video_resolution_filter | Video | - | Keeps samples containing videos with horizontal and vertical resolutions within the specified range | +| word_num_filter | General | en, zh | Keeps samples with word count within the specified range | +| word_repetition_filter | General | en, zh | Keeps samples with word-level n-gram repetition ratio within the specified range | ## Deduplicator @@ -119,6 +140,7 @@ All the specific operators are listed below, each featured with several capabili | document_minhash_deduplicator | General | en, zh | Deduplicates samples at document-level using MinHashLSH | | document_simhash_deduplicator | General | en, zh | Deduplicates samples at document-level using SimHash | | image_deduplicator | Image | - | Deduplicates samples at document-level using exact matching of images between documents | +| video_deduplicator | Video | - | Deduplicates samples at document-level using exact matching of videos between documents | ## Selector diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index f3df33d89..4517c614c 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -2,6 +2,8 @@ 算子 (Operator) 
是协助数据修改、清理、过滤、去重等基本流程的集合。我们支持广泛的数据来源和文件格式,并支持对自定义数据集的灵活扩展。 +这个页面提供了OP的基本描述,用户可以参考[API文档](https://alibaba.github.io/data-juicer/)更细致了解每个OP的具体参数,并且可以查看、运行单元测试,来体验[各OP的用法示例](../tests/ops)以及每个OP作用于内置测试数据样本时的效果。 + ## 概览 Data-Juicer 中的算子分为以下 5 种类型。 @@ -9,9 +11,9 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 26 | 对数据样本进行编辑和转换 | -| [ Filter ]( #filter ) | 28 | 过滤低质量样本 | -| [ Deduplicator ]( #deduplicator ) | 4 | 识别、删除重复样本 | +| [ Mapper ]( #mapper ) | 38 | 对数据样本进行编辑和转换 | +| [ Filter ]( #filter ) | 36 | 过滤低质量样本 | +| [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 | 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 @@ -23,6 +25,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 - Financial: 与金融领域相关 - Image: 专用于图像或多模态 - Audio: 专用于音频或多模态 + - Video: 专用于视频或多模态 - Multimodal: 专用于多模态 * Language 标签 @@ -46,6 +49,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 算子 | 场景 | 语言 | 描述 | |-----------------------------------------------------|-----------------------|-----------|--------------------------------------------------------| +| audio_ffmpeg_wrapped_mapper | Audio | - | 运行 FFmpeg 语音过滤器的简单封装 | | chinese_convert_mapper | General | zh | 用于在繁体中文、简体中文和日文汉字之间进行转换(借助 [opencc](https://github.com/BYVoid/OpenCC)) | | clean_copyright_mapper | Code | en, zh | 删除代码文件开头的版权声明 (:warning: 必须包含单词 *copyright*) | | clean_email_mapper | General | en, zh | 删除邮箱信息 | @@ -54,9 +58,9 @@ Data-Juicer 中的算子分为以下 5 种类型。 | clean_links_mapper | General, Code | en, zh | 删除链接,例如以 http 或 ftp 开头的 | | expand_macro_mapper | LaTeX | en, zh | 扩展通常在 TeX 文档顶部定义的宏 | | fix_unicode_mapper | General | en, zh | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | -| generate_caption_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | -| gpt4v_generate_mapper | Multimodal | - | 基于gpt-4-vision和图像生成文本 | -| image_blur_mapper | Multimodal | - | 对图像进行模糊处理 | +| image_blur_mapper 
| Multimodal | - | 对图像进行模糊处理 | +| image_captioning_from_gpt4v_mapper | Multimodal | - | 基于gpt-4-vision和图像生成文本 | +| image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | | image_diffusion_mapper | Multimodal | - | 用stable diffusion生成图像,对图像进行增强 | | nlpaug_en_mapper | General | en | 使用`nlpaug`库对英语文本进行简单增强 | | nlpcda_zh_mapper | General | zh | 使用`nlpcda`库对中文文本进行简单增强 | @@ -70,8 +74,18 @@ Data-Juicer 中的算子分为以下 5 种类型。 | remove_specific_chars_mapper | General | en, zh | 删除任何用户指定的字符或子字符串 | | remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) | | remove_words_with_incorrect_
substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 | -| replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | +| replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | | sentence_split_mapper | General | en | 根据语义拆分和重组句子 | +| video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | +| video_captioning_from_video_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(video-blip)和原始样本中的视频中指定帧的图像。 | +| video_ffmpeg_wrapped_mapper | Video | - | 运行 FFmpeg 视频过滤器的简单封装 | +| video_resize_aspect_ratio_mapper | Video | - | 将视频的宽高比调整到指定范围内 | +| video_resize_resolution_mapper | Video | - | 将视频映射到给定的分辨率区间 | +| video_split_by_duration_mapper | Multimodal | - | 根据时长将视频切分为多个片段 | +| video_split_by_key_frame_mapper | Multimodal | - | 根据关键帧切分视频 | +| video_split_by_scene_mapper | Multimodal | - | 将视频切分为场景片段 | +| video_tagging_from_audio_mapper | Multimodal | - | 从视频提取的音频中生成视频标签 | +| video_tagging_from_frames_mapper | Multimodal | - | 从视频提取的帧中生成视频标签 | | whitespace_normalization_mapper | General | en, zh | 将各种 Unicode 空白标准化为常规 ASCII 空格 (U+0020) | ## Filter
@@ -80,7 +94,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 |--------------------------------|------------|--------|---------------------------------------------| | alphanumeric_filter | General | en, zh | 保留字母数字比例在指定范围内的样本 | | audio_duration_filter | Audio | - | 保留样本中包含的音频的时长在指定范围内的样本 | -| audio_nmf_snr_filter | Audio | - | 保留样本中包含的音频信噪比SNR(基于非负矩阵分解方法NMF计算)在指定范围内的样本. | +| audio_nmf_snr_filter | Audio | - | 保留样本中包含的音频信噪比SNR(基于非负矩阵分解方法NMF计算)在指定范围内的样本 | | audio_size_filter | Audio | - | 保留样本中包含的音频的大小(bytes)在指定范围内的样本 | | average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 | | character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 | @@ -89,6 +103,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 | | image_shape_filter | Image | - | 保留样本中包含的图片的形状(即宽和高)在指定范围内的样本 | | image_size_filter | Image | - | 保留样本中包含的图片的大小(bytes)在指定范围内的样本 | +| image_aesthetics_filter | Image | - | 保留包含美学分数在指定范围内的图像的样本 | | image_text_matching_filter | Multimodal | - | 保留图像-文本的分类匹配分(基于BLIP模型)在指定范围内的样本 | | image_text_similarity_filter | Multimodal | - | 保留图像-文本的特征余弦相似度(基于CLIP模型)在指定范围内的样本 | | language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 | @@ -104,6 +119,13 @@ Data-Juicer 中的算子分为以下 5 种类型。 | text_entity_dependency_filter | General | en, zh | 保留文本部分的依存树中具有非独立实体的样本 | | text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 | | token_num_filter | General | en, zh | 保留token数在指定范围内的样本 | +| video_aspect_ratio_filter | Video | - | 保留样本中包含的视频的宽高比在指定范围内的样本 | +| video_duration_filter | Video | - | 保留样本中包含的视频的时长在指定范围内的样本 | +| video_aesthetics_filter | Video | - | 保留指定帧的美学分数在指定范围内的样本| +| video_frames_text_similarity_filter | Multimodal | - | 保留视频中指定帧的图像-文本的特征余弦相似度(基于CLIP模型)在指定范围内的样本 | +| video_motion_score_filter | Video | - | 保留样本中包含的视频的运动份(基于稠密光流)在指定范围内的样本 | +| video_ocr_area_ratio_filter | Video | - | 保留样本中包含的视频的特定帧中检测出的文本的面积占比在指定范围内的样本 | +| video_resolution_filter | Video | - | 
保留样本中包含的视频的分辨率(包括横向分辨率和纵向分辨率)在指定范围内的样本 |
 | word_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
 | word_repetition_filter | General | en, zh | 保留 word-level n-gram 重复比率在指定范围内的样本 |
@@ -115,6 +137,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
 | document_minhash_deduplicator | General | en, zh | 使用 MinHashLSH 在文档级别对样本去重 |
 | document_simhash_deduplicator | General | en, zh | 使用 SimHash 在文档级别对样本去重 |
 | image_deduplicator | Image | - | 使用文档之间图像的精确匹配在文档级别删除重复样本 |
+| video_deduplicator | Video | - | 使用文档之间视频的精确匹配在文档级别删除重复样本 |
 ## Selector
diff --git a/environments/dist_requires.txt b/environments/dist_requires.txt
index e02756318..0edc5aa35 100644
--- a/environments/dist_requires.txt
+++ b/environments/dist_requires.txt
@@ -1 +1 @@
-ray
+ray==2.9.2
diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt
index 31ef270e8..d7696c75b 100644
--- a/environments/minimal_requires.txt
+++ b/environments/minimal_requires.txt
@@ -1,7 +1,8 @@
-fsspec==2023.3.0
+fsspec==2023.5.0
 pyarrow<=12.0.0
 pandas==2.0.0
 datasets==2.11.0
+av
 soundfile
 librosa
 loguru
diff --git a/environments/science_requires.txt b/environments/science_requires.txt
index 5116c8026..4421aad0d 100644
--- a/environments/science_requires.txt
+++ b/environments/science_requires.txt
@@ -1,3 +1,4 @@
+easyocr
 fasttext-wheel
 kenlm
 sentencepiece
@@ -8,10 +9,18 @@ selectolax
 nlpaug
 nlpcda
 nltk
-transformers
+transformers>=4.37
+transformers_stream_generator
+einops
+accelerate
+tiktoken
 opencc==1.1.6
 imagededup
 torch
+torchaudio
 dlib
 spacy-pkuseg==0.0.32
 diffusers
+simple-aesthetics-predictor
+scenedetect[opencv]
+ffmpeg-python
diff --git a/scripts/README.md b/scripts/README.md
new file mode 100644
index 000000000..e20de6fab
--- /dev/null
+++ b/scripts/README.md
@@ -0,0 +1,14 @@
+# Scripts for Running on Multi Nodes
+
+
+#### Running Using DLC (Deep Learning Containers)
+
+Internally we use [DLC](https://www.alibabacloud.com/help/zh/pai/user-guide/container-training/) from
[PAI](https://www.alibabacloud.com/zh/product/machine-learning) to process data on multiple nodes.
+
+The scripts to run are in ./dlc folder.
+
+#### Running Using Slurm
+
+ - [ ] We will provide scripts to support running on slurm.
+
+You can also manually partition the data according to specific circumstances and then use Slurm to run it on multiple machines by yourself.
\ No newline at end of file
diff --git a/scripts/dlc/partition_data_dlc.py b/scripts/dlc/partition_data_dlc.py
new file mode 100644
index 000000000..b0f5bbbfc
--- /dev/null
+++ b/scripts/dlc/partition_data_dlc.py
@@ -0,0 +1,49 @@
+import argparse
+import json
+import os
+from collections import defaultdict
+from typing import List
+
+
+def partition_data(json_file_path: str, hostnames: List[str]):
+    with open(json_file_path, 'r') as f:
+        data = [json.loads(line) for line in f]
+    video_to_entries_map = defaultdict(list)
+    for entry in data:
+        video_path = entry['videos'][0]
+        video_to_entries_map[video_path].append(entry)
+    nodes_data = defaultdict(list)
+    nodes_video_size = {k: 0 for k in hostnames}
+
+    # distribute videos to nodes based on the total size of videos
+    video_sizes = {
+        video: os.path.getsize(video)
+        for video in video_to_entries_map.keys()
+    }
+
+    sorted_videos = sorted(video_sizes, key=video_sizes.get, reverse=True)
+    for video in sorted_videos:
+        min_node = min(nodes_video_size, key=nodes_video_size.get)
+        nodes_data[min_node].extend(video_to_entries_map[video])
+        nodes_video_size[min_node] += video_sizes[video]
+
+    for hostname in hostnames:
+        host_file_path = f"{json_file_path.rsplit('.', 1)[0]}_{hostname}.json"
+        with open(host_file_path, 'w') as f:
+            for entry in nodes_data[hostname]:
+                f.write(json.dumps(entry) + '\n')
+
+
+# Usage
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Partition data across hostnames.')
+
+    parser.add_argument('file_path',
+                        type=str,
+                        help='Path of the file to distribute.')
+    parser.add_argument('hostnames',
nargs='+', help='The list of hostnames')
+
+    args = parser.parse_args()
+
+    partition_data(args.file_path, args.hostnames)
diff --git a/scripts/dlc/run_on_dlc.sh b/scripts/dlc/run_on_dlc.sh
new file mode 100644
index 000000000..8ed356e99
--- /dev/null
+++ b/scripts/dlc/run_on_dlc.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# parameters
+datajuicer_path= # path to data-juicer
+config_path= # path to config file
+
+# hostname
+hostname=$(hostname)
+
+# into datajuicer_path
+cd "$datajuicer_path" || { echo "Could not change directory to $datajuicer_path"; exit 1; }
+
+# copy and generate new config file for current host
+
+config_basename=$(basename "$config_path")
+config_dirname=$(dirname "$config_path")
+config_extension="${config_basename##*.}"
+config_basename="${config_basename%.*}"
+
+new_config_file="${config_dirname}/${config_basename}_$hostname.$config_extension"
+cp "$config_path" "$new_config_file" || { echo "Could not copy config file"; exit 1; }
+
+echo "$new_config_file"
+
+if [[ "$OSTYPE" == "darwin"* ]]; then
+    SED_I_SUFFIX=".bak"
+else
+    SED_I_SUFFIX=""
+fi
+
+if grep -q "dataset_path: .*\.json" "$new_config_file"; then
+    # .json data_path
+    sed -i$SED_I_SUFFIX "s|\(dataset_path: \)\(.*\)\(/[^/]*\)\(.json\)|\1\2\3_$hostname\4|" "$new_config_file"
+else
+    # dir dataset_path
+    sed -i$SED_I_SUFFIX "s|\(dataset_path: '\)\(.*\)'\(.*\)|\1\2_$hostname'\3|" "$new_config_file"
+fi
+
+# run to process data
+python tools/process_data.py --config "$new_config_file"
diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py
index 9e66e6b66..a2748ed7f 100644
--- a/tests/config/test_config_funcs.py
+++ b/tests/config/test_config_funcs.py
@@ -7,12 +7,13 @@
 from data_juicer.config import init_configs
 from data_juicer.ops import load_ops
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
 test_yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               'demo_4_test.yaml')
-class
ConfigTest(DataJuicerTestCaseBase): def test_help_info(self): out = StringIO() @@ -35,12 +36,13 @@ def test_yaml_cfg_file(self): self.assertIsInstance(cfg, Namespace) self.assertEqual(cfg.project_name, 'test_demo') self.assertDictEqual( - cfg.process[0], - {'whitespace_normalization_mapper': { - 'text_key': 'text', - 'image_key': 'images', - 'audio_key': 'audios', - }}, 'nested dict load fail, for nonparametric op') + cfg.process[0], { + 'whitespace_normalization_mapper': { + 'text_key': 'text', + 'image_key': 'images', + 'audio_key': 'audios', + } + }, 'nested dict load fail, for nonparametric op') self.assertDictEqual( cfg.process[1], { 'language_id_score_filter': { diff --git a/tests/format/data/structured/demo-dataset.jsonl b/tests/format/data/structured/demo-dataset.jsonl index 77a0a1d88..116bf29e8 100644 --- a/tests/format/data/structured/demo-dataset.jsonl +++ b/tests/format/data/structured/demo-dataset.jsonl @@ -3,4 +3,4 @@ {"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} {"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} {"text": "Today is Sunday and it's a happy day!", "meta": {"src": "Arxiv", "date": "2023-04-27", "version": "1.0"}} -{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} \ No newline at end of file +{"text": "Do you need a cup of coffee?", "meta": {"src": "code", "author": "xxx"}} diff --git a/tests/format/test_csv_formatter.py b/tests/format/test_csv_formatter.py index 9db1ad343..591ccd61a 100644 --- a/tests/format/test_csv_formatter.py +++ b/tests/format/test_csv_formatter.py @@ -2,9 +2,10 @@ import unittest from data_juicer.format.csv_formatter import CsvFormatter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class CsvFormatterTest(unittest.TestCase): +class CsvFormatterTest(DataJuicerTestCaseBase): def setUp(self): self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
diff --git a/tests/format/test_mixture_formatter.py b/tests/format/test_mixture_formatter.py index fc16dcbe1..a4d339695 100644 --- a/tests/format/test_mixture_formatter.py +++ b/tests/format/test_mixture_formatter.py @@ -2,9 +2,10 @@ import unittest from data_juicer.format.mixture_formatter import MixtureFormatter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class MixtureFormatterTest(unittest.TestCase): +class MixtureFormatterTest(DataJuicerTestCaseBase): def setUp(self): self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), @@ -33,7 +34,8 @@ def test_sample_number(self): def test_sample_number_weight(self): max_samples = 2 - formatter = MixtureFormatter('0.5 ' + self._file, max_samples=max_samples) + formatter = MixtureFormatter('0.5 ' + self._file, + max_samples=max_samples) ds = formatter.load_dataset() self.assertEqual(len(ds), max_samples) self.assertEqual(list(ds.features.keys()), ['text', 'meta']) @@ -45,13 +47,6 @@ def test_multi_datasets_without_weight(self): self.assertEqual(len(ds), 12) self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - def test_multi_datasets_with_weight(self): - data_path = self._file + ' ' + self._file2 - formatter = MixtureFormatter(data_path) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 12) - self.assertEqual(list(ds.features.keys()), ['text', 'meta']) - def test_multi_datasets_with_one_weight(self): data_path = '0.5 ' + self._file + ' ' + self._file2 formatter = MixtureFormatter(data_path) @@ -74,5 +69,6 @@ def test_multi_datasets_with_sample(self): self.assertEqual(len(ds), max_samples) self.assertEqual(list(ds.features.keys()), ['text', 'meta']) + if __name__ == '__main__': unittest.main() diff --git a/tests/format/test_parquet_formatter.py b/tests/format/test_parquet_formatter.py index 107ea870c..6df093368 100644 --- a/tests/format/test_parquet_formatter.py +++ b/tests/format/test_parquet_formatter.py @@ -2,9 +2,10 @@ import unittest from 
data_juicer.format.parquet_formatter import ParquetFormatter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class CsvFormatterTest(unittest.TestCase): +class CsvFormatterTest(DataJuicerTestCaseBase): def setUp(self): self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), diff --git a/tests/format/test_tsv_formatter.py b/tests/format/test_tsv_formatter.py index cde6bed85..46f1fad4d 100644 --- a/tests/format/test_tsv_formatter.py +++ b/tests/format/test_tsv_formatter.py @@ -2,9 +2,10 @@ import unittest from data_juicer.format.tsv_formatter import TsvFormatter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class TsvFormatterTest(unittest.TestCase): +class TsvFormatterTest(DataJuicerTestCaseBase): def setUp(self): self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), diff --git a/tests/format/test_unify_format.py b/tests/format/test_unify_format.py index c9b41d19d..52b87493d 100644 --- a/tests/format/test_unify_format.py +++ b/tests/format/test_unify_format.py @@ -5,9 +5,10 @@ from data_juicer.format.formatter import load_dataset, unify_format from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class UnifyFormatTest(unittest.TestCase): +class UnifyFormatTest(DataJuicerTestCaseBase): def run_test(self, sample, args=None): if args is None: @@ -347,26 +348,26 @@ def test_hetero_meta(self): file_path = os.path.join(cur_dir, 'demo-dataset.jsonl') ds = load_dataset('json', data_files=file_path) ds = unify_format(ds) - import datetime - # the 'None' fields are missing fields after merging - sample = [{ - 'text': "Today is Sunday and it's a happy day!", - 'meta': { - 'src': 'Arxiv', - 'date': datetime.datetime(2023, 4, 27, 0, 0), - 'version': '1.0', - 'author': None - } - }, { - 'text': 'Do you need a cup of coffee?', - 'meta': { - 'src': 'code', - 'date': None, - 'version': None, - 'author': 'xxx' - } - }] + # import datetime + # the 
'None' fields are missing fields after merging + # sample = [{ + # 'text': "Today is Sunday and it's a happy day!", + # 'meta': { + # 'src': 'Arxiv', + # 'date': datetime.datetime(2023, 4, 27, 0, 0), + # 'version': '1.0', + # 'author': None + # } + # }, { + # 'text': 'Do you need a cup of coffee?', + # 'meta': { + # 'src': 'code', + # 'date': None, + # 'version': None, + # 'author': 'xxx' + # } + # }] # test nested and missing field for the following cases: # 1. first row, then column unified_sample_first = ds[0] diff --git a/tests/ops/data/video1.mp4 b/tests/ops/data/video1.mp4 new file mode 100644 index 000000000..5b0cad49f Binary files /dev/null and b/tests/ops/data/video1.mp4 differ diff --git a/tests/ops/data/video2.mp4 b/tests/ops/data/video2.mp4 new file mode 100644 index 000000000..28acb927f Binary files /dev/null and b/tests/ops/data/video2.mp4 differ diff --git a/tests/ops/data/video3-no-audio.mp4 b/tests/ops/data/video3-no-audio.mp4 new file mode 100644 index 000000000..ad30ec95b Binary files /dev/null and b/tests/ops/data/video3-no-audio.mp4 differ diff --git a/tests/ops/data/video3.mp4 b/tests/ops/data/video3.mp4 new file mode 100644 index 000000000..45db64a51 Binary files /dev/null and b/tests/ops/data/video3.mp4 differ diff --git a/tests/ops/data/video4.mp4 b/tests/ops/data/video4.mp4 new file mode 100644 index 000000000..8bf5fe0ea Binary files /dev/null and b/tests/ops/data/video4.mp4 differ diff --git a/tests/ops/data/video5.mp4 b/tests/ops/data/video5.mp4 new file mode 100644 index 000000000..46a52855e Binary files /dev/null and b/tests/ops/data/video5.mp4 differ diff --git a/tests/ops/deduplicator/test_document_deduplicator.py b/tests/ops/deduplicator/test_document_deduplicator.py index 740caae18..5a37a2e91 100644 --- a/tests/ops/deduplicator/test_document_deduplicator.py +++ b/tests/ops/deduplicator/test_document_deduplicator.py @@ -4,9 +4,10 @@ from data_juicer.ops.deduplicator.document_deduplicator import \ DocumentDeduplicator +from 
data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class DocumentDeduplicatorTest(unittest.TestCase): +class DocumentDeduplicatorTest(DataJuicerTestCaseBase): def _run_doc_dedup(self, dataset: Dataset, target_list, op): dataset = dataset.map(op.compute_hash) diff --git a/tests/ops/deduplicator/test_document_minhash_deduplicator.py b/tests/ops/deduplicator/test_document_minhash_deduplicator.py index b60209e8b..5190ed1e4 100644 --- a/tests/ops/deduplicator/test_document_minhash_deduplicator.py +++ b/tests/ops/deduplicator/test_document_minhash_deduplicator.py @@ -4,9 +4,10 @@ from data_juicer.ops.deduplicator.document_minhash_deduplicator import \ DocumentMinhashDeduplicator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class DocumentMinhashDeduplicatorTest(unittest.TestCase): +class DocumentMinhashDeduplicatorTest(DataJuicerTestCaseBase): def _run_minhash_dedup(self, dataset: Dataset, target_list, op): dataset = dataset.map(op.compute_hash) diff --git a/tests/ops/deduplicator/test_document_simhash_deduplicator.py b/tests/ops/deduplicator/test_document_simhash_deduplicator.py index d021423c0..ddde50e82 100644 --- a/tests/ops/deduplicator/test_document_simhash_deduplicator.py +++ b/tests/ops/deduplicator/test_document_simhash_deduplicator.py @@ -4,9 +4,10 @@ from data_juicer.ops.deduplicator.document_simhash_deduplicator import \ DocumentSimhashDeduplicator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class DocumentSimhashDeduplicatorTest(unittest.TestCase): +class DocumentSimhashDeduplicatorTest(DataJuicerTestCaseBase): def _run_simhash_dedup(self, dataset: Dataset, target_list, op): dataset = dataset.map(op.compute_hash) diff --git a/tests/ops/deduplicator/test_image_deduplicator.py b/tests/ops/deduplicator/test_image_deduplicator.py index 3ac131506..a643b55be 100644 --- a/tests/ops/deduplicator/test_image_deduplicator.py +++ b/tests/ops/deduplicator/test_image_deduplicator.py @@ -3,30 +3,31 @@ from 
datasets import Dataset -from data_juicer.ops.deduplicator.image_deduplicator import \ - ImageDeduplicator +from data_juicer.ops.deduplicator.image_deduplicator import ImageDeduplicator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ImageDeduplicatorTest(unittest.TestCase): +class ImageDeduplicatorTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') img1_path = os.path.join(data_path, 'img1.png') img2_path = os.path.join(data_path, 'img2.jpg') img3_path = os.path.join(data_path, 'img3.jpg') - # img4.png is a duplicate sample of img1.png - img4_path = os.path.join(data_path, 'img4.png') - # img5.jpg is a duplicate sample of img2.jpg - img5_path = os.path.join(data_path, 'img5.jpg') - # img6.jpg is a duplicate sample of img3.jpg - img6_path = os.path.join(data_path, 'img6.jpg') - # img7.jpg is a duplicate sample of img6.jpg - img7_path = os.path.join(data_path, 'img7.jpg') - - - def _run_image_deduplicator(self, - dataset: Dataset, target_list, - op): + # img1_dup.png is a duplicate sample of img1.png + img4_path = os.path.join(data_path, 'img1_dup.png') + os.symlink(img1_path, img4_path) + # img2_dup.jpg is a duplicate sample of img2.jpg + img5_path = os.path.join(data_path, 'img2_dup.jpg') + os.symlink(img2_path, img5_path) + # img3_dup.jpg is a duplicate sample of img3.jpg + img6_path = os.path.join(data_path, 'img3_dup.jpg') + os.symlink(img3_path, img6_path) + # img3_dup_dup.jpg is a duplicate sample of img6.jpg + img7_path = os.path.join(data_path, 'img3_dup_dup.jpg') + os.symlink(img6_path, img7_path) + + def _run_image_deduplicator(self, dataset: Dataset, target_list, op): dataset = dataset.map(op.compute_hash) dataset, _ = op.process(dataset) @@ -63,11 +64,7 @@ def test_2(self): }, { 'images': [self.img2_path] }] - tgt_list = [{ - 'images': [self.img1_path] - }, { - 'images': 
[self.img2_path] - }] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img2_path]}] dataset = Dataset.from_list(ds_list) op = ImageDeduplicator() self._run_image_deduplicator(dataset, tgt_list, op) @@ -216,5 +213,6 @@ def test_8(self): op = ImageDeduplicator(method='ahash') self._run_image_deduplicator(dataset, tgt_list, op) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/deduplicator/test_video_deduplicator.py b/tests/ops/deduplicator/test_video_deduplicator.py new file mode 100644 index 000000000..951ed6bf0 --- /dev/null +++ b/tests/ops/deduplicator/test_video_deduplicator.py @@ -0,0 +1,150 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.deduplicator.video_deduplicator import VideoDeduplicator +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoDeduplicatorTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + video1_path = os.path.join(data_path, 'video1.mp4') + video2_path = os.path.join(data_path, 'video2.mp4') + video3_path = os.path.join(data_path, 'video3.mp4') + # video1_dup.mp4 is a duplicate sample of video1.mp4 + video4_path = os.path.join(data_path, 'video1_dup.mp4') + os.symlink(video1_path, video4_path) + # video2_dup.mp4 is a duplicate sample of video2.mp4 + video5_path = os.path.join(data_path, 'video2_dup.mp4') + os.symlink(video2_path, video5_path) + # video3_dup.mp4 is a duplicate sample of video3.mp4 + video6_path = os.path.join(data_path, 'video3_dup.mp4') + os.symlink(video3_path, video6_path) + # video3_dup_dup.mp4 is a duplicate sample of video6.mp4 + video7_path = os.path.join(data_path, 'video3_dup_dup.mp4') + os.symlink(video6_path, video7_path) + + def _run_video_deduplicator(self, dataset: Dataset, target_list, op): + + dataset = dataset.map(op.compute_hash) + dataset, _ = op.process(dataset) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = 
dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_1(self): + + ds_list = [{ + 'videos': [self.video1_path] + }, { + 'videos': [self.video2_path] + }, { + 'videos': [self.video3_path] + }] + tgt_list = [{ + 'videos': [self.video1_path] + }, { + 'videos': [self.video2_path] + }, { + 'videos': [self.video3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoDeduplicator() + self._run_video_deduplicator(dataset, tgt_list, op) + + def test_2(self): + + ds_list = [{ + 'videos': [self.video1_path] + }, { + 'videos': [self.video2_path] + }, { + 'videos': [self.video2_path] + }] + tgt_list = [{ + 'videos': [self.video1_path] + }, { + 'videos': [self.video2_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoDeduplicator() + self._run_video_deduplicator(dataset, tgt_list, op) + + def test_3(self): + + ds_list = [{ + 'videos': [self.video1_path] + }, { + 'videos': [self.video2_path] + }, { + 'videos': [self.video3_path] + }, { + 'videos': [self.video4_path] + }, { + 'videos': [self.video5_path] + }, { + 'videos': [self.video6_path] + }, { + 'videos': [self.video7_path] + }] + tgt_list = [{ + 'videos': [self.video1_path] + }, { + 'videos': [self.video2_path] + }, { + 'videos': [self.video3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoDeduplicator() + self._run_video_deduplicator(dataset, tgt_list, op) + + def test_4(self): + + ds_list = [{ + 'videos': [self.video1_path, self.video2_path, self.video3_path] + }, { + 'videos': [self.video4_path, self.video5_path, self.video6_path] + }, { + 'videos': [self.video7_path, self.video5_path] + }, { + 'videos': [self.video6_path, self.video5_path] + }] + tgt_list = [{ + 'videos': [self.video1_path, self.video2_path, self.video3_path] + }, { + 'videos': [self.video7_path, self.video5_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoDeduplicator() + self._run_video_deduplicator(dataset, tgt_list, op) + + def test_5(self): + + ds_list = [{ + 'videos': 
[self.video1_path, self.video2_path] + }, { + 'videos': [self.video2_path, self.video1_path] + }, { + 'videos': [self.video4_path, self.video5_path] + }, { + 'videos': [self.video7_path, self.video7_path] + }, { + 'videos': [self.video6_path, self.video6_path] + }] + tgt_list = [{ + 'videos': [self.video1_path, self.video2_path] + }, { + 'videos': [self.video2_path, self.video1_path] + }, { + 'videos': [self.video7_path, self.video7_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoDeduplicator() + self._run_video_deduplicator(dataset, tgt_list, op) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_alphanumeric_filter.py b/tests/ops/filter/test_alphanumeric_filter.py index a8558cf06..efca696c2 100644 --- a/tests/ops/filter/test_alphanumeric_filter.py +++ b/tests/ops/filter/test_alphanumeric_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.alphanumeric_filter import AlphanumericFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class AlphanumericFilterTest(unittest.TestCase): +class AlphanumericFilterTest(DataJuicerTestCaseBase): def _run_alphanumeric_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/filter/test_audio_duration_filter.py b/tests/ops/filter/test_audio_duration_filter.py index f12b45bb7..91a39bfd8 100644 --- a/tests/ops/filter/test_audio_duration_filter.py +++ b/tests/ops/filter/test_audio_duration_filter.py @@ -5,12 +5,13 @@ from data_juicer.ops.filter.audio_duration_filter import AudioDurationFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class AudioDurationFilterTest(unittest.TestCase): +class AudioDurationFilterTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = 
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') aud1_path = os.path.join(data_path, 'audio1.wav') # about 6s aud2_path = os.path.join(data_path, 'audio2.wav') # about 14s aud3_path = os.path.join(data_path, 'audio3.ogg') # about 1min59s @@ -49,7 +50,7 @@ def test_default_filter(self): op = AudioDurationFilter() self._run_audio_duration_filter(dataset, tgt_list, op) - def test_filter_short_audios(self): + def test_filter_long_audios(self): ds_list = [{ 'audios': [self.aud1_path] @@ -58,14 +59,12 @@ def test_filter_short_audios(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud1_path] - }] + tgt_list = [{'audios': [self.aud1_path]}] dataset = Dataset.from_list(ds_list) op = AudioDurationFilter(max_duration=10) self._run_audio_duration_filter(dataset, tgt_list, op) - def test_filter_long_audios(self): + def test_filter_short_audios(self): ds_list = [{ 'audios': [self.aud1_path] @@ -74,9 +73,7 @@ def test_filter_long_audios(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud3_path] - }] + tgt_list = [{'audios': [self.aud3_path]}] dataset = Dataset.from_list(ds_list) op = AudioDurationFilter(min_duration=60) self._run_audio_duration_filter(dataset, tgt_list, op) @@ -90,12 +87,9 @@ def test_filter_audios_within_range(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud2_path] - }] + tgt_list = [{'audios': [self.aud2_path]}] dataset = Dataset.from_list(ds_list) - op = AudioDurationFilter(min_duration=10, - max_duration=20) + op = AudioDurationFilter(min_duration=10, max_duration=20) self._run_audio_duration_filter(dataset, tgt_list, op) def test_any(self): @@ -143,12 +137,9 @@ def test_filter_in_parallel(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud2_path] - }] + tgt_list = [{'audios': [self.aud2_path]}] dataset = Dataset.from_list(ds_list) - op = AudioDurationFilter(min_duration=10, - max_duration=20) + op = 
AudioDurationFilter(min_duration=10, max_duration=20) self._run_audio_duration_filter(dataset, tgt_list, op, np=2) diff --git a/tests/ops/filter/test_audio_nmf_snr_filter.py b/tests/ops/filter/test_audio_nmf_snr_filter.py index 84b73d2c8..728c43f39 100644 --- a/tests/ops/filter/test_audio_nmf_snr_filter.py +++ b/tests/ops/filter/test_audio_nmf_snr_filter.py @@ -5,12 +5,13 @@ from data_juicer.ops.filter.audio_nmf_snr_filter import AudioNMFSNRFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class AudioNMFSNRFilterTest(unittest.TestCase): +class AudioNMFSNRFilterTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') aud1_path = os.path.join(data_path, 'audio1.wav') # about -7dB aud2_path = os.path.join(data_path, 'audio2.wav') # about 6dB aud3_path = os.path.join(data_path, 'audio3.ogg') # about 3dB @@ -58,11 +59,7 @@ def test_filter_low_snr_audios(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud2_path] - }, { - 'audios': [self.aud3_path] - }] + tgt_list = [{'audios': [self.aud2_path]}, {'audios': [self.aud3_path]}] dataset = Dataset.from_list(ds_list) op = AudioNMFSNRFilter(min_snr=0) self._run_audio_nmf_snr_filter(dataset, tgt_list, op) @@ -76,11 +73,7 @@ def test_filter_high_snr_audios(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud1_path] - }, { - 'audios': [self.aud3_path] - }] + tgt_list = [{'audios': [self.aud1_path]}, {'audios': [self.aud3_path]}] dataset = Dataset.from_list(ds_list) op = AudioNMFSNRFilter(min_snr=-1000, max_snr=5) self._run_audio_nmf_snr_filter(dataset, tgt_list, op) @@ -94,9 +87,7 @@ def test_filter_audios_within_range(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud3_path] - }] + tgt_list = [{'audios': [self.aud3_path]}] 
dataset = Dataset.from_list(ds_list) op = AudioNMFSNRFilter(min_snr=0, max_snr=5) self._run_audio_nmf_snr_filter(dataset, tgt_list, op) @@ -142,9 +133,7 @@ def test_filter_in_parallel(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud3_path] - }] + tgt_list = [{'audios': [self.aud3_path]}] dataset = Dataset.from_list(ds_list) op = AudioNMFSNRFilter(min_snr=0, max_snr=5, any_or_all='any') self._run_audio_nmf_snr_filter(dataset, tgt_list, op, np=2) diff --git a/tests/ops/filter/test_audio_size_filter.py b/tests/ops/filter/test_audio_size_filter.py index c47241965..00b4158d7 100644 --- a/tests/ops/filter/test_audio_size_filter.py +++ b/tests/ops/filter/test_audio_size_filter.py @@ -5,17 +5,18 @@ from data_juicer.ops.filter.audio_size_filter import AudioSizeFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class AudioSizeFilterTest(unittest.TestCase): +class AudioSizeFilterTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') aud1_path = os.path.join(data_path, 'audio1.wav') # 970574 / 948K aud2_path = os.path.join(data_path, 'audio2.wav') # 2494872 / 2.4M aud3_path = os.path.join(data_path, 'audio3.ogg') # 597254 / 583K - def _run_audio_size_filter(self,dataset: Dataset, target_list, op, np=1): + def _run_audio_size_filter(self, dataset: Dataset, target_list, op, np=1): if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) @@ -34,11 +35,9 @@ def test_min_max(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud1_path] - }] + tgt_list = [{'audios': [self.aud1_path]}] dataset = Dataset.from_list(ds_list) - op = AudioSizeFilter(min_size="800kb", max_size="1MB") + op = AudioSizeFilter(min_size='800kb', max_size='1MB') 
self._run_audio_size_filter(dataset, tgt_list, op) def test_min(self): @@ -50,13 +49,9 @@ def test_min(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud1_path] - }, { - 'audios': [self.aud2_path] - }] + tgt_list = [{'audios': [self.aud1_path]}, {'audios': [self.aud2_path]}] dataset = Dataset.from_list(ds_list) - op = AudioSizeFilter(min_size="900kib") + op = AudioSizeFilter(min_size='900kib') self._run_audio_size_filter(dataset, tgt_list, op) def test_max(self): @@ -68,13 +63,9 @@ def test_max(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud1_path] - }, { - 'audios': [self.aud3_path] - }] + tgt_list = [{'audios': [self.aud1_path]}, {'audios': [self.aud3_path]}] dataset = Dataset.from_list(ds_list) - op = AudioSizeFilter(max_size="2MiB") + op = AudioSizeFilter(max_size='2MiB') self._run_audio_size_filter(dataset, tgt_list, op) def test_any(self): @@ -92,8 +83,9 @@ def test_any(self): 'audios': [self.aud1_path, self.aud3_path] }] dataset = Dataset.from_list(ds_list) - op = AudioSizeFilter(min_size="800kb", max_size="1MB", - any_or_all='any') + op = AudioSizeFilter(min_size='800kb', + max_size='1MB', + any_or_all='any') self._run_audio_size_filter(dataset, tgt_list, op) def test_all(self): @@ -107,8 +99,9 @@ def test_all(self): }] tgt_list = [] dataset = Dataset.from_list(ds_list) - op = AudioSizeFilter(min_size="800kb", max_size="1MB", - any_or_all='all') + op = AudioSizeFilter(min_size='800kb', + max_size='1MB', + any_or_all='all') self._run_audio_size_filter(dataset, tgt_list, op) def test_filter_in_parallel(self): @@ -120,11 +113,9 @@ def test_filter_in_parallel(self): }, { 'audios': [self.aud3_path] }] - tgt_list = [{ - 'audios': [self.aud1_path] - }] + tgt_list = [{'audios': [self.aud1_path]}] dataset = Dataset.from_list(ds_list) - op = AudioSizeFilter(min_size="800kb", max_size="1MB") + op = AudioSizeFilter(min_size='800kb', max_size='1MB') self._run_audio_size_filter(dataset, tgt_list, op, np=2) diff 
--git a/tests/ops/filter/test_average_line_length_filter.py b/tests/ops/filter/test_average_line_length_filter.py index 740d5f3c4..a1c39e702 100644 --- a/tests/ops/filter/test_average_line_length_filter.py +++ b/tests/ops/filter/test_average_line_length_filter.py @@ -5,9 +5,10 @@ from data_juicer.ops.filter.average_line_length_filter import \ AverageLineLengthFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class AverageLineLengthFilterTest(unittest.TestCase): +class AverageLineLengthFilterTest(DataJuicerTestCaseBase): def _run_average_line_length_filter(self, dataset: Dataset, target_list, op): diff --git a/tests/ops/filter/test_character_repetition_filter.py b/tests/ops/filter/test_character_repetition_filter.py index b54d76a71..85133c133 100644 --- a/tests/ops/filter/test_character_repetition_filter.py +++ b/tests/ops/filter/test_character_repetition_filter.py @@ -5,9 +5,10 @@ from data_juicer.ops.filter.character_repetition_filter import \ CharacterRepetitionFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class CharacterRepetitionFilterTest(unittest.TestCase): +class CharacterRepetitionFilterTest(DataJuicerTestCaseBase): def _run_character_repetition_filter(self, dataset: Dataset, target_list, op): diff --git a/tests/ops/filter/test_face_area_filter.py b/tests/ops/filter/test_face_area_filter.py index 0008c9377..1e747ec59 100644 --- a/tests/ops/filter/test_face_area_filter.py +++ b/tests/ops/filter/test_face_area_filter.py @@ -5,20 +5,22 @@ from data_juicer.ops.filter.face_area_filter import FaceAreaFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class FaceAreaFilterTest(unittest.TestCase): +class FaceAreaFilterTest(DataJuicerTestCaseBase): maxDiff = None - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') 
+ data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') img1_path = os.path.join(data_path, 'cat.jpg') img2_path = os.path.join(data_path, 'lena.jpg') img3_path = os.path.join(data_path, 'lena-face.jpg') def _run_face_area_filter(self, - dataset: Dataset, target_list, + dataset: Dataset, + target_list, op, num_proc=1): if Fields.stats not in dataset.features: @@ -39,9 +41,7 @@ def test_filter_small(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img3_path] - }] + tgt_list = [{'images': [self.img3_path]}] dataset = Dataset.from_list(ds_list) op = FaceAreaFilter(min_ratio=0.4, max_ratio=1.0) self._run_face_area_filter(dataset, tgt_list, op) @@ -55,11 +55,7 @@ def test_filter_large(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img1_path] - }, { - 'images': [self.img2_path] - }] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img2_path]}] dataset = Dataset.from_list(ds_list) op = FaceAreaFilter(min_ratio=0.0, max_ratio=0.4) self._run_face_area_filter(dataset, tgt_list, op) @@ -67,20 +63,27 @@ def test_filter_large(self): def test_filter_multimodal(self): ds_list = [{ - 'text': 'a test sentence', 'images': [] + 'text': 'a test sentence', + 'images': [] }, { - 'text': 'a test sentence', 'images': [self.img1_path] + 'text': 'a test sentence', + 'images': [self.img1_path] }, { - 'text': 'a test sentence', 'images': [self.img2_path] + 'text': 'a test sentence', + 'images': [self.img2_path] }, { - 'text': 'a test sentence', 'images': [self.img3_path] + 'text': 'a test sentence', + 'images': [self.img3_path] }] tgt_list = [{ - 'text': 'a test sentence', 'images': [] + 'text': 'a test sentence', + 'images': [] }, { - 'text': 'a test sentence', 'images': [self.img1_path] + 'text': 'a test sentence', + 'images': [self.img1_path] }, { - 'text': 'a test sentence', 'images': [self.img2_path] + 'text': 'a test sentence', + 'images': [self.img2_path] }] dataset = 
Dataset.from_list(ds_list) op = FaceAreaFilter() @@ -103,9 +106,7 @@ def test_any(self): 'images': [self.img1_path, self.img3_path] }] dataset = Dataset.from_list(ds_list) - op = FaceAreaFilter(min_ratio=0.0, - max_ratio=0.4, - any_or_all='any') + op = FaceAreaFilter(min_ratio=0.0, max_ratio=0.4, any_or_all='any') self._run_face_area_filter(dataset, tgt_list, op) def test_all(self): @@ -117,13 +118,9 @@ def test_all(self): }, { 'images': [self.img1_path, self.img3_path] }] - tgt_list = [{ - 'images': [self.img1_path, self.img2_path] - }] + tgt_list = [{'images': [self.img1_path, self.img2_path]}] dataset = Dataset.from_list(ds_list) - op = FaceAreaFilter(min_ratio=0.0, - max_ratio=0.4, - any_or_all='all') + op = FaceAreaFilter(min_ratio=0.0, max_ratio=0.4, any_or_all='all') self._run_face_area_filter(dataset, tgt_list, op) def test_filter_multi_process(self): @@ -135,11 +132,7 @@ def test_filter_multi_process(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img1_path] - }, { - 'images': [self.img2_path] - }] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img2_path]}] dataset = Dataset.from_list(ds_list) op = FaceAreaFilter() self._run_face_area_filter(dataset, tgt_list, op, num_proc=3) diff --git a/tests/ops/filter/test_flagged_words_filter.py b/tests/ops/filter/test_flagged_words_filter.py index af7ddf233..e346eb0f5 100644 --- a/tests/ops/filter/test_flagged_words_filter.py +++ b/tests/ops/filter/test_flagged_words_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.flagged_words_filter import FlaggedWordFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class FlaggedWordFilterTest(unittest.TestCase): +class FlaggedWordFilterTest(DataJuicerTestCaseBase): def _run_flagged_words_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/filter/test_image_aesthetics_filter.py 
b/tests/ops/filter/test_image_aesthetics_filter.py new file mode 100644 index 000000000..ef221bf08 --- /dev/null +++ b/tests/ops/filter/test_image_aesthetics_filter.py @@ -0,0 +1,155 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.image_aesthetics_filter import \ + ImageAestheticsFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class ImageAestheticsFilterTest(DataJuicerTestCaseBase): + + maxDiff = None + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + img1_path = os.path.join(data_path, 'cat.jpg') + img2_path = os.path.join(data_path, 'blip.jpg') + img3_path = os.path.join(data_path, 'lena-face.jpg') + + model_id = \ + 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE' + + # with shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE + # the img1, img2, img3 gets scores 0.4382, 0.5973, 0.5216 respectively + + def _run_image_aesthetics_filter(self, + dataset: Dataset, + target_list, + op, + num_proc=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=num_proc) + dataset = dataset.filter(op.process, num_proc=num_proc) + dataset = dataset.remove_columns(Fields.stats) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_filter_small(self): + + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{'images': [self.img2_path]}] + dataset = Dataset.from_list(ds_list) + op = ImageAestheticsFilter(hf_scorer_model=self.model_id, + min_score=0.55, + max_score=1.0) + self._run_image_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_large(self): + + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': 
[self.img3_path] + }] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img3_path]}] + dataset = Dataset.from_list(ds_list) + op = ImageAestheticsFilter(hf_scorer_model=self.model_id, + min_score=0.4, + max_score=0.55) + self._run_image_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_multimodal(self): + + ds_list = [{ + 'text': 'a test sentence', + 'images': [] + }, { + 'text': 'a test sentence', + 'images': [self.img1_path] + }, { + 'text': 'a test sentence', + 'images': [self.img2_path] + }, { + 'text': 'a test sentence', + 'images': [self.img3_path] + }] + tgt_list = [{ + 'text': 'a test sentence', + 'images': [] + }, { + 'text': 'a test sentence', + 'images': [self.img2_path] + }, { + 'text': 'a test sentence', + 'images': [self.img3_path] + }] + dataset = Dataset.from_list(ds_list) + op = ImageAestheticsFilter(hf_scorer_model=self.model_id, ) + self._run_image_aesthetics_filter(dataset, tgt_list, op) + + def test_any(self): + + ds_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + dataset = Dataset.from_list(ds_list) + op = ImageAestheticsFilter(hf_scorer_model=self.model_id, + any_or_all='any') + self._run_image_aesthetics_filter(dataset, tgt_list, op) + + def test_all(self): + + ds_list = [{ + 'images': [self.img1_path, self.img2_path] + }, { + 'images': [self.img2_path, self.img3_path] + }, { + 'images': [self.img1_path, self.img3_path] + }] + tgt_list = [{'images': [self.img2_path, self.img3_path]}] + dataset = Dataset.from_list(ds_list) + op = ImageAestheticsFilter(hf_scorer_model=self.model_id, + any_or_all='all') + self._run_image_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_multi_process(self): + + ds_list = [{ + 'images': 
[self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{'images': [self.img2_path]}, {'images': [self.img3_path]}] + dataset = Dataset.from_list(ds_list) + op = ImageAestheticsFilter(hf_scorer_model=self.model_id, ) + self._run_image_aesthetics_filter(dataset, tgt_list, op, num_proc=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_image_aspect_ratio_filter.py b/tests/ops/filter/test_image_aspect_ratio_filter.py index a328d934a..d8d3df0ea 100644 --- a/tests/ops/filter/test_image_aspect_ratio_filter.py +++ b/tests/ops/filter/test_image_aspect_ratio_filter.py @@ -6,18 +6,18 @@ from data_juicer.ops.filter.image_aspect_ratio_filter import \ ImageAspectRatioFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ImageAspectRatioFilterTest(unittest.TestCase): +class ImageAspectRatioFilterTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') img1_path = os.path.join(data_path, 'img1.png') img2_path = os.path.join(data_path, 'img2.jpg') img3_path = os.path.join(data_path, 'img3.jpg') - def _run_image_aspect_ratio_filter(self, - dataset: Dataset, target_list, + def _run_image_aspect_ratio_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, @@ -37,9 +37,7 @@ def test_filter1(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img1_path] - }] + tgt_list = [{'images': [self.img1_path]}] dataset = Dataset.from_list(ds_list) op = ImageAspectRatioFilter(min_ratio=0.8, max_ratio=1.2) self._run_image_aspect_ratio_filter(dataset, tgt_list, op) @@ -53,11 +51,7 @@ def test_filter2(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img1_path] 
- }, { - 'images': [self.img2_path] - }] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img2_path]}] dataset = Dataset.from_list(ds_list) op = ImageAspectRatioFilter(min_ratio=0.8) self._run_image_aspect_ratio_filter(dataset, tgt_list, op) @@ -71,11 +65,7 @@ def test_filter3(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img1_path] - }, { - 'images': [self.img3_path] - }] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img3_path]}] dataset = Dataset.from_list(ds_list) op = ImageAspectRatioFilter(max_ratio=1.2) self._run_image_aspect_ratio_filter(dataset, tgt_list, op) diff --git a/tests/ops/filter/test_image_shape_filter.py b/tests/ops/filter/test_image_shape_filter.py index 3cc73406c..e7e5deaed 100644 --- a/tests/ops/filter/test_image_shape_filter.py +++ b/tests/ops/filter/test_image_shape_filter.py @@ -5,20 +5,18 @@ from data_juicer.ops.filter.image_shape_filter import ImageShapeFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ImageShapeFilterTest(unittest.TestCase): +class ImageShapeFilterTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') img1_path = os.path.join(data_path, 'img1.png') img2_path = os.path.join(data_path, 'img2.jpg') img3_path = os.path.join(data_path, 'img3.jpg') - def _run_image_shape_filter(self, - dataset: Dataset, - target_list, - op): + def _run_image_shape_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) @@ -37,12 +35,9 @@ def test_filter1(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img2_path] - }] + tgt_list = [{'images': [self.img2_path]}] dataset = Dataset.from_list(ds_list) - op = 
ImageShapeFilter(min_width=400, - min_height=400) + op = ImageShapeFilter(min_width=400, min_height=400) self._run_image_shape_filter(dataset, tgt_list, op) def test_filter2(self): @@ -54,14 +49,9 @@ def test_filter2(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self.img1_path] - }, { - 'images': [self.img3_path] - }] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img3_path]}] dataset = Dataset.from_list(ds_list) - op = ImageShapeFilter(max_width=500, - max_height=500) + op = ImageShapeFilter(max_width=500, max_height=500) self._run_image_shape_filter(dataset, tgt_list, op) def test_filter3(self): @@ -99,9 +89,7 @@ def test_any(self): 'images': [self.img2_path, self.img3_path] }] dataset = Dataset.from_list(ds_list) - op = ImageShapeFilter(min_width=400, - min_height=400, - any_or_all='any') + op = ImageShapeFilter(min_width=400, min_height=400, any_or_all='any') self._run_image_shape_filter(dataset, tgt_list, op) def test_all(self): @@ -115,9 +103,7 @@ def test_all(self): }] tgt_list = [] dataset = Dataset.from_list(ds_list) - op = ImageShapeFilter(min_width=400, - min_height=400, - any_or_all='all') + op = ImageShapeFilter(min_width=400, min_height=400, any_or_all='all') self._run_image_shape_filter(dataset, tgt_list, op) diff --git a/tests/ops/filter/test_image_size_filter.py b/tests/ops/filter/test_image_size_filter.py index 46cfff62f..fcc5e3e76 100644 --- a/tests/ops/filter/test_image_size_filter.py +++ b/tests/ops/filter/test_image_size_filter.py @@ -5,19 +5,18 @@ from data_juicer.ops.filter.image_size_filter import ImageSizeFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ImageSizeFilterTest(unittest.TestCase): +class ImageSizeFilterTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') 
img1_path = os.path.join(data_path, 'img1.png') img2_path = os.path.join(data_path, 'img2.jpg') img3_path = os.path.join(data_path, 'img3.jpg') - def _run_image_size_filter(self, - dataset: Dataset, target_list, - op): + def _run_image_size_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) @@ -29,54 +28,56 @@ def _run_image_size_filter(self, def test_min_max(self): - ds_list = [{ - 'images': [self.img1_path] # 171KB - }, { - 'images': [self.img2_path] # 189KB - }, { - 'images': [self.img3_path] # 114KB - }] - tgt_list = [{ - 'images': [self.img1_path] - }] + ds_list = [ + { + 'images': [self.img1_path] # 171KB + }, + { + 'images': [self.img2_path] # 189KB + }, + { + 'images': [self.img3_path] # 114KB + } + ] + tgt_list = [{'images': [self.img1_path]}] dataset = Dataset.from_list(ds_list) - op = ImageSizeFilter(min_size="120kb", max_size="180KB") + op = ImageSizeFilter(min_size='120kb', max_size='180KB') self._run_image_size_filter(dataset, tgt_list, op) def test_min(self): - ds_list = [{ - 'images': [self.img1_path] # 171KB - }, { - 'images': [self.img2_path] # 189KB - }, { - 'images': [self.img3_path] # 114KB - }] - tgt_list = [{ - 'images': [self.img1_path] - }, { - 'images': [self.img2_path] - }] + ds_list = [ + { + 'images': [self.img1_path] # 171KB + }, + { + 'images': [self.img2_path] # 189KB + }, + { + 'images': [self.img3_path] # 114KB + } + ] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img2_path]}] dataset = Dataset.from_list(ds_list) - op = ImageSizeFilter(min_size="120kib") + op = ImageSizeFilter(min_size='120kib') self._run_image_size_filter(dataset, tgt_list, op) def test_max(self): - ds_list = [{ - 'images': [self.img1_path] # 171KB - }, { - 'images': [self.img2_path] # 189KB - }, { - 'images': [self.img3_path] # 114KB - }] - tgt_list = [{ - 'images': [self.img1_path] - }, { - 'images': [self.img3_path] - }] + 
ds_list = [ + { + 'images': [self.img1_path] # 171KB + }, + { + 'images': [self.img2_path] # 189KB + }, + { + 'images': [self.img3_path] # 114KB + } + ] + tgt_list = [{'images': [self.img1_path]}, {'images': [self.img3_path]}] dataset = Dataset.from_list(ds_list) - op = ImageSizeFilter(max_size="180KiB") + op = ImageSizeFilter(max_size='180KiB') self._run_image_size_filter(dataset, tgt_list, op) def test_any(self): @@ -94,8 +95,9 @@ def test_any(self): 'images': [self.img1_path, self.img3_path] }] dataset = Dataset.from_list(ds_list) - op = ImageSizeFilter(min_size="120kb", max_size="180KB", - any_or_all='any') + op = ImageSizeFilter(min_size='120kb', + max_size='180KB', + any_or_all='any') self._run_image_size_filter(dataset, tgt_list, op) def test_all(self): @@ -109,7 +111,9 @@ def test_all(self): }] tgt_list = [] dataset = Dataset.from_list(ds_list) - op = ImageSizeFilter(min_size="120kb", max_size="180KB", any_or_all='all') + op = ImageSizeFilter(min_size='120kb', + max_size='180KB', + any_or_all='all') self._run_image_size_filter(dataset, tgt_list, op) diff --git a/tests/ops/filter/test_image_text_matching_filter.py b/tests/ops/filter/test_image_text_matching_filter.py index 15adfb5d4..7620b84a8 100644 --- a/tests/ops/filter/test_image_text_matching_filter.py +++ b/tests/ops/filter/test_image_text_matching_filter.py @@ -1,13 +1,17 @@ +# flake8: noqa: E501 + import os import unittest from datasets import Dataset -from data_juicer.ops.filter.image_text_matching_filter import ImageTextMatchingFilter +from data_juicer.ops.filter.image_text_matching_filter import \ + ImageTextMatchingFilter from data_juicer.utils.constant import Fields from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + class ImageTextMatchingFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', @@ -20,7 +24,7 @@ class ImageTextMatchingFilterTest(DataJuicerTestCaseBase): 
@classmethod def tearDownClass(cls) -> None: super().tearDownClass(cls.hf_blip) - + def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1): if Fields.stats not in dataset.features: @@ -30,7 +34,9 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1): dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats, num_proc=num_proc, with_rank=True) + dataset = dataset.map(op.compute_stats, + num_proc=num_proc, + with_rank=True) dataset = dataset.filter(op.process, num_proc=num_proc) dataset = dataset.select_columns(column_names=['text', 'images']) res_list = dataset.to_list() @@ -39,23 +45,26 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1): def test_no_eoc_special_token(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'{SpecialTokens.image}a man sitting on the grass with a cat', + 'text': + f'{SpecialTokens.image}a man sitting on the grass with a cat', 'images': [self.demo_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }] dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='any', - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='any', + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op) def test_eoc_special_token(self): @@ -65,7 +74,8 @@ def test_eoc_special_token(self): f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.eoc}', 'images': [self.demo_path] }, { - 'text': f'{SpecialTokens.image}a man sitting on the grass with a cat', + 'text': + f'{SpecialTokens.image}a man sitting on the grass with a 
cat', 'images': [self.demo_path] }] tgt_list = [{ @@ -76,10 +86,10 @@ def test_eoc_special_token(self): dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='any', - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='any', + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op) def test_horizontal_flip(self): @@ -89,7 +99,8 @@ def test_horizontal_flip(self): f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.eoc}', 'images': [self.demo_path] }, { - 'text': f'{SpecialTokens.image}a man sitting on the grass with a cat', + 'text': + f'{SpecialTokens.image}a man sitting on the grass with a cat', 'images': [self.demo_path] }] tgt_list = [{ @@ -100,12 +111,12 @@ def test_horizontal_flip(self): dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=True, - vertical_flip=False, - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=True, + vertical_flip=False, + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op) def test_vertical_flip(self): @@ -115,7 +126,8 @@ def test_vertical_flip(self): f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.eoc}', 'images': [self.demo_path] }, { - 'text': f'{SpecialTokens.image}a man sitting on the grass with a cat', + 'text': + f'{SpecialTokens.image}a man sitting on the grass with a cat', 'images': [self.demo_path] }] tgt_list = [{ @@ -126,12 +138,12 @@ def test_vertical_flip(self): dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=False, - vertical_flip=True, - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=False, + vertical_flip=True, + min_score=0.003, + max_score=1.0) 
self._run_filter(dataset, tgt_list, op) def test_keep_any(self): @@ -150,10 +162,10 @@ def test_keep_any(self): }] dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='any', - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='any', + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op) def test_keep_all(self): @@ -167,66 +179,71 @@ def test_keep_all(self): tgt_list = [] dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='all', - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='all', + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op) def test_reduce_avg(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog ' + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog ' f'{SpecialTokens.image} {SpecialTokens.eoc}', 'images': [self.demo_path, self.img3_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog ' + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog ' f'{SpecialTokens.image} {SpecialTokens.eoc}', 'images': [self.demo_path, self.img3_path] }] dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='any', - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='any', + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op) def test_reduce_max(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog ' + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog ' f'{SpecialTokens.image} {SpecialTokens.eoc}', 'images': [self.demo_path, self.img3_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog ' + 'text': + f'{SpecialTokens.image}a 
woman sitting on the beach with a dog ' f'{SpecialTokens.image} {SpecialTokens.eoc}', 'images': [self.demo_path, self.img3_path] }] dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='max', - any_or_all='any', - min_score=0.003, - max_score=1.0) + reduce_mode='max', + any_or_all='any', + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op) def test_reduce_min(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog ' + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog ' f'{SpecialTokens.image} {SpecialTokens.eoc}', 'images': [self.demo_path, self.img3_path] }] dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='min', - any_or_all='any', - min_score=0.1, - max_score=0.9) + reduce_mode='min', + any_or_all='any', + min_score=0.1, + max_score=0.9) self._run_filter(dataset, [], op) def test_multi_process(self): @@ -245,10 +262,10 @@ def test_multi_process(self): }] * 10 dataset = Dataset.from_list(ds_list) op = ImageTextMatchingFilter(hf_blip=self.hf_blip, - reduce_mode='avg', - any_or_all='any', - min_score=0.003, - max_score=1.0) + reduce_mode='avg', + any_or_all='any', + min_score=0.003, + max_score=1.0) self._run_filter(dataset, tgt_list, op, num_proc=4) diff --git a/tests/ops/filter/test_image_text_similarity_filter.py b/tests/ops/filter/test_image_text_similarity_filter.py index f50637561..549ee3137 100644 --- a/tests/ops/filter/test_image_text_similarity_filter.py +++ b/tests/ops/filter/test_image_text_similarity_filter.py @@ -3,11 +3,13 @@ from datasets import Dataset -from data_juicer.ops.filter.image_text_similarity_filter import ImageTextSimilarityFilter +from data_juicer.ops.filter.image_text_similarity_filter import \ + ImageTextSimilarityFilter from data_juicer.utils.constant import Fields from data_juicer.utils.mm_utils import SpecialTokens from 
data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + class ImageTextSimilarityFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', @@ -30,7 +32,9 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1): dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats, num_proc=num_proc, with_rank=True) + dataset = dataset.map(op.compute_stats, + num_proc=num_proc, + with_rank=True) dataset = dataset.filter(op.process, num_proc=num_proc) dataset = dataset.select_columns(column_names=['text', 'images']) res_list = dataset.to_list() @@ -52,12 +56,12 @@ def test_no_eoc_special_token(self): dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=False, - vertical_flip=False, - min_score=0.2, - max_score=0.9) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=False, + vertical_flip=False, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_eoc_special_token(self): @@ -67,7 +71,8 @@ def test_eoc_special_token(self): f'{SpecialTokens.image}a photo of a cat{SpecialTokens.eoc}', 'images': [self.cat_path] }, { - 'text': f'{SpecialTokens.image}a photo of a dog{SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image}a photo of a dog{SpecialTokens.eoc}', 'images': [self.cat_path] }] tgt_list = [{ @@ -78,12 +83,12 @@ def test_eoc_special_token(self): dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=False, - vertical_flip=False, - min_score=0.2, - max_score=0.9) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=False, + vertical_flip=False, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_horizontal_flip(self): @@ -104,12 +109,12 @@ def test_horizontal_flip(self): 
dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=True, - vertical_flip=False, - min_score=0.24, - max_score=0.9) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=True, + vertical_flip=False, + min_score=0.24, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_vertical_flip(self): @@ -130,12 +135,12 @@ def test_vertical_flip(self): dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=False, - vertical_flip=True, - min_score=0.2, - max_score=0.9) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=False, + vertical_flip=True, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_keep_any(self): @@ -154,12 +159,12 @@ def test_keep_any(self): }] dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=False, - vertical_flip=False, - min_score=0.2, - max_score=0.9) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=False, + vertical_flip=False, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_keep_all(self): @@ -173,12 +178,12 @@ def test_keep_all(self): tgt_list = [] dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='all', - horizontal_flip=False, - vertical_flip=False, - min_score=0.2, - max_score=0.9) + reduce_mode='avg', + any_or_all='all', + horizontal_flip=False, + vertical_flip=False, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_reduce_avg(self): @@ -195,12 +200,12 @@ def test_reduce_avg(self): }] dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=False, - 
vertical_flip=False, - min_score=0.2, - max_score=0.9) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=False, + vertical_flip=False, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_reduce_max(self): @@ -217,12 +222,12 @@ def test_reduce_max(self): }] dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='max', - any_or_all='any', - horizontal_flip=False, - vertical_flip=False, - min_score=0.2, - max_score=0.9) + reduce_mode='max', + any_or_all='any', + horizontal_flip=False, + vertical_flip=False, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op) def test_reduce_min(self): @@ -240,12 +245,12 @@ def test_reduce_min(self): dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='min', - any_or_all='any', - horizontal_flip=False, - vertical_flip=False, - min_score=0.1, - max_score=0.9) + reduce_mode='min', + any_or_all='any', + horizontal_flip=False, + vertical_flip=False, + min_score=0.1, + max_score=0.9) self._run_filter(dataset, tgt_list, op) @@ -268,12 +273,12 @@ def test_multi_process(self): }] * 10 dataset = Dataset.from_list(ds_list) op = ImageTextSimilarityFilter(hf_clip=self.hf_clip, - reduce_mode='avg', - any_or_all='any', - horizontal_flip=False, - vertical_flip=False, - min_score=0.2, - max_score=0.9) + reduce_mode='avg', + any_or_all='any', + horizontal_flip=False, + vertical_flip=False, + min_score=0.2, + max_score=0.9) self._run_filter(dataset, tgt_list, op, num_proc=4) diff --git a/tests/ops/filter/test_language_id_score_filter.py b/tests/ops/filter/test_language_id_score_filter.py index 0b6e50daa..21d71ceb5 100644 --- a/tests/ops/filter/test_language_id_score_filter.py +++ b/tests/ops/filter/test_language_id_score_filter.py @@ -5,9 +5,10 @@ from data_juicer.ops.filter.language_id_score_filter import \ LanguageIDScoreFilter from data_juicer.utils.constant import Fields +from 
data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class LanguageIDScoreFilterTest(unittest.TestCase): +class LanguageIDScoreFilterTest(DataJuicerTestCaseBase): def _run_language_id_score_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/filter/test_maximum_line_length_filter.py b/tests/ops/filter/test_maximum_line_length_filter.py index 8bcf6aa83..ef8a6d33e 100644 --- a/tests/ops/filter/test_maximum_line_length_filter.py +++ b/tests/ops/filter/test_maximum_line_length_filter.py @@ -5,9 +5,10 @@ from data_juicer.ops.filter.maximum_line_length_filter import \ MaximumLineLengthFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class MaximumLineLengthFilterTest(unittest.TestCase): +class MaximumLineLengthFilterTest(DataJuicerTestCaseBase): def _run_maximum_line_length_filter(self, dataset: Dataset, target_list, op): diff --git a/tests/ops/filter/test_perplexity_filter.py b/tests/ops/filter/test_perplexity_filter.py index 4b45598dd..114bdb307 100644 --- a/tests/ops/filter/test_perplexity_filter.py +++ b/tests/ops/filter/test_perplexity_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.perplexity_filter import PerplexityFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class PerplexityFilterTest(unittest.TestCase): +class PerplexityFilterTest(DataJuicerTestCaseBase): def _run_perplexity_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/filter/test_phrase_grounding_recall_filter.py b/tests/ops/filter/test_phrase_grounding_recall_filter.py index c5510014d..16d689e7d 100644 --- a/tests/ops/filter/test_phrase_grounding_recall_filter.py +++ b/tests/ops/filter/test_phrase_grounding_recall_filter.py @@ -1,13 +1,17 @@ +# flake8: noqa: E501 + import os import unittest from datasets import 
Dataset -from data_juicer.ops.filter.phrase_grounding_recall_filter import PhraseGroundingRecallFilter +from data_juicer.ops.filter.phrase_grounding_recall_filter import \ + PhraseGroundingRecallFilter from data_juicer.utils.constant import Fields from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + class PhraseGroundingRecallFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', @@ -24,7 +28,7 @@ class PhraseGroundingRecallFilterTest(DataJuicerTestCaseBase): @classmethod def tearDownClass(cls) -> None: super().tearDownClass(cls.hf_owlvit) - + def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1): if Fields.stats not in dataset.features: @@ -34,7 +38,9 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1): dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats, num_proc=num_proc, with_rank=True) + dataset = dataset.map(op.compute_stats, + num_proc=num_proc, + with_rank=True) dataset = dataset.filter(op.process, num_proc=num_proc) dataset = dataset.select_columns(column_names=['text', 'images']) res_list = dataset.to_list() @@ -43,35 +49,45 @@ def _run_filter(self, dataset: Dataset, target_list, op, num_proc=1): def test_general(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', + 'text': + f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', 'images': [self.cat_path] }, { - 'text': f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory 
foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path] }, { - 'text': f'{SpecialTokens.image} A bus with red advertisements is running on the street. {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A bus with red advertisements is running on the street. {SpecialTokens.eoc}', 'images': [self.img2_path] }, { - 'text': f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', 'images': [self.img3_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', + 'text': + f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', 'images': [self.cat_path] }, { - 'text': f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path] }, { - 'text': f'{SpecialTokens.image} A bus with red advertisements is running on the street. {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A bus with red advertisements is running on the street. 
{SpecialTokens.eoc}', 'images': [self.img2_path] }, { - 'text': f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', 'images': [self.img3_path] }] @@ -88,29 +104,37 @@ def test_general(self): def test_high_recall(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', + 'text': + f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', 'images': [self.cat_path] }, { - 'text': f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path] }, { - 'text': f'{SpecialTokens.image} A bus with red advertisements is running on the street. {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A bus with red advertisements is running on the street. 
{SpecialTokens.eoc}', 'images': [self.img2_path] }, { - 'text': f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', 'images': [self.img3_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', + 'text': + f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', 'images': [self.cat_path] }, { - 'text': f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', 'images': [self.img3_path] }] @@ -127,14 +151,17 @@ def test_high_recall(self): def test_high_conf_thr(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'{SpecialTokens.image}a man sitting on the grass with a cat', + 'text': + f'{SpecialTokens.image}a man sitting on the grass with a cat', 'images': [self.demo_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }] @@ -152,17 +179,21 @@ def test_high_conf_thr(self): def test_low_conf_thr(self): # some similar but different objects might be detected incorrectly ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach 
with a dog', 'images': [self.demo_path] }, { - 'text': f'{SpecialTokens.image}a man sitting on the grass with a cat', + 'text': + f'{SpecialTokens.image}a man sitting on the grass with a cat', 'images': [self.demo_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'{SpecialTokens.image}a man sitting on the grass with a cat', + 'text': + f'{SpecialTokens.image}a man sitting on the grass with a cat', 'images': [self.demo_path] }] @@ -183,7 +214,8 @@ def test_low_area_ratio(self): 'text': f'{SpecialTokens.image} a photo of a woman\'s face', 'images': [self.face_path] }, { - 'text': f'{SpecialTokens.image}A bus with red advertisements is running on the street.', + 'text': + f'{SpecialTokens.image}A bus with red advertisements is running on the street.', 'images': [self.img2_path] }] tgt_list = [] @@ -205,11 +237,13 @@ def test_high_area_ratio(self): 'text': f'{SpecialTokens.image} a photo of a woman\'s face', 'images': [self.face_path] }, { - 'text': f'{SpecialTokens.image}A bus with red advertisements is running on the street.', + 'text': + f'{SpecialTokens.image}A bus with red advertisements is running on the street.', 'images': [self.img2_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}A bus with red advertisements is running on the street.', + 'text': + f'{SpecialTokens.image}A bus with red advertisements is running on the street.', 'images': [self.img2_path] }] @@ -227,17 +261,21 @@ def test_high_area_ratio(self): def test_reduce_avg(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', 'images': [self.demo_path, self.cat_path] }, { - 'text': f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam 
mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path, self.img2_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', 'images': [self.demo_path, self.cat_path] }, { - 'text': f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path, self.img2_path] }] @@ -254,14 +292,17 @@ def test_reduce_avg(self): def test_reduce_max(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', 'images': [self.demo_path, self.cat_path] }, { - 'text': f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path, self.img2_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', 'images': [self.demo_path, self.cat_path] }] @@ -278,10 +319,12 @@ def test_reduce_max(self): def test_reduce_min(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog{SpecialTokens.image}', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a 
dog{SpecialTokens.image}', 'images': [self.demo_path, self.cat_path] }, { - 'text': f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path, self.img2_path] }] tgt_list = [] @@ -300,8 +343,8 @@ def test_keep_all(self): ds_list = [{ 'text': - f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}' - f'{SpecialTokens.image} a woman sitting on the beach with a dog', + f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}' + f'{SpecialTokens.image} a woman sitting on the beach with a dog', 'images': [self.img1_path, self.cat_path, self.demo_path] }] tgt_list = [] @@ -320,14 +363,14 @@ def test_keep_any(self): ds_list = [{ 'text': - f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}' - f'{SpecialTokens.image} a woman sitting on the beach with a dog', + f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}' + f'{SpecialTokens.image} a woman sitting on the beach with a dog', 'images': [self.img1_path, self.cat_path, self.demo_path] }] tgt_list = [{ 'text': - f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}' - f'{SpecialTokens.image} a woman sitting on the beach with a dog', + f'{SpecialTokens.image} {SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}' + f'{SpecialTokens.image} a woman sitting on the beach with a dog', 'images': [self.img1_path, self.cat_path, self.demo_path] }] @@ -344,29 
+387,37 @@ def test_keep_any(self): def test_process_in_parallel(self): ds_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', + 'text': + f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', 'images': [self.cat_path] }, { - 'text': f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} select luxury furniture 3 - inch gel memory foam mattress topper {SpecialTokens.eoc}', 'images': [self.img1_path] }, { - 'text': f'{SpecialTokens.image} A bus with red advertisements is running on the street. {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A bus with red advertisements is running on the street. {SpecialTokens.eoc}', 'images': [self.img2_path] }, { - 'text': f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', 'images': [self.img3_path] }] tgt_list = [{ - 'text': f'{SpecialTokens.image}a woman sitting on the beach with a dog', + 'text': + f'{SpecialTokens.image}a woman sitting on the beach with a dog', 'images': [self.demo_path] }, { - 'text': f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', + 'text': + f'Two cats are sleeping on the couch with two remote controls{SpecialTokens.image}', 'images': [self.cat_path] }, { - 'text': f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', + 'text': + f'{SpecialTokens.image} A woman carrying a bag is walking in a rainy alley holding an umbrella {SpecialTokens.eoc}', 'images': 
[self.img3_path] }] diff --git a/tests/ops/filter/test_special_characters_filter.py b/tests/ops/filter/test_special_characters_filter.py index 301291bc4..4ea505968 100644 --- a/tests/ops/filter/test_special_characters_filter.py +++ b/tests/ops/filter/test_special_characters_filter.py @@ -5,9 +5,10 @@ from data_juicer.ops.filter.special_characters_filter import \ SpecialCharactersFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class SpecialCharactersFilterTest(unittest.TestCase): +class SpecialCharactersFilterTest(DataJuicerTestCaseBase): def _run_special_characters_filter(self, dataset: Dataset, target_list, op): diff --git a/tests/ops/filter/test_specified_field_filter.py b/tests/ops/filter/test_specified_field_filter.py index a3bd51020..3086e2b00 100644 --- a/tests/ops/filter/test_specified_field_filter.py +++ b/tests/ops/filter/test_specified_field_filter.py @@ -3,9 +3,10 @@ from datasets import Dataset from data_juicer.ops.filter.specified_field_filter import SpecifiedFieldFilter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class SpecifiedFieldFilterTest(unittest.TestCase): +class SpecifiedFieldFilterTest(DataJuicerTestCaseBase): def _run_specified_field_filter(self, dataset: Dataset, target_list, op): dataset = dataset.map(op.compute_stats) diff --git a/tests/ops/filter/test_specified_numeric_field_filter.py b/tests/ops/filter/test_specified_numeric_field_filter.py index f82fd4617..c580f6905 100644 --- a/tests/ops/filter/test_specified_numeric_field_filter.py +++ b/tests/ops/filter/test_specified_numeric_field_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.specified_numeric_field_filter import \ SpecifiedNumericFieldFilter +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class SpecifiedNumericFieldFilterTest(unittest.TestCase): +class SpecifiedNumericFieldFilterTest(DataJuicerTestCaseBase): def 
_run_specified_numeric_field_filter(self, dataset: Dataset, target_list, op): diff --git a/tests/ops/filter/test_stop_words_filter.py b/tests/ops/filter/test_stop_words_filter.py index 60219c1c5..8772b6960 100644 --- a/tests/ops/filter/test_stop_words_filter.py +++ b/tests/ops/filter/test_stop_words_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.stopwords_filter import StopWordsFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class StopWordsFilterTest(unittest.TestCase): +class StopWordsFilterTest(DataJuicerTestCaseBase): def _run_stopwords_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/filter/test_suffix_filter.py b/tests/ops/filter/test_suffix_filter.py index ea2407245..48980c120 100644 --- a/tests/ops/filter/test_suffix_filter.py +++ b/tests/ops/filter/test_suffix_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.suffix_filter import SuffixFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class SuffixFilterTest(unittest.TestCase): +class SuffixFilterTest(DataJuicerTestCaseBase): def _run_suffix_filter(self, dataset: Dataset, target_list, op): dataset = dataset.map(op.compute_stats) diff --git a/tests/ops/filter/test_text_action_filter.py b/tests/ops/filter/test_text_action_filter.py index 9a146ea33..78b40dfad 100644 --- a/tests/ops/filter/test_text_action_filter.py +++ b/tests/ops/filter/test_text_action_filter.py @@ -1,14 +1,15 @@ -import unittest import os +import unittest from datasets import Dataset from data_juicer.ops.filter.text_action_filter import TextActionFilter from data_juicer.utils.constant import Fields from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class TextActionFilterTest(unittest.TestCase): +class TextActionFilterTest(DataJuicerTestCaseBase): 
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') @@ -16,7 +17,8 @@ class TextActionFilterTest(unittest.TestCase): cat_path = os.path.join(data_path, 'cat.jpg') img3_path = os.path.join(data_path, 'img3.jpg') - def _run_text_action_filter(self, dataset: Dataset, target_list, op, column_names): + def _run_text_action_filter(self, dataset: Dataset, target_list, op, + column_names): if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) @@ -34,7 +36,7 @@ def test_en_text_case(self): 'text': 'Tom plays piano.' }, { 'text': 'Tom played piano.' - },{ + }, { 'text': 'I play piano.' }, { 'text': 'to play piano.' @@ -53,7 +55,7 @@ def test_en_text_case(self): 'text': 'Tom plays piano.' }, { 'text': 'Tom played piano.' - },{ + }, { 'text': 'I play piano.' }, { 'text': 'to play piano.' @@ -75,11 +77,7 @@ def test_zh_text_case(self): }, { 'text': '我有一只猫,它是一只猫' }] - tgt_list = [{ - 'text': '小明在 弹奏钢琴' - }, { - 'text': 'Tom在打篮球' - }] + tgt_list = [{'text': '小明在 弹奏钢琴'}, {'text': 'Tom在打篮球'}] dataset = Dataset.from_list(ds_list) op = TextActionFilter(lang='zh') self._run_text_action_filter(dataset, tgt_list, op, ['text']) @@ -95,14 +93,14 @@ def test_image_text_case(self): 'text': f'{SpecialTokens.image}背影{SpecialTokens.eoc}', 'images': [self.img3_path] }, { - 'text': f'雨中行走的女人背影', + 'text': '雨中行走的女人背影', 'images': [self.img3_path] }] tgt_list = [{ 'text': f'{SpecialTokens.image}小猫咪正在睡觉。{SpecialTokens.eoc}', 'images': [self.cat_path] }, { - 'text': f'雨中行走的女人背影', + 'text': '雨中行走的女人背影', 'images': [self.img3_path] }] @@ -110,5 +108,6 @@ def test_image_text_case(self): op = TextActionFilter(lang='zh') self._run_text_action_filter(dataset, tgt_list, op, ['text', 'images']) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/filter/test_text_entity_dependency_filter.py b/tests/ops/filter/test_text_entity_dependency_filter.py index a9daef1c1..6247318f7 100644 --- 
a/tests/ops/filter/test_text_entity_dependency_filter.py +++ b/tests/ops/filter/test_text_entity_dependency_filter.py @@ -1,14 +1,16 @@ -import unittest import os +import unittest from datasets import Dataset -from data_juicer.ops.filter.text_entity_dependency_filter import TextEntityDependencyFilter +from data_juicer.ops.filter.text_entity_dependency_filter import \ + TextEntityDependencyFilter from data_juicer.utils.constant import Fields from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class TextEntityDependencyFilterTest(unittest.TestCase): +class TextEntityDependencyFilterTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') @@ -16,7 +18,8 @@ class TextEntityDependencyFilterTest(unittest.TestCase): cat_path = os.path.join(data_path, 'cat.jpg') img3_path = os.path.join(data_path, 'img3.jpg') - def _run_text_entity_denpendency_filter(self, dataset: Dataset, target_list, op, column_names): + def _run_text_entity_denpendency_filter(self, dataset: Dataset, + target_list, op, column_names): if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) @@ -34,7 +37,7 @@ def test_en_text_case(self): 'text': 'Tom is playing piano.' }, { 'text': 'piano.' 
- },{ + }, { 'text': 'a green tree', }, { 'text': 'tree', @@ -50,7 +53,8 @@ def test_en_text_case(self): }] dataset = Dataset.from_list(ds_list) op = TextEntityDependencyFilter(lang='en', any_or_all='any') - self._run_text_entity_denpendency_filter(dataset, tgt_list, op, ['text']) + self._run_text_entity_denpendency_filter(dataset, tgt_list, op, + ['text']) def test_zh_text_case(self): @@ -67,16 +71,11 @@ def test_zh_text_case(self): }, { 'text': '书。山。星星。土豆。' }] - tgt_list = [{ - 'text': '她在笑' - }, { - 'text': '枯藤老树昏鸦' - }, { - 'text': '一只会上树的猫' - }] + tgt_list = [{'text': '她在笑'}, {'text': '枯藤老树昏鸦'}, {'text': '一只会上树的猫'}] dataset = Dataset.from_list(ds_list) op = TextEntityDependencyFilter(lang='zh', any_or_all='all') - self._run_text_entity_denpendency_filter(dataset, tgt_list, op, ['text']) + self._run_text_entity_denpendency_filter(dataset, tgt_list, op, + ['text']) def test_image_text_case(self): ds_list = [{ @@ -89,20 +88,22 @@ def test_image_text_case(self): 'text': f'{SpecialTokens.image}背影{SpecialTokens.eoc}', 'images': [self.img3_path] }, { - 'text': f'撑着伞的女人背影', + 'text': '撑着伞的女人背影', 'images': [self.img3_path] }] tgt_list = [{ 'text': f'{SpecialTokens.image}三只缩成一团的小猫咪。{SpecialTokens.eoc}', 'images': [self.cat_path] }, { - 'text': f'撑着伞的女人背影', + 'text': '撑着伞的女人背影', 'images': [self.img3_path] }] dataset = Dataset.from_list(ds_list) op = TextEntityDependencyFilter(lang='zh', any_or_all='any') - self._run_text_entity_denpendency_filter(dataset, tgt_list, op, ['text', 'images']) + self._run_text_entity_denpendency_filter(dataset, tgt_list, op, + ['text', 'images']) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/filter/test_text_length_filter.py b/tests/ops/filter/test_text_length_filter.py index 1ff93e422..cb5df982b 100644 --- a/tests/ops/filter/test_text_length_filter.py +++ b/tests/ops/filter/test_text_length_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.text_length_filter import TextLengthFilter from data_juicer.utils.constant 
import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class TextLengthFilterTest(unittest.TestCase): +class TextLengthFilterTest(DataJuicerTestCaseBase): def _run_text_length_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/filter/test_token_num_filter.py b/tests/ops/filter/test_token_num_filter.py index a830e91fe..514ce21c3 100644 --- a/tests/ops/filter/test_token_num_filter.py +++ b/tests/ops/filter/test_token_num_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.token_num_filter import TokenNumFilter from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class WordNumFilterTest(unittest.TestCase): +class WordNumFilterTest(DataJuicerTestCaseBase): def test_token_num(self): src = [ diff --git a/tests/ops/filter/test_video_aesthetics_filter.py b/tests/ops/filter/test_video_aesthetics_filter.py new file mode 100644 index 000000000..afa6a3f0e --- /dev/null +++ b/tests/ops/filter/test_video_aesthetics_filter.py @@ -0,0 +1,244 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_aesthetics_filter import \ + VideoAestheticsFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoAestheticsFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + # vid-low: keyframes -- 0.410, uniform-3 -- 0.410, uniform-5 -- 0.406 + # vid-mid: keyframes -- 0.448, uniform-3 -- 0.419, uniform-5 -- 0.449 + # vid-high: keyframes -- 0.468, uniform-3 -- 0.474, uniform-5 -- 0.480 + vid_low_path = os.path.join(data_path, 'video4.mp4') + vid_mid_path = os.path.join(data_path, 'video1.mp4') + vid_high_path = os.path.join(data_path, 'video3.mp4') + vid_low_text = ( + 
f'{SpecialTokens.video} [[q]]: Can you summarize what the girls ' + f'are doing in the video?\n", "[[a]]: Sure. The video shows a girl' + f' brushing the hair of another girl who keeps moving her face ' + f'around while the first girl keeps brushing the hair.' + f'{SpecialTokens.eoc}') + vid_mid_text = (f'{SpecialTokens.video} 白色的小羊站在一旁讲话。' + f'旁边还有两只灰色猫咪和一只拉着灰狼的猫咪' + f'{SpecialTokens.eoc}') + vid_high_text = (f'两个长头发的女子正坐在一张圆桌前讲话互动。 ' + f'{SpecialTokens.video} {SpecialTokens.eoc}') + + hf_aesthetics_scorer = \ + 'shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE' + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_aesthetics_scorer) + + def _run_video_aesthetics_filter(self, + dataset: Dataset, + target_list, + op, + np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default_filter(self): + ds_list = [{ + 'videos': [self.vid_low_path], + 'text': self.vid_low_text, + }, { + 'videos': [self.vid_mid_path], + 'text': self.vid_mid_text, + }, { + 'videos': [self.vid_high_path], + 'text': self.vid_high_text, + }] + tgt_list = [{ + 'videos': [self.vid_low_path] + }, { + 'videos': [self.vid_mid_path] + }, { + 'videos': [self.vid_high_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer) + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_large_score_videos(self): + ds_list = [{ + 'videos': [self.vid_low_path], + 'text': self.vid_low_text, + }, { + 'videos': [self.vid_mid_path], + 'text': self.vid_mid_text, + }, { + 'videos': [self.vid_high_path], + 'text': self.vid_high_text, + }] + tgt_list = [{ + 'videos': 
[self.vid_low_path] + }, { + 'videos': [self.vid_mid_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer, max_score=0.45) + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_small_score_videos(self): + ds_list = [{ + 'videos': [self.vid_low_path], + 'text': self.vid_low_text, + }, { + 'videos': [self.vid_mid_path], + 'text': self.vid_mid_text, + }, { + 'videos': [self.vid_high_path], + 'text': self.vid_high_text, + }] + tgt_list = [{ + 'videos': [self.vid_mid_path] + }, { + 'videos': [self.vid_high_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer, min_score=0.415) + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_videos_within_range_keyframes(self): + ds_list = [{ + 'videos': [self.vid_low_path], + 'text': self.vid_low_text, + }, { + 'videos': [self.vid_mid_path], + 'text': self.vid_mid_text, + }, { + 'videos': [self.vid_high_path], + 'text': self.vid_high_text, + }] + tgt_list = [{'videos': [self.vid_mid_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer, + min_score=0.415, + max_score=0.47) + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_keyframes(self): + ds_list = [{ + 'videos': [self.vid_low_path], + 'text': self.vid_low_text, + }, { + 'videos': [self.vid_mid_path], + 'text': self.vid_mid_text, + }, { + 'videos': [self.vid_high_path], + 'text': self.vid_high_text, + }] + tgt_list = [ + { + 'videos': [self.vid_mid_path] + }, + ] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer, + min_score=0.411, + max_score=0.45, + frame_sampling_method='keyframe') + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_uniform_frames_with_different_frame_num(self): + ds_list = [{ + 'videos': [self.vid_low_path], + 'text': self.vid_low_text, + }, { + 'videos': 
[self.vid_mid_path], + 'text': self.vid_mid_text, + }, { + 'videos': [self.vid_high_path], + 'text': self.vid_high_text, + }] + tgt_list = [{'videos': [self.vid_mid_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer, + min_score=0.41, + max_score=0.48, + frame_sampling_method='uniform', + frame_num=5) + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_any(self): + ds_list = [{ + 'videos': [self.vid_low_path, self.vid_mid_path], + 'text': self.vid_low_text + self.vid_mid_text, + }, { + 'videos': [self.vid_mid_path, self.vid_high_path], + 'text': self.vid_mid_text + self.vid_high_text, + }, { + 'videos': [self.vid_low_path, self.vid_high_path], + 'text': self.vid_low_text + self.vid_high_text, + }] + tgt_list = [{ + 'videos': [self.vid_low_path, self.vid_mid_path] + }, { + 'videos': [self.vid_mid_path, self.vid_high_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer, + min_score=0.415, + max_score=0.45, + any_or_all='any') + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_all(self): + ds_list = [{ + 'videos': [self.vid_low_path, self.vid_mid_path], + 'text': self.vid_low_text + self.vid_mid_text, + }, { + 'videos': [self.vid_mid_path, self.vid_high_path], + 'text': self.vid_mid_text + self.vid_high_text, + }, { + 'videos': [self.vid_low_path, self.vid_high_path], + 'text': self.vid_low_text + self.vid_high_text, + }] + tgt_list = [] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter(self.hf_aesthetics_scorer, + min_score=0.415, + max_score=0.45, + any_or_all='all') + self._run_video_aesthetics_filter(dataset, tgt_list, op) + + def test_filter_in_parallel(self): + + ds_list = [{ + 'videos': [self.vid_low_path], + 'text': self.vid_low_text, + }, { + 'videos': [self.vid_mid_path], + 'text': self.vid_mid_text, + }, { + 'videos': [self.vid_high_path], + 'text': self.vid_high_text, + }] + tgt_list = [{'videos': 
[self.vid_mid_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoAestheticsFilter( + self.hf_aesthetics_scorer, + min_score=0.415, + max_score=0.45, + ) + self._run_video_aesthetics_filter(dataset, tgt_list, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_video_aspect_ratio_filter.py b/tests/ops/filter/test_video_aspect_ratio_filter.py new file mode 100644 index 000000000..b07844097 --- /dev/null +++ b/tests/ops/filter/test_video_aspect_ratio_filter.py @@ -0,0 +1,106 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_aspect_ratio_filter import \ + VideoAspectRatioFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoAspectRatioFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # 640x360, 16:9 + vid2_path = os.path.join(data_path, 'video2.mp4') # 480x640, 3:4 + vid3_path = os.path.join(data_path, 'video3.mp4') # 362x640, 181:320 + + def _run_op(self, dataset: Dataset, target_list, op, np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default_params(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoAspectRatioFilter() + self._run_op(dataset, tgt_list, op) + + def 
test_any(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + + dataset = Dataset.from_list(ds_list) + op = VideoAspectRatioFilter(min_ratio='3/4', + max_ratio='16/9', + any_or_all='any') + self._run_op(dataset, tgt_list, op) + + def test_all(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{'videos': [self.vid1_path, self.vid2_path]}] + + dataset = Dataset.from_list(ds_list) + op = VideoAspectRatioFilter(min_ratio='3/4', + max_ratio='16/9', + any_or_all='all') + self._run_op(dataset, tgt_list, op) + + def test_parallel(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid1_path]}, {'videos': [self.vid2_path]}] + + dataset = Dataset.from_list(ds_list) + op = VideoAspectRatioFilter(min_ratio='3/4', max_ratio='16/9') + self._run_op(dataset, tgt_list, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_video_duration_filter.py b/tests/ops/filter/test_video_duration_filter.py new file mode 100644 index 000000000..2954836bf --- /dev/null +++ b/tests/ops/filter/test_video_duration_filter.py @@ -0,0 +1,147 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_duration_filter import VideoDurationFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoDurationFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
'..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # about 12s + vid2_path = os.path.join(data_path, 'video2.mp4') # about 23s + vid3_path = os.path.join(data_path, 'video3.mp4') # about 50s + + def _run_video_duration_filter(self, + dataset: Dataset, + target_list, + op, + np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default_filter(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoDurationFilter() + self._run_video_duration_filter(dataset, tgt_list, op) + + def test_filter_long_videos(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid1_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoDurationFilter(max_duration=15) + self._run_video_duration_filter(dataset, tgt_list, op) + + def test_filter_short_videos(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid3_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoDurationFilter(min_duration=30) + self._run_video_duration_filter(dataset, tgt_list, op) + + def test_filter_videos_within_range(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = 
Dataset.from_list(ds_list) + op = VideoDurationFilter(min_duration=16, max_duration=42) + self._run_video_duration_filter(dataset, tgt_list, op) + + def test_any(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoDurationFilter(min_duration=15, + max_duration=30, + any_or_all='any') + self._run_video_duration_filter(dataset, tgt_list, op) + + def test_all(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [] + dataset = Dataset.from_list(ds_list) + op = VideoDurationFilter(min_duration=15, + max_duration=30, + any_or_all='all') + self._run_video_duration_filter(dataset, tgt_list, op) + + def test_filter_in_parallel(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoDurationFilter(min_duration=15, max_duration=30) + self._run_video_duration_filter(dataset, tgt_list, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_video_frames_text_similarity_filter.py b/tests/ops/filter/test_video_frames_text_similarity_filter.py new file mode 100644 index 000000000..04e7355e5 --- /dev/null +++ b/tests/ops/filter/test_video_frames_text_similarity_filter.py @@ -0,0 +1,274 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_frames_text_similarity_filter import \ + VideoFramesTextSimilarityFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import 
SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoFramesTextSimilarityFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + # vid1: keyframes -- 0.2515, uniform-2 -- 0.2378, uniform-3 -- 0.2342 + # vid2: keyframes -- 0.2686, uniform-2 -- 0.2741, uniform-3 -- 0.2697 + # vid3: keyframes -- 0.3020, uniform-2 -- 0.3044, uniform-3 -- 0.2998 + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + + hf_clip = 'openai/clip-vit-base-patch32' + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_clip) + + def _run_video_frames_text_similarity_filter(self, + dataset: Dataset, + target_list, + op, + np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default_filter(self): + ds_list = [{ + 'videos': [self.vid1_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + }, { + 'videos': [self.vid2_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + }, { + 'videos': [self.vid3_path], + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip) + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_filter_large_score_videos(self): + ds_list = [{ + 'videos': 
[self.vid1_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + }, { + 'videos': [self.vid2_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + }, { + 'videos': [self.vid3_path], + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{'videos': [self.vid1_path]}, {'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip, max_score=0.3) + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_filter_small_score_videos(self): + ds_list = [{ + 'videos': [self.vid1_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + }, { + 'videos': [self.vid2_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + }, { + 'videos': [self.vid3_path], + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{'videos': [self.vid2_path]}, {'videos': [self.vid3_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip, min_score=0.26) + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_filter_videos_within_range_keyframes(self): + ds_list = [{ + 'videos': [self.vid1_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + }, { + 'videos': [self.vid2_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + }, { + 'videos': [self.vid3_path], + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip, + min_score=0.26, + max_score=0.3) + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_filter_uniform_frames(self): + ds_list = [{ + 'videos': 
[self.vid1_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + }, { + 'videos': [self.vid2_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + }, { + 'videos': [self.vid3_path], + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{'videos': [self.vid2_path]}, {'videos': [self.vid3_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip, + min_score=0.26, + max_score=0.3, + frame_sampling_method='uniform') + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_filter_uniform_frames_with_different_frame_num(self): + ds_list = [{ + 'videos': [self.vid1_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + }, { + 'videos': [self.vid2_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + }, { + 'videos': [self.vid3_path], + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip, + min_score=0.26, + max_score=0.3, + frame_sampling_method='uniform', + frame_num=2) + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_any(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc} {SpecialTokens.video} 身穿白色上衣的男子,' + f'拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + }, { + 'videos': [self.vid2_path, self.vid3_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc} 两个长头发的女子正坐在一张圆桌前讲话互动。 ' + f'{SpecialTokens.video} {SpecialTokens.eoc}', + }, { + 'videos': [self.vid1_path, self.vid3_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc} 
两个长头发的女子正坐在一张圆桌前讲话互动。 ' + f'{SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip, + min_score=0.26, + max_score=0.3, + frame_sampling_method='uniform', + frame_num=2, + any_or_all='any') + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_all(self): + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc} {SpecialTokens.video} 身穿白色上衣的男子,' + f'拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + }, { + 'videos': [self.vid2_path, self.vid3_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc} 两个长头发的女子正坐在一张圆桌前讲话互动。 ' + f'{SpecialTokens.video} {SpecialTokens.eoc}', + }, { + 'videos': [self.vid1_path, self.vid3_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc} 两个长头发的女子正坐在一张圆桌前讲话互动。 ' + f'{SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter(self.hf_clip, + min_score=0.26, + max_score=0.3, + frame_sampling_method='uniform', + frame_num=2, + any_or_all='all') + self._run_video_frames_text_similarity_filter(dataset, tgt_list, op) + + def test_filter_in_parallel(self): + + ds_list = [{ + 'videos': [self.vid1_path], + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + }, { + 'videos': [self.vid2_path], + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + }, { + 'videos': [self.vid3_path], + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoFramesTextSimilarityFilter( + self.hf_clip, + 
min_score=0.26, + max_score=0.3, + ) + self._run_video_frames_text_similarity_filter(dataset, + tgt_list, + op, + np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_video_motion_score_filter.py b/tests/ops/filter/test_video_motion_score_filter.py new file mode 100644 index 000000000..0c7ce3f5d --- /dev/null +++ b/tests/ops/filter/test_video_motion_score_filter.py @@ -0,0 +1,140 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_motion_score_filter import \ + VideoMotionScoreFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoMotionScoreFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # 1.8210126 + vid2_path = os.path.join(data_path, 'video2.mp4') # 3.600746 + vid3_path = os.path.join(data_path, 'video3.mp4') # 1.1822891 + + def _run_helper(self, op, source_list, target_list, np=1): + dataset = Dataset.from_list(source_list) + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScoreFilter() + self._run_helper(op, ds_list, tgt_list) + + def test_high(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': 
[self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + op = VideoMotionScoreFilter(min_score=3.0) + self._run_helper(op, ds_list, tgt_list) + + def test_low(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid3_path]}] + op = VideoMotionScoreFilter(min_score=0.0, max_score=1.50) + self._run_helper(op, ds_list, tgt_list) + + def test_middle(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid1_path]}] + op = VideoMotionScoreFilter(min_score=1.5, max_score=3.0) + self._run_helper(op, ds_list, tgt_list) + + def test_any(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + op = VideoMotionScoreFilter(min_score=1.5, + max_score=3.0, + any_or_all='any') + self._run_helper(op, ds_list, tgt_list) + + def test_all(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [] + op = VideoMotionScoreFilter(min_score=1.5, + max_score=3.0, + any_or_all='all') + self._run_helper(op, ds_list, tgt_list) + + def test_parallel(self): + import multiprocess as mp + mp.set_start_method('forkserver', force=True) + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid1_path]}] + op = VideoMotionScoreFilter(min_score=1.5, max_score=3.0) + self._run_helper(op, ds_list, tgt_list, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git 
a/tests/ops/filter/test_video_ocr_area_ratio_filter.py b/tests/ops/filter/test_video_ocr_area_ratio_filter.py new file mode 100644 index 000000000..420094d2b --- /dev/null +++ b/tests/ops/filter/test_video_ocr_area_ratio_filter.py @@ -0,0 +1,157 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_ocr_area_ratio_filter import \ + VideoOcrAreaRatioFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoOcrAreaRatioFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # about 0.067 + vid2_path = os.path.join(data_path, 'video2.mp4') # about 0.288 + vid3_path = os.path.join(data_path, 'video3.mp4') # about 0.075 + + def _run_video_ocr_area_ratio_filter(self, + dataset: Dataset, + target_list, + op, + np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default_filter(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoOcrAreaRatioFilter() + self._run_video_ocr_area_ratio_filter(dataset, tgt_list, op) + + def test_filter_large_ratio_videos(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid1_path]}, {'videos': [self.vid3_path]}] + 
dataset = Dataset.from_list(ds_list) + op = VideoOcrAreaRatioFilter(max_area_ratio=0.1) + self._run_video_ocr_area_ratio_filter(dataset, tgt_list, op) + + def test_filter_small_ratio_videos(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoOcrAreaRatioFilter(min_area_ratio=0.2) + self._run_video_ocr_area_ratio_filter(dataset, tgt_list, op) + + def test_filter_videos_within_range(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid3_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoOcrAreaRatioFilter(min_area_ratio=0.07, max_area_ratio=0.1) + self._run_video_ocr_area_ratio_filter(dataset, tgt_list, op) + + def test_any(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoOcrAreaRatioFilter(min_area_ratio=0.07, + max_area_ratio=0.1, + any_or_all='any') + self._run_video_ocr_area_ratio_filter(dataset, tgt_list, op) + + def test_all(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [] + dataset = Dataset.from_list(ds_list) + op = VideoOcrAreaRatioFilter(min_area_ratio=0.07, + max_area_ratio=0.1, + any_or_all='all') + self._run_video_ocr_area_ratio_filter(dataset, tgt_list, op) + + def test_filter_in_parallel(self): + + # WARNING: current parallel tests only work in spawn method + import multiprocess + original_method = multiprocess.get_start_method() + 
multiprocess.set_start_method('spawn', force=True) + # WARNING: current parallel tests only work in spawn method + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid3_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoOcrAreaRatioFilter(min_area_ratio=0.07, max_area_ratio=0.1) + self._run_video_ocr_area_ratio_filter(dataset, tgt_list, op, np=2) + + # WARNING: current parallel tests only work in spawn method + multiprocess.set_start_method(original_method, force=True) + # WARNING: current parallel tests only work in spawn method + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_video_resolution_filter.py b/tests/ops/filter/test_video_resolution_filter.py new file mode 100644 index 000000000..210662a3e --- /dev/null +++ b/tests/ops/filter/test_video_resolution_filter.py @@ -0,0 +1,151 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_resolution_filter import \ + VideoResolutionFilter +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoResolutionFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + # video1: horizontal resolution 640p, vertical resolution 360p + # video2: horizontal resolution 480p, vertical resolution 640p + # video3: horizontal resolution 362p, vertical resolution 640p + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + + def _run_video_resolution_filter(self, + dataset: Dataset, + target_list, + op, + np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=np) + dataset = 
dataset.filter(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default_filter(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoResolutionFilter() + self._run_video_resolution_filter(dataset, tgt_list, op) + + def test_filter_low_resolution_videos(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoResolutionFilter(min_width=480, min_height=480) + self._run_video_resolution_filter(dataset, tgt_list, op) + + def test_filter_high_resolution_videos(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid1_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoResolutionFilter(max_width=640, max_height=480) + self._run_video_resolution_filter(dataset, tgt_list, op) + + def test_filter_videos_within_range(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoResolutionFilter(min_width=400, max_width=500) + self._run_video_resolution_filter(dataset, tgt_list, op) + + def test_any(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, 
self.vid3_path] + }] + dataset = Dataset.from_list(ds_list) + op = VideoResolutionFilter(min_width=400, + max_width=500, + any_or_all='any') + self._run_video_resolution_filter(dataset, tgt_list, op) + + def test_all(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [] + dataset = Dataset.from_list(ds_list) + op = VideoResolutionFilter(min_width=400, + max_width=500, + any_or_all='all') + self._run_video_resolution_filter(dataset, tgt_list, op) + + def test_filter_in_parallel(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid2_path]}] + dataset = Dataset.from_list(ds_list) + op = VideoResolutionFilter(min_width=400, max_width=500) + self._run_video_resolution_filter(dataset, tgt_list, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_word_num_filter.py b/tests/ops/filter/test_word_num_filter.py index d4ee8b239..6a4967f97 100644 --- a/tests/ops/filter/test_word_num_filter.py +++ b/tests/ops/filter/test_word_num_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.word_num_filter import WordNumFilter from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class WordNumFilterTest(unittest.TestCase): +class WordNumFilterTest(DataJuicerTestCaseBase): def _run_word_num_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/filter/test_word_repetition_filter.py b/tests/ops/filter/test_word_repetition_filter.py index 53435fd70..cf5f02330 100644 --- a/tests/ops/filter/test_word_repetition_filter.py +++ b/tests/ops/filter/test_word_repetition_filter.py @@ -4,9 +4,10 @@ from data_juicer.ops.filter.word_repetition_filter import WordRepetitionFilter from 
data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class WordRepetitionFilterTest(unittest.TestCase): +class WordRepetitionFilterTest(DataJuicerTestCaseBase): def _run_word_repetition_filter(self, dataset: Dataset, target_list, op): if Fields.stats not in dataset.features: diff --git a/tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py b/tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py new file mode 100644 index 000000000..4ee4fdd61 --- /dev/null +++ b/tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py @@ -0,0 +1,60 @@ +import os +import unittest + +import librosa +from datasets import Dataset + +from data_juicer.ops.mapper.audio_ffmpeg_wrapped_mapper import \ + AudioFFmpegWrappedMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class AudioFFmpegWrappedMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + aud1_path = os.path.join(data_path, 'audio1.wav') # 5.501678004535147 + aud2_path = os.path.join(data_path, 'audio2.wav') # 14.142426303854876 + aud3_path = os.path.join(data_path, 'audio3.ogg') # 119.87591836734694 + + def _run_op(self, ds_list, target_list, op, np=1): + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, num_proc=np) + + def get_size(dataset): + durations = [] + res_list = dataset.to_list() + for sample in res_list: + sample_durations = [] + for aud_path in sample['audios']: + sample_durations.append( + librosa.get_duration(path=aud_path)) + durations.append(sample_durations) + return durations + + sizes = get_size(dataset) + self.assertEqual(sizes, target_list) + + def test_resize(self): + ds_list = [{ + 'audios': [self.aud1_path, self.aud2_path, self.aud3_path] + }] + tgt_list = [[5.501678004535147, 6.0, 6.0]] + op = AudioFFmpegWrappedMapper('atrim', + filter_kwargs={'end': 6}, + capture_stderr=False) + self._run_op(ds_list, tgt_list, op) + + def 
test_resize_parallel(self): + ds_list = [{ + 'audios': [self.aud1_path, self.aud2_path, self.aud3_path] + }] + tgt_list = [[5.501678004535147, 6.0, 6.0]] + op = AudioFFmpegWrappedMapper('atrim', + filter_kwargs={'end': 6}, + capture_stderr=False) + self._run_op(ds_list, tgt_list, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_chinese_convert_mapper.py b/tests/ops/mapper/test_chinese_convert_mapper.py index fd35bbbb1..9bbe8e8df 100644 --- a/tests/ops/mapper/test_chinese_convert_mapper.py +++ b/tests/ops/mapper/test_chinese_convert_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.chinese_convert_mapper import ChineseConvertMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ChineseConvertMapperTest(unittest.TestCase): +class ChineseConvertMapperTest(DataJuicerTestCaseBase): def setUp(self, mode='s2t'): self.op = ChineseConvertMapper(mode) diff --git a/tests/ops/mapper/test_clean_copyright_mapper.py b/tests/ops/mapper/test_clean_copyright_mapper.py index 302942d26..726d829f7 100644 --- a/tests/ops/mapper/test_clean_copyright_mapper.py +++ b/tests/ops/mapper/test_clean_copyright_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.clean_copyright_mapper import CleanCopyrightMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class CleanCopyrightMapperTest(unittest.TestCase): +class CleanCopyrightMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = CleanCopyrightMapper() diff --git a/tests/ops/mapper/test_clean_email_mapper.py b/tests/ops/mapper/test_clean_email_mapper.py index 9e20aede9..b3f0e5e9a 100644 --- a/tests/ops/mapper/test_clean_email_mapper.py +++ b/tests/ops/mapper/test_clean_email_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.clean_email_mapper import CleanEmailMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class 
CleanEmailMapperTest(unittest.TestCase): +class CleanEmailMapperTest(DataJuicerTestCaseBase): def _run_clean_email(self, op, samples): for sample in samples: @@ -45,6 +46,7 @@ def test_replace_email(self): }] op = CleanEmailMapper(repl='') self._run_clean_email(op, samples) - + + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_clean_html_mapper.py b/tests/ops/mapper/test_clean_html_mapper.py index ecab4114d..69249b60a 100644 --- a/tests/ops/mapper/test_clean_html_mapper.py +++ b/tests/ops/mapper/test_clean_html_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.clean_html_mapper import CleanHtmlMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class CleanHtmlMapperTest(unittest.TestCase): +class CleanHtmlMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = CleanHtmlMapper() diff --git a/tests/ops/mapper/test_clean_ip_mapper.py b/tests/ops/mapper/test_clean_ip_mapper.py index 85d61c569..ccbaf52b7 100644 --- a/tests/ops/mapper/test_clean_ip_mapper.py +++ b/tests/ops/mapper/test_clean_ip_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.clean_ip_mapper import CleanIpMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class CleanIpMapperTest(unittest.TestCase): +class CleanIpMapperTest(DataJuicerTestCaseBase): def _run_clean_ip(self, op, samples): for sample in samples: @@ -63,5 +64,7 @@ def test_replace_ipv4(self): }] op = CleanIpMapper(repl='') self._run_clean_ip(op, samples) + + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_clean_links_mapper.py b/tests/ops/mapper/test_clean_links_mapper.py index 5c22e7ccd..28e14b2d9 100644 --- a/tests/ops/mapper/test_clean_links_mapper.py +++ b/tests/ops/mapper/test_clean_links_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.clean_links_mapper import CleanLinksMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class 
CleanLinksMapperTest(unittest.TestCase): +class CleanLinksMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = CleanLinksMapper() @@ -216,22 +217,28 @@ def test_no_link_text(self): def test_replace_links_text(self): - samples = [{ - 'text': 'ftp://user:password@ftp.example.com:21/', - 'target': '' - }, { - 'text': 'This is a sample for test', - 'target': 'This is a sample for test', - }, { - 'text': 'abcd://ef is a sample for test', - 'target': ' is a sample for test', - }, { + samples = [ + { + 'text': 'ftp://user:password@ftp.example.com:21/', + 'target': '' + }, + { + 'text': 'This is a sample for test', + 'target': 'This is a sample for test', + }, + { + 'text': 'abcd://ef is a sample for test', + 'target': ' is a sample for test', + }, + { 'text': 'HTTP://example.com/my-page.html?param1=value1&param2=value2', 'target': '' - },] + }, + ] op = CleanLinksMapper(repl='') self._run_clean_links(op, samples) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_exapnd_macro_mapper.py b/tests/ops/mapper/test_exapnd_macro_mapper.py index 3cdc8a0c1..68dbf047b 100644 --- a/tests/ops/mapper/test_exapnd_macro_mapper.py +++ b/tests/ops/mapper/test_exapnd_macro_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.expand_macro_mapper import ExpandMacroMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ExpandMacroMapperTest(unittest.TestCase): +class ExpandMacroMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = ExpandMacroMapper() diff --git a/tests/ops/mapper/test_fix_unicode_mapper.py b/tests/ops/mapper/test_fix_unicode_mapper.py index f77e53eb7..547020b51 100644 --- a/tests/ops/mapper/test_fix_unicode_mapper.py +++ b/tests/ops/mapper/test_fix_unicode_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.fix_unicode_mapper import FixUnicodeMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class FixUnicodeMapperTest(unittest.TestCase): 
+class FixUnicodeMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = FixUnicodeMapper() diff --git a/tests/ops/mapper/test_image_blur_mapper.py b/tests/ops/mapper/test_image_blur_mapper.py index c0885e295..632c1978b 100644 --- a/tests/ops/mapper/test_image_blur_mapper.py +++ b/tests/ops/mapper/test_image_blur_mapper.py @@ -1,25 +1,23 @@ import os import unittest -import numpy as np +import numpy as np from datasets import Dataset -from data_juicer.utils.mm_utils import load_image from data_juicer.ops.mapper.image_blur_mapper import ImageBlurMapper +from data_juicer.utils.mm_utils import load_image +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ImageBlurMapperTest(unittest.TestCase): +class ImageBlurMapperTest(DataJuicerTestCaseBase): - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), - '..', 'data') + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') img1_path = os.path.join(data_path, 'img1.png') img2_path = os.path.join(data_path, 'img2.jpg') img3_path = os.path.join(data_path, 'img3.jpg') - def _get_blured_img_path(self, path): - return os.path.join(os.path.dirname(path), '_blured.'.join(os.path.basename(path).split('.'))) - - def _get_blur_kernel(self, blur_type = 'gaussian', radius = 2): + def _get_blur_kernel(self, blur_type='gaussian', radius=2): from PIL import ImageFilter if blur_type == 'mean': return ImageFilter.BLUR @@ -28,11 +26,10 @@ def _get_blur_kernel(self, blur_type = 'gaussian', radius = 2): else: return ImageFilter.GaussianBlur(radius) - def _run_image_blur_mapper(self, op, source_list, target_list, blur_kernel): + def _run_image_blur_mapper(self, op, source_list, blur_kernel): dataset = Dataset.from_list(source_list) dataset = dataset.map(op.process) res_list = dataset.to_list() - self.assertEqual(res_list, target_list) for source, res in zip(source_list, res_list): for s_path, r_path in zip(source[op.image_key], res[op.image_key]): s_img = 
load_image(s_path).convert('RGB').filter(blur_kernel) @@ -51,16 +48,9 @@ def test(self): }, { 'images': [self.img3_path] }] - tgt_list = [{ - 'images': [self._get_blured_img_path(self.img1_path)] - }, { - 'images': [self._get_blured_img_path(self.img2_path)] - }, { - 'images': [self._get_blured_img_path(self.img3_path)] - }] - op = ImageBlurMapper(p = 1, blur_type = 'gaussian', radius = 2) + op = ImageBlurMapper(p=1, blur_type='gaussian', radius=2) blur_kernel = self._get_blur_kernel('gaussian', 2) - self._run_image_blur_mapper(op, ds_list, tgt_list, blur_kernel) + self._run_image_blur_mapper(op, ds_list, blur_kernel) def test_blur_type(self): ds_list = [{ @@ -70,16 +60,9 @@ def test_blur_type(self): }, { 'images': [self.img1_path] }] - tgt_list = [{ - 'images': [self._get_blured_img_path(self.img2_path)] - }, { - 'images': [self._get_blured_img_path(self.img3_path)] - }, { - 'images': [self._get_blured_img_path(self.img1_path)] - }] - op = ImageBlurMapper(p = 1, blur_type = 'box', radius = 2) + op = ImageBlurMapper(p=1, blur_type='box', radius=2) blur_kernel = self._get_blur_kernel('box', 2) - self._run_image_blur_mapper(op, ds_list, tgt_list, blur_kernel) + self._run_image_blur_mapper(op, ds_list, blur_kernel) def test_radius(self): ds_list = [{ @@ -89,16 +72,9 @@ def test_radius(self): }, { 'images': [self.img1_path] }] - tgt_list = [{ - 'images': [self._get_blured_img_path(self.img3_path)] - }, { - 'images': [self._get_blured_img_path(self.img2_path)] - }, { - 'images': [self._get_blured_img_path(self.img1_path)] - }] - op = ImageBlurMapper(p = 1, blur_type = 'gaussian', radius = 5) + op = ImageBlurMapper(p=1, blur_type='gaussian', radius=5) blur_kernel = self._get_blur_kernel('gaussian', 5) - self._run_image_blur_mapper(op, ds_list, tgt_list, blur_kernel) + self._run_image_blur_mapper(op, ds_list, blur_kernel) def test_multi_img(self): ds_list = [{ @@ -108,16 +84,9 @@ def test_multi_img(self): }, { 'images': [self.img3_path, self.img1_path] }] - tgt_list = [{ 
- 'images': [self._get_blured_img_path(self.img1_path), self._get_blured_img_path(self.img2_path), self._get_blured_img_path(self.img3_path)] - }, { - 'images': [self._get_blured_img_path(self.img2_path)] - }, { - 'images': [self._get_blured_img_path(self.img3_path), self._get_blured_img_path(self.img1_path)] - }] - op = ImageBlurMapper(p = 1, blur_type = 'gaussian', radius = 2) + op = ImageBlurMapper(p=1, blur_type='gaussian', radius=2) blur_kernel = self._get_blur_kernel('gaussian', 2) - self._run_image_blur_mapper(op, ds_list, tgt_list, blur_kernel) + self._run_image_blur_mapper(op, ds_list, blur_kernel) if __name__ == '__main__': diff --git a/tests/ops/mapper/test_image_captioning_mapper.py b/tests/ops/mapper/test_image_captioning_mapper.py new file mode 100644 index 000000000..56d48621f --- /dev/null +++ b/tests/ops/mapper/test_image_captioning_mapper.py @@ -0,0 +1,243 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset +from data_juicer.ops.mapper.image_captioning_mapper import \ + ImageCaptioningMapper +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP in the GitHub actions due to disk space limitation. +# These tests have been tested locally. 
+@SKIPPED_TESTS.register_module() +class ImageCaptioningMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + + cat_path = os.path.join(data_path, 'cat.jpg') + img3_path = os.path.join(data_path, 'img3.jpg') + + hf_img2seq = 'Salesforce/blip2-opt-2.7b' + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_img2seq) + + def _run_mapper(self, + dataset: NestedDataset, + op, + num_proc=1, + caption_num=0): + + dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True) + dataset_list = dataset.select_columns(column_names=['text']).to_list() + # assert the caption is generated successfully in terms of not_none + # as the generated content is not deterministic + self.assertEqual(len(dataset_list), caption_num) + + def test_no_eoc_special_token(self): + + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }, { + 'text': f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 1 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any') + self._run_mapper(dataset, op, caption_num=len(dataset) * 2) + + def test_eoc_special_token(self): + + ds_list = [ + { + 'text': + f'{SpecialTokens.image}a photo of a cat{SpecialTokens.eoc}', + 'images': [self.cat_path] + }, + { + 'text': + f'{SpecialTokens.image}a photo, a women with an umbrella{SpecialTokens.eoc}', # noqa: E501 + 'images': [self.img3_path] + } + ] + caption_num = 1 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any') + self._run_mapper(dataset, op, caption_num=len(dataset) * 2) + + def test_multi_candidate_keep_random_any(self): + + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 
'images': [self.cat_path] + }, { + 'text': f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 4 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any') + self._run_mapper(dataset, op, caption_num=len(dataset) * 2) + + def test_multi_candidate_keep_all(self): + + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }, { + 'text': f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 4 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='all') + self._run_mapper(dataset, + op, + caption_num=(1 + caption_num) * len(dataset)) + + def test_multi_candidate_keep_similar_one(self): + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }, { + 'text': f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 4 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='similar_one_simhash') + self._run_mapper(dataset, op, caption_num=len(dataset) * 2) + + def test_multi_process(self): + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }] * 10 + caption_num = 1 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any') + self._run_mapper(dataset, op, num_proc=4, caption_num=len(dataset) * 2) + + def test_no_eoc_special_token_remove_original_sample(self): + + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }, { + 'text': 
f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 1 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any', + keep_original_sample=False) + self._run_mapper(dataset, op, caption_num=len(dataset)) + + def test_eoc_special_token_remove_original_sample(self): + + ds_list = [ + { + 'text': + f'{SpecialTokens.image}a photo of a cat{SpecialTokens.eoc}', + 'images': [self.cat_path] + }, + { + 'text': + f'{SpecialTokens.image}a photo, a women with an umbrella{SpecialTokens.eoc}', # noqa: E501 + 'images': [self.img3_path] + } + ] + caption_num = 1 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any', + keep_original_sample=False) + self._run_mapper(dataset, op, caption_num=len(dataset)) + + def test_multi_candidate_keep_random_any_remove_original_sample(self): + + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }, { + 'text': f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 4 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any', + keep_original_sample=False) + self._run_mapper(dataset, op, caption_num=len(dataset)) + + def test_multi_candidate_keep_all_remove_original_sample(self): + + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }, { + 'text': f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 4 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='all', + keep_original_sample=False) 
+ self._run_mapper(dataset, op, caption_num=caption_num * len(dataset)) + + def test_multi_candidate_keep_similar_one_remove_original_sample(self): + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }, { + 'text': f'{SpecialTokens.image}a photo, a women with an umbrella', + 'images': [self.img3_path] + }] + caption_num = 4 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='similar_one_simhash', + keep_original_sample=False) + self._run_mapper(dataset, op, caption_num=len(dataset)) + + def test_multi_process_remove_original_sample(self): + ds_list = [{ + 'text': f'{SpecialTokens.image}a photo of a cat', + 'images': [self.cat_path] + }] * 10 + caption_num = 1 + dataset = NestedDataset.from_list(ds_list) + op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, + caption_num=caption_num, + keep_candidate_mode='random_any', + keep_original_sample=False) + self._run_mapper(dataset, op, num_proc=4, caption_num=len(dataset)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_image_diffusion_mapper.py b/tests/ops/mapper/test_image_diffusion_mapper.py index 30bf3b7d3..bdc0d0ea4 100644 --- a/tests/ops/mapper/test_image_diffusion_mapper.py +++ b/tests/ops/mapper/test_image_diffusion_mapper.py @@ -2,11 +2,13 @@ import shutil import unittest +from data_juicer import _cuda_device_count from data_juicer.core.data import NestedDataset -from data_juicer.ops.mapper.image_diffusion_mapper import \ - ImageDiffusionMapper +from data_juicer.ops.mapper.image_diffusion_mapper import ImageDiffusionMapper from data_juicer.utils.mm_utils import SpecialTokens -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + # Skip tests for this OP in the GitHub actions due to disk space limitation. 
# These tests have been tested locally. @@ -20,7 +22,7 @@ class ImageDiffusionMapperTest(DataJuicerTestCaseBase): img3_path = os.path.join(data_path, 'img3.jpg') hf_diffusion = 'CompVis/stable-diffusion-v1-4' - hf_blip2 = 'Salesforce/blip2-opt-2.7b' + hf_img2seq = 'Salesforce/blip2-opt-2.7b' # dir to save the images produced in the tests output_dir = '../diffusion_output/' @@ -28,12 +30,18 @@ class ImageDiffusionMapperTest(DataJuicerTestCaseBase): @classmethod def tearDownClass(cls) -> None: super().tearDownClass(cls.hf_diffusion) - super().tearDownClass(cls.hf_blip2) + super().tearDownClass(cls.hf_img2seq) - def _run_mapper(self, dataset: NestedDataset, op, move_to_dir, num_proc=1, total_num=1): + def _run_mapper(self, + dataset: NestedDataset, + op, + move_to_dir, + num_proc=1, + total_num=1): dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True) - dataset_list = dataset.select_columns(column_names=['images']).to_list() + dataset_list = dataset.select_columns( + column_names=['images']).to_list() self.assertEqual(len(dataset_list), total_num) if not os.path.exists(move_to_dir): @@ -42,29 +50,27 @@ def _run_mapper(self, dataset: NestedDataset, op, move_to_dir, num_proc=1, total for image_path in data['images']: if str(image_path) != str(self.cat_path) \ and str(image_path) != str(self.img3_path): - move_to_path = os.path.join(move_to_dir, os.path.basename(image_path)) - shutil.move(image_path, move_to_path) + cp_to_path = os.path.join(move_to_dir, + os.path.basename(image_path)) + shutil.copyfile(image_path, cp_to_path) def test_for_strength(self): ds_list = [{ 'text': f'{SpecialTokens.image}a photo of a cat', - 'caption': f'a women with an umbrella', + 'caption': 'a women with an umbrella', 'images': [self.cat_path] }] aug_num = 3 dataset = NestedDataset.from_list(ds_list) - op = ImageDiffusionMapper( - hf_diffusion=self.hf_diffusion, - strength=1.0, - aug_num=aug_num, - keep_original_sample=True, - caption_key='caption' - ) - self._run_mapper( - 
dataset, op, - os.path.join(self.output_dir, 'test_for_strength'), - total_num=(aug_num+1)*len(ds_list)) - + op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, + strength=1.0, + aug_num=aug_num, + keep_original_sample=True, + caption_key='caption') + self._run_mapper(dataset, + op, + os.path.join(self.output_dir, 'test_for_strength'), + total_num=(aug_num + 1) * len(ds_list)) def test_for_given_caption_list(self): @@ -76,16 +82,15 @@ def test_for_given_caption_list(self): aug_num = 2 dataset = NestedDataset.from_list(ds_list) - op = ImageDiffusionMapper( - hf_diffusion=self.hf_diffusion, - aug_num=aug_num, - keep_original_sample=False, - caption_key='captions' - ) - self._run_mapper( - dataset, op, - os.path.join(self.output_dir, 'test_for_given_caption_list'), - total_num=aug_num*len(ds_list)) + op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, + aug_num=aug_num, + keep_original_sample=False, + caption_key='captions') + self._run_mapper(dataset, + op, + os.path.join(self.output_dir, + 'test_for_given_caption_list'), + total_num=aug_num * len(ds_list)) def test_for_given_caption_string(self): @@ -99,16 +104,15 @@ def test_for_given_caption_string(self): aug_num = 1 dataset = NestedDataset.from_list(ds_list) - op = ImageDiffusionMapper( - hf_diffusion=self.hf_diffusion, - aug_num=aug_num, - keep_original_sample=False, - caption_key='text' - ) - self._run_mapper( - dataset, op, - os.path.join(self.output_dir, 'test_for_given_caption_string'), - total_num=aug_num*len(ds_list)) + op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, + aug_num=aug_num, + keep_original_sample=False, + caption_key='text') + self._run_mapper(dataset, + op, + os.path.join(self.output_dir, + 'test_for_given_caption_string'), + total_num=aug_num * len(ds_list)) def test_for_no_given_caption(self): @@ -122,16 +126,15 @@ def test_for_no_given_caption(self): aug_num = 2 dataset = NestedDataset.from_list(ds_list) - op = ImageDiffusionMapper( - hf_diffusion=self.hf_diffusion, - 
aug_num=aug_num, - keep_original_sample=False, - hf_blip2=self.hf_blip2 - ) - self._run_mapper( - dataset, op, - os.path.join(self.output_dir, 'test_for_no_given_caption'), - total_num=aug_num*len(ds_list)) + op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, + aug_num=aug_num, + keep_original_sample=False, + hf_img2seq=self.hf_img2seq) + self._run_mapper(dataset, + op, + os.path.join(self.output_dir, + 'test_for_no_given_caption'), + total_num=aug_num * len(ds_list)) def test_for_fp16_given_caption_string(self): @@ -145,17 +148,16 @@ def test_for_fp16_given_caption_string(self): aug_num = 1 dataset = NestedDataset.from_list(ds_list) - op = ImageDiffusionMapper( - hf_diffusion=self.hf_diffusion, - floating_point='fp16', - aug_num=aug_num, - keep_original_sample=False, - caption_key='text' - ) - self._run_mapper( - dataset, op, - os.path.join(self.output_dir, 'test_for_fp16_given_caption_string'), - total_num=aug_num*len(ds_list)) + op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, + floating_point='fp16', + aug_num=aug_num, + keep_original_sample=False, + caption_key='text') + self._run_mapper(dataset, + op, + os.path.join(self.output_dir, + 'test_for_fp16_given_caption_string'), + total_num=aug_num * len(ds_list)) def test_for_multi_process_given_caption_string(self): @@ -169,17 +171,23 @@ def test_for_multi_process_given_caption_string(self): aug_num = 1 dataset = NestedDataset.from_list(ds_list) - op = ImageDiffusionMapper( - hf_diffusion=self.hf_diffusion, - aug_num=aug_num, - keep_original_sample=False, - caption_key='text' - ) - self._run_mapper( - dataset, op, - os.path.join(self.output_dir, 'test_for_given_caption_string'), - num_proc=2, - total_num=aug_num*len(ds_list)) + op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, + aug_num=aug_num, + keep_original_sample=False, + caption_key='text') + + # keep num_proc <= the number of CUDA devices when CUDA is available + num_proc = 2 + if _cuda_device_count() == 1: + num_proc = 1 + +
self._run_mapper(dataset, + op, + os.path.join(self.output_dir, + 'test_for_given_caption_string'), + num_proc=num_proc, + total_num=aug_num * len(ds_list)) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_nlpaug_en_mapper.py b/tests/ops/mapper/test_nlpaug_en_mapper.py index fa93e9273..5451ffd7c 100644 --- a/tests/ops/mapper/test_nlpaug_en_mapper.py +++ b/tests/ops/mapper/test_nlpaug_en_mapper.py @@ -1,10 +1,13 @@ +# flake8: noqa: E501 + import unittest from data_juicer.core import NestedDataset from data_juicer.ops.mapper.nlpaug_en_mapper import NlpaugEnMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class NlpaugEnMapperTest(unittest.TestCase): +class NlpaugEnMapperTest(DataJuicerTestCaseBase): def setUp(self): self.samples = NestedDataset.from_dict({ @@ -119,7 +122,8 @@ def test_all_aug_methods_with_sequential_off(self): (aug_num * aug_method_num + 1) * len(self.samples)) self.assertEqual(len(result['meta']), len(result['text'])) - def test_number_of_generated_samples_with_sequential_on_remove_original_sample(self): + def test_number_of_generated_samples_with_sequential_on_remove_original_sample( + self): aug_num = 3 aug_method_num = 3 op = NlpaugEnMapper( @@ -132,11 +136,11 @@ def test_number_of_generated_samples_with_sequential_on_remove_original_sample(s ) self.assertEqual(len(op.aug), aug_method_num) result = self.samples.map(op.process) - self.assertEqual(len(result['text']), - aug_num * len(self.samples)) + self.assertEqual(len(result['text']), aug_num * len(self.samples)) self.assertEqual(len(result['meta']), len(result['text'])) - def test_number_of_generated_samples_with_sequential_off_remove_original_sample(self): + def test_number_of_generated_samples_with_sequential_off_remove_original_sample( + self): aug_num = 3 aug_method_num = 3 op = NlpaugEnMapper( @@ -201,8 +205,7 @@ def test_all_aug_methods_with_sequential_on_remove_original_sample(self): ) self.assertEqual(len(op.aug), aug_method_num) 
result = self.samples.map(op.process) - self.assertEqual(len(result['text']), - aug_num * len(self.samples)) + self.assertEqual(len(result['text']), aug_num * len(self.samples)) self.assertEqual(len(result['meta']), len(result['text'])) def test_all_aug_methods_with_sequential_off_remove_original_sample(self): diff --git a/tests/ops/mapper/test_nlpcda_zh_mapper.py b/tests/ops/mapper/test_nlpcda_zh_mapper.py index 6110f0130..80aa2bf84 100644 --- a/tests/ops/mapper/test_nlpcda_zh_mapper.py +++ b/tests/ops/mapper/test_nlpcda_zh_mapper.py @@ -1,10 +1,13 @@ +# flake8: noqa: E501 + import unittest from data_juicer.core import NestedDataset from data_juicer.ops.mapper.nlpcda_zh_mapper import NlpcdaZhMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class NlpaugEnMapperTest(unittest.TestCase): +class NlpaugEnMapperTest(DataJuicerTestCaseBase): def setUp(self): self.samples = NestedDataset.from_dict({ @@ -142,7 +145,8 @@ def test_all_aug_methods_with_sequential_off(self): self.assertGreaterEqual(len(result['text']), len(self.samples['text'])) self.assertEqual(len(result['meta']), len(result['text'])) - def test_number_of_generated_samples_with_sequential_on_remove_original_sample(self): + def test_number_of_generated_samples_with_sequential_on_remove_original_sample( + self): aug_num = 3 aug_method_num = 3 op = NlpcdaZhMapper( @@ -160,7 +164,8 @@ def test_number_of_generated_samples_with_sequential_on_remove_original_sample(s self.assertGreaterEqual(len(result['text']), len(self.samples['text'])) self.assertEqual(len(result['meta']), len(result['text'])) - def test_number_of_generated_samples_with_sequential_off_remove_original_sample(self): + def test_number_of_generated_samples_with_sequential_off_remove_original_sample( + self): aug_num = 3 aug_method_num = 3 op = NlpcdaZhMapper( @@ -173,9 +178,9 @@ def test_number_of_generated_samples_with_sequential_off_remove_original_sample( ) self.assertEqual(len(op.aug_pipeline), aug_method_num) result = 
self.samples.map(op.process) - self.assertLessEqual(len(result['text']), - aug_num * aug_method_num * - len(self.samples['text'])) + self.assertLessEqual( + len(result['text']), + aug_num * aug_method_num * len(self.samples['text'])) self.assertGreaterEqual(len(result['text']), len(self.samples['text'])) self.assertEqual(len(result['meta']), len(result['text'])) @@ -244,9 +249,9 @@ def test_all_aug_methods_with_sequential_off_remove_original_sample(self): ) self.assertEqual(len(op.aug_pipeline), aug_method_num) result = self.samples.map(op.process) - self.assertLessEqual(len(result['text']), - aug_num * aug_method_num * - len(self.samples['text'])) + self.assertLessEqual( + len(result['text']), + aug_num * aug_method_num * len(self.samples['text'])) self.assertGreaterEqual(len(result['text']), len(self.samples['text'])) self.assertEqual(len(result['meta']), len(result['text'])) diff --git a/tests/ops/mapper/test_punctuation_normalization_mapper.py b/tests/ops/mapper/test_punctuation_normalization_mapper.py index a114b83b1..a69d4040e 100644 --- a/tests/ops/mapper/test_punctuation_normalization_mapper.py +++ b/tests/ops/mapper/test_punctuation_normalization_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.punctuation_normalization_mapper import \ PunctuationNormalizationMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class PunctuationNormalizationMapperTest(unittest.TestCase): +class PunctuationNormalizationMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = PunctuationNormalizationMapper() diff --git a/tests/ops/mapper/test_remove_bibliography_mapper.py b/tests/ops/mapper/test_remove_bibliography_mapper.py index 449cb59c7..76096fe93 100644 --- a/tests/ops/mapper/test_remove_bibliography_mapper.py +++ b/tests/ops/mapper/test_remove_bibliography_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.remove_bibliography_mapper import \ RemoveBibliographyMapper +from data_juicer.utils.unittest_utils import 
DataJuicerTestCaseBase -class RemoveBibliographyMapperTest(unittest.TestCase): +class RemoveBibliographyMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = RemoveBibliographyMapper() diff --git a/tests/ops/mapper/test_remove_comments_mapper.py b/tests/ops/mapper/test_remove_comments_mapper.py index d61494c14..81a0df5de 100644 --- a/tests/ops/mapper/test_remove_comments_mapper.py +++ b/tests/ops/mapper/test_remove_comments_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.remove_comments_mapper import RemoveCommentsMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class RemoveCommentsMapperTest(unittest.TestCase): +class RemoveCommentsMapperTest(DataJuicerTestCaseBase): def _run_remove_comments(self, samples, op): for sample in samples: diff --git a/tests/ops/mapper/test_remove_header_mapper.py b/tests/ops/mapper/test_remove_header_mapper.py index ea7170fad..c91bfe790 100644 --- a/tests/ops/mapper/test_remove_header_mapper.py +++ b/tests/ops/mapper/test_remove_header_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.remove_header_mapper import RemoveHeaderMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class RemoveHeaderMapperTest(unittest.TestCase): +class RemoveHeaderMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = RemoveHeaderMapper() diff --git a/tests/ops/mapper/test_remove_long_words_mapper.py b/tests/ops/mapper/test_remove_long_words_mapper.py index 01962e508..533d7a717 100644 --- a/tests/ops/mapper/test_remove_long_words_mapper.py +++ b/tests/ops/mapper/test_remove_long_words_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.remove_long_words_mapper import \ RemoveLongWordsMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class RemoveLongWordsMapperTest(unittest.TestCase): +class RemoveLongWordsMapperTest(DataJuicerTestCaseBase): def _run_remove_long_words(self, samples, op): for sample in samples: diff 
--git a/tests/ops/mapper/test_remove_non_chinese_character_mapper.py b/tests/ops/mapper/test_remove_non_chinese_character_mapper.py index d7c1953c8..283a75ab0 100644 --- a/tests/ops/mapper/test_remove_non_chinese_character_mapper.py +++ b/tests/ops/mapper/test_remove_non_chinese_character_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.remove_non_chinese_character_mapper import \ RemoveNonChineseCharacterlMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class RemoveNonChineseCharacterlMapperrTest(unittest.TestCase): +class RemoveNonChineseCharacterlMapperrTest(DataJuicerTestCaseBase): def setUp(self, keep_alphabet=True, keep_number=True, keep_punc=True): self.op = RemoveNonChineseCharacterlMapper(keep_alphabet, keep_number, diff --git a/tests/ops/mapper/test_remove_repeat_sentences_mapper.py b/tests/ops/mapper/test_remove_repeat_sentences_mapper.py index 923ac5824..a7fe347fe 100644 --- a/tests/ops/mapper/test_remove_repeat_sentences_mapper.py +++ b/tests/ops/mapper/test_remove_repeat_sentences_mapper.py @@ -1,9 +1,13 @@ +# flake8: noqa: E501 + import unittest -from data_juicer.ops.mapper.remove_repeat_sentences_mapper import RemoveRepeatSentencesMapper +from data_juicer.ops.mapper.remove_repeat_sentences_mapper import \ + RemoveRepeatSentencesMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class RemoveRepeatSentencesMapperTest(unittest.TestCase): +class RemoveRepeatSentencesMapperTest(DataJuicerTestCaseBase): def _run_helper(self, samples, op): for sample in samples: @@ -12,44 +16,52 @@ def _run_helper(self, samples, op): def test_text(self): - samples = [ - { - 'text': '今天天气真不错,阳光明媚,适合出去散步。小明说:“今天天气真不错,我们去海边吧。” 小红回答说:“好主意!” 但是,小李觉得:“今天天气真不错,我们去爬山吧。” 今天天气真不错,阳光明媚,适合出去散步。昨天下了一整天的雨,今天终于放晴了。昨天下了一整天的雨,今天终于放晴了。', - 'target': '今天天气真不错,阳光明媚,适合出去散步。小明说:“今天天气真不错,我们去海边吧。” 小红回答说:“好主意!” 但是,小李觉得:“今天天气真不错,我们去爬山吧。”昨天下了一整天的雨,今天终于放晴了。', - }, { - 'text': 'The quick brown fox jumps over the lazy dog. 
Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? The quick brown fox jumps over the lazy dog. Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? "Let\'s seize the day," Tom exclaimed, full of enthusiasm. "Let\'s seize the day," Tom exclaimed, full of enthusiasm.', - 'target': 'The quick brown fox jumps over the lazy dog. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. "Let\'s seize the day," Tom exclaimed, full of enthusiasm.' - }, { - 'text': '''我很开心 。但是你不开心 。我很开心 。\n你好呀!我很开心 。我好的。你好呀!''', - 'target': '''我很开心 。但是你不开心 。\n你好呀!我好的。''' - }, { - 'text': '默认配置下,长度低于2的句子不会被去重。去重?去重。去重!重。重...... 重! 1234?3215. 1234. 3. 3. 3', - 'target': '默认配置下,长度低于2的句子不会被去重。去重?重。重...... 重! 1234?3215. 3. 3. 3' - } - ] + samples = [{ + 'text': + '今天天气真不错,阳光明媚,适合出去散步。小明说:“今天天气真不错,我们去海边吧。” 小红回答说:“好主意!” 但是,小李觉得:“今天天气真不错,我们去爬山吧。” 今天天气真不错,阳光明媚,适合出去散步。昨天下了一整天的雨,今天终于放晴了。昨天下了一整天的雨,今天终于放晴了。', + 'target': + '今天天气真不错,阳光明媚,适合出去散步。小明说:“今天天气真不错,我们去海边吧。” 小红回答说:“好主意!” 但是,小李觉得:“今天天气真不错,我们去爬山吧。”昨天下了一整天的雨,今天终于放晴了。', + }, { + 'text': + 'The quick brown fox jumps over the lazy dog. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? The quick brown fox jumps over the lazy dog. Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? "Let\'s seize the day," Tom exclaimed, full of enthusiasm. "Let\'s seize the day," Tom exclaimed, full of enthusiasm.', + 'target': + 'The quick brown fox jumps over the lazy dog. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. 
"Let\'s seize the day," Tom exclaimed, full of enthusiasm.' + }, { + 'text': '''我很开心 。但是你不开心 。我很开心 。\n你好呀!我很开心 。我好的。你好呀!''', + 'target': '''我很开心 。但是你不开心 。\n你好呀!我好的。''' + }, { + 'text': + '默认配置下,长度低于2的句子不会被去重。去重?去重。去重!重。重...... 重! 1234?3215. 1234. 3. 3. 3', + 'target': + '默认配置下,长度低于2的句子不会被去重。去重?重。重...... 重! 1234?3215. 3. 3. 3' + }] op = RemoveRepeatSentencesMapper() self._run_helper(samples, op) def test_text2(self): - samples = [ - { - 'text': 'Life is what happens when you\'re busy making other plans. John Lennon once said. Life is what happens when you\'re busy making other plans. This phrase has resonated with many people over the years. 人生就是当你忙于制定其他计划时发生的事情。对很多人来说,这句话引起了共鸣。', - 'target': 'Life is what happens when you\'re busy making other plans. John Lennon once said. This phrase has resonated with many people over the years. 人生就是当你忙于制定其他计划时发生的事情。对很多人来说,这句话引起了共鸣。', - }, { - 'text': 'The quick brown fox jumps over the lazy dog. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? The quick brown fox jumps over the lazy dog. Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? "Let\'s seize the day," Tom exclaimed, full of enthusiasm. "Let\'s seize the day," Tom exclaimed, full of enthusiasm.', - 'target': 'The quick brown fox jumps over the lazy dog. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. "Let\'s seize the day," Tom exclaimed, full of enthusiasm.' - }, { - 'text': '''我很开心 。但是你不开心 。我很开心 。\n你好呀!我很开心 。我好的。你好呀!''', - 'target': '''我很开心 。但是你不开心 。\n你好呀!我好的。你好呀!''' - }, { - 'text': '去重?去重。去重!重。重...... 重! 1234?3215. 1234. 3. 3. 3', - 'target': '去重?去重。去重!重。重...... 重! 1234?3215. 1234. 3. 3. 
3' - } - ] - - op = RemoveRepeatSentencesMapper(lowercase=True, ignore_special_character=False, min_repeat_sentence_length=5) + samples = [{ + 'text': + 'Life is what happens when you\'re busy making other plans. John Lennon once said. Life is what happens when you\'re busy making other plans. This phrase has resonated with many people over the years. 人生就是当你忙于制定其他计划时发生的事情。对很多人来说,这句话引起了共鸣。', + 'target': + 'Life is what happens when you\'re busy making other plans. John Lennon once said. This phrase has resonated with many people over the years. 人生就是当你忙于制定其他计划时发生的事情。对很多人来说,这句话引起了共鸣。', + }, { + 'text': + 'The quick brown fox jumps over the lazy dog. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? The quick brown fox jumps over the lazy dog. Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? "Let\'s seize the day," Tom exclaimed, full of enthusiasm. "Let\'s seize the day," Tom exclaimed, full of enthusiasm.', + 'target': + 'The quick brown fox jumps over the lazy dog. Isn\'t it amazing how a simple sentence can contain every letter of the alphabet? Speaking of weather, yesterday was quite dreary; however, today is absolutely delightful. "Let\'s seize the day," Tom exclaimed, full of enthusiasm.' + }, { + 'text': '''我很开心 。但是你不开心 。我很开心 。\n你好呀!我很开心 。我好的。你好呀!''', + 'target': '''我很开心 。但是你不开心 。\n你好呀!我好的。你好呀!''' + }, { + 'text': '去重?去重。去重!重。重...... 重! 1234?3215. 1234. 3. 3. 3', + 'target': '去重?去重。去重!重。重...... 重! 1234?3215. 1234. 3. 3. 
3' + }] + + op = RemoveRepeatSentencesMapper(lowercase=True, + ignore_special_character=False, + min_repeat_sentence_length=5) self._run_helper(samples, op) diff --git a/tests/ops/mapper/test_remove_specific_chars_mapper.py b/tests/ops/mapper/test_remove_specific_chars_mapper.py index 4073d45df..f61a3f6fc 100644 --- a/tests/ops/mapper/test_remove_specific_chars_mapper.py +++ b/tests/ops/mapper/test_remove_specific_chars_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.remove_specific_chars_mapper import \ RemoveSpecificCharsMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class RemoveSpecificCharsMapperTest(unittest.TestCase): +class RemoveSpecificCharsMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = RemoveSpecificCharsMapper() diff --git a/tests/ops/mapper/test_remove_table_text_mapper.py b/tests/ops/mapper/test_remove_table_text_mapper.py index d08585d3e..2be4a2453 100644 --- a/tests/ops/mapper/test_remove_table_text_mapper.py +++ b/tests/ops/mapper/test_remove_table_text_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.remove_table_text_mapper import \ RemoveTableTextMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class RemoveTableTextMapperTest(unittest.TestCase): +class RemoveTableTextMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = RemoveTableTextMapper() diff --git a/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py index ad1fbe183..02157ad52 100644 --- a/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py +++ b/tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.remove_words_with_incorrect_substrings_mapper import \ RemoveWordsWithIncorrectSubstringsMapper # noqa: E501 +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class 
RemoveWordsWithIncorrectSubstringsMapperTest(unittest.TestCase): +class RemoveWordsWithIncorrectSubstringsMapperTest(DataJuicerTestCaseBase): def _run_remove_words_with_incorrect_sbstrings(self, samples, op): for sample in samples: diff --git a/tests/ops/mapper/test_replace_content_mapper.py b/tests/ops/mapper/test_replace_content_mapper.py index ec6ae512e..64f88c888 100644 --- a/tests/ops/mapper/test_replace_content_mapper.py +++ b/tests/ops/mapper/test_replace_content_mapper.py @@ -1,12 +1,12 @@ import unittest -from data_juicer.ops.mapper.replace_content_mapper import \ - ReplaceContentMapper +from data_juicer.ops.mapper.replace_content_mapper import ReplaceContentMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class ReplaceContentMapperTest(unittest.TestCase): +class ReplaceContentMapperTest(DataJuicerTestCaseBase): - def _run_helper(self,op, samples): + def _run_helper(self, op, samples): for sample in samples: result = op.process(sample) self.assertEqual(result['text'], result['target']) @@ -34,7 +34,6 @@ def test_special_char_pattern_text(self): op = ReplaceContentMapper(pattern='●■', repl='') self._run_helper(op, samples) - def test_raw_digit_pattern_text(self): samples = [ @@ -45,7 +44,7 @@ def test_raw_digit_pattern_text(self): ] op = ReplaceContentMapper(pattern=r'\d+(?:,\d+)*', repl='') self._run_helper(op, samples) - + def test_regular_digit_pattern_text(self): samples = [ @@ -57,5 +56,6 @@ def test_regular_digit_pattern_text(self): op = ReplaceContentMapper(pattern='\\d+(?:,\\d+)*', repl='') self._run_helper(op, samples) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/mapper/test_sentence_split_mapper.py b/tests/ops/mapper/test_sentence_split_mapper.py index abd914bda..3cdf3a977 100644 --- a/tests/ops/mapper/test_sentence_split_mapper.py +++ b/tests/ops/mapper/test_sentence_split_mapper.py @@ -1,9 +1,10 @@ import unittest from data_juicer.ops.mapper.sentence_split_mapper import SentenceSplitMapper +from 
data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class SentenceSplitMapperTest(unittest.TestCase): +class SentenceSplitMapperTest(DataJuicerTestCaseBase): def _run_helper(self, op, samples): for sample in samples: diff --git a/tests/ops/mapper/test_video_captioning_from_audio_mapper.py b/tests/ops/mapper/test_video_captioning_from_audio_mapper.py new file mode 100644 index 000000000..3a842bab8 --- /dev/null +++ b/tests/ops/mapper/test_video_captioning_from_audio_mapper.py @@ -0,0 +1,160 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset +from data_juicer.ops.mapper.video_captioning_from_audio_mapper import \ + VideoCaptioningFromAudioMapper +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP in the GitHub actions due to disk space limitation. +# These tests have been tested locally. +@SKIPPED_TESTS.register_module() +class VideoCaptioningFromAudioMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + + @staticmethod + def _count_generated_caption_num(text): + chunks = text.split(SpecialTokens.eoc) + vid_num = 0 + cap_num = 0 + for chunk in chunks: + if chunk.strip() == '': + continue + vid_num += chunk.count(SpecialTokens.video) + caps = [ + cap for cap in chunk.split(SpecialTokens.video) if cap.strip() + ] + cap_num += len(caps) + return vid_num, cap_num + + def _run_op(self, dataset: NestedDataset, caption_num, op, np=1): + dataset = dataset.map(op.process, num_proc=np) + text_list = dataset.select_columns(column_names=['text']).to_list() + for txt in text_list: + vid_num, cap_num = self._count_generated_caption_num(txt['text']) + self.assertEqual(vid_num, cap_num) + 
self.assertEqual(len(dataset), caption_num) + + def test_default_params(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid2_path] + }, { + 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。', + 'videos': [self.vid3_path] + }] + dataset = NestedDataset.from_list(ds_list) + op = VideoCaptioningFromAudioMapper() + self._run_op(dataset, len(dataset) * 2, op) + + def test_with_eoc(self): + + ds_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc}', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。 ' + f'{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + dataset = NestedDataset.from_list(ds_list) + op = VideoCaptioningFromAudioMapper() + self._run_op(dataset, len(dataset) * 2, op) + + def test_no_original_samples(self): + + ds_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc}', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。 ' + f'{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + dataset = NestedDataset.from_list(ds_list) + op = VideoCaptioningFromAudioMapper(keep_original_sample=False) + self._run_op(dataset, len(dataset), op) + + def test_multi_chunk_samples(self): + + ds_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc} {SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,' + f'拍打自己的胃部。 {SpecialTokens.eoc}', + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} 
{SpecialTokens.eoc} ' + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid3_path, self.vid1_path] + }] + dataset = NestedDataset.from_list(ds_list) + op = VideoCaptioningFromAudioMapper() + self._run_op(dataset, len(dataset) * 2, op) + + def test_multi_video_samples(self): + + ds_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。 ' + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。 ' + f'{SpecialTokens.eoc}', + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'text': + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc} ' + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。 ' + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。 ' + f'{SpecialTokens.eoc} {SpecialTokens.video} 白色的小羊站在一旁讲话。' + f'旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': + [self.vid3_path, self.vid1_path, self.vid2_path, self.vid1_path] + }] + dataset = NestedDataset.from_list(ds_list) + op = VideoCaptioningFromAudioMapper() + self._run_op(dataset, len(dataset) * 2, op) + + def test_parallel(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid2_path] + }, { + 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。', + 'videos': [self.vid3_path] + }] + dataset = NestedDataset.from_list(ds_list) + op = VideoCaptioningFromAudioMapper() + self._run_op(dataset, len(dataset) * 2, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_captioning_from_video_mapper.py b/tests/ops/mapper/test_video_captioning_from_video_mapper.py new file mode 100644 index 000000000..012761af5 --- /dev/null +++ b/tests/ops/mapper/test_video_captioning_from_video_mapper.py @@ -0,0 +1,232 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.video_captioning_from_video_mapper import \ + 
VideoCaptioningFromVideoMapper +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) + + +# Skip tests for this OP in GitHub Actions due to disk space limitations. +# These tests have been run locally. +@SKIPPED_TESTS.register_module() +class VideoCaptioningFromVideoMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + hf_video_blip = 'kpyu/video-blip-opt-2.7b-ego4d' + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_video_blip) + + def _run_mapper(self, ds_list, op, num_proc=1, caption_num=0): + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True) + dataset_list = dataset.select_columns(column_names=['text']).to_list() + # only check the number of output samples, because the generated + # caption content is not deterministic + self.assertEqual(len(dataset_list), caption_num) + + def test_default_params_no_eoc(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip) + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + def test_default_params_with_eoc(self): + + ds_list = [ + { + 'text': + f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪' + f'{SpecialTokens.eoc}', + 'videos': [self.vid1_path] + }, + { + 'text': + f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃{SpecialTokens.eoc}', # noqa: E501 + 'videos': [self.vid2_path] + } + ] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip) + self._run_mapper(ds_list, op, 
caption_num=len(ds_list) * 2) + + def test_multi_candidate_keep_random_any(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + caption_num = 4 + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + caption_num=caption_num, + keep_candidate_mode='random_any') + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + def test_multi_candidate_keep_all(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + caption_num = 4 + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + caption_num=caption_num, + keep_candidate_mode='all') + self._run_mapper(ds_list, + op, + caption_num=(1 + caption_num) * len(ds_list)) + + def test_multi_candidate_keep_similar_one(self): + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + caption_num = 4 + op = VideoCaptioningFromVideoMapper( + hf_video_blip=self.hf_video_blip, + caption_num=caption_num, + keep_candidate_mode='similar_one_simhash') + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + def test_remove_original_sample(self): + + ds_list = [ + { + 'text': + f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, + { + 'text': + f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', # noqa: E501 + 'videos': [self.vid2_path] + } + ] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + keep_original_sample=False) + self._run_mapper(ds_list, op, caption_num=len(ds_list)) + + def test_multi_candidate_remove_original_sample(self): + + 
ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + caption_num = 4 + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + caption_num=caption_num, + keep_original_sample=False) + self._run_mapper(ds_list, op, caption_num=len(ds_list)) + + def test_multi_process(self): + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }] * 10 + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip) + self._run_mapper(ds_list, op, num_proc=4, caption_num=len(ds_list) * 2) + + def test_multi_process_remove_original_sample(self): + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }] * 10 + + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + keep_original_sample=False) + self._run_mapper(ds_list, op, num_proc=4, caption_num=len(ds_list)) + + def test_frame_sampling_method(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + frame_sampling_method='uniform') + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + def test_frame_num(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + frame_sampling_method='uniform', + frame_num=5) + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + def test_horizontal_flip(self): + + ds_list = [{ + 'text': 
f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + horizontal_flip=True) + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + def test_vertical_flip(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video}白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video}身穿白色上衣的男子,拿着一个东西,拍打自己的胃', + 'videos': [self.vid2_path] + }] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip, + vertical_flip=True) + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + def test_multi_tag(self): + + ds_list = [{ + 'text': f'{SpecialTokens.video}{SpecialTokens.video}' + '白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪', + 'videos': [ + self.vid1_path, + self.vid1_path, + ] + }] + op = VideoCaptioningFromVideoMapper(hf_video_blip=self.hf_video_blip) + self._run_mapper(ds_list, op, caption_num=len(ds_list) * 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py b/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py new file mode 100644 index 000000000..1071bd864 --- /dev/null +++ b/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py @@ -0,0 +1,69 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.mapper.video_ffmpeg_wrapped_mapper import \ + VideoFFmpegWrappedMapper +from data_juicer.utils.mm_utils import load_video +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoFFmpegWrappedMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # 640x360, 16:9 + vid2_path = os.path.join(data_path, 'video2.mp4') # 480x640, 3:4 + vid3_path = os.path.join(data_path, 
'video3.mp4') # 362x640, 181:320 + + def _run_op(self, ds_list, target_list, op, np=1): + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, num_proc=np) + + def get_size(dataset): + sizes = [] + res_list = dataset.to_list() + for sample in res_list: + sample_list = [] + for value in sample['videos']: + video = load_video(value) + width = video.streams.video[0].codec_context.width + height = video.streams.video[0].codec_context.height + sample_list.append((width, height)) + video.close() + sizes.append(sample_list) + return sizes + + sizes = get_size(dataset) + self.assertEqual(sizes, target_list) + + def test_resize(self): + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path, self.vid3_path] + }] + tgt_list = [[(400, 480), (400, 480), (400, 480)]] + op = VideoFFmpegWrappedMapper('scale', + filter_kwargs={ + 'width': 400, + 'height': 480 + }, + capture_stderr=False) + self._run_op(ds_list, tgt_list, op) + + def test_resize_parallel(self): + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path, self.vid3_path] + }] + tgt_list = [[(400, 480), (400, 480), (400, 480)]] + op = VideoFFmpegWrappedMapper('scale', + filter_kwargs={ + 'width': 400, + 'height': 480 + }, + capture_stderr=False) + self._run_op(ds_list, tgt_list, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py b/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py new file mode 100644 index 000000000..3db841646 --- /dev/null +++ b/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py @@ -0,0 +1,150 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.mapper.video_resize_aspect_ratio_mapper import \ + VideoResizeAspectRatioMapper +from data_juicer.utils.mm_utils import load_video +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoResizeAspectRatioMapperTest(DataJuicerTestCaseBase): + + data_path = 
os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # 640x360, 16:9 + vid2_path = os.path.join(data_path, 'video2.mp4') # 480x640, 3:4 + vid3_path = os.path.join(data_path, 'video3.mp4') # 362x640, 181:320 + + def _run_op(self, dataset: Dataset, target_list, op, np=1): + dataset = dataset.map(op.process, num_proc=np) + + def get_size(dataset): + sizes = [] + res_list = dataset.to_list() + for sample in res_list: + sample_list = [] + for value in sample['videos']: + video = load_video(value) + width = video.streams.video[0].codec_context.width + height = video.streams.video[0].codec_context.height + sample_list.append((width, height)) + video.close() + sizes.append(sample_list) + return sizes + + sizes = get_size(dataset) + self.assertEqual(sizes, target_list) + + def test_default_params(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [ + [(640, 360)], # no change + [(480, 640)], # no change + [(362, 640)] # no change + ] + dataset = Dataset.from_list(ds_list) + op = VideoResizeAspectRatioMapper() + self._run_op(dataset, tgt_list, op) + + def test_min_ratio_increase(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [ + [(640, 360)], # no change + [(480, 640)], # no change + [(480, 640)] # 181:320 to 3:4 + ] + dataset = Dataset.from_list(ds_list) + op = VideoResizeAspectRatioMapper(min_ratio='3/4', strategy='increase') + self._run_op(dataset, tgt_list, op) + + def test_min_ratio_decrease(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [ + [(640, 360)], # no change + [(480, 640)], # no change + [(362, 482)] # ratio 181:320 to 3:4 + ] + dataset = Dataset.from_list(ds_list) + op = 
VideoResizeAspectRatioMapper(min_ratio='3/4', strategy='decrease') + self._run_op(dataset, tgt_list, op) + + def test_max_ratio_increase(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [ + [(640, 480)], # 16:9 to 4:3 + [(480, 640)], # no change + [(362, 640)] # no change + ] + dataset = Dataset.from_list(ds_list) + op = VideoResizeAspectRatioMapper(max_ratio='4/3', strategy='increase') + self._run_op(dataset, tgt_list, op) + + def test_max_ratio_decrease(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [ + [(480, 360)], # 16:9 to 4:3 + [(480, 640)], # no change + [(362, 640)] # no change + ] + dataset = Dataset.from_list(ds_list) + op = VideoResizeAspectRatioMapper(max_ratio='4/3', strategy='decrease') + self._run_op(dataset, tgt_list, op) + + def test_parallel(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [ + [(480, 360)], # 16:9 to 4:3 + [(480, 640)], # no change + [(362, 640)] # no change + ] + dataset = Dataset.from_list(ds_list) + op = VideoResizeAspectRatioMapper(max_ratio='4/3', strategy='decrease') + self._run_op(dataset, tgt_list, op, np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_resize_resolution_mapper.py b/tests/ops/mapper/test_video_resize_resolution_mapper.py new file mode 100644 index 000000000..8f8b2cafa --- /dev/null +++ b/tests/ops/mapper/test_video_resize_resolution_mapper.py @@ -0,0 +1,187 @@ +import os +import unittest + +import ffmpeg +from datasets import Dataset + +from data_juicer.ops.mapper.video_resize_resolution_mapper import \ + VideoResizeResolutionMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class 
VideoResizeResolutionMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + # video1: horizontal resolution 640p, vertical resolution 360p + # video2: horizontal resolution 480p, vertical resolution 640p + # video3: horizontal resolution 362p, vertical resolution 640p + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + + def _get_size_list(self, dataset: Dataset): + res_list = [] + for sample in dataset.to_list(): + cur_list = [] + for value in sample['videos']: + print(value) + probe = ffmpeg.probe(value) + video_stream = next((stream for stream in probe['streams'] + if stream['codec_type'] == 'video'), None) + width = int(video_stream['width']) + height = int(video_stream['height']) + cur_list.append((width, height)) + res_list.append(cur_list) + return res_list + + def _run_video_resize_resolution_mapper(self, + dataset: Dataset, + target_list, + op, + test_name, + np=1): + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.process, num_proc=np) + dataset = dataset.select_columns(column_names=[op.video_key]) + + # check each video manually (uncommenting this block also requires `import shutil`) + # output_dir = '../video_resize_resolution_mapper' + # move_to_dir = os.path.join(output_dir, test_name) + # if not os.path.exists(move_to_dir): + # os.makedirs(move_to_dir) + # for sample in dataset.to_list(): + # for value in sample['videos']: + # move_to_path = os.path.join(move_to_dir, + # os.path.basename(value)) + # shutil.copyfile(value, move_to_path) + + res_list = self._get_size_list(dataset) + self.assertEqual(res_list, target_list) + + def test_default_mapper(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [[(640, 360)], [(480, 640)], [(362, 
640)]] + dataset = Dataset.from_list(ds_list) + op = VideoResizeResolutionMapper() + self._run_video_resize_resolution_mapper(dataset, tgt_list, op, + 'test_default_mapper') + + def test_width_mapper(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [[(480, 270)], [(480, 640)], [(400, 708)]] + dataset = Dataset.from_list(ds_list) + op = VideoResizeResolutionMapper(min_width=400, max_width=480) + self._run_video_resize_resolution_mapper(dataset, tgt_list, op, + 'test_width_mapper') + + def test_height_mapper(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [[(854, 480)], [(360, 480)], [(272, 480)]] + dataset = Dataset.from_list(ds_list) + op = VideoResizeResolutionMapper(min_height=480, max_height=480) + self._run_video_resize_resolution_mapper(dataset, tgt_list, op, + 'test_height_mapper') + + def test_width_and_height_mapper(self): + + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path, self.vid3_path] + }] + tgt_list = [[(480, 480), (400, 480), (400, 480)]] + dataset = Dataset.from_list(ds_list) + op = VideoResizeResolutionMapper(min_width=400, + max_width=480, + min_height=480, + max_height=480) + self._run_video_resize_resolution_mapper( + dataset, tgt_list, op, 'test_width_and_height_mapper') + + def test_keep_aspect_ratio_decrease_mapper(self): + + ds_list = [{'videos': [self.vid1_path]}] + tgt_list = [[(480, 270)]] + dataset = Dataset.from_list(ds_list) + op = VideoResizeResolutionMapper( + min_width=400, + max_width=480, + min_height=480, + max_height=480, + force_original_aspect_ratio='decrease') + self._run_video_resize_resolution_mapper( + dataset, tgt_list, op, 'test_keep_aspect_ratio_decrease_mapper') + + def test_keep_aspect_ratio_increase_mapper(self): + + ds_list = [{'videos': [self.vid1_path]}] + tgt_list = [[(854, 480)]] + dataset = Dataset.from_list(ds_list) 
+ op = VideoResizeResolutionMapper( + min_width=400, + max_width=480, + min_height=480, + max_height=480, + force_original_aspect_ratio='increase') + self._run_video_resize_resolution_mapper( + dataset, tgt_list, op, 'test_keep_aspect_ratio_increase_mapper') + + def test_force_divisible_by(self): + + ds_list = [{'videos': [self.vid1_path]}] + tgt_list = [[(480, 272)]] + dataset = Dataset.from_list(ds_list) + op = VideoResizeResolutionMapper( + min_width=400, + max_width=480, + min_height=480, + max_height=480, + force_original_aspect_ratio='decrease', + force_divisible_by=4) + self._run_video_resize_resolution_mapper(dataset, tgt_list, op, + 'test_force_divisible_by') + + def test_filter_in_parallel(self): + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [[(480, 270)], [(480, 640)], [(400, 708)]] + dataset = Dataset.from_list(ds_list) + op = VideoResizeResolutionMapper(min_width=400, max_width=480) + self._run_video_resize_resolution_mapper(dataset, + tgt_list, + op, + 'test_filter_in_parallel', + np=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_split_by_duration_mapper.py b/tests/ops/mapper/test_video_split_by_duration_mapper.py new file mode 100644 index 000000000..43089dfa7 --- /dev/null +++ b/tests/ops/mapper/test_video_split_by_duration_mapper.py @@ -0,0 +1,232 @@ +# flake8: noqa: E501 + +import os +import unittest + +from data_juicer.core.data import NestedDataset +from data_juicer.ops.mapper.video_split_by_duration_mapper import \ + VideoSplitByDurationMapper +from data_juicer.utils.file_utils import add_suffix_to_filename +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoSplitByDurationMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = 
os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + + def _get_res_list(self, dataset, source_list): + res_list = [] + origin_paths = [self.vid1_path, self.vid2_path, self.vid3_path] + idx = 0 + for sample in dataset.to_list(): + output_paths = sample['videos'] + + # for keep_original_sample=True + if set(output_paths) <= set(origin_paths): + res_list.append(sample) + continue + + source = source_list[idx] + idx += 1 + + output_file_names = [ + os.path.splitext(os.path.basename(p))[0] for p in output_paths + ] + split_frames_nums = [] + for origin_path in source['videos']: + origin_file_name = os.path.splitext( + os.path.basename(origin_path))[0] + cnt = 0 + for output_file_name in output_file_names: + if origin_file_name in output_file_name: + cnt += 1 + split_frames_nums.append(cnt) + + res_list.append({ + 'text': sample['text'], + 'split_frames_num': split_frames_nums + }) + + return res_list + + def _run_video_split_by_duration_mapper(self, + op, + source_list, + target_list, + num_proc=1): + dataset = NestedDataset.from_list(source_list) + dataset = dataset.map(op.process, num_proc=num_proc) + res_list = self._get_res_list(dataset, source_list) + self.assertEqual(res_list, target_list) + + def test(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', + 'split_frames_num': [2] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 
'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [5] + }] + op = VideoSplitByDurationMapper(split_duration=10, + keep_original_sample=False) + self._run_video_split_by_duration_mapper(op, ds_list, tgt_list) + + def test_keep_ori_sample(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', + 'split_frames_num': [2] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [5] + }] + op = VideoSplitByDurationMapper() + self._run_video_split_by_duration_mapper(op, ds_list, tgt_list) + + def test_multi_process(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 
两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', + 'split_frames_num': [2] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [5] + }] + op = VideoSplitByDurationMapper(keep_original_sample=False) + self._run_video_split_by_duration_mapper(op, + ds_list, + tgt_list, + num_proc=2) + + def test_multi_chunk(self): + ds_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [2, 3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [3, 5] + }, { + 'text': + 
f'{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [2, 5] + }] + op = VideoSplitByDurationMapper(keep_original_sample=False) + self._run_video_split_by_duration_mapper(op, ds_list, tgt_list) + + def test_min_last_split_duration(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', + 'split_frames_num': [1] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [5] + }] + op = VideoSplitByDurationMapper(split_duration=10, + min_last_split_duration=3, + keep_original_sample=False) + self._run_video_split_by_duration_mapper(op, ds_list, tgt_list) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_split_by_key_frame_mapper.py b/tests/ops/mapper/test_video_split_by_key_frame_mapper.py new file mode 100644 index 000000000..997ae9ed8 --- /dev/null +++ b/tests/ops/mapper/test_video_split_by_key_frame_mapper.py @@ -0,0 +1,200 @@ +# flake8: noqa: E501 + +import os +import unittest + +from data_juicer.core.data import NestedDataset +from data_juicer.ops.mapper.video_split_by_key_frame_mapper import \ + 
VideoSplitByKeyFrameMapper +from data_juicer.utils.file_utils import add_suffix_to_filename +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoSplitByKeyFrameMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + + def _get_res_list(self, dataset, source_list): + res_list = [] + origin_paths = [self.vid1_path, self.vid2_path, self.vid3_path] + idx = 0 + for sample in dataset.to_list(): + output_paths = sample['videos'] + + # for keep_original_sample=True + if set(output_paths) <= set(origin_paths): + res_list.append(sample) + continue + + source = source_list[idx] + idx += 1 + + output_file_names = [ + os.path.splitext(os.path.basename(p))[0] for p in output_paths + ] + split_frames_nums = [] + for origin_path in source['videos']: + origin_file_name = os.path.splitext( + os.path.basename(origin_path))[0] + cnt = 0 + for output_file_name in output_file_names: + if origin_file_name in output_file_name: + cnt += 1 + split_frames_nums.append(cnt) + + res_list.append({ + 'text': sample['text'], + 'split_frames_num': split_frames_nums + }) + + return res_list + + def _run_video_split_by_key_frame_mapper(self, + op, + source_list, + target_list, + num_proc=1): + dataset = NestedDataset.from_list(source_list) + dataset = dataset.map(op.process, num_proc=num_proc) + res_list = self._get_res_list(dataset, source_list) + self.assertEqual(res_list, target_list) + + def test(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 
两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [6] + }] + op = VideoSplitByKeyFrameMapper(keep_original_sample=False) + self._run_video_split_by_key_frame_mapper(op, ds_list, tgt_list) + + def test_keep_ori_sample(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 
{SpecialTokens.eoc}', + 'split_frames_num': [6] + }] + op = VideoSplitByKeyFrameMapper() + self._run_video_split_by_key_frame_mapper(op, ds_list, tgt_list) + + def test_multi_process(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [6] + }] + op = VideoSplitByKeyFrameMapper(keep_original_sample=False) + self._run_video_split_by_key_frame_mapper(op, + ds_list, + tgt_list, + num_proc=2) + + def test_multi_chunk(self): + ds_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 
白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'split_frames_num': [3, 3] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [3, 6] + }, { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'split_frames_num': [3, 6] + }] + op = VideoSplitByKeyFrameMapper(keep_original_sample=False) + self._run_video_split_by_key_frame_mapper(op, ds_list, tgt_list) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_split_by_scene_mapper.py b/tests/ops/mapper/test_video_split_by_scene_mapper.py new file mode 100644 index 000000000..f4b3263aa --- /dev/null +++ b/tests/ops/mapper/test_video_split_by_scene_mapper.py @@ -0,0 +1,171 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.mapper.video_split_by_scene_mapper import \ + VideoSplitBySceneMapper +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoSplitBySceneMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # about 12s + vid2_path = os.path.join(data_path, 'video2.mp4') # about 23s + vid3_path = os.path.join(data_path, 'video3.mp4') # about 50s + + vid1_base, vid1_ext = os.path.splitext(os.path.basename(vid1_path)) + 
vid2_base, vid2_ext = os.path.splitext(os.path.basename(vid2_path)) + vid3_base, vid3_ext = os.path.splitext(os.path.basename(vid3_path)) + + op_name = 'video_split_by_scene_mapper' + + def get_res_list(self, dataset: Dataset): + res_list = [] + for sample in dataset.to_list(): + scene_num = len(sample['videos']) + if 'text' in sample: + res_list.append({ + 'scene_num': scene_num, + 'text': sample['text'] + }) + else: + res_list.append({'scene_num': scene_num}) + return res_list + + def _run_helper(self, op, source_list, target_list): + dataset = Dataset.from_list(source_list) + dataset = dataset.map(op.process) + res_list = self.get_res_list(dataset) + self.assertEqual(res_list, target_list) + + def test_ContentDetector(self): + ds_list = [ + { + 'videos': [self.vid1_path] # 3 scenes + }, + { + 'videos': [self.vid2_path] # 1 scene + }, + { + 'videos': [self.vid3_path] # 2 scenes + } + ] + tgt_list = [{'scene_num': 3}, {'scene_num': 1}, {'scene_num': 2}] + op = VideoSplitBySceneMapper(detector='ContentDetector', + threshold=27.0, + min_scene_len=15) + self._run_helper(op, ds_list, tgt_list) + + def test_AdaptiveDetector(self): + ds_list = [ + { + 'videos': [self.vid1_path] # 3 scenes + }, + { + 'videos': [self.vid2_path] # 1 scene + }, + { + 'videos': [self.vid3_path] # 8 scenes + } + ] + tgt_list = [{'scene_num': 3}, {'scene_num': 1}, {'scene_num': 8}] + op = VideoSplitBySceneMapper(detector='AdaptiveDetector', + threshold=3.0, + min_scene_len=15) + self._run_helper(op, ds_list, tgt_list) + + def test_ThresholdDetector(self): + ds_list = [ + { + 'videos': [self.vid1_path] # 1 scene + }, + { + 'videos': [self.vid2_path] # 1 scene + }, + { + 'videos': [self.vid3_path] # 1 scene + } + ] + tgt_list = [{'scene_num': 1}, {'scene_num': 1}, {'scene_num': 1}] + op = VideoSplitBySceneMapper(detector='ThresholdDetector', + threshold=12.0, + min_scene_len=15) + self._run_helper(op, ds_list, tgt_list) + + def test_default_progress(self): + ds_list = [ + { + 'videos': 
[self.vid1_path] # 3 scenes + }, + { + 'videos': [self.vid2_path] # 1 scene + }, + { + 'videos': [self.vid3_path] # 2 scenes + } + ] + tgt_list = [{'scene_num': 3}, {'scene_num': 1}, {'scene_num': 2}] + op = VideoSplitBySceneMapper(show_progress=True) + self._run_helper(op, ds_list, tgt_list) + + def test_default_kwargs(self): + ds_list = [ + { + 'videos': [self.vid1_path] # 2 scenes + }, + { + 'videos': [self.vid2_path] # 1 scene + }, + { + 'videos': [self.vid3_path] # 2 scenes + } + ] + tgt_list = [{'scene_num': 2}, {'scene_num': 1}, {'scene_num': 2}] + op = VideoSplitBySceneMapper(luma_only=True, kernel_size=5) + self._run_helper(op, ds_list, tgt_list) + + def test_default_with_text(self): + ds_list = [ + { + 'text': + f'{SpecialTokens.video} this is video1 {SpecialTokens.eoc}', + 'videos': [self.vid1_path] # 3 scenes + }, + { + 'text': + f'{SpecialTokens.video} this is video2 {SpecialTokens.eoc}', + 'videos': [self.vid2_path] # 1 scene + }, + { + 'text': + f'{SpecialTokens.video} this is video3 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] # 2 scenes + } + ] + tgt_list = [ + { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} this is video1 {SpecialTokens.eoc}', # noqa: E501 + 'scene_num': 3 + }, + { + 'text': + f'{SpecialTokens.video} this is video2 {SpecialTokens.eoc}', + 'scene_num': 1 + }, + { + 'text': + f'{SpecialTokens.video}{SpecialTokens.video} this is video3 {SpecialTokens.eoc}', # noqa: E501 + 'scene_num': 2 + } + ] + op = VideoSplitBySceneMapper() + self._run_helper(op, ds_list, tgt_list) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py new file mode 100644 index 000000000..042109d86 --- /dev/null +++ b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py @@ -0,0 +1,155 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset +from 
data_juicer.ops.mapper.video_tagging_from_audio_mapper import \ + VideoTaggingFromAudioMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoTaggingFromAudioMapperTest(DataJuicerTestCaseBase): + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # Music + vid2_path = os.path.join(data_path, 'video2.mp4') # Music + vid3_path = os.path.join(data_path, 'video3.mp4') # Music + vid4_path = os.path.join(data_path, 'video4.mp4') # Speech + vid5_path = os.path.join(data_path, 'video5.mp4') # Speech + vid3_no_aud_path = os.path.join(data_path, 'video3-no-audio.mp4') # EMPTY + + hf_ast = 'MIT/ast-finetuned-audioset-10-10-0.4593' + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass(cls.hf_ast) + + def _run_video_tagging_from_audio_mapper(self, + op, + source_list, + target_list, + num_proc=1): + dataset = NestedDataset.from_list(source_list) + dataset = dataset.map(op.process, num_proc=num_proc) + res_list = dataset.select_columns([Fields.video_audio_tags + ])[Fields.video_audio_tags] + self.assertEqual(res_list, target_list) + + def test(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': f'{SpecialTokens.video} 一个人在帮另一个人梳头发。 {SpecialTokens.eoc}', + 'videos': [self.vid4_path] + }, { + 'text': + f'{SpecialTokens.video} 一个穿着红色连衣裙的女人在试衣服。 {SpecialTokens.eoc}', + 'videos': [self.vid5_path] + }] + tgt_list = [['Music'], ['Music'], ['Speech'], ['Speech']] + op = VideoTaggingFromAudioMapper(self.hf_ast) + self._run_video_tagging_from_audio_mapper(op, ds_list, tgt_list) + + def test_multi_chunk(self): + ds_list = [{ + 'text': + 
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。' + f'{SpecialTokens.eoc}{SpecialTokens.video} ' + f'身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}{SpecialTokens.video} 一个人在帮另一个人梳头发。 ' + f'{SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid4_path] + }, { + 'text': + f'一个穿着红色连衣裙的女人在试衣服。 {SpecialTokens.video} {SpecialTokens.eoc} ' + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid5_path, self.vid1_path] + }] + tgt_list = [['Music', 'Music'], ['Music', 'Speech'], + ['Speech', 'Music']] + op = VideoTaggingFromAudioMapper(self.hf_ast) + self._run_video_tagging_from_audio_mapper(op, ds_list, tgt_list) + + def test_multi_video(self): + ds_list = [{ + 'text': + f'{SpecialTokens.video} {SpecialTokens.video} 白色的小羊站在一旁讲话。' + f'旁边还有两只灰色猫咪和一只拉着灰狼的猫咪; 一个人在帮另一个人梳头发 {SpecialTokens.eoc}' + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid1_path, self.vid4_path, self.vid2_path] + }, { + 'text': + f'一个穿着红色连衣裙的女人在试衣服。 {SpecialTokens.video} {SpecialTokens.video} ' + f'白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid5_path, self.vid1_path] + }] + tgt_list = [['Music', 'Speech', 'Music'], ['Speech', 'Music']] + op = VideoTaggingFromAudioMapper(self.hf_ast) + self._run_video_tagging_from_audio_mapper(op, ds_list, tgt_list) + + def test_no_video(self): + ds_list = [{ + 'text': '白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [] + }, { + 'text': f'身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}' + f'{SpecialTokens.video} 一个人在帮另一个人梳头发。 {SpecialTokens.eoc}', + 'videos': [self.vid4_path] + }] + tgt_list = [[], ['Speech']] + op = VideoTaggingFromAudioMapper(self.hf_ast) + self._run_video_tagging_from_audio_mapper(op, ds_list, tgt_list) + + def test_no_audio(self): + ds_list = [{ + 'text': + f'{SpecialTokens.video} {SpecialTokens.video} 白色的小羊站在一旁讲话。' + f'旁边还有两只灰色猫咪和一只拉着灰狼的猫咪; 两个长头发的女子正坐在一张圆桌前讲话互动。 
{SpecialTokens.eoc}' + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid1_path, self.vid3_no_aud_path, self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} {SpecialTokens.video} ' + f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} 一个人在帮另一个人梳头发。', + 'videos': [self.vid3_path, self.vid3_no_aud_path, self.vid4_path] + }] + tgt_list = [['Music', 'EMPTY', 'Music'], ['Music', 'EMPTY', 'Speech']] + op = VideoTaggingFromAudioMapper(self.hf_ast) + self._run_video_tagging_from_audio_mapper(op, ds_list, tgt_list) + + def test_multi_process(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': f'{SpecialTokens.video} 一个人在帮另一个人梳头发。 {SpecialTokens.eoc}', + 'videos': [self.vid4_path] + }, { + 'text': + f'{SpecialTokens.video} 一个穿着红色连衣裙的女人在试衣服。 {SpecialTokens.eoc}', + 'videos': [self.vid5_path] + }] + tgt_list = [['Music'], ['Music'], ['Speech'], ['Speech']] + op = VideoTaggingFromAudioMapper(self.hf_ast) + self._run_video_tagging_from_audio_mapper(op, + ds_list, + tgt_list, + num_proc=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py new file mode 100644 index 000000000..ea13109fc --- /dev/null +++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py @@ -0,0 +1,244 @@ +# flake8: noqa: E501 +import os +import unittest + +from data_juicer.core.data import NestedDataset +from data_juicer.ops.mapper.video_tagging_from_frames_mapper import \ + VideoTaggingFromFramesMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoTaggingFromFramesMapperTest(DataJuicerTestCaseBase): + 
data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + + def _run_video_tagging_from_frames_mapper(self, + op, + source_list, + target_list, + num_proc=1): + dataset = NestedDataset.from_list(source_list) + dataset = dataset.map(op.process, num_proc=num_proc) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path], + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path], + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' + ]] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path], + Fields.video_frame_tags: [[ + 'woman', 'table', 'girl', 'sit', 'person', 'laptop', + 'bookshelf', 'conversation', 'round table', 'computer', 'man', + 'closet', 'stool', 'computer screen', 'laugh', 'cabinet', + 'hand', 'selfie', 'stand' + ]] + }] + op = VideoTaggingFromFramesMapper() + self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) + + def test_uniform(self): + ds_list = [{ + 'text': 
f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path], + Fields.video_frame_tags: [[ + 'animal', 'cartoon', 'anime', 'game', 'screenshot', + 'video game', 'robe', 'ray', 'text', 'writing', 'yellow', + 'doll', 'tail', 'cartoon character', 'sky', 'person' + ]] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path], + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'hand', 'catch', 'bulletin board', 'blind', 'play', 'Wii', + 'cotton candy', 'tennis racket', 'game controller', 'remote', + 'stand', 'video game', 'Wii controller', 'racket', + 'baseball uniform', 'toy', 'green' + ]] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path], + Fields.video_frame_tags: [[ + 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person', + 'round table', 'computer', 'girl', 'laptop', 'man', 'closet', + 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand', + 'point' + ]] + }] + op = VideoTaggingFromFramesMapper(frame_sampling_method='uniform', + frame_num=10) + self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) + + def test_multi_process(self): + # WARNING: current parallel tests only work in spawn method + import multiprocess + original_method = multiprocess.get_start_method() + multiprocess.set_start_method('spawn', force=True) + # WARNING: current parallel tests only work in spawn method + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + 
}, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path], + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path], + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' + ]] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path], + Fields.video_frame_tags: [[ + 'woman', 'table', 'girl', 'sit', 'person', 'laptop', + 'bookshelf', 'conversation', 'round table', 'computer', 'man', + 'closet', 'stool', 'computer screen', 'laugh', 'cabinet', + 'hand', 'selfie', 'stand' + ]] + }] + op = VideoTaggingFromFramesMapper() + self._run_video_tagging_from_frames_mapper(op, + ds_list, + tgt_list, + num_proc=2) + # WARNING: current parallel tests only work in spawn method + multiprocess.set_start_method(original_method, force=True) + # WARNING: current parallel tests only work in spawn method + + def test_multi_chunk(self): + ds_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video} 
白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', + 'videos': [self.vid1_path, self.vid2_path], + Fields.video_frame_tags: + [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ], + [ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' + ]] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid3_path], + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' + ], + [ + 'woman', 'table', 'girl', 'sit', + 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', + 'computer', 'man', 'closet', 'stool', + 'computer screen', 'laugh', + 'cabinet', 'hand', 'selfie', 'stand' + ]] + }, { + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid1_path, self.vid3_path], + Fields.video_frame_tags: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ], + [ + 'woman', 'table', 'girl', 'sit', + 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', + 'computer', 'man', 'closet', 'stool', + 'computer screen', 'laugh', + 'cabinet', 'hand', 'selfie', 'stand' + ]] + }] + op = VideoTaggingFromFramesMapper() + self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) + + +if __name__ == '__main__': + 
unittest.main() diff --git a/tests/ops/mapper/test_whitespace_normalization_mapper.py b/tests/ops/mapper/test_whitespace_normalization_mapper.py index 0bffdf60c..985cc7076 100644 --- a/tests/ops/mapper/test_whitespace_normalization_mapper.py +++ b/tests/ops/mapper/test_whitespace_normalization_mapper.py @@ -2,9 +2,10 @@ from data_juicer.ops.mapper.whitespace_normalization_mapper import \ WhitespaceNormalizationMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class WhitespaceNormalizationMapperTest(unittest.TestCase): +class WhitespaceNormalizationMapperTest(DataJuicerTestCaseBase): def setUp(self): self.op = WhitespaceNormalizationMapper() diff --git a/tests/ops/selector/test_frequency_specified_field_selector.py b/tests/ops/selector/test_frequency_specified_field_selector.py index 8e6e32440..4593e83ef 100644 --- a/tests/ops/selector/test_frequency_specified_field_selector.py +++ b/tests/ops/selector/test_frequency_specified_field_selector.py @@ -4,9 +4,10 @@ from data_juicer.ops.selector.frequency_specified_field_selector import \ FrequencySpecifiedFieldSelector +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class FrequencySpecifiedFieldSelectorTest(unittest.TestCase): +class FrequencySpecifiedFieldSelectorTest(DataJuicerTestCaseBase): def _run_frequency_selector(self, dataset: Dataset, target_list, op): dataset = op.process(dataset) diff --git a/tests/ops/selector/test_topk_specified_field_selector.py b/tests/ops/selector/test_topk_specified_field_selector.py index 0f386a1e2..f10129ded 100644 --- a/tests/ops/selector/test_topk_specified_field_selector.py +++ b/tests/ops/selector/test_topk_specified_field_selector.py @@ -4,9 +4,10 @@ from data_juicer.ops.selector.topk_specified_field_selector import \ TopkSpecifiedFieldSelector +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -class TopkSpecifiedFieldSelectorTest(unittest.TestCase): +class TopkSpecifiedFieldSelectorTest(DataJuicerTestCaseBase): 
def _run_topk_selector(self, dataset: Dataset, target_list, op): dataset = op.process(dataset) diff --git a/tests/ops/test_op_fusion.py b/tests/ops/test_op_fusion.py index 7f13ad431..ad50ba472 100644 --- a/tests/ops/test_op_fusion.py +++ b/tests/ops/test_op_fusion.py @@ -1,11 +1,12 @@ import unittest from data_juicer.ops.load import load_ops -from data_juicer.utils.unittest_utils import SKIPPED_TESTS +from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, + DataJuicerTestCaseBase) @SKIPPED_TESTS.register_module() -class OpFusionTest(unittest.TestCase): +class OpFusionTest(DataJuicerTestCaseBase): def _run_op_fusion(self, original_process_list, target_process_list): new_process_list, _ = load_ops(original_process_list, op_fusion=True) @@ -165,9 +166,9 @@ def test_regular_config(self): } }, { - 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 [ - { # noqa: E501 + { 'words_num_filter': { 'lang': 'en', 'max_num': 100000, @@ -622,9 +623,9 @@ def test_multiple_groups(self): } }, { - 'OpFusion:(words_num_filter,word_repetition_filter,perplexity_filter)': + 'OpFusion:(words_num_filter,word_repetition_filter,perplexity_filter)': # noqa: E501 [ - { # noqa: E501 + { 'words_num_filter': { 'lang': 'en', 'max_num': 100000, @@ -713,9 +714,9 @@ def test_only_fusible_ops(self): } }] target_process = [{ - 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 [ - { # noqa: E501 + { 'words_num_filter': { 'lang': 'en', 'max_num': 100000, @@ -931,9 +932,9 @@ def test_different_intermediate_vars(self): } }, { - 'OpFusion:(average_line_length_filter,maximum_line_length_filter)': + 
'OpFusion:(average_line_length_filter,maximum_line_length_filter)': # noqa: E501 [ - { # noqa: E501 + { 'average_line_length_filter': { 'min_len': 10, 'text_key': 'text', @@ -948,9 +949,9 @@ def test_different_intermediate_vars(self): ] }, { - 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': # noqa: E501 [ - { # noqa: E501 + { 'words_num_filter': { 'lang': 'en', 'max_num': 100000, diff --git a/tools/multimodal/README.md b/tools/multimodal/README.md index 33f2ddcb4..7b1426c2d 100644 --- a/tools/multimodal/README.md +++ b/tools/multimodal/README.md @@ -69,11 +69,15 @@ These tools consist of two types: For now, dataset formats that are supported by Data-Juicer are listed in the following table. -| Format | Type | source_format_to_data_juicer_format | data_juicer_format_to_target_format | Ref. | -|------------|------------|-------------------------------------|-------------------------------------|------------------------------------------------------------------------------------------------------------------| -| LLaVA-like | image-text | `llava_to_dj.py` | `dj_to_llava.py` | [Format Description](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format) | -| MMC4-like | image-text | `mmc4_to_dj.py` | `dj_to_mmc4.py` | [Format Description](https://github.com/allenai/mmc4#documents) | -| WavCaps-like | audio-text | `wavcaps_to_dj.py` | `dj_to_wavcaps.py` | [Format Description](https://github.com/XinhaoMei/WavCaps#table-of-contents) | +| Format | Type | source_format_to_data_juicer_format | data_juicer_format_to_target_format | Ref. 
| +|--------------------|------------|-------------------------------------|-------------------------------------|------------------------------------------------------------------------------------------------------------------| +| LLaVA-like | image-text | `llava_to_dj.py` | `dj_to_llava.py` | [Format Description](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format) | +| MMC4-like | image-text | `mmc4_to_dj.py` | `dj_to_mmc4.py` | [Format Description](https://github.com/allenai/mmc4#documents) | +| WavCaps-like | audio-text | `wavcaps_to_dj.py` | `dj_to_wavcaps.py` | [Format Description](https://github.com/XinhaoMei/WavCaps#table-of-contents) | +| Video-ChatGPT-like | video-text | `video_chatgpt_to_dj.py` | `dj_to_video_chatgpt.py` | [Format Description](https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data) | +| Youku-mPLUG-like | video-text | `youku_to_dj.py` | `dj_to_youku.py` | [Format Description](https://modelscope.cn/datasets/modelscope/Youku-AliceMind/summary) | +| InternVid-like | video-text | `internvid_to_dj.py` | `dj_to_internvid.py` | [Format Description](https://huggingface.co/datasets/OpenGVLab/InternVid) | + For all tools, you can run the following command to find out the usage of them: @@ -161,7 +165,7 @@ Users should be cautious about this point if you need this matrix in later usage Despite these extra fields, tools for MMC4 can perfectly convert MMC4-like datasets to Data-Juicer-format datasets and convert them back~ -### WavCaps-like +#### WavCaps-like The [WavCaps](https://github.com/XinhaoMei/WavCaps#dataset) is composed of four sub-datasets: [FreeSound](https://freesound.org/), [BBC Sound Effects](https://sound-effects.bbcrewind.co.uk/),[SoundBible](https://soundbible.com/) and [AudioSet Strongly-labelled Subset](https://research.google.com/audioset/download_strong.html). Each sub-dataset has different fields. 
For example, the 'description' field is included in SoundBible, but does not exist in AudioSet. To ensure that the different sub-datasets can be properly merged after conversion, the union of all fields from the sub-datasets is used during the wavcaps_to_dj stage, and all fields are fully retained during the dj_to_wavcaps stage. @@ -196,3 +200,35 @@ The [WavCaps](https://github.com/XinhaoMei/WavCaps#dataset) is composed of four "tags": "" }] } ``` + +#### Video-ChatGPT-like + +The Video-ChatGPT dataset contains 3 types of data in a unified format: +- Topics for video summarization; +- Description-based question-answers (exploring spatial, temporal, relationships, and reasoning concepts); +- Creative/generative question-answers. +They all obey the `` format, where the `video_id` is in the form "v_youtube_id". We assume that users have already downloaded these videos; the corresponding storage directory must be specified when using the converter tool. + + + +#### Youku-mPLUG-like + +The Youku-mPLUG dataset comes in 4 formats: pretrain, classification, retrieval, and captioning. +They differ slightly from each other in field names or other attributes, but all of them obey the `` format. + +#### InternVid-like + +The InternVid dataset contains 4 fields: +- `YoutubeID`: the YouTube ID of the video used in the sample. +We assume that users have already downloaded these videos +and that this field has been replaced with the video's storage path. +- `Start_timestamp`: the start timestamp, as a string, of the video clip for the +corresponding caption. +- `End_timestamp`: the end timestamp, as a string, of the video clip for the +corresponding caption. +- `Caption`: the corresponding caption for the video clip. + +As we can see, the caption in this dataset corresponds to the video clip +specified by the start/end timestamps instead of the whole video. So the +conversion tool will cut the specified video clip for you if the argument +`cut_videos` is set to True. 
You can cut before conversion by yourself as well. diff --git a/tools/multimodal/README_ZH.md b/tools/multimodal/README_ZH.md index 996bdbb54..8e29df36f 100644 --- a/tools/multimodal/README_ZH.md +++ b/tools/multimodal/README_ZH.md @@ -62,11 +62,14 @@ 目前,Data-Juicer 支持的数据集格式在下面表格中列出。 -| 格式 | 类型 | source_format_to_data_juicer_format | data_juicer_format_to_target_format | 格式参考 | -|----------|-------|-------------------------------------|-------------------------------------|----------------------------------------------------------------------------------------------------| -| 类LLaVA格式 | 图像-文本 | `llava_to_dj.py` | `dj_to_llava.py` | [格式描述](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format) | -| 类MMC4格式 | 图像-文本 | `mmc4_to_dj.py` | `dj_to_mmc4.py` | [格式描述](https://github.com/allenai/mmc4#documents) | -| 类WavCaps格式 | 音频-文本 | `wavcaps_to_dj.py` | `dj_to_wavcaps.py` | [格式描述](https://github.com/XinhaoMei/WavCaps#table-of-contents) | +| 格式 | 类型 | source_format_to_data_juicer_format | data_juicer_format_to_target_format | 格式参考 | +|------------------|-------|-------------------------------------|-------------------------------------|----------------------------------------------------------------------------------------------------| +| 类LLaVA格式 | 图像-文本 | `llava_to_dj.py` | `dj_to_llava.py` | [格式描述](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format) | +| 类MMC4格式 | 图像-文本 | `mmc4_to_dj.py` | `dj_to_mmc4.py` | [格式描述](https://github.com/allenai/mmc4#documents) | +| 类WavCaps格式 | 音频-文本 | `wavcaps_to_dj.py` | `dj_to_wavcaps.py` | [格式描述](https://github.com/XinhaoMei/WavCaps#table-of-contents) | +| 类Video-ChatGPT格式 |视频-文本 | `video_chatgpt_to_dj.py` | `dj_to_video_chatgpt.py` | [格式描述]( https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data) | | +| 类Youku-mPLUG格式 | 视频-文本 | `youku_to_dj.py` | `dj_to_youku.py` | [格式描述](https://modelscope.cn/datasets/modelscope/Youku-AliceMind/summary) | | +| 
类InternVid格式 | 视频-文本 | `internvid_to_dj.py` | `dj_to_internvid.py` | [格式描述](https://huggingface.co/datasets/OpenGVLab/InternVid) | | 对于所有工具,您可以运行以下命令来了解它们的详细用法: @@ -168,3 +171,27 @@ python tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py --hel "tags": "" }] } ``` + +#### 类Video-ChatGPT格式 +Video-ChatGPT数据集包含3种统一格式的数据: +- 视频摘要主题 +- 基于描述的问题答案(探索空间、时间、关系和推理概念); +- 以及创意/生成性问题解答。 +它们都遵循“”格式,其中“video_id”表示为YouTube视频的id:“v_youtube_id”。 我们假设用户已经下载了这些视频,在使用转换工具时需要指定相应的存储目录。 +#### 类Youku-mPLUG格式 + +Youku-mPLUG数据集中一共有4种类型的格式:pretrain,classification, +retrieval,captioning。它们在字段名称或者其他属性上会有轻微的差异,但是所有类型都遵从 `` 的格式。 + +#### 类InternVid格式 + +InternVid数据集包括4个字段: +- `YoutubeID`: 样本中使用的视频的Youtube ID。我们假设用户已经下载了这些视频, +并且这个字段已经被替换为了视频的存储路径。 +- `Start_timestamp`: 与caption对应的视频片段的开始时间戳字符串。 +- `End_timestamp`: 与caption对应的视频片段的结束时间戳字符串 +- `Caption`: 与视频片段对应的caption。 + +正如我们看到,该数据集中的caption对应到了一段由开始/结束时间戳指定的视频片段,而非整段视频。 +因此,如果 `cut_videos` 参数设置为 True,针对该数据集的转换工具会为您剪辑出指定的视频片段。 +您也可以在转换前自行对下载的视频进行剪辑。 diff --git a/tools/multimodal/__init__.py b/tools/multimodal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py new file mode 100644 index 000000000..89d1e8a38 --- /dev/null +++ b/tools/multimodal/data_juicer_format_to_target_format/dj_to_internvid.py @@ -0,0 +1,154 @@ +# This tool is used to convert multimodal dataset in Data-Juicer format to a +# target dataset in InternVid format. +# +# Data-Juicer format: +# - two extra fields: +# - text: a chunk of text with the video special token. 
+# - videos: video paths list, including cut videos according to their timestamps # noqa: E501 +# - other fields in the original format can be kept or not +# - in jsonl +# {'videos': ['videos/qJrOyggIB-w-cut.mp4'], +# 'text': 'a screen shot of heroes of the storm with people in action', +# 'Start_timestamp': '00:07:33.689', +# 'End_timestamp': '00:07:51.085', +# 'Aesthetic_Score': 4.29296875, +# 'UMT_Score': 0.4501953125} +# +# Corresponding InternVid format: +# - in jsonl +# - restore to Caption and YoutubeID +# {'YoutubeID': 'videos/qJrOyggIB-w.mp4', +# 'Start_timestamp': '00:07:33.689', +# 'End_timestamp': '00:07:51.085', +# 'Caption': 'a screen shot of heroes of the storm with people in action', +# 'Aesthetic_Score': 4.29296875, +# 'UMT_Score': 0.4501953125} +# +# Reference: +# https://huggingface.co/datasets/OpenGVLab/InternVid + +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + +from data_juicer.utils.mm_utils import SpecialTokens +from tools.multimodal.utils import remove_dj_special_tokens + + +def main( + dj_ds_path: str, + target_internvid_ds_path: str, + eoc_special_token: str = SpecialTokens.eoc, + video_special_token: str = SpecialTokens.video, + sent_seperator: str = ' ', + convert_to_relative_paths: bool = False, + original_internvid_ds_path: str = None, +): + """ + Convert a Data-Juicer-format dataset to an InternVid-like dataset. + + :param dj_ds_path: path to the input dataset in Data-Juicer format. + :param target_internvid_ds_path: path to store the converted dataset in + InternVid format. + :param eoc_special_token: the special token for "end of a chunk". It's used + to split sentence chunks explicitly. Default: <|__dj__eoc|> (from + Data-Juicer). + :param video_special_token: the special token for videos. It's used to + locate the videos in the text. In typical InternVid-like datasets, + this special token is not specified.
So we simply use the default video + special token from our Data-Juicer. Default: <__dj__video> (from + Data-Juicer). + :param sent_seperator: separator to split different sentences. Default: " " + :param convert_to_relative_paths: whether to convert the video paths in this + dataset to relative paths to the original dataset. If it's True, an + extra argument original_internvid_ds_path is required. When the + processed and converted dataset will be used on another machine, it's + better to set this argument to True. Default: False. + :param original_internvid_ds_path: path to the original unprocessed + InternVid dataset, which is used to help to recover the relative video + paths for better migration. Default: None. + """ + # ----- Constant settings. Better not to change them. ----- + text_key = 'text' # default key of field to store the sample text + video_key = 'videos' # default key of field to store the video list + tgt_text_key = 'Caption' # default target key of field to store texts + tgt_video_key = 'YoutubeID' # default target field to store videos + # ----- Constant settings. Better not to change them. ----- + + # check arguments + # check paths + if not os.path.exists(dj_ds_path): + raise FileNotFoundError( + f'Input dataset [{dj_ds_path}] can not be found.') + if not target_internvid_ds_path.endswith('.jsonl'): + raise ValueError( + 'Only support "jsonl" target dataset file for InternVid now.') + if os.path.dirname(target_internvid_ds_path) \ + and not os.path.exists(os.path.dirname(target_internvid_ds_path)): + logger.info( + f'Create directory [{os.path.dirname(target_internvid_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(target_internvid_ds_path)) + # if convert_to_relative_paths is True, check if the + # original_internvid_ds_path is provided as well.
+ if convert_to_relative_paths: + if not original_internvid_ds_path: + raise ValueError('When convert_to_relative_paths is set to True, ' + 'the original_internvid_ds_path must be provided ' + 'for recovering the relative paths. Please ' + 'check and retry.') + original_internvid_ds_path = os.path.abspath( + original_internvid_ds_path) + # if provided original_internvid_ds_path is the dataset file path, only + # keep the directory path. + if os.path.isfile(original_internvid_ds_path): + original_internvid_ds_path = os.path.dirname( + original_internvid_ds_path) + + # save InternVid dataset from Data-Juicer format + logger.info('Start converting the original dataset to InternVid format...') + with jl.open(dj_ds_path) as reader: + with jl.open(target_internvid_ds_path, mode='w') as writer: + for line_num, s in enumerate(tqdm(reader)): + video = s.pop(video_key)[0] + text = s.pop(text_key) + + new_sample = {} + # add other fields + for key in s: + new_sample[key] = s[key] + + # add video + if convert_to_relative_paths: + if video.startswith(original_internvid_ds_path): + video = os.path.relpath(video, + original_internvid_ds_path) + else: + raise ValueError( + f'The original_internvid_ds_path ' + f'[{original_internvid_ds_path}] is not the ' + f'directory that contains the video ' + f'[{video}] in the sample of line number ' + f'[{line_num}]. 
Please check if the correct ' + f'original_internvid_ds_path is provided or ' + f'something wrong with this sample, and try ' + f'again later.') + new_sample[tgt_video_key] = video + + # add caption + text = remove_dj_special_tokens(text.strip(), + eoc_special_token, + sent_seperator, + video_special_token) + + new_sample[tgt_text_key] = text + + writer.write(new_sample) + logger.info(f'Store the target dataset into [{target_internvid_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py index c3ccf8faa..8c7c4dc0d 100644 --- a/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py +++ b/tools/multimodal/data_juicer_format_to_target_format/dj_to_mmc4.py @@ -128,7 +128,7 @@ def main( Data-Juicer, such as "images", "text", ... Default: False. :param convert_to_relative_paths: whether convert the image paths in this dataset to relative paths to the original dataset. If it's True, an - extra argument original_llava_ds_path is required. When the processed + extra argument original_mmc4_ds_path is required. When the processed and converted dataset will be used in another machine, it's better to set this argument to True. Default: False. :param original_mmc4_ds_path: path to the original unprocessed MMC4 @@ -157,7 +157,7 @@ def main( f'the target dataset.') os.makedirs(os.path.dirname(target_mmc4_ds_path)) - # if convert_to_relative_paths is True, check if the original_llava_ds_path + # if convert_to_relative_paths is True, check if the original_mmc4_ds_path # is provided as well. 
if convert_to_relative_paths: if not original_mmc4_ds_path: diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py new file mode 100644 index 000000000..27b1ddb9f --- /dev/null +++ b/tools/multimodal/data_juicer_format_to_target_format/dj_to_video_chatgpt.py @@ -0,0 +1,167 @@ +# This tool is used to convert multimodal dataset in Data-Juicer format to a +# target dataset in Video-ChatGPT format. +# +# Reference: +# https://github.com/mbzuai-oryx/Video-ChatGPT/tree/main/data +# +# # Corresponding Data-Juicer format: +# - two new fields to store the main data: 'youtube_id' and 'text', +# the 'videos' is actual path to the video files +# {'youtube_id': 'k_ZXmr8pmrs', +# 'videos': ['youtube_video_dir/v_k_ZXmr8pmrs.mp4'], +# 'text': +# '<__dj__video>' +# '[[q]]: What are the main activities that take place in the video? \n' +# '[[a]]: The main activities that take place in the video are the +# preparation of camera equipment by a man.... <|__dj__eoc|>' +# } +# +# Video-ChatGPT format: +# - Topics for Video summarization; Description-based question-answers +# (exploring spatial, temporal, relationships, and reasoning concepts); +# and Creative/generative question-answers +# - in json file, a single line storing text-video tuples in a list, +# below is an example +# +# [{'q': 'What are the main activities that take place in the video?', +# 'a': 'The main activities that take place in the video are the preparation of +# camera equipment by a man, a group of men riding a helicopter, and a man +# sailing a boat through the water.', +# 'video_id': 'v_k_ZXmr8pmrs'}, ...] 
+# +import json +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + +from data_juicer.utils.mm_utils import SpecialTokens +from tools.multimodal.utils import remove_dj_special_tokens + + +def main( + dj_ds_path: str, + target_video_chatgpt_ds_path: str, + eoc_special_token: str = SpecialTokens.eoc, + video_special_token: str = SpecialTokens.video, + sent_seperator: str = ' ', + convert_to_relative_paths: bool = False, + original_video_chatgpt_ds_path: str = None, +): + """ + Convert a Data-Juicer-format dataset to a Video-ChatGPT-like dataset. + + :param dj_ds_path: path to the input dataset in Data-Juicer format. + :param target_video_chatgpt_ds_path: path to store the converted dataset in + Video-ChatGPT format. + :param eoc_special_token: the special token for "end of a chunk". It's used + to split sentence chunks explicitly. Default: <|__dj__eoc|> (from + Data-Juicer). + :param video_special_token: the special token for videos. It's used to + locate the videos in the text. In typical Video-ChatGPT-like datasets, + this special token is not specified. So we simply use the default video + special token from our Data-Juicer. Default: <__dj__video> (from + Data-Juicer). + :param sent_seperator: separator to split different sentences. Default: " " + :param convert_to_relative_paths: whether to convert the video paths in this + dataset to relative paths to the original dataset. If it's True, an + extra argument original_video_chatgpt_ds_path is required. When the + processed and converted dataset will be used on another machine, it's + better to set this argument to True. Default: False. + :param original_video_chatgpt_ds_path: path to the original unprocessed + Video-ChatGPT dataset, which is used to help to recover the relative video + paths for better migration. Default: None. + """ + # ----- Constant settings. Better not to change them.
----- + text_key = 'text' # default key of field to store the sample text + video_key = 'videos' # default key of field to store video files path + video_id_key = 'youtube_id' # default key of field to store youtube id + tgt_q_key = 'q' # default original key of field to store texts + tgt_a_key = 'a' + tgt_video_key = 'video_id' # default original field to store videos + # ----- Constant settings. Better not to change them. ----- + + # check arguments + if not os.path.exists(dj_ds_path): + raise FileNotFoundError(f'Input Video_ChatGPT dataset in dj format, ' + f'[{dj_ds_path}], can not be found.') + if not target_video_chatgpt_ds_path.endswith('.json'): + raise ValueError( + 'Only support "json" target dataset file for Video_ChatGPT now.') + if (os.path.dirname(target_video_chatgpt_ds_path) and + not os.path.exists(os.path.dirname(target_video_chatgpt_ds_path))): + logger.info(f'Create directory ' + f'[{os.path.dirname(target_video_chatgpt_ds_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(target_video_chatgpt_ds_path)) + + # if convert_to_relative_paths is True, check if the + # original_video_chatgpt_ds_path is provided as well. + if convert_to_relative_paths: + if not original_video_chatgpt_ds_path: + raise ValueError( + 'When convert_to_relative_paths is set to True, ' + 'the original_video_chatgpt_ds_path must be provided ' + 'for recovering the relative paths. Please ' + 'check and retry.') + original_video_chatgpt_ds_path = os.path.abspath( + original_video_chatgpt_ds_path) + # if provided original_video_chatgpt_ds_path is the dataset file path, + # only keep the directory path. 
+ if os.path.isfile(original_video_chatgpt_ds_path): + original_video_chatgpt_ds_path = os.path.dirname( + original_video_chatgpt_ds_path) + + # save Video-ChatGPT dataset from Data-Juicer format + logger.info('Start converting the DJ dataset to Video-ChatGPT format...') + all_samples = [] + with jl.open(dj_ds_path) as reader: + for line_num, s in enumerate(tqdm(reader)): + video_path = s.pop(video_key)[0] + new_sample = {} + + video_id = s.pop(video_id_key) + new_sample[tgt_video_key] = 'v_' + video_id + # add video + if convert_to_relative_paths: + if video_path.startswith(original_video_chatgpt_ds_path): + video_path = os.path.relpath( + video_path, original_video_chatgpt_ds_path) + else: + raise ValueError( + f'The original_video_chatgpt_ds_path ' + f'[{original_video_chatgpt_ds_path}] is not the ' + f'directory that contains the video ' + f'[{video_path}] in the sample of line number ' + f'[{line_num}]. Please check if the correct ' + f'original_video_chatgpt_ds_path is provided or ' + f'something wrong with this sample, and try ' + f'again later.') + new_sample[video_key] = video_path + + # add question and answer + text = s.pop(text_key).strip() + text = remove_dj_special_tokens(text, eoc_special_token, + sent_seperator, + video_special_token) + # get the question and answer + parts = text.split(f'[[{tgt_q_key}]]:')[1] + q, a = parts.split(f'[[{tgt_a_key}]]:') + new_sample[tgt_q_key] = q.strip() + new_sample[tgt_a_key] = a.strip() + + # add other fields + for key in s: + if key not in [tgt_q_key, tgt_a_key]: + new_sample[key] = s[key] + + all_samples.append(new_sample) + with open(target_video_chatgpt_ds_path, 'w') as file: + json.dump(all_samples, file) + logger.info(f'Store the dataset into [{target_video_chatgpt_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_youku.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_youku.py new file mode 100644 index 
000000000..923576b70 --- /dev/null +++ b/tools/multimodal/data_juicer_format_to_target_format/dj_to_youku.py @@ -0,0 +1,214 @@ +# This tool is used to convert multimodal dataset in Data-Juicer format to a +# target dataset in Youku-mPLUG-like format. +# +# Corresponding Data-Juicer format: +# - two new fields to store the main data: +# - text: a chunk of text with the video special token. +# - videos: video paths list +# - other fields in the original format can be kept or not +# - in jsonl +# +# Youku-mPLUG-pretrain Data-Juicer format: +# {'videos': ['videos/pretrain/14111Y1211b-1134b18bAE55bFE7Jbb7135YE3aY54EaB14ba7CbAa1AbACB24527A.flv'], # noqa: E501 +# 'text': '<__dj__video> 妈妈给宝宝听胎心,看看宝宝是怎么做的,太调皮了 <|__dj__eoc|>'} +# +# Youku-mPLUG-classification Data-Juicer format: +# {'videos': ['videos/classification/14111B1211bFBCBCJYF48B55b7523C51F3-8b3a5YbCa5817aBb38a5YAC-241F71J.mp4'], # noqa: E501 +# 'text': '<__dj__video> 兔宝宝刚出生,为什么兔妈妈要把它们吃掉?看完涨见识了 <|__dj__eoc|>', +# 'label': '宠物'} +# +# Youku-mPLUG-retrieval Data-Juicer format: +# {'videos': ['videos/retrieval/14111B1211bA1F31-E-2YB57--C518FCEBC553abJYFa5541a31C8a57522AYJbYF4aTdfofa112.mp4'], # noqa: E501 +# 'text': '<__dj__video> 身穿黑色上衣戴着头盔的女子在路上骑着摩托车四周还停放了一些车 <|__dj__eoc|>'} +# +# Youku-mPLUG-captioning Data-Juicer format: +# {'videos': ['videos/caption/14111B1211bEJB-1b3b-J3b7b8Y213BJ32-521a1EA8a53-3aBA72aA-4-2-CF1EJ8aTdfofa114.mp4'], # noqa: E501 +# 'text': '<__dj__video> 穿白色球服的女生高高跳起,接住了球。 <|__dj__eoc|> <__dj__video> 一排穿红色短袖的女生正在接受颁奖。 <|__dj__eoc|>'} # noqa: E501 +# +# Youku-mPLUG format: +# - 4 types: pretrain, classification, retrieval, captioning +# - in csv +# - text-video pair with other fields (label, ...)
+# +# Youku-mPLUG-pretrain format: +# {'video_id:FILE': 'videos/pretrain/14111Y1211b-1134b18bAE55bFE7Jbb7135YE3aY54EaB14ba7CbAa1AbACB24527A.flv', # noqa: E501 +# 'title': '妈妈给宝宝听胎心,看看宝宝是怎么做的,太调皮了'} +# +# Youku-mPLUG-classification format: +# {'video_id:FILE': 'videos/classification/14111B1211bFBCBCJYF48B55b7523C51F3-8b3a5YbCa5817aBb38a5YAC-241F71J.mp4', # noqa: E501 +# 'title': '兔宝宝刚出生,为什么兔妈妈要把它们吃掉?看完涨见识了', +# 'label': '宠物'} +# +# Youku-mPLUG-retrieval format: +# {'clip_name:FILE': 'videos/retrieval/14111B1211bA1F31-E-2YB57--C518FCEBC553abJYFa5541a31C8a57522AYJbYF4aTdfofa112.mp4', # noqa: E501 +# 'caption': '身穿黑色上衣戴着头盔的女子在路上骑着摩托车四周还停放了一些车'} +# +# Youku-mPLUG-captioning format: +# {'video_id:FILE': 'videos/caption/14111B1211bEJB-1b3b-J3b7b8Y213BJ32-521a1EA8a53-3aBA72aA-4-2-CF1EJ8aTdfofa114.mp4', # noqa: E501 +# 'golden_caption': '穿白色球服的女生高高跳起,接住了球。'} +# +# Reference: +# https://modelscope.cn/datasets/modelscope/Youku-AliceMind/summary + +import csv +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + +from data_juicer.utils.mm_utils import SpecialTokens +from tools.multimodal.utils import remove_dj_special_tokens + + +def main( + dj_ds_path: str, + target_youku_ds_path: str, + eoc_special_token: str = SpecialTokens.eoc, + video_special_token: str = SpecialTokens.video, + sent_seperator: str = ' ', + subset_type: str = 'classification', + convert_to_relative_paths: bool = False, + original_youku_ds_path: str = None, +): + """ + Convert a Data-Juicer-format dataset to a Youku-mPLUG-like dataset. + + :param dj_ds_path: path to the input dataset in Data-Juicer format. + :param target_youku_ds_path: path to store the converted dataset in + Youku-mPLUG format. + :param eoc_special_token: the special token for "end of a chunk". It's used + to split sentence chunks explicitly. Default: <|__dj__eoc|> (from + Data-Juicer). + :param video_special_token: the special token for videos. It's used to + locate the videos in the text.
In typical Youku-mPLUG-like datasets, + this special token is not specified. So we simply use the default video + special token from our Data-Juicer. Default: <__dj__video> (from + Data-Juicer). + :param sent_seperator: separator to split different sentences. Default: " " + :param subset_type: the subset type of the input dataset. Should be one of + ["pretrain", "classification", "retrieval", "captioning"]. Default: + "classification". + :param convert_to_relative_paths: whether to convert the video paths in this + dataset to relative paths to the original dataset. If it's True, an + extra argument original_youku_ds_path is required. When the processed + and converted dataset will be used on another machine, it's better to + set this argument to True. Default: False. + :param original_youku_ds_path: path to the original unprocessed Youku-mPLUG + dataset, which is used to help to recover the relative video paths for + better migration. Default: None. + """ + # ----- Constant settings. Better not to change them. ----- + text_key = 'text' # default key of field to store the sample text + video_key = 'videos' # default key of field to store the video list + fields_infos = { + 'pretrain': { + 'video_key': 'video_id:FILE', + 'text_key': 'title', + 'other_required_keys': [], + }, + 'classification': { + 'video_key': 'video_id:FILE', + 'text_key': 'title', + 'other_required_keys': ['label'], + }, + 'retrieval': { + 'video_key': 'clip_name:FILE', + 'text_key': 'caption', + 'other_required_keys': [], + }, + 'captioning': { + 'video_key': 'video_id:FILE', + 'text_key': 'golden_caption', + 'other_required_keys': [], + } + } + # ----- Constant settings. Better not to change them.
----- + + # check arguments + # check paths + if not os.path.exists(dj_ds_path): + raise FileNotFoundError( + f'Input dataset [{dj_ds_path}] can not be found.') + if not target_youku_ds_path.endswith('.csv'): + raise ValueError( + 'Only support "csv" target dataset file for Youku-mPLUG now.') + if os.path.dirname(target_youku_ds_path) \ + and not os.path.exists(os.path.dirname(target_youku_ds_path)): + logger.info( + f'Create directory [{os.path.dirname(target_youku_ds_path)}] for ' + f'the target dataset.') + os.makedirs(os.path.dirname(target_youku_ds_path)) + # check subset type; fail fast, because an unknown subset type has no + # entry in fields_infos and the lookups below would raise a KeyError + if subset_type not in fields_infos: + raise ValueError(f'Arg subset_type should be one of ["pretrain", ' + f'"classification", "retrieval", "captioning"], but ' + f'given [{subset_type}].') + tgt_video_key = fields_infos[subset_type]['video_key'] + tgt_text_key = fields_infos[subset_type]['text_key'] + tgt_required_keys = fields_infos[subset_type]['other_required_keys'] + + # if convert_to_relative_paths is True, check if the original_youku_ds_path + # is provided as well. + if convert_to_relative_paths: + if not original_youku_ds_path: + raise ValueError('When convert_to_relative_paths is set to True, ' + 'the original_youku_ds_path must be provided ' + 'for recovering the relative paths. Please ' + 'check and retry.') + original_youku_ds_path = os.path.abspath(original_youku_ds_path) + # if provided original_youku_ds_path is the dataset file path, only + # keep the directory path.
+ if os.path.isfile(original_youku_ds_path): + original_youku_ds_path = os.path.dirname(original_youku_ds_path) + + # save Youku-mPLUG dataset from Data-Juicer format + logger.info( + 'Start converting the original dataset to Youku-mPLUG format...') + with jl.open(dj_ds_path) as reader: + with open(target_youku_ds_path, 'w') as csvfile: + writer = csv.DictWriter(csvfile, + fieldnames=[tgt_video_key, tgt_text_key] + + tgt_required_keys) + # write headers first + writer.writeheader() + for line_num, s in enumerate(tqdm(reader)): + new_sample = {} + # add required fields + for key in tgt_required_keys: + if key not in s: + raise ValueError(f'Required key [{key}] is not in the ' + f'original Data-Juicer dataset.') + new_sample[key] = s[key] + + # add video, only keep the first one + video = s[video_key][0] + if convert_to_relative_paths: + if video.startswith(original_youku_ds_path): + video = os.path.relpath(video, original_youku_ds_path) + else: + raise ValueError( + f'The original_youku_ds_path ' + f'[{original_youku_ds_path}] is not the ' + f'directory that contains the video ' + f'[{video}] in the sample of line number ' + f'[{line_num}]. 
Please check if the correct ' + f'original_youku_ds_path is provided or ' + f'something wrong with this sample, and try ' + f'again later.') + new_sample[tgt_video_key] = video + + # add text, remove extra special tokens + text = s[text_key].strip() + text = remove_dj_special_tokens(text, eoc_special_token, + sent_seperator, + video_special_token) + new_sample[tgt_text_key] = text + + writer.writerow(new_sample) + logger.info(f'Store the target dataset into [{target_youku_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py new file mode 100644 index 000000000..93b51ccb8 --- /dev/null +++ b/tools/multimodal/source_format_to_data_juicer_format/internvid_to_dj.py @@ -0,0 +1,159 @@ +# This tool is used to convert multimodal dataset in InternVid format to a +# target dataset in Data-Juicer format. +# +# InternVid format: +# - in jsonl +# - caption-video pair with other fields (CLIP_Score, ...) +# - videos are from Youtube, and start/end timestamps are given +# - **Notice**: only YoutubeIDs are provided in the original dataset. +# Here we suppose that users have downloaded these videos already, +# and the YoutubeIDs are replaced with their video paths. +# {'YoutubeID': 'videos/qJrOyggIB-w.mp4', +# 'Start_timestamp': '00:07:33.689', +# 'End_timestamp': '00:07:51.085', +# 'Caption': 'a screen shot of heroes of the storm with people in action', +# 'Aesthetic_Score': 4.29296875, +# 'UMT_Score': 0.4501953125} +# +# Corresponding Data-Juicer format: +# - two new fields are added: +# - text: a chunk of text with the video special token. 
+# - videos: video paths list, including cut videos according to their timestamps # noqa: E501 +# - other fields in the original format can be kept or not +# - in jsonl +# {'videos': ['videos/qJrOyggIB-w-cut.mp4'], +# 'text': 'a screen shot of heroes of the storm with people in action', +# 'Start_timestamp': '00:07:33.689', +# 'End_timestamp': '00:07:51.085', +# 'Aesthetic_Score': 4.29296875, +# 'UMT_Score': 0.4501953125} +# +# +# Reference: +# https://huggingface.co/datasets/OpenGVLab/InternVid + +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + +from data_juicer.utils.file_utils import add_suffix_to_filename +from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds, + timecode_string_to_seconds) +from tools.multimodal.utils import (check_args_load_to_dj_data, + convert_text_to_dj) + + +def main( + internvid_ds_path: str, + target_ds_path: str, + eoc_special_token: str = SpecialTokens.eoc, + video_special_token: str = SpecialTokens.video, + add_eoc_at_last: bool = True, + sent_seperator: str = ' ', + video_special_token_insert_pos: str = 'before', + cut_videos: bool = True, + cut_video_store_path: str = None, + keep_other_fields: bool = True, +): + """ + Convert an InternVid-like dataset to the Data-Juicer format. + + :param internvid_ds_path: path to the input InternVid-like dataset. + :param target_ds_path: path to store the converted dataset in Data-Juicer + format. + :param eoc_special_token: the special token for "end of a chunk". It's used + to split sentence chunks explicitly. Default: <|__dj__eoc|> (from + Data-Juicer). + :param video_special_token: the special token for videos. It's used to + locate the videos in the text. In typical InternVide-like datasets, + this special token is not specified. So we simply use the default video + special token from our Data-Juicer. Default: <__dj__video> (from + Data-Juicer). 
+ :param add_eoc_at_last: whether to add an extra eoc_special_token at the + end of text. Default: True. + :param sent_seperator: separator to split different sentences or tokens. + Default: " " + :param video_special_token_insert_pos: the position in the sentence to + insert the corresponding video special token. Should be one of: [ + "before", "after", "random"]. Default: "before". + :param cut_videos: whether to cut the videos into smaller ones according to + their start/end timestamps. Default: True. If you did this process + before converting, please set it to False. + :param cut_video_store_path: a path to store the cut videos. If cut_videos + is True and this path is None, store the cut videos into the same + directory as the original videos. + :param keep_other_fields: whether to keep other fields in the original + datasets. Default: True. + """ + # ----- Constant settings. Better not to change them. ----- + text_key = 'text' # default key of field to store the sample text + video_key = 'videos' # default key of field to store the video list + ori_text_key = 'Caption' # default original key of field to store texts + ori_video_key = 'YoutubeID' # default original field to store videos + # ----- Constant settings. Better not to change them. ----- + + input_ds_dir = os.path.dirname(internvid_ds_path) + + # check arguments + check_args_load_to_dj_data(add_eoc_at_last, keep_other_fields, + target_ds_path, internvid_ds_path, + video_special_token_insert_pos, '.jsonl') + if cut_videos: + logger.warning('You set the cut_videos arg to True.
This tool will ' + 'take a video cut from the input video according to ' + 'the start/end timestamps.') + + # start conversion + logger.info('Start converting the original InternVid dataset...') + with jl.open(internvid_ds_path) as reader: + with jl.open(target_ds_path, mode='w') as writer: + for s in tqdm(reader): + video = s.pop(ori_video_key) + text = s.pop(ori_text_key) + + # convert text to data-juicer format + # add video special token + new_sample, text = convert_text_to_dj( + text, s, add_eoc_at_last, eoc_special_token, + keep_other_fields, sent_seperator, video_special_token, + video_special_token_insert_pos) + + # cut videos if needed + if cut_videos: + video = os.path.join(input_ds_dir, video) + cut_video_path = None + if cut_video_store_path is None: + # set it to the directory stores the original videos + cut_video_path = os.path.dirname( + os.path.abspath(video)) + else: + cut_video_path = cut_video_store_path + # cut the video and store in a new path + video_basename = os.path.basename(video) + new_video = os.path.join( + cut_video_path, + add_suffix_to_filename( + video_basename, + f'_{s["Start_timestamp"]}_{s["End_timestamp"]}')) + start_pts = timecode_string_to_seconds( + s['Start_timestamp']) + end_pts = timecode_string_to_seconds(s['End_timestamp']) + cut_video_by_seconds(video, new_video, start_pts, end_pts) + video = new_video + + new_sample[video_key] = [video] + new_sample[text_key] = text + if cut_videos: + # add a meta field to record whether this video is cut + new_sample['is_cut'] = True + else: + new_sample['is_cut'] = False + writer.write(new_sample) + logger.info(f'Store the target dataset into [{target_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py new file mode 100644 index 000000000..6aa1db738 --- /dev/null +++ 
b/tools/multimodal/source_format_to_data_juicer_format/video_chatgpt_to_dj.py @@ -0,0 +1,134 @@ +# This tool is used to convert multimodal dataset in Video_Chatgpt format to a +# target dataset in Data-Juicer format. +# +# Video-ChatGPT format: +# - Topics for Video summarization; Description-based question-answers +# (exploring spatial, temporal, relationships, and reasoning concepts); +# and Creative/generative question-answers +# - in json file, a single line storing text-video pairs in a list, +# below is an example +# +# [{'q': 'What are the main activities that take place in the video?', +# 'a': 'The main activities that take place in the video are the preparation of +# camera equipment by a man, a group of men riding a helicopter, and a man +# sailing a boat through the water.', +# 'video_id': 'v_k_ZXmr8pmrs'}, ...] +# +# +# # Corresponding Data-Juicer format: +# - two new fields to store the main data: 'youtube_id' and 'text', +# the 'videos' is actual path to the video files +# {'youtube_id': 'k_ZXmr8pmrs', +# 'videos': ['youtube_video_dir/v_k_ZXmr8pmrs.mp4'], +# 'text': +# '<__dj__video>' +# '[[q]]: What are the main activities that take place in the video? \n' +# '[[a]]: The main activities that take place in the video are the +# preparation of camera equipment by a man.... <|__dj__eoc|>' +# } +# + +import json +import os + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + +from data_juicer.utils.mm_utils import SpecialTokens +from tools.multimodal.utils import (check_args_load_to_dj_data, + convert_text_to_dj) + + +@logger.catch +def main( + video_chatgpt_ds_path: str, + target_ds_dj_path: str, + eoc_special_token: str = SpecialTokens.eoc, + video_special_token: str = SpecialTokens.video, + add_eoc_at_last: bool = True, + sent_seperator: str = ' ', + video_special_token_insert_pos: str = 'before', + keep_other_fields: bool = True, +): + """ + Convert a Video_Chatgpt-like dataset to the Data-Juicer format. 
+ + :param video_chatgpt_ds_path: path to the input Video_Chatgpt-like dataset. + :param target_ds_dj_path: path to store the converted dataset in + Data-Juicer format. + :param eoc_special_token: the special token for "end of a chunk". It's used + to split sentence chunks explicitly. Default: <|__dj__eoc|> (from + Data-Juicer). + :param video_special_token: the special token for videos. It's used to + locate the videos in the text. In typical Video_Chatgpt-like datasets, + this special token is not specified. So we simply use the default video + special token from our Data-Juicer. Default: <__dj__video> (from + Data-Juicer). + :param add_eoc_at_last: whether to add an extra eoc_special_token at the + end of text. Default: True. + :param sent_seperator: separator to split different sentences or tokens. + Default: " " + :param video_special_token_insert_pos: the position in the sentence to + insert the corresponding video special token. Should be one of: [ + "before", "after", "random"]. Default: "before". + :param keep_other_fields: whether to keep other fields in the original + datasets. Default: True. + """ + # ----- Constant settings. Better not to change them. ----- + text_key = 'text' # default key of field to store the sample text + video_key = 'videos' # default key of field to store the video list + video_id_key = 'youtube_id' # default original field to store video id + ori_text_key_q = 'q' # default original key of field to store texts + ori_text_key_a = 'a' # default original key of field to store texts + ori_video_key = 'video_id' # default original field to store video ids + + def format_dj_text(text_q, text_a): + """ + This function returns a formatted string. + + :param text_q: Text for the question + :param text_a: Text for the answer + :return: Formatted string + """ + return f'[[{ori_text_key_q}]]:{text_q} \n[[{ori_text_key_a}]]:{text_a}' + + # ----- Constant settings. Better not to change them. 
----- + + input_ds_dir = os.path.dirname(video_chatgpt_ds_path) + + # check arguments + check_args_load_to_dj_data(add_eoc_at_last, keep_other_fields, + target_ds_dj_path, video_chatgpt_ds_path, + video_special_token_insert_pos, '.jsonl') + + # start conversion + logger.info(f'Start converting the original Video_Chatgpt dataset ' + f'from {video_chatgpt_ds_path}...') + with open(video_chatgpt_ds_path, 'r') as json_file: + ori_data = json.load(json_file) + with jl.open(target_ds_dj_path, mode='w') as writer: + for s in tqdm(ori_data): + # v_k_ZXmr8pmrs --> k_ZXmr8pmrs + video_id = s.pop(ori_video_key)[2:] + text_q = s.pop(ori_text_key_q) + text_a = s.pop(ori_text_key_a) + + text = format_dj_text(text_q, text_a) + + new_sample, text = convert_text_to_dj( + text, s, add_eoc_at_last, eoc_special_token, keep_other_fields, + sent_seperator, video_special_token, + video_special_token_insert_pos) + + new_sample[video_key] = [os.path.join(input_ds_dir, video_id)] + new_sample[text_key] = text + new_sample[video_id_key] = video_id + + writer.write(new_sample) + logger.info(f'Store the target dataset into [{target_ds_dj_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py new file mode 100644 index 000000000..f347dd090 --- /dev/null +++ b/tools/multimodal/source_format_to_data_juicer_format/youku_to_dj.py @@ -0,0 +1,162 @@ +# This tool is used to convert multimodal dataset in Youku-mPLUG format to a +# target dataset in Data-Juicer format. +# +# Youku-mPLUG format: +# - 4 types: pretrain, classification, retrieval, captioning +# - in csv +# - text-video pair with other fields (label, ...) 
+# +# Youku-mPLUG-pretrain format: +# {'video_id:FILE': 'videos/pretrain/14111Y1211b-1134b18bAE55bFE7Jbb7135YE3aY54EaB14ba7CbAa1AbACB24527A.flv', # noqa: E501 +# 'title': '妈妈给宝宝听胎心,看看宝宝是怎么做的,太调皮了'} +# +# Youku-mPLUG-classification format: +# {'video_id:FILE': 'videos/classification/14111B1211bFBCBCJYF48B55b7523C51F3-8b3a5YbCa5817aBb38a5YAC-241F71J.mp4', # noqa: E501 +# 'title': '兔宝宝刚出生,为什么兔妈妈要把它们吃掉?看完涨见识了', +# 'label': '宠物'} +# +# Youku-mPLUG-retrieval format: +# {'clip_name:FILE': 'videos/retrieval/14111B1211bA1F31-E-2YB57--C518FCEBC553abJYFa5541a31C8a57522AYJbYF4aTdfofa112.mp4', # noqa: E501 +# 'caption': '身穿黑色上衣戴着头盔的女子在路上骑着摩托车四周还停放了一些车'} +# +# Youku-mPLUG-captioning format: +# {'video_id:FILE': 'videos/caption/14111B1211bEJB-1b3b-J3b7b8Y213BJ32-521a1EA8a53-3aBA72aA-4-2-CF1EJ8aTdfofa114.mp4', # noqa: E501 +# 'golden_caption': '穿白色球服的女生高高跳起,接住了球。']} +# +# Corresponding Data-Juicer format: +# - two new fields are added: +# - text: a chunk of text with the video special token. +# - videos: video paths list +# - other fields in the original format can be kept or not +# - in jsonl +# +# Youku-mPLUG-pretrain Data-Juicer format: +# {'videos': ['videos/pretrain/14111Y1211b-1134b18bAE55bFE7Jbb7135YE3aY54EaB14ba7CbAa1AbACB24527A.flv'], # noqa: E501 +# 'text': '<__dj__video> 妈妈给宝宝听胎心,看看宝宝是怎么做的,太调皮了 <|__dj__eoc|>'} +# +# Youku-mPLUG-classification Data-Juicer format: +# {'videos': ['videos/classification/14111B1211bFBCBCJYF48B55b7523C51F3-8b3a5YbCa5817aBb38a5YAC-241F71J.mp4'], # noqa: E501 +# 'text': '<__dj__video> 兔宝宝刚出生,为什么兔妈妈要把它们吃掉?看完涨见识了 <|__dj__eoc|>', +# 'label': '宠物'} +# +# Youku-mPLUG-retrieval Data-Juicer format: +# {'videos': ['videos/retrieval/14111B1211bA1F31-E-2YB57--C518FCEBC553abJYFa5541a31C8a57522AYJbYF4aTdfofa112.mp4'], # noqa: E501 +# 'text': '<__dj__video> 身穿黑色上衣戴着头盔的女子在路上骑着摩托车四周还停放了一些车 <|__dj__eoc|>'} +# +# Youku-mPLUG-captioning Data-Juicer format: +# {'videos': 
['videos/caption/14111B1211bEJB-1b3b-J3b7b8Y213BJ32-521a1EA8a53-3aBA72aA-4-2-CF1EJ8aTdfofa114.mp4'], # noqa: E501 +# 'text': '<__dj__video> 穿白色球服的女生高高跳起,接住了球。 <|__dj__eoc|> <__dj__video> 一排穿红色短袖的女生正在接受颁奖。 <|__dj__eoc|>']} # noqa: E501 +# +# Reference: +# https://modelscope.cn/datasets/modelscope/Youku-AliceMind/summary + +import csv + +import fire +import jsonlines as jl +from loguru import logger +from tqdm import tqdm + +from data_juicer.utils.mm_utils import SpecialTokens +from tools.multimodal.utils import (check_args_load_to_dj_data, + convert_text_to_dj) + + +@logger.catch +def main( + youku_ds_path: str, + target_ds_path: str, + eoc_special_token: str = SpecialTokens.eoc, + video_special_token: str = SpecialTokens.video, + add_eoc_at_last: bool = True, + sent_seperator: str = ' ', + video_special_token_insert_pos: str = 'before', + subset_type: str = 'classification', + keep_other_fields: bool = True, +): + """ + Convert a Youku-mPLUG-like dataset to the Data-Juicer format. + + :param youku_ds_path: path to the input Youku-mPLUG-like dataset. + :param target_ds_path: path to store the converted dataset in Data-Juicer + format. + :param eoc_special_token: the special token for "end of a chunk". It's used + to split sentence chunks explicitly. Default: <|__dj__eoc|> (from + Data-Juicer). + :param video_special_token: the special token for videos. It's used to + locate the videos in the text. In typical Youku-mPLUG-like datasets, + this special token is not specified. So we simply use the default video + special token from our Data-Juicer. Default: <__dj__video> (from + Data-Juicer). + :param add_eoc_at_last: whether to add an extra eoc_special_token at the + end of text. Default: True. + :param sent_seperator: separator to split different sentences or tokens. + Default: " " + :param video_special_token_insert_pos: the position in the sentence to + insert the corresponding video special token. Should be one of: [ + "before", "after", "random"]. 
Default: "before". + :param subset_type: the subset type of the input dataset. Should be one of + ["pretrain", "classification", "retrieval", "captioning"]. Default: + "classification". + :param keep_other_fields: whether to keep other fields in the original + datasets. Default: True. + """ + # ----- Constant settings. Better not to change them. ----- + text_key = 'text' # default key of field to store the sample text + video_key = 'videos' # default key of field to store the video list + fields_infos = { + 'pretrain': { + 'video_key': 'video_id:FILE', + 'text_key': 'title', + }, + 'classification': { + 'video_key': 'video_id:FILE', + 'text_key': 'title', + }, + 'retrieval': { + 'video_key': 'clip_name:FILE', + 'text_key': 'caption', + }, + 'captioning': { + 'video_key': 'video_id:FILE', + 'text_key': 'golden_caption', + } + } + # ----- Constant settings. Better not to change them. ----- + + # check arguments + check_args_load_to_dj_data(add_eoc_at_last, keep_other_fields, + target_ds_path, youku_ds_path, + video_special_token_insert_pos, '.jsonl') + # check subset type; fail fast instead of hitting a KeyError below + if subset_type not in fields_infos: + raise ValueError(f'Arg subset_type should be one of ["pretrain", ' + f'"classification", "retrieval", "captioning"], but ' + f'given [{subset_type}].') + ori_video_key = fields_infos[subset_type]['video_key'] + ori_text_key = fields_infos[subset_type]['text_key'] + + # load Youku-mPLUG dataset + logger.info('Start converting the original Youku-mPLUG dataset...') + with open(youku_ds_path) as csvfile: + reader = csv.DictReader(csvfile) + with jl.open(target_ds_path, mode='w') as writer: + for row in tqdm(reader): + video = row[ori_video_key] + text = row[ori_text_key] + + # convert text to data-juicer format + # add video special token + new_sample, text = convert_text_to_dj( + text, row, add_eoc_at_last, eoc_special_token, + keep_other_fields, sent_seperator, video_special_token, + video_special_token_insert_pos) + # all 
sentences correspond to the same video + 
new_sample[video_key] = [video] + new_sample[text_key] = text + writer.write(new_sample) + logger.info(f'Store the target dataset into [{target_ds_path}].') + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/tools/multimodal/utils.py b/tools/multimodal/utils.py new file mode 100644 index 000000000..595c52702 --- /dev/null +++ b/tools/multimodal/utils.py @@ -0,0 +1,82 @@ +import os +import random +from copy import deepcopy + +from loguru import logger + + +def remove_dj_special_tokens(text, eoc_special_token, sent_seperator, + video_special_token): + # remove possible sentence seperator + if text.startswith(sent_seperator): + text = text[len(sent_seperator):].strip() + if text.endswith(sent_seperator): + text = text[:-len(sent_seperator)].strip() + # remove eoc token + if text.endswith(eoc_special_token): + text = text[:-len(eoc_special_token)].strip() + # remove possible video special token + if text.startswith(video_special_token): + text = text[len(video_special_token):].strip() + if text.startswith(sent_seperator): + text = text[len(sent_seperator):].strip() + elif text.endswith(video_special_token): + text = text[:-len(video_special_token)].strip() + if text.endswith(sent_seperator): + text = text[:-len(sent_seperator)].strip() + return text + + +def check_args_load_to_dj_data(add_eoc_at_last, keep_other_fields, + target_ds_dj_path, video_ds_path, + video_special_token_insert_pos, + target_ds_path_suffix): + if not os.path.exists(video_ds_path): + raise FileNotFoundError(f'Input dataset ' + f'[{video_ds_path}] can not be found.') + if not target_ds_dj_path.endswith(target_ds_path_suffix): + raise ValueError( + f'Only support "{target_ds_path_suffix}" target dataset file now.') + if os.path.dirname(target_ds_dj_path) \ + and not os.path.exists(os.path.dirname(target_ds_dj_path)): + logger.info(f'Create directory [{os.path.dirname(target_ds_dj_path)}] ' + f'for the target dataset.') + os.makedirs(os.path.dirname(target_ds_dj_path)) + # check insert 
position + if video_special_token_insert_pos not in ['random', 'before', 'after']: + raise ValueError(f'Arg video_special_token_insert_pos should be one ' + f'of ["before", "after", "random"], but given ' + f'[{video_special_token_insert_pos}]') + # check whether to add the eoc special token at last + if not add_eoc_at_last: + logger.warning('You choose not to add special eoc token at the last, ' + 'which might cause some compatibility problems for ' + 'other types of datasets (e.g. OpenFlamingo).') + if not keep_other_fields: + logger.warning('You choose not to keep other fields in the original ' + 'dataset. Thus some information might be lost in the ' + 'processed and converted-back dataset!') + + +def convert_text_to_dj(text, original_sample, add_eoc_at_last, + eoc_special_token, keep_other_fields, sent_seperator, + video_special_token, video_special_token_insert_pos): + if video_special_token_insert_pos == 'before': + text = video_special_token + sent_seperator + text + elif video_special_token_insert_pos == 'after': + text += sent_seperator + video_special_token + else: + if random.random() < 0.5: + # before + text = video_special_token + sent_seperator \ + + text + else: + # after + text += sent_seperator + video_special_token + if add_eoc_at_last: + text += f'{sent_seperator}{eoc_special_token}'
 + if keep_other_fields: + new_sample = deepcopy(original_sample) + else: + new_sample = {} + return new_sample, text
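
To illustrate what `convert_text_to_dj` produces, here is a minimal standalone sketch of its two deterministic branches (`'before'` and `'after'`). It is not part of this diff: the helper name `build_dj_sample` is hypothetical, and the token strings are hard-coded from the defaults quoted in the docstrings above (`<__dj__video>` and `<|__dj__eoc|>`) so the example runs without importing `data_juicer`. The `'random'` branch is left out to keep the output reproducible.

```python
from copy import deepcopy

# Assumed values: the default tokens as quoted in the docstrings above.
VIDEO_TOKEN = '<__dj__video>'
EOC_TOKEN = '<|__dj__eoc|>'


def build_dj_sample(text, original_sample, insert_pos='before', sep=' ',
                    add_eoc_at_last=True, keep_other_fields=True):
    """Hypothetical helper mirroring the 'before'/'after' branches."""
    if insert_pos == 'before':
        # video token first, then the separator, then the text
        text = VIDEO_TOKEN + sep + text
    elif insert_pos == 'after':
        # text first, then the separator, then the video token
        text = text + sep + VIDEO_TOKEN
    else:
        raise ValueError(f'unsupported insert_pos in this sketch: {insert_pos}')
    if add_eoc_at_last:
        # close the chunk with the end-of-chunk token
        text += sep + EOC_TOKEN
    # copy the remaining fields so the caller's sample is not mutated
    new_sample = deepcopy(original_sample) if keep_other_fields else {}
    return new_sample, text


source_fields = {'label': 'pets'}
kept, text_before = build_dj_sample('A rabbit eats grass.', source_fields)
dropped, text_after = build_dj_sample('A rabbit eats grass.', source_fields,
                                      insert_pos='after',
                                      add_eoc_at_last=False,
                                      keep_other_fields=False)
print(text_before)
print(text_after)
```

With `keep_other_fields=True` the returned sample is a deep copy of the remaining source fields, so later writes to `videos` and `text` do not touch the original row.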