From 6951904b4d85081eed06b9d4da43912431b53454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Mon, 6 Nov 2023 20:23:05 +0800 Subject: [PATCH 1/6] added auto-HPO feature with WandB --- data_juicer/config/config.py | 200 ++++++++++++++-------- data_juicer/format/formatter.py | 9 +- tools/hpo/README.md | 51 ++++++ tools/hpo/configs/process.yaml | 19 ++ tools/hpo/configs/quality_score_hpo.yaml | 29 ++++ tools/hpo/demo-redpajama-c4-refined.jsonl | 10 ++ tools/hpo/execute_hpo.py | 46 +++++ tools/hpo/objects.py | 53 ++++++ tools/quality_classifier/__init__.py | 0 tools/quality_classifier/eval.py | 2 +- tools/quality_classifier/predict.py | 25 +-- tools/quality_classifier/qc_utils.py | 37 ++-- tools/quality_classifier/train.py | 2 +- 13 files changed, 382 insertions(+), 101 deletions(-) create mode 100644 tools/hpo/README.md create mode 100644 tools/hpo/configs/process.yaml create mode 100644 tools/hpo/configs/quality_score_hpo.yaml create mode 100644 tools/hpo/demo-redpajama-c4-refined.jsonl create mode 100644 tools/hpo/execute_hpo.py create mode 100644 tools/hpo/objects.py create mode 100644 tools/quality_classifier/__init__.py diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 5e31c38c1..ed18ac470 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -7,10 +7,10 @@ from jsonargparse import (ActionConfigFile, ArgumentParser, dict_to_namespace, namespace_to_dict) from jsonargparse.typing import NonNegativeInt, PositiveInt -from loguru import logger from data_juicer.ops.base_op import OPERATORS from data_juicer.utils.logger_utils import setup_logger +from loguru import logger def init_configs(args=None): @@ -26,11 +26,16 @@ def init_configs(args=None): """ parser = ArgumentParser(default_env=True, default_config_files=None) + parser.add_argument('--config', + action=ActionConfigFile, + help='Path to a configuration file.', + required=True) + parser.add_argument( - '--config', - 
action=ActionConfigFile, - help='Path to a configuration file.', - required=True) + '--hpo_config', + type=str, + help='Path to a configuration file when using auto-HPO tool.', + required=False) # basic global paras with extended type hints # e.g., files can be mode include flags @@ -39,159 +44,154 @@ def init_configs(args=None): # "dw": "path to a directory that exists and is writeable") # "dc": "path to a directory that can be created if it does not exist") # "drw": "path to a directory that exists and is readable and writeable") - parser.add_argument( - '--project_name', - type=str, - default='hello_world', - help='Name of your data process project.') + parser.add_argument('--project_name', + type=str, + default='hello_world', + help='Name of your data process project.') parser.add_argument( '--executor_type', type=str, default='default', choices=['default', 'ray'], - help='Type of executor, support "default" or "ray" for now.' - ) + help='Type of executor, support "default" or "ray" for now.') parser.add_argument( '--dataset_path', type=str, help='Path to datasets with optional weights(0.0-1.0), 1.0 as ' - 'default. Accepted format: dataset1-path dataset2-path ' - ' dataset3-path ...') + 'default. Accepted format: dataset1-path dataset2-path ' + ' dataset3-path ...') parser.add_argument( '--export_path', type=str, default='./outputs/hello_world.jsonl', help='Path to export and save the output processed dataset. The ' - 'directory to store the processed dataset will be the work ' - 'directory of this process.') + 'directory to store the processed dataset will be the work ' + 'directory of this process.') parser.add_argument( '--export_shard_size', type=NonNegativeInt, default=0, help='Shard size of exported dataset in Byte. In default, it\'s 0, ' - 'which means export the whole dataset into only one file. 
If ' - 'it\'s set a positive number, the exported dataset will be split ' - 'into several sub-dataset shards, and the max size of each shard ' - 'won\'t larger than the export_shard_size') + 'which means export the whole dataset into only one file. If ' + 'it\'s set a positive number, the exported dataset will be split ' + 'into several sub-dataset shards, and the max size of each shard ' + 'won\'t larger than the export_shard_size') parser.add_argument( '--export_in_parallel', type=bool, default=False, help='Whether to export the result dataset in parallel to a single ' - 'file, which usually takes less time. It only works when ' - 'export_shard_size is 0, and its default number of processes is ' - 'the same as the argument np. **Notice**: If it\'s True, ' - 'sometimes exporting in parallel might require much more time ' - 'due to the IO blocking, especially for very large datasets. ' - 'When this happens, False is a better choice, although it takes ' - 'more time.') - parser.add_argument( - '--np', - type=PositiveInt, - default=4, - help='Number of processes to process dataset.') + 'file, which usually takes less time. It only works when ' + 'export_shard_size is 0, and its default number of processes is ' + 'the same as the argument np. **Notice**: If it\'s True, ' + 'sometimes exporting in parallel might require much more time ' + 'due to the IO blocking, especially for very large datasets. ' + 'When this happens, False is a better choice, although it takes ' + 'more time.') + parser.add_argument('--np', + type=PositiveInt, + default=4, + help='Number of processes to process dataset.') parser.add_argument( '--text_keys', type=Union[str, List[str]], default='text', help='Key name of field where the sample texts to be processed, e.g., ' - '`text`, `text.instruction`, `text.output`, ... Note: currently, ' - 'we support specify only ONE key for each op, for cases ' - 'requiring multiple keys, users can specify the op multiple ' - 'times. 
We will only use the first key of `text_keys` when you ' - 'set multiple keys.') + '`text`, `text.instruction`, `text.output`, ... Note: currently, ' + 'we support specify only ONE key for each op, for cases ' + 'requiring multiple keys, users can specify the op multiple ' + 'times. We will only use the first key of `text_keys` when you ' + 'set multiple keys.') parser.add_argument( '--suffixes', type=Union[str, List[str], Tuple[str]], default=[], help='Suffixes of files that will be find and loaded. If not set, we ' - 'will find all suffix files, and select a suitable formatter ' - 'with the most files as default.') + 'will find all suffix files, and select a suitable formatter ' + 'with the most files as default.') parser.add_argument( '--use_cache', type=bool, default=True, help='Whether to use the cache management of huggingface datasets. It ' - 'might take up lots of disk space when using cache') + 'might take up lots of disk space when using cache') parser.add_argument( '--ds_cache_dir', type=str, default=None, help='Cache dir for HuggingFace datasets. In default it\'s the same ' - 'as the environment variable `HF_DATASETS_CACHE`, whose default ' - 'value is usually "~/.cache/huggingface/datasets". If this ' - 'argument is set to a valid path by users, it will override the ' - 'default cache dir.') + 'as the environment variable `HF_DATASETS_CACHE`, whose default ' + 'value is usually "~/.cache/huggingface/datasets". If this ' + 'argument is set to a valid path by users, it will override the ' + 'default cache dir.') parser.add_argument( '--cache_compress', type=str, default=None, help='The compression method of the cache file, which can be' - 'specified in ["gzip", "zstd", "lz4"]. If this parameter is' - 'None, the cache file will not be compressed.') + 'specified in ["gzip", "zstd", "lz4"]. 
If this parameter is' + 'None, the cache file will not be compressed.') parser.add_argument( '--use_checkpoint', type=bool, default=False, help='Whether to use the checkpoint management to save the latest ' - 'version of dataset to work dir when processing. Rerun the same ' - 'config will reload the checkpoint and skip ops before it. Cache ' - 'will be disabled when it is true . If args of ops before the ' - 'checkpoint are changed, all ops will be rerun from the ' - 'beginning.') + 'version of dataset to work dir when processing. Rerun the same ' + 'config will reload the checkpoint and skip ops before it. Cache ' + 'will be disabled when it is true . If args of ops before the ' + 'checkpoint are changed, all ops will be rerun from the ' + 'beginning.') parser.add_argument( '--temp_dir', type=str, default=None, help='Path to the temp directory to store intermediate caches when ' - 'cache is disabled. In default it\'s None, so the temp dir will ' - 'be specified by system. NOTICE: you should be caution when ' - 'setting this argument because it might cause unexpected program ' - 'behaviors when this path is set to an unsafe directory.') + 'cache is disabled. In default it\'s None, so the temp dir will ' + 'be specified by system. NOTICE: you should be caution when ' + 'setting this argument because it might cause unexpected program ' + 'behaviors when this path is set to an unsafe directory.') parser.add_argument( '--open_tracer', type=bool, default=False, help='Whether to open the tracer to trace samples changed during ' - 'process. It might take more time when opening tracer.') + 'process. It might take more time when opening tracer.') parser.add_argument( '--op_list_to_trace', type=List[str], default=[], help='Which ops will be traced by tracer. If it\'s empty, all ops in ' - 'cfg.process will be traced. Only available when open_tracer is ' - 'true.') + 'cfg.process will be traced. 
Only available when open_tracer is ' + 'true.') parser.add_argument( '--trace_num', type=int, default=10, help='Number of samples extracted by tracer to show the dataset ' - 'difference before and after a op. Only available when ' - 'open_tracer is true.') + 'difference before and after a op. Only available when ' + 'open_tracer is true.') parser.add_argument( '--op_fusion', type=bool, default=False, help='Whether to fuse operators that share the same intermediate ' - 'variables automatically. Op fusion might reduce the memory ' - 'requirements slightly but speed up the whole process.') + 'variables automatically. Op fusion might reduce the memory ' + 'requirements slightly but speed up the whole process.') parser.add_argument( '--process', type=List[Dict], help='List of several operators with their arguments, these ops will ' - 'be applied to dataset in order') + 'be applied to dataset in order') parser.add_argument( '--save_stats_in_one_file', type=bool, default=False, help='Whether to save all stats to only one file. Only used in ' - 'Analysis.') - parser.add_argument( - '--ray_address', - type=str, - default='auto', - help='The address of the Ray cluster.' 
- ) + 'Analysis.') + parser.add_argument('--ray_address', + type=str, + default='auto', + help='The address of the Ray cluster.') # add all parameters of the registered ops class to the parser, # and these op parameters can be modified through the command line, @@ -285,7 +285,9 @@ def init_setup_from_cfg(cfg): timestamp = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time())) cfg.timestamp = timestamp logfile_name = timestamp + '.txt' - setup_logger(save_dir=log_dir, filename=logfile_name, redirect=cfg.executor_type=='default') + setup_logger(save_dir=log_dir, + filename=logfile_name, + redirect=cfg.executor_type == 'default') # whether or not to use cache management # disabling the cache or using checkpoint explicitly will turn off the @@ -391,6 +393,7 @@ def sort_op_by_types_and_names(op_name_classes): deduplicator_ops) + sorted(selector_ops) return ops_sorted_by_types + def config_backup(cfg): cfg_path = cfg.config[0].absolute work_dir = cfg.work_dir @@ -399,6 +402,7 @@ def config_backup(cfg): f'work_dir [{work_dir}]') shutil.copyfile(cfg_path, target_path) + def display_config(cfg): from tabulate import tabulate import pprint @@ -416,3 +420,57 @@ def display_config(cfg): logger.info('Configuration table: ') print(table) + + +def merge_config(ori_cfg, new_cfg: Dict): + """ + Merge configuration from new_cfg into ori_cfg + + :param ori_cfg: the original configuration object, whose type is + expected as namespace from jsonargparse + :param new_cfg: the configuration object to be merged, whose type is + expected as dict or namespace from jsonargparse + + :return cfg_after_merge + """ + try: + ori_specified_op_names = set() + ori_specified_op_idx = {} # {op_name: op_order} + # format of ori_cfg.process + # ori_cfg.process[i] = { + # op_in_process_name: + # None if internal_op_para is None else + # namespace_to_dict(internal_op_para) + # } + for op_order, op_in_process in enumerate(ori_cfg.process): + op_name = list(op_in_process.keys())[0] + 
ori_specified_op_names.add(op_name) + ori_specified_op_idx[op_name] = op_order + + for new_k, new_v in new_cfg.items(): + # merge parameters other than `cfg.process` and DJ-OPs + if new_k in ori_cfg and new_k != 'process' and '.' not in new_k: + ori_cfg[new_k] = new_v + else: + # merge parameters of DJ-OPs into cfg.process + # for nested style, e.g., `remove_table_text_mapper.min_col: 2` + key_as_groups = new_k.split('.') + if len(key_as_groups) > 1 and \ + key_as_groups[0] in ori_specified_op_names: + op_name, para_name = key_as_groups[0], key_as_groups[1] + op_order = ori_specified_op_idx[op_name] + ori_cfg.process[op_order][op_name][para_name] = new_v + + ori_cfg = init_setup_from_cfg(ori_cfg) + + # copy the config file into the work directory + config_backup(ori_cfg) + + # show the final config tables before the process started + print('=' * 10, '\nAfter merging, the new cfg becomes:', '=' * 10) + display_config(ori_cfg) + + return ori_cfg + + except ArgumentError: + logger.error('Config merge failed') diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py index 54acdc8a0..123f61a45 100644 --- a/data_juicer/format/formatter.py +++ b/data_juicer/format/formatter.py @@ -1,13 +1,12 @@ import os from typing import List, Tuple, Union -from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset -from loguru import logger - from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import (find_files_with_suffix, is_absolute_path) from data_juicer.utils.registry import Registry +from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset +from loguru import logger FORMATTERS = Registry('Formatters') @@ -262,4 +261,6 @@ def load_formatter(dataset_path, # no data else: - raise NotImplementedError + raise ValueError('Can not found local data or huggingface ' + 'dataset-hub for your given path: ' + f'{dataset_path} and suffixes: {suffixes}') diff --git a/tools/hpo/README.md 
b/tools/hpo/README.md new file mode 100644 index 000000000..8728c043e --- /dev/null +++ b/tools/hpo/README.md @@ -0,0 +1,51 @@ +# Hyper-parameter Optimization for Data Recipe + +## Auto-HPO + +We incorporate an automated HPO tool, WandB [Sweep](https://docs.wandb.ai/guides/sweeps), into Data-Juicer to streamline the finding of good data processing hyper-parameters. +With this tool, users can investigate correlations and importance scores of +specific hyper-parameters of data recipes from the HPO view. + +*Note*: this is an experimental feature. Auto-HPO for data recipes still has +much room to explore. Feel free to provide more suggestions, discussion, +and contribution via new PRs! + + +## Prerequisite +You need to install data-juicer first. +Besides, the tool leverages WandB; install it via `pip install wandb`. +Before using this tool, you need to run +`wandb login` and enter your WandB +API key. +If you have your own instance of WandB (e.g., [locally-hosted machine](https://docs.wandb.ai/guides/hosting/)), run the following script: + +```shell +wandb login --host <URL of your wandb instance> +# enter your api key +``` + + + +## Usage and Customization + +Given a data recipe, characterized by a specified configuration file +`<data-process-cfg-file-path>`, you can use `execute_hpo.py` to search the +hyper-parameter space defined by `<hpo-cfg-file-path>`. +```shell +# cd tools/hpo +python execute_hpo.py --config <data-process-cfg-file-path> --hpo_config <hpo-cfg-file-path> + +# e.g., +python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml +``` + +We provide an illustrative objective "quality_score" in `hpo/objects.py`, +which uses a quality scorer to measure the processed data, and links the average scores to hyper-parameters of data recipes. 
+ +You can implement your own HPO objective in the `get_hpo_objective` function, e.g., one that links the data +recipes to +- model_loss (by replacing the quality scorer with a training procedure), +- downstream_task (by replacing the quality scorer with a training and an +evaluation procedure), or +- some synergy measures that combine metrics you are interested in, such that + the trade-offs from different views can be explored. diff --git a/tools/hpo/configs/process.yaml b/tools/hpo/configs/process.yaml new file mode 100644 index 000000000..4f42662ef --- /dev/null +++ b/tools/hpo/configs/process.yaml @@ -0,0 +1,19 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo-process-hpo' +dataset_path: 'demo-redpajama-c4-refined.jsonl' +np: 4 # number of subprocess to process your dataset + +export_path: './outputs/demo-hpo-process/demo-hpo-processed.jsonl' + +# process schedule +# a list of several process operators with their arguments +process: + - character_repetition_filter: # filter text with the character repetition ratio out of specific range + rep_len: 10 # repetition length for char-level n-gram + min_ratio: 0.0 # the min ratio of filter range + max_ratio: 0.5 + - text_length_filter: # filter text with length out of specific range + min_len: 10 # the min length of filter range + max_len: 10000 diff --git a/tools/hpo/configs/quality_score_hpo.yaml b/tools/hpo/configs/quality_score_hpo.yaml new file mode 100644 index 000000000..0f5ef41a8 --- /dev/null +++ b/tools/hpo/configs/quality_score_hpo.yaml @@ -0,0 +1,29 @@ + +sweep_name: hpo_for_data-juicer +# sweep_count: 10 + +# hpo configuration from original sweep, see more options and details in +# https://docs.wandb.ai/guides/sweeps/define-sweep-configuration + +method: bayes # ["random", "grid", "bayes"] + +metric: + name: quality_score # defined in hpo/objects.py + goal: maximize # ["maximize", "minimize", +"target"] + +parameters: +# can be [single value, multiple values, probabilities, 
distribution, nested] + character_repetition_filter.rep_len: + values: [2, 4, 8, 16] + character_repetition_filter.max_ratio: + values: [0.3, 0.5, 0.7] + text_length_filter.min_len: + distribution: q_log_uniform_values + min: 8 + max: 512 + + +#early_terminate: +# type: hyperband +# max_iter: 27 +# s: 2 diff --git a/tools/hpo/demo-redpajama-c4-refined.jsonl b/tools/hpo/demo-redpajama-c4-refined.jsonl new file mode 100644 index 000000000..03893f4a1 --- /dev/null +++ b/tools/hpo/demo-redpajama-c4-refined.jsonl @@ -0,0 +1,10 @@ +{"text":"Athlone Community Radio will be broadcasting live for 4 hours from Golden Island Shopping Centre in Athlone, the heart of Ireland, to Celebrate World Radio Day 2019. Our live broadcast will begin at 12pm and finish at 4pm. The Theme of World Radio Day is \"Dialogue, Tolerance and Peace\".\nThis hour will feature local Musicians from Athlone. We will discuss how music plays a key role in promoting Tolerance , Dialogue and Peace. This programme will be presented by presenter\/Sound Engineer with the station and also a musician Cathal McCormack.\nAthlone Today is our daily current affairs programme. We will be delving into the World Radio Theme 2019 of tolerance, dialogue and peace in our community. We will feature an interview with renowned expert in Media Literacy Irena Djak Cvetkovic on access to media, creation of Media and the critical evaluation of media and how media can play a part in this theme. With over 20 years of experience in training, Deridre Hunt shares her unique perspective into the world of training and radio. Nikki Dube who presented radio in South Africa to over 1 million listeners gives us an insight into the importance of Radio as tool for Community Development. Presenter and Producer of the Arts Programme,Philomena Barry, chats about Radio in Gambia. 
Writer Aidan Shortall has written a special poem entitled \"Dialogue, Tolerance and Peace\" to mark the theme."} +{"text":"\"Unemployment, Relative Price Dispersion and the Implicit Contract Model.\"\n\"The Federal Reserve's Preferences for Inflation and Unemployment: An Analysis of Fed Chairmen.\"\n\"Do Presidential Administrations Have Preferred Rates of Unemployment?\"\n\"The Baseball Writer and The National Baseball Hall of Fame: Are Election Outcomes Affected by Race or Ethnicity?,\" (with Cliff Reid).\n\"Race, Ethnicity and The National Baseball Hall of Fame: Is There Discrimination in the Nomination Process?,\" (with Cliff Reid and John Santos).\n\"Discrimination, Voting Behavior and the National Baseball Hall of Fame: The Case of Pitchers,\" (with Cliff Reid).\n\"A Comparison of Two Models to Forecast Voting for Membership into the National Baseball Hall of Fame,\" (with Cliff Reid).\n\"Jackie Robinson and The National Baseball Hall of Fame,\" (with Cliff Reid).\n\"Race, Ethnicity and the Market(s) for Baseball Cards,\" (with Jim Meehan and Cliff Reid).\n\"Further Analysis of the Determinants of Sovereign Credit Ratings,\" (with Shannon Landauer, '99).\n\"An Empirical Analysis of the Relationship between the Yield Spread and the Probability of Recessions,\" (with Sunil Thakor, '99)."} +{"text":"Volunteers are needed to help with the Assembly for Children at the 99th International Assembly in July. Below are the dates, times, and volunteers needed for each session. If you will be attending the Assembly and are interested in helping with the Assembly for Children, register online by clicking here. A background check and pastoral referral must be completed and can be found on the website.\n\"Hands-on\" laborers needed for take down of Assembly for Children props and décor."} +{"text":"It is possible to erase up to 4 HDD \/ SSD at the same time. 
IDE HDD connection is also possible with dedicated adapter.\nHDD \/ SSD is exchanged from the order in which erasing is completed, and asynchronous erase function which can erase newly is installed.Five types of erase algorithms are installed. The erasing method can be selected according to the application.\nHDD copy function installed. It is possible to copy one HDD data to up to three HDDs simultaneously.\nData delete contents can be printed with the attached dedicated printer . It is also possible to output text data of work log to USB memory.\nDedicated carrying case with waterproof \/ dustproof specification is included, which can contain the main body and all accessories.\nThis product is compatible with SATA 6 Gbps HDD \/ SSD, but the internal transfer speed is up to 130 MB \/ sec.\n2.5 \"HDD and 3.5\" HDD can not be connected at the same time."} +{"text":"The High brightness is being compared to 350 Watts conventional xenon light source that can be utilized for critical illness like laparoscopy without any problem and with utmost ease.\nIt has an extended service life, which means it can easily last for 60000 hours at a max, which is 120 times of xenon. This further emphasises that one doest not need to change the bulb for a number of years and can be used for a long time.\nWe have an extensive range of power from 100 to 240 V \/ AC; 50 \/ 60 Hz.\nThe perfect colour temperature ranges from 5000 K to 6500 K, and the colour rendering is over 70.\nIt does not gives out light in the Ultra Violet or Infrared Rays.\nThe customized connector, it accepts fibre light directs with dynamic and on the go areas varying from 3 mm to 10 mm in diameter or span.\nIt helps saves half the energy i.e. 
it will be more than 50%.\nIt is considered environment friendly."} +{"text":"The Arlington County Board plans to vote Saturday afternoon on giving Amazon $23 million and other incentives to build a headquarters campus in Crystal City, but only after hearing scores of northern Virginia residents and advocates testify for or against the project.\nThe five-member board is expected to support the plan, which was announced amid much hoopla on Nov. 13. The proposed county incentives are part of an agreement in which Amazon would occupy significant office space and bring at least 25,000 high-paying jobs to Arlington in coming years.\nOpponents hope to postpone the vote until after additional public hearings, where they want representatives of the online retail giant to answer questions directly from anyone in the community.\nThe Saturday hearing was scheduled to begin no earlier than 1 p.m. and last several hours before the vote. Ninety-one people signed up in advance to speak on the topic.\nIn the four months since Arlington won a much-publicized, nationwide contest to attract the facility known as HQ2, Arlington residents have been asking questions about its impact on their community.\nPeople have looked at the county's five online Q&A sessions 14,000 times, and about 400 attended community events to discuss the provisions in the Amazon agreement. Board members and county staff also met with scores of civic organizations, served on multiple panels and appeared on television, online and in news articles to discuss the deal.\nMost Arlingtonians, northern Virginians and residents of the Washington region support Amazon's arrival, several surveys have found. Business organizations, universities and nonprofit groups came out strongly for the deal.\nBut a small, vocal group of activists has sought to block the project, saying that the county and commonwealth should not give any incentives to one of the world's most valuable companies. 
They also have demanded housing and job protections for existing residents.\nThese opponents - including left-wing organizations and immigrants groups - felt empowered after Amazon canceled plans last month to build a headquarters facility in New York City, also with 25,000 jobs. The company withdrew after criticism of the plan from some elected leaders, unions and community activists.\nIn Virginia, however, such opposition did not appear to catch fire among the broader public.\nOfficials estimate that the Amazon project's net fiscal impact on Arlington could be worth additional revenue of $162 million over 12 years and $392.5 million over 16 years.\nThe incentives agreement promises the world's largest online retailer cash grants estimated at about $23 million if it occupies 6.05 million square feet of office space in Crystal City and Pentagon City through 2035.\nThe money would come from an expected increase in the hotel, motel and lodging tax paid by visitors; Amazon would get up to 15 percent of that increase, pegged to how much floor space is in active use by the company each year from 2020 to 2035.\nAmazon's offices will be located within an already-established special tax district where a portion of the property tax revenue goes toward infrastructure improvements such as parks and wider sidewalks.\nThe incentive agreement says that half of any new revenue from that district starting in 2021 will go specifically toward improvements around the Amazon buildings for the following 10 years. That grant is worth an estimated $28 million but the county says it's not a grant just for Amazon, because the improvements will benefit other companies in the immediate area. 
Amazon will have a chance to express its opinion on how the county uses the money, although the board will make the decision.\nThe county also offered Amazon the possibility of using its fast, fiber-optic network connection, which would be the subject of a separate agreement if the company chooses to use it.\nIt's not yet clear whether Amazon will pay the local business license tax because that tax is levied only on certain types of business, and Amazon has not yet announced which of its business units will be based in Arlington. If the company does pay the license tax, then some of its operations could be eligible for a discount of up to 72 percent under an existing program designed to attract technology companies.\nWhile Arlington pored over the details, the Virginia General Assembly passed, and Democratic Gov. Ralph Northam signed, an incentives package worth up to $750 million for Amazon."} +{"text":"It's kind of hard to tell from the picture, but Miss Fancy Nancy is sitting in a chair that is upholstered with children's b ook covers (well, not literally I hope, but with fabric that has children's book covers on it). From where I sit I can see The Giving Tree, The Story of Ferdinand, The Hobbit, and one of the Narnia books. It is much better if you go look at the real cover of Fancy Nancy and the Dazzling Book Report.\nSo, when I live in my house in the English countryside, and I am sitting in my personal library, this is what I am going to be sitting on."} +{"text":"One of the biggest games in the world is now available on smartphones around the world thanks to Tencent. The Chinese company's official port of PlayerUnknown's Battlegrounds (PUBG) is now available on iOS and Android worldwide.\nWe tried PUBG on iPhone X when it released in China last month and came away very impressed. 
It successfully squeezed much of the PC and Xbox game down to a much smaller screen -- and in places looked even better than the Xbox One version of the game.\nIt comes a week after rival Fortnite began inviting players to test its own mobile version. Unlike PUBG, Fortnite will allow players to play across platforms -- meaning iPhone players can kill or be killed by gamers on PC or PlayStation 4.\nFortnite recently surpassed pioneer PUBG in popularity. Both games feature 100 players landing on a deserted island filled with weapons, with everyone fighting to be the last person left standing. While PUBG adopts a realistic setting, Fortnite has a more cartoon-like look -- and an emphasis on building structures sets it apart from its rival.\nBut Tencent may find its main competition for PUBG on mobile could come from other Chinese companies. NetEase was quick to copy PUBG and released similar games for smartphone while the original was still on PC. NetEase says just one of its clones, Knives Out, has over 100 million registered players -- far more than the 40 million playing PUBG on PC and Xbox."} +{"text":"How many backlinks per day for new site?\nDiscussion in 'Black Hat SEO' started by Omoplata, Dec 3, 2010.\n1) for a newly created site, what's the max # backlinks per day I should do to be safe?\n2) how long do I have to let my site age before I can start making more blinks?\nI did about 6000 forum profiles every 24 hours for 10 days for one of my sites which had a brand new domain.\nThere is three backlinks for every of these forum profile so thats 18 000 backlinks every 24 hours and nothing happened in terms of being penalized or sandboxed. This is now maybe 3 months ago and the site is ranking on first page for a lot of my targeted keywords.\nbuild more you can in starting but do manual submission and not spammy type means manual + relevant to the post.. then after 1 month you can make a big blast..\nWow, dude, you built 18k backlinks a day on a brand new site? 
How quickly did you rank up? What kind of competition\/searches did those keywords have?"} +{"text":"Disturbia Clothing coupon code gift !\nDisturbia Clothing coupon code gift ! ! !\nDisturbia Clothing is always bringing something different and out of the mainstream, focusing on the dark side of popular culture, subversive iconography, childhood nostalgia, and angry slogans, all made with a strong sense of independence, and a quintessential British punk D.I.Y. ethic.\nIf you are a fan on Facebook, you've already noticed they are running a 20% discount, with the coupon code FACEBOOK20 until Monday the 18th.\nDisturbia Clothing coupon code! Don't forget to use it until Monday the 18th ! ! !\nJust log into your Facebook account, get your coupon and stay tune for our next special discounts ! Keep your eyes on the prize and, of course, on our Facebook page.\nInspiring graphics of the day - intense emotions graphics !"} diff --git a/tools/hpo/execute_hpo.py b/tools/hpo/execute_hpo.py new file mode 100644 index 000000000..66658564e --- /dev/null +++ b/tools/hpo/execute_hpo.py @@ -0,0 +1,46 @@ +import sys + +import yaml +from jsonargparse import namespace_to_dict + +import wandb +from data_juicer.config import init_configs, merge_config +from objects import get_hpo_objective + +# 1: load the defined search space +sweep_cfg_file_path = None +for i in range(len(sys.argv) - 1): + if sys.argv[i] == '--hpo_config': + sweep_cfg_file_path = sys.argv[i + 1] + break +if not sweep_cfg_file_path: + raise ValueError('Not found --hpo_config, you should specify your ' + 'hpo cfg file path following `--hpo_config`') +with open(sweep_cfg_file_path) as f: + sweep_configuration = yaml.safe_load(f) + + +def search(): + wandb.init(project=sweep_configuration['sweep_name']) + + # 2.1: Choose objective that links the hyper-parameters you want to search + object_func = get_hpo_objective(sweep_configuration['metric']['name']) + + dj_cfg = init_configs() + # merge the new hyper-parameters selected 
by HPO scheduler + dj_cfg = merge_config(dj_cfg, wandb.config) + wandb.config = namespace_to_dict(dj_cfg) # for configuration track + + # 2.2: calculate objective using new hyper-parameters, track the results + score = object_func(dj_cfg) + wandb.log({sweep_configuration['metric']['name']: score}) + + +# 3: Start the sweep, iteratively search hyper-parameters +sweep_id = wandb.sweep(sweep=sweep_configuration, + project=sweep_configuration['sweep_name']) + +wandb.agent(sweep_id, + function=search, + count=sweep_configuration['sweep_count'] + if 'sweep_count' in sweep_configuration else None) diff --git a/tools/hpo/objects.py b/tools/hpo/objects.py new file mode 100644 index 000000000..f6749bc58 --- /dev/null +++ b/tools/hpo/objects.py @@ -0,0 +1,53 @@ +from data_juicer.core import Executor +from tools.quality_classifier.predict import predict_score + + +def get_hpo_objective(obj_name): + if obj_name == 'quality_score': + return obj_quality_score + # elif obj_name == "model_loss": + # return obj_model_loss + # elif obj_name == "downstream_task": + # return obj_downstream_task + # elif obj_name == "synergy_metric": + # return obj_synergy_metric + else: + raise NotImplementedError( + f'unsupported objective type in HPO: {obj_name}. ' + f'Please implement it first.') + + +def obj_quality_score(dj_cfg): + """ + HPO loop: cfg --> data --> data score --> cfg --> data --> ... + + :param dj_cfg: specified data recipe (as a search point) + :return: a data score, after + 1. processing data according to the dj_cfg; + 2. 
applying a quality classifier + """ + + if dj_cfg.executor_type == 'default': + executor = Executor(dj_cfg) + elif dj_cfg.executor_type == 'ray': + from data_juicer.core.ray_executor import RayExecutor + executor = RayExecutor(dj_cfg) + else: + raise NotImplementedError( + f'unsupported executor_type: {dj_cfg.executor_type}, ' + f'expected in [`default`, `ray`]', ) + executor.run() + + # calculate and aggregate data score + + # feel free to customize the quality scorer, via the following args + # [--model ] \ + # [--tokenizer ] \ + # [--keep_method ] \ + # [--text_key ] \ + overall_quality_stats = predict_score(dj_cfg.export_path, + dj_cfg.export_path, + overall_stats=True) + + # by default, using the mean quality score of processed data as final score + return overall_quality_stats.loc['mean'] diff --git a/tools/quality_classifier/__init__.py b/tools/quality_classifier/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tools/quality_classifier/eval.py b/tools/quality_classifier/eval.py index 80e92ff67..47da8ef03 100644 --- a/tools/quality_classifier/eval.py +++ b/tools/quality_classifier/eval.py @@ -24,7 +24,7 @@ import fire from loguru import logger -from qc_utils import eval, init_spark, load_datasets +from .qc_utils import eval, init_spark, load_datasets @logger.catch diff --git a/tools/quality_classifier/predict.py b/tools/quality_classifier/predict.py index 80fda65fb..71a258ce7 100644 --- a/tools/quality_classifier/predict.py +++ b/tools/quality_classifier/predict.py @@ -60,18 +60,18 @@ import fire from loguru import logger -from qc_utils import (export_result, init_spark, load_dataset, predict, - prepare_model) +from .qc_utils import (export_result, init_spark, load_dataset, predict, + prepare_model) @logger.catch -def main(dataset_path, - result_path, - model='gpt3', - tokenizer=None, - keep_method='gpt3', - text_key='text', - overall_stats=False): +def predict_score(dataset_path, + result_path, + model='gpt3', + tokenizer=None, + 
keep_method='gpt3', + text_key='text', + overall_stats=False): """ Use specific quality classifier to predict document scores on your dataset :param dataset_path: the path to the dataset you want to predict for @@ -90,6 +90,7 @@ def main(dataset_path, :param overall_stats: whether to output an overall stats report on predicted document scores. It's False in default :return: + average quality score of the document """ # set default tokenizers for default models if model == 'chinese': @@ -114,12 +115,14 @@ def main(dataset_path, export_result(pred, result_path) # generate overall statistics on doc scores + overall = pred.select('doc_score').toPandas().describe(include='all') if overall_stats: - overall = pred.select('doc_score').toPandas().describe(include='all') # export to result report file overall.to_csv(os.path.join(result_path, 'overall.csv')) overall.to_markdown(os.path.join(result_path, 'overall.md')) + return overall + if __name__ == '__main__': - fire.Fire(main) + fire.Fire(predict_score) diff --git a/tools/quality_classifier/qc_utils.py b/tools/quality_classifier/qc_utils.py index af6b776cd..0ab2c42ed 100644 --- a/tools/quality_classifier/qc_utils.py +++ b/tools/quality_classifier/qc_utils.py @@ -2,8 +2,12 @@ import zipfile import numpy as np + import sentencepiece as spm import wget +from data_juicer.utils.cache_utils import DATA_JUICER_MODELS_CACHE +from data_juicer.utils.model_utils import (MODEL_LINKS, + prepare_sentencepiece_model) from loguru import logger from pyspark.ml import Pipeline, PipelineModel from pyspark.ml.classification import LogisticRegression @@ -12,25 +16,32 @@ from pyspark.sql.functions import col, rand, udf from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType -from data_juicer.utils.cache_utils import DATA_JUICER_MODELS_CACHE -from data_juicer.utils.model_utils import (MODEL_LINKS, - prepare_sentencepiece_model) - -def init_spark(): +def init_spark(spark_executor_memory=None, + spark_driver_memory=None, + 
spark_executor_memoryOverhead=None): """ Initialize a spark session. You can set parameters such as memory, number of partitions, timeout and so on here :return: A spark session instance. """ - spark = (SparkSession.builder.config('spark.driver.memory', '64g').config( - 'spark.executor.memory', - '64g').config('spark.sql.shuffle.partitions', '300').config( - 'spark.sql.execution.arrow.pyspark.enabled', - 'true').config('spark.executor.memoryOverhead', '20000').config( - 'spark.network.timeout', - '10000s').config('spark.executor.heartbeatInterval', - '3600s').getOrCreate()) + if not spark_executor_memory: + spark_executor_memory = '64g' + if not spark_driver_memory: + spark_driver_memory = '64g' + if not spark_executor_memoryOverhead: + spark_executor_memoryOverhead = '20000' + spark = (SparkSession.builder.config( + 'spark.driver.memory', spark_driver_memory).config( + 'spark.executor.memory', spark_executor_memory).config( + 'spark.sql.shuffle.partitions', '300').config( + 'spark.sql.execution.arrow.pyspark.enabled', + 'true').config('spark.executor.memoryOverhead', + spark_executor_memoryOverhead).config( + 'spark.network.timeout', + '10000s').config( + 'spark.executor.heartbeatInterval', + '3600s').getOrCreate()) logger.info('Spark initialization done.') return spark diff --git a/tools/quality_classifier/train.py b/tools/quality_classifier/train.py index 3774a9539..dab2b2759 100644 --- a/tools/quality_classifier/train.py +++ b/tools/quality_classifier/train.py @@ -30,7 +30,7 @@ import fire from loguru import logger -from qc_utils import eval, init_spark, load_datasets, shuffle, train +from .qc_utils import eval, init_spark, load_datasets, shuffle, train @logger.catch From 4101678bffd405e0ac9bf8847758cb231e63b8c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Wed, 8 Nov 2023 12:02:40 +0800 Subject: [PATCH 2/6] added auto-HPO feature with WandB. - The core modifications are in tools/hpo, and data_juicer/config.py. 
- The others are from pre-commit run --- .github/ISSUE_TEMPLATE/bug_report.yml | 2 +- .github/ISSUE_TEMPLATE/custom.md | 2 - .github/ISSUE_TEMPLATE/feature_request.yml | 2 +- .github/ISSUE_TEMPLATE/question.yml | 2 +- LICENSE | 2 - README.md | 26 +- README_ZH.md | 6 +- app.py | 2 +- configs/README.md | 2 +- configs/README_ZH.md | 2 +- .../data_juicer_recipes/alpaca_cot/README.md | 6 +- .../alpaca_cot/README_ZH.md | 2 +- .../alpaca_cot/alpaca-cot-en-refine.yaml | 12 +- .../alpaca_cot/alpaca-cot-zh-refine.yaml | 4 +- .../redpajama-c4-refine.yaml | 2 +- data_juicer/analysis/overall_analysis.py | 2 + data_juicer/config/config.py | 20 +- data_juicer/core/data.py | 17 +- data_juicer/core/executor.py | 12 +- data_juicer/core/exporter.py | 3 +- data_juicer/core/ray_executor.py | 18 +- data_juicer/format/formatter.py | 5 +- data_juicer/ops/common/helper_func.py | 11 +- data_juicer/ops/filter/alphanumeric_filter.py | 2 +- .../ops/filter/average_line_length_filter.py | 2 +- .../ops/filter/character_repetition_filter.py | 3 +- .../ops/filter/flagged_words_filter.py | 11 +- .../ops/filter/language_id_score_filter.py | 6 +- .../ops/filter/maximum_line_length_filter.py | 2 +- data_juicer/ops/filter/perplexity_filter.py | 9 +- data_juicer/ops/filter/stopwords_filter.py | 11 +- data_juicer/ops/filter/token_num_filter.py | 5 +- data_juicer/ops/filter/word_num_filter.py | 11 +- .../ops/filter/word_repetition_filter.py | 18 +- data_juicer/ops/load.py | 1 + data_juicer/ops/mapper/nlpaug_en_mapper.py | 8 +- data_juicer/ops/mapper/nlpcda_zh_mapper.py | 6 +- ..._words_with_incorrect_substrings_mapper.py | 6 +- .../ops/mapper/sentence_split_mapper.py | 6 +- .../mapper/whitespace_normalization_mapper.py | 7 +- data_juicer/ops/op_fusion.py | 21 +- data_juicer/tools/__init__.py | 8 +- data_juicer/utils/compress.py | 15 +- data_juicer/utils/constant.py | 1 + data_juicer/utils/logger_utils.py | 8 +- data_juicer/utils/model_utils.py | 3 +- demos/README.md | 2 +- demos/README_ZH.md | 2 +- 
demos/auto_evaluation_helm/README_ZH.md | 2 +- demos/auto_evaluation_helm/app.py | 8 +- demos/data_process_hpo/app.py | 11 +- demos/data_process_loop/app.py | 4 +- .../data_process_loop/data/demo-dataset.jsonl | 2 +- .../data/demo-dataset.jsonl | 2 +- demos/data_visualization_op_effect/app.py | 4 +- demos/data_visualization_statistics/app.py | 2 +- demos/overview_scan/app.py | 9 +- demos/overview_scan/data/demo-dataset.jsonl | 2 +- .../process_cft_zh_data/data/alpaca-cot.jsonl | 2 +- demos/process_on_ray/data/demo-dataset.json | 2 +- .../data/demo-dataset.jsonl | 2 +- .../data/demo-dataset.jsonl | 2 +- .../quality_classifier/qc_utils.py | 1 + setup.py | 4 +- tests/format/test_unify_format.py | 3 +- .../test_character_repetition_filter.py | 4 +- tests/ops/filter/test_token_num_filter.py | 76 +- tests/ops/mapper/test_nlpaug_en_mapper.py | 5 +- tests/ops/mapper/test_nlpcda_zh_mapper.py | 13 +- tests/ops/test_op_fusion.py | 710 ++++++++++-------- .../converter/convert_gpt_to_transformers.py | 345 +++++---- tools/converter/modeling_megatron_llama.py | 469 +++++++----- tools/evaluator/config/evaluator_example.yaml | 4 +- .../evaluator/config/helm_spec_template.conf | 2 +- tools/evaluator/evaluator.py | 151 ++-- tools/evaluator/gpt_eval/README.md | 2 +- tools/evaluator/gpt_eval/answer_generator.py | 103 +-- tools/evaluator/gpt_eval/config/config.yaml | 2 +- .../evaluator/gpt_eval/config/reviewer.jsonl | 2 +- tools/evaluator/gpt_eval/gpt_evaluator.py | 132 ++-- tools/evaluator/recorder/README.md | 2 +- .../recorder/config/leaderboard_example.yaml | 2 +- .../recorder/config/llama_example.yaml | 2 +- .../recorder/config/mymodel_example.yaml | 2 +- tools/evaluator/recorder/wandb_writer.py | 202 ++--- tools/hpo/README.md | 59 +- tools/hpo/README_ZH.md | 49 ++ tools/hpo/configs/quality_score_hpo.yaml | 4 +- tools/hpo/execute_hpo.py | 4 +- tools/hpo/objects.py | 13 +- tools/postprocess/README.md | 4 +- tools/postprocess/README_ZH.md | 4 +- tools/postprocess/count_token.py | 15 
+- tools/postprocess/deserialize_meta.py | 11 +- tools/preprocess/README.md | 2 +- tools/preprocess/README_ZH.md | 2 +- .../raw_alpaca_cot_merge_add_meta.py | 61 +- tools/quality_classifier/predict.py | 11 +- tools/quality_classifier/qc_utils.py | 10 +- 99 files changed, 1694 insertions(+), 1183 deletions(-) create mode 100644 tools/hpo/README_ZH.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index ddab3f93a..ad7c70fd9 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -113,4 +113,4 @@ body: - type: textarea attributes: label: Additional 额外信息 - description: Anything else you would like to share? 其他您想分享的信息。 \ No newline at end of file + description: Anything else you would like to share? 其他您想分享的信息。 diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md index 48d5f81fa..b894315f4 100644 --- a/.github/ISSUE_TEMPLATE/custom.md +++ b/.github/ISSUE_TEMPLATE/custom.md @@ -6,5 +6,3 @@ labels: '' assignees: '' --- - - diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index b1353af0f..b3ec179bb 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -49,4 +49,4 @@ body: (Optional) We encourage you to submit a [Pull Request](https://github.com/alibaba/data-juicer/pulls) (PR) to help improve Data-Juicer for everyone, especially if you have a good understanding of how to implement a fix or feature. (可选项)我们鼓励您提交一个 [Pull Request (PR)]([Pull Request](https://github.com/alibaba/data-juicer/pulls)) 来为开源社区提升 Data-Juicer 的能力,尤其是如果您对如何实现或者修复一个功能有比较不错的理解的时候~ options: - - label: Yes I'd like to help by submitting a PR! 是的!我愿意提供帮助并提交一个PR! \ No newline at end of file + - label: Yes I'd like to help by submitting a PR! 是的!我愿意提供帮助并提交一个PR! 
diff --git a/.github/ISSUE_TEMPLATE/question.yml b/.github/ISSUE_TEMPLATE/question.yml index 50a8bdb7c..c99bccfc8 100644 --- a/.github/ISSUE_TEMPLATE/question.yml +++ b/.github/ISSUE_TEMPLATE/question.yml @@ -48,4 +48,4 @@ body: - type: textarea attributes: label: Additional 额外信息 - description: Anything else you would like to share? 其他您想分享的信息。 \ No newline at end of file + description: Anything else you would like to share? 其他您想分享的信息。 diff --git a/LICENSE b/LICENSE index f2c4ebb9e..bc0945c47 100644 --- a/LICENSE +++ b/LICENSE @@ -417,5 +417,3 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - - diff --git a/README.md b/README.md index 5398be759..43eab4c7f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -English | [**中文**](README_ZH.md) +English | [**中文**](README_ZH.md) # Data-Juicer: A One-Stop Data Processing System for Large Language Models @@ -26,7 +26,7 @@ English | [**中文**](README_ZH.md) [![QualityClassifier](https://img.shields.io/badge/Tools-Quality_Classifier-saddlebrown?logo=Markdown)](tools/quality_classifier/README.md) [![AutoEvaluation](https://img.shields.io/badge/Tools-Auto_Evaluation-saddlebrown?logo=Markdown)](tools/evaluator/README.md) -Data-Juicer is a one-stop data processing system to make data higher-quality, +Data-Juicer is a one-stop data processing system to make data higher-quality, juicier, and more digestible for LLMs. This project is being actively updated and maintained, and we will periodically enhance and add more features and data recipes. We welcome you to join us in promoting LLM data development and research! 
@@ -68,22 +68,22 @@ Table of Contents ![Overview](https://img.alicdn.com/imgextra/i2/O1CN01IMPeD11xYRUYLmXKO_!!6000000006455-2-tps-3620-1604.png) -- **Systematic & Reusable**: - Empowering users with a systematic library of 20+ reusable [config recipes](configs), 50+ core [OPs](docs/Operators.md), and feature-rich - dedicated [toolkits](#documentation), designed to +- **Systematic & Reusable**: + Empowering users with a systematic library of 20+ reusable [config recipes](configs), 50+ core [OPs](docs/Operators.md), and feature-rich + dedicated [toolkits](#documentation), designed to function independently of specific LLM datasets and processing pipelines. -- **Data-in-the-loop**: Allowing detailed data analyses with an automated +- **Data-in-the-loop**: Allowing detailed data analyses with an automated report generation feature for a deeper understanding of your dataset. Coupled with multi-dimension automatic evaluation capabilities, it supports a timely feedback loop at multiple stages in the LLM development process. ![Data-in-the-loop](https://img.alicdn.com/imgextra/i1/O1CN011E99C01ndLZ55iCUS_!!6000000005112-0-tps-2701-1050.jpg) -- **Comprehensive Data Processing Recipes**: Offering tens of [pre-built data - processing recipes](configs/data_juicer_recipes/README.md) for pre-training, fine-tuning, en, zh, and more scenarios. Validated on - reference LLaMA models. +- **Comprehensive Data Processing Recipes**: Offering tens of [pre-built data + processing recipes](configs/data_juicer_recipes/README.md) for pre-training, fine-tuning, en, zh, and more scenarios. Validated on + reference LLaMA models. ![exp_llama](https://img.alicdn.com/imgextra/i2/O1CN019WtUPP1uhebnDlPR8_!!6000000006069-2-tps-2530-1005.png) -- **Enhanced Efficiency**: Providing a speedy data processing pipeline - requiring less memory and CPU usage, optimized for maximum productivity. 
+- **Enhanced Efficiency**: Providing a speedy data processing pipeline + requiring less memory and CPU usage, optimized for maximum productivity. ![sys-perf](https://img.alicdn.com/imgextra/i4/O1CN01Sk0q2U1hdRxbnQXFg_!!6000000004300-0-tps-2438-709.jpg) @@ -137,13 +137,13 @@ pip install py-data-juicer ### Using Docker -- You can +- You can - either pull our pre-built image from DockerHub: ```shell docker pull datajuicer/data-juicer: ``` - - or run the following command to build the docker image including the + - or run the following command to build the docker image including the latest `data-juicer` with provided [Dockerfile](Dockerfile): ```shell diff --git a/README_ZH.md b/README_ZH.md index 48fe9dfcc..437c70bcc 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -76,7 +76,7 @@ Data-Juicer 是一个一站式数据处理系统,旨在为大语言模型 (LLM * **效率增强**:提供高效的数据处理流水线,减少内存占用和CPU开销,提高生产力。 ![sys-perf](https://img.alicdn.com/imgextra/i4/O1CN01Sk0q2U1hdRxbnQXFg_!!6000000004300-0-tps-2438-709.jpg) * **用户友好**:设计简单易用,提供全面的[文档](#documentation)、简易[入门指南](#快速上手)和[演示配置](configs/README_ZH.md),并且可以轻松地添加/删除[现有配置](configs/config_all.yaml)中的算子。 - + * **灵活 & 易扩展**:支持大多数数据格式(如jsonl、parquet、csv等),并允许灵活组合算子。支持[自定义算子](docs/DeveloperGuide_ZH.md#构建自己的算子),以执行定制化的数据处理。 @@ -301,7 +301,7 @@ Data-Juicer 在 Apache License 2.0 协议下发布。 我们非常欢迎贡献新功能、修复漏洞以及讨论。请参考[开发者指南](docs/DeveloperGuide_ZH.md)。 -欢迎加入我们的[Slack channel](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp), 或[DingDing群](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8275bc8g7ypp&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11) 。 +欢迎加入我们的[Slack channel](https://join.slack.com/t/data-juicer/shared_invite/zt-23zxltg9d-Z4d3EJuhZbCLGwtnLWWUDg?spm=a2c22.12281976.0.0.7a8275bc8g7ypp), 
或[DingDing群](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976.0.0.7a8275bc8g7ypp&code=v1,k1,C0DI7CwRFrg7gJP5aMC95FUmsNuwuKJboT62BqP5DAk=&_dt_no_comment=1&origin=11) 。 ## 参考文献 如果您发现我们的工作对您的研发有帮助,请引用以下[论文](https://arxiv.org/abs/2309.02033) 。 @@ -315,4 +315,4 @@ eprint={2309.02033}, archivePrefix={arXiv}, primaryClass={cs.LG} } -``` \ No newline at end of file +``` diff --git a/app.py b/app.py index 30505ac05..9ccc8e255 100644 --- a/app.py +++ b/app.py @@ -230,7 +230,7 @@ class Visualize: @staticmethod def filter_dataset(dataset): if Fields.stats not in dataset.features: - return + return text_key = st.session_state.get('text_key', 'text') text = dataset[text_key] stats = pd.DataFrame(dataset[Fields.stats]) diff --git a/configs/README.md b/configs/README.md index 3f205c799..873c02dc5 100644 --- a/configs/README.md +++ b/configs/README.md @@ -29,4 +29,4 @@ We have reproduced the processing flow of some RedPajama datasets. Please refer We have reproduced the processing flow of some BLOOM datasets. please refer to the [reproduced_bloom](reproduced_bloom) folder for details. ### Data-Juicer Recipes -We have refined some open source datasets (including CFT datasets) by using Data-Juicer and have provided configuration files for the refined flow. please refer to the [data_juicer_recipes](data_juicer_recipes) folder for details. \ No newline at end of file +We have refined some open source datasets (including CFT datasets) by using Data-Juicer and have provided configuration files for the refined flow. please refer to the [data_juicer_recipes](data_juicer_recipes) folder for details. 
diff --git a/configs/README_ZH.md b/configs/README_ZH.md index 67a50eaaa..041f565a2 100644 --- a/configs/README_ZH.md +++ b/configs/README_ZH.md @@ -30,4 +30,4 @@ Demo 配置文件用于帮助用户快速熟悉 Data-Juicer 的基本功能, 我们已经重现了部分 BLOOM 数据集的处理流程,请参阅 [reproduced_bloom](reproduced_bloom) 文件夹以获取详细说明。 ### Data-Juicer 菜谱 -我们使用 Data-Juicer 更细致地处理了一些开源数据集(包含 CFT 数据集),并提供了处理流程的配置文件。请参阅 [data_juicer_recipes](data_juicer_recipes) 文件夹以获取详细说明。 \ No newline at end of file +我们使用 Data-Juicer 更细致地处理了一些开源数据集(包含 CFT 数据集),并提供了处理流程的配置文件。请参阅 [data_juicer_recipes](data_juicer_recipes) 文件夹以获取详细说明。 diff --git a/configs/data_juicer_recipes/alpaca_cot/README.md b/configs/data_juicer_recipes/alpaca_cot/README.md index 11dd8e887..787e52601 100644 --- a/configs/data_juicer_recipes/alpaca_cot/README.md +++ b/configs/data_juicer_recipes/alpaca_cot/README.md @@ -6,7 +6,7 @@ This folder contains some configuration files to allow users to easily and quick The raw data files can be downloaded from [Alpaca-CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) on HuggingFace. 
-### Convert raw Alpaca-CoT data to jsonl +### Convert raw Alpaca-CoT data to jsonl Use [raw_alpaca_cot_merge_add_meta.py](../../../tools/preprocess/raw_alpaca_cot_merge_add_meta.py) to select `instruction`, `input` and `output` columns and merge them to `text` field with a space, and add extra [ META ]( #meta_info) info to dataset: ```shell @@ -66,7 +66,7 @@ Each sample in refined data of Alpaca-CoT contains meta info listed as below: * `CFT-SR`: tagged as Single-round Dialog datasets * `CFT-MR`: tagged as Multi-round Dialog datasets - + * `CFT-P`: tagged as Preference datasets @@ -111,4 +111,4 @@ Each sample in refined data of Alpaca-CoT contains meta info listed as below: | StackExchange | MT | COL | EN | StackExchange | | ✅ | | ✅ | | ConvAI2 | TS | HG | EN | ConvAI2 | | ✅ | | | | FastChat | MT | SI | EN | FastChat | | ✅ | | | -| Tabular-LLM-Data | MT | COL | EN/CN | Tabular-LLM-Data | ✅ | | | | \ No newline at end of file +| Tabular-LLM-Data | MT | COL | EN/CN | Tabular-LLM-Data | ✅ | | | | diff --git a/configs/data_juicer_recipes/alpaca_cot/README_ZH.md b/configs/data_juicer_recipes/alpaca_cot/README_ZH.md index c482fdec2..62d751755 100644 --- a/configs/data_juicer_recipes/alpaca_cot/README_ZH.md +++ b/configs/data_juicer_recipes/alpaca_cot/README_ZH.md @@ -111,4 +111,4 @@ python tools/process_data.py --config configs/data_juicer_recipes/alpaca_cot/alp | StackExchange | MT | COL | EN | StackExchange | | ✅ | | ✅ | | ConvAI2 | TS | HG | EN | ConvAI2 | | ✅ | | | | FastChat | MT | SI | EN | FastChat | | ✅ | | | -| Tabular-LLM-Data | MT | COL | EN/CN | Tabular-LLM-Data | ✅ | | | | \ No newline at end of file +| Tabular-LLM-Data | MT | COL | EN/CN | Tabular-LLM-Data | ✅ | | | | diff --git a/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-en-refine.yaml b/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-en-refine.yaml index 85b338b92..e55626789 100644 --- a/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-en-refine.yaml +++ 
b/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-en-refine.yaml @@ -10,23 +10,23 @@ open_tracer: true # a list of several process operators with their arguments process: - document_deduplicator: # 104636705 - lowercase: true + lowercase: true ignore_non_character: true - + - alphanumeric_filter: # 104636381 tokenization: false - min_ratio: 0.1 + min_ratio: 0.1 - character_repetition_filter: # 104630030 rep_len: 10 - max_ratio: 0.6 + max_ratio: 0.6 - flagged_words_filter: # 104576967 lang: en tokenization: true - max_ratio: 0.017 + max_ratio: 0.017 - maximum_line_length_filter: # 104575811 min_len: 20 - text_length_filter: # 104573711 - min_len: 30 + min_len: 30 - document_simhash_deduplicator: # 72855345 tokenization: space diff --git a/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-zh-refine.yaml b/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-zh-refine.yaml index 9c20f78dc..9b252b949 100644 --- a/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-zh-refine.yaml +++ b/configs/data_juicer_recipes/alpaca_cot/alpaca-cot-zh-refine.yaml @@ -15,10 +15,10 @@ process: - alphanumeric_filter: # 16957388 tokenization: false - min_ratio: 0.10 + min_ratio: 0.10 - character_repetition_filter: # 16956845 rep_len: 10 - max_ratio: 0.6 + max_ratio: 0.6 - flagged_words_filter: # 16954629 lang: zh tokenization: true diff --git a/configs/data_juicer_recipes/redpajama-c4-refine.yaml b/configs/data_juicer_recipes/redpajama-c4-refine.yaml index c31b55785..fca603059 100644 --- a/configs/data_juicer_recipes/redpajama-c4-refine.yaml +++ b/configs/data_juicer_recipes/redpajama-c4-refine.yaml @@ -49,4 +49,4 @@ process: lowercase: true ignore_pattern: '\p{P}' num_blocks: 6 - hamming_distance: 4 \ No newline at end of file + hamming_distance: 4 diff --git a/data_juicer/analysis/overall_analysis.py b/data_juicer/analysis/overall_analysis.py index bd4c92b7c..117d43d68 100644 --- a/data_juicer/analysis/overall_analysis.py +++ b/data_juicer/analysis/overall_analysis.py @@ -3,6 +3,8 @@ 
import pandas as pd from data_juicer.utils.constant import Fields + + class OverallAnalysis: """Apply analysis on the overall stats, including mean, std, quantiles, etc.""" diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index ed18ac470..0fbbb80f2 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -7,10 +7,10 @@ from jsonargparse import (ActionConfigFile, ArgumentParser, dict_to_namespace, namespace_to_dict) from jsonargparse.typing import NonNegativeInt, PositiveInt +from loguru import logger from data_juicer.ops.base_op import OPERATORS from data_juicer.utils.logger_utils import setup_logger -from loguru import logger def init_configs(args=None): @@ -450,7 +450,13 @@ def merge_config(ori_cfg, new_cfg: Dict): for new_k, new_v in new_cfg.items(): # merge parameters other than `cfg.process` and DJ-OPs if new_k in ori_cfg and new_k != 'process' and '.' not in new_k: + print( + '=' * 15, f'\nBefore merging, the cfg item is: ' + f'{new_k}: {ori_cfg[new_k]}') ori_cfg[new_k] = new_v + print( + f'After merging, the cfg item is: ' + f'{new_k}: {new_v}\n', '=' * 15, '\n') else: # merge parameters of DJ-OPs into cfg.process # for nested style, e.g., `remove_table_text_mapper.min_col: 2` @@ -459,17 +465,21 @@ def merge_config(ori_cfg, new_cfg: Dict): key_as_groups[0] in ori_specified_op_names: op_name, para_name = key_as_groups[0], key_as_groups[1] op_order = ori_specified_op_idx[op_name] + ori_cfg_val = ori_cfg.process[op_order][op_name][para_name] + print( + '=' * 15, f'\nBefore merging, the cfg item is: ' + f'{new_k}: {ori_cfg_val}' + ) ori_cfg.process[op_order][op_name][para_name] = new_v + print( + f'After merging, the cfg item is: ' + f'{new_k}: {new_v}\n', '=' * 15, '\n') ori_cfg = init_setup_from_cfg(ori_cfg) # copy the config file into the work directory config_backup(ori_cfg) - # show the final config tables before the process started - print('=' * 10, '\nAfter merging, the new cfg becomes:', '=' * 10) - 
display_config(ori_cfg) - return ori_cfg except ArgumentError: diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 04ed6718d..6e43a4524 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -8,8 +8,9 @@ from loguru import logger from data_juicer.utils import cache_utils -from data_juicer.utils.compress import (cleanup_compressed_cache_files, - compress, decompress, CompressionOff) +from data_juicer.utils.compress import (CompressionOff, + cleanup_compressed_cache_files, + compress, decompress) from data_juicer.utils.fingerprint_utils import generate_fingerprint @@ -173,15 +174,13 @@ def map(self, *args, **kargs): kargs['new_fingerprint'] = new_fingerprint if cache_utils.CACHE_COMPRESS: - decompress(self, - kargs['new_fingerprint'], + decompress(self, kargs['new_fingerprint'], kargs['num_proc'] if 'num_proc' in kargs else 1) new_ds = NestedDataset(super().map(*args, **kargs)) if cache_utils.CACHE_COMPRESS: - compress(self, - new_ds, + compress(self, new_ds, kargs['num_proc'] if 'num_proc' in kargs else 1) if self.need_to_cleanup_caches: @@ -215,8 +214,7 @@ def filter(self, *args, **kargs): # after). 
So we need to decompress these two sets of compressed cache # files if cache_utils.CACHE_COMPRESS: - decompress(self, - [kargs['new_fingerprint'], self._fingerprint], + decompress(self, [kargs['new_fingerprint'], self._fingerprint], kargs['num_proc'] if 'num_proc' in kargs else 1) # Turn off the compression due to it invokes map actually in the filter @@ -231,8 +229,7 @@ def filter(self, *args, **kargs): self.need_to_cleanup_caches = prev_state if cache_utils.CACHE_COMPRESS: - compress(self, - new_ds, + compress(self, new_ds, kargs['num_proc'] if 'num_proc' in kargs else 1) if self.need_to_cleanup_caches: diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index 78e8066c9..eb32e306d 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -67,8 +67,7 @@ def __init__(self, cfg=None): logger.info('Preparing exporter...') self.exporter = Exporter(self.cfg.export_path, self.cfg.export_shard_size, - self.cfg.export_in_parallel, - self.cfg.np) + self.cfg.export_in_parallel, self.cfg.np) # setup tracer self.open_tracer = self.cfg.open_tracer @@ -120,14 +119,9 @@ def run(self, load_data_np=None): op_name in self.op_list_to_trace: if op.is_batched_op(): self.tracer.trace_batch_mapper( - op_name, - dataset, - tmp, - op.text_key) + op_name, dataset, tmp, op.text_key) else: - self.tracer.trace_mapper(op_name, - dataset, - tmp, + self.tracer.trace_mapper(op_name, dataset, tmp, op.text_key) elif isinstance(op, Filter): if Fields.stats not in dataset.features: diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py index a77a12c93..9450c9482 100644 --- a/data_juicer/core/exporter.py +++ b/data_juicer/core/exporter.py @@ -161,8 +161,7 @@ def _export_impl(self, dataset, export_path, suffix, export_stats=True): Exporter.to_jsonl( ds_stats, stats_file, - num_proc=self.num_proc if self.export_in_parallel else 1 - ) + num_proc=self.num_proc if self.export_in_parallel else 1) def export(self, dataset): """ diff --git 
a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index 726949e52..d1290c206 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -1,13 +1,11 @@ -import os - +import ray +import ray.data as rd from loguru import logger + from data_juicer.config import init_configs -from data_juicer.ops import (Filter, Mapper, load_ops) +from data_juicer.ops import Filter, Mapper, load_ops from data_juicer.utils.constant import Fields -import ray -import ray.data as rd - class RayExecutor: """ @@ -35,7 +33,6 @@ def __init__(self, cfg=None): ray.init(self.cfg.ray_address) self.process_list = self.cfg.process - def run(self, load_data_np=None): """ Running the dataset process pipeline. @@ -57,7 +54,8 @@ def run(self, load_data_np=None): # - If checkpoint is open, clean the cache files after each process if Fields.stats not in dataset.columns(fetch_if_missing=False): logger.info(f'columns {dataset.columns(fetch_if_missing=False)}') - dataset = dataset.add_column(Fields.stats, lambda df: [{}] * len(df)) + dataset = dataset.add_column(Fields.stats, + lambda df: [{}] * len(df)) logger.info('Processing data...') for op_cfg, op in zip(self.process_list, self.ops): op_name, _ = list(op_cfg.items())[0] @@ -68,7 +66,9 @@ def run(self, load_data_np=None): dataset = dataset.map(op.compute_stats) dataset = dataset.filter(op.process) else: - logger.error('Ray executor only support Filter and Mapper OPs for now') + logger.error( + 'Ray executor only support Filter and Mapper OPs for now' + ) raise NotImplementedError except: # noqa: E722 logger.error(f'An error occurred during Op [{op_name}].') diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py index 123f61a45..0a8629bfc 100644 --- a/data_juicer/format/formatter.py +++ b/data_juicer/format/formatter.py @@ -1,12 +1,13 @@ import os from typing import List, Tuple, Union +from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset +from loguru import 
logger + from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import (find_files_with_suffix, is_absolute_path) from data_juicer.utils.registry import Registry -from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset -from loguru import logger FORMATTERS = Registry('Formatters') diff --git a/data_juicer/ops/common/helper_func.py b/data_juicer/ops/common/helper_func.py index 1b622de74..fdf38d153 100644 --- a/data_juicer/ops/common/helper_func.py +++ b/data_juicer/ops/common/helper_func.py @@ -122,10 +122,12 @@ def words_augmentation(words, group_size, join_char): return augmentation -def get_words_from_document(document, - token_func=None, - new_line=True, - tab=True,): +def get_words_from_document( + document, + token_func=None, + new_line=True, + tab=True, +): """ Get words from a document. Useful to compute ratios, like the stopwords ratio. @@ -143,6 +145,7 @@ def get_words_from_document(document, words = split_on_whitespace(document, new_line, tab) return words + def words_refinement(words, lower_case=False, strip_chars=None, diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py index d65921981..7dabf5552 100644 --- a/data_juicer/ops/filter/alphanumeric_filter.py +++ b/data_juicer/ops/filter/alphanumeric_filter.py @@ -3,7 +3,7 @@ from jsonargparse.typing import PositiveFloat from data_juicer.utils.constant import Fields, StatsKeys -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter from ..common import get_words_from_document diff --git a/data_juicer/ops/filter/average_line_length_filter.py b/data_juicer/ops/filter/average_line_length_filter.py index d89d0c2bd..0cacd84f0 100644 --- a/data_juicer/ops/filter/average_line_length_filter.py +++ b/data_juicer/ops/filter/average_line_length_filter.py @@ -2,7 +2,7 @@ from jsonargparse.typing import 
PositiveInt -from data_juicer.utils.constant import Fields, StatsKeys, InterVars +from data_juicer.utils.constant import Fields, InterVars, StatsKeys from ..base_op import OPERATORS, Filter from ..op_fusion import INTER_LINES diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py index f6b65e35c..3e5ce7e34 100644 --- a/data_juicer/ops/filter/character_repetition_filter.py +++ b/data_juicer/ops/filter/character_repetition_filter.py @@ -59,7 +59,8 @@ def compute_stats(self, sample): freq_char_ngrams = sorted(list(freq_char_ngrams.values()), reverse=True) - num_no_rep_char_ngrams = len([el for el in freq_char_ngrams if el == 1]) + num_no_rep_char_ngrams = len( + [el for el in freq_char_ngrams if el == 1]) num_rep_char_ngrams = min( int(np.sqrt(len(freq_char_ngrams))), len(freq_char_ngrams) - num_no_rep_char_ngrams, diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py index 140bc970b..fbc5e4eb8 100644 --- a/data_juicer/ops/filter/flagged_words_filter.py +++ b/data_juicer/ops/filter/flagged_words_filter.py @@ -4,14 +4,14 @@ from jsonargparse.typing import ClosedUnitInterval, List -from data_juicer.utils.constant import Fields, StatsKeys, InterVars -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.constant import Fields, InterVars, StatsKeys +from data_juicer.utils.model_utils import get_model, prepare_model from ...utils.asset_utils import ASSET_DIR, load_words_asset from ..base_op import OPERATORS, Filter -from ..op_fusion import INTER_WORDS from ..common import (SPECIAL_CHARACTERS, get_words_from_document, words_refinement) +from ..op_fusion import INTER_WORDS @OPERATORS.register_module('flagged_words_filter') @@ -79,8 +79,9 @@ def compute_stats(self, sample, context=False): if context and words_key in sample[Fields.context]: words = sample[Fields.context][words_key] else: - tokenizer = 
get_model(self.model_key, lang=self.lang, - model_type='sentencepiece') + tokenizer = get_model(self.model_key, + lang=self.lang, + model_type='sentencepiece') words = get_words_from_document( sample[self.text_key], token_func=tokenizer.encode_as_pieces if tokenizer else None) diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py index 60d066931..cc3520be5 100644 --- a/data_juicer/ops/filter/language_id_score_filter.py +++ b/data_juicer/ops/filter/language_id_score_filter.py @@ -2,7 +2,7 @@ from loguru import logger from data_juicer.utils.constant import Fields, StatsKeys -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter @@ -38,7 +38,9 @@ def compute_stats(self, sample): return sample text = sample[self.text_key].lower().replace('\n', ' ') - ft_model = get_model(self.model_key, lang=self.lang, model_type='fasttext') + ft_model = get_model(self.model_key, + lang=self.lang, + model_type='fasttext') if ft_model is None: err_msg = 'Model not loaded. Please retry later.' 
logger.error(err_msg) diff --git a/data_juicer/ops/filter/maximum_line_length_filter.py b/data_juicer/ops/filter/maximum_line_length_filter.py index edebd8180..f5dd2d0d5 100644 --- a/data_juicer/ops/filter/maximum_line_length_filter.py +++ b/data_juicer/ops/filter/maximum_line_length_filter.py @@ -2,7 +2,7 @@ from jsonargparse.typing import PositiveInt -from data_juicer.utils.constant import Fields, StatsKeys, InterVars +from data_juicer.utils.constant import Fields, InterVars, StatsKeys from ..base_op import OPERATORS, Filter from ..op_fusion import INTER_LINES diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py index a182fba9a..975279b7c 100644 --- a/data_juicer/ops/filter/perplexity_filter.py +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -4,12 +4,12 @@ from jsonargparse.typing import PositiveFloat -from data_juicer.utils.constant import Fields, StatsKeys, InterVars -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.constant import Fields, InterVars, StatsKeys +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter -from ..op_fusion import INTER_WORDS from ..common import get_words_from_document +from ..op_fusion import INTER_WORDS @OPERATORS.register_module('perplexity_filter') @@ -49,7 +49,8 @@ def compute_stats(self, sample, context=False): if context and words_key in sample[Fields.context]: words = sample[Fields.context][words_key] else: - tokenizer = get_model(self.sp_model_key, self.lang, 'sentencepiece') + tokenizer = get_model(self.sp_model_key, self.lang, + 'sentencepiece') words = get_words_from_document( sample[self.text_key], token_func=tokenizer.encode_as_pieces if tokenizer else None) diff --git a/data_juicer/ops/filter/stopwords_filter.py b/data_juicer/ops/filter/stopwords_filter.py index 37f6d42b1..03c2a5a15 100644 --- a/data_juicer/ops/filter/stopwords_filter.py +++ 
b/data_juicer/ops/filter/stopwords_filter.py @@ -5,13 +5,13 @@ from jsonargparse.typing import ClosedUnitInterval, List from data_juicer.utils.asset_utils import ASSET_DIR, load_words_asset -from data_juicer.utils.constant import Fields, StatsKeys, InterVars -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.constant import Fields, InterVars, StatsKeys +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter -from ..op_fusion import INTER_WORDS from ..common import (SPECIAL_CHARACTERS, get_words_from_document, words_refinement) +from ..op_fusion import INTER_WORDS @OPERATORS.register_module('stopwords_filter') @@ -77,8 +77,9 @@ def compute_stats(self, sample, context=False): if context and words_key in sample[Fields.context]: words = sample[Fields.context][words_key] else: - tokenizer = get_model(self.model_key, lang=self.lang, - model_type='sentencepiece') + tokenizer = get_model(self.model_key, + lang=self.lang, + model_type='sentencepiece') words = get_words_from_document( sample[self.text_key], token_func=tokenizer.encode_as_pieces if tokenizer else None) diff --git a/data_juicer/ops/filter/token_num_filter.py b/data_juicer/ops/filter/token_num_filter.py index 21066382a..5bdc58a3d 100644 --- a/data_juicer/ops/filter/token_num_filter.py +++ b/data_juicer/ops/filter/token_num_filter.py @@ -3,7 +3,7 @@ from jsonargparse.typing import PositiveInt from data_juicer.utils.constant import Fields, StatsKeys -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter from ..common import get_words_from_document @@ -48,8 +48,7 @@ def compute_stats(self, sample): tokenizer = get_model(self.model_key, model_type='huggingface') tokens = get_words_from_document( sample[self.text_key], - token_func=tokenizer.tokenize if tokenizer else None - ) + 
token_func=tokenizer.tokenize if tokenizer else None) sample[Fields.stats][StatsKeys.num_token] = len(tokens) return sample diff --git a/data_juicer/ops/filter/word_num_filter.py b/data_juicer/ops/filter/word_num_filter.py index 3d589287f..98e544b1c 100644 --- a/data_juicer/ops/filter/word_num_filter.py +++ b/data_juicer/ops/filter/word_num_filter.py @@ -2,13 +2,13 @@ from jsonargparse.typing import PositiveInt -from data_juicer.utils.constant import Fields, StatsKeys, InterVars -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.constant import Fields, InterVars, StatsKeys +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Filter -from ..op_fusion import INTER_WORDS from ..common import (SPECIAL_CHARACTERS, get_words_from_document, words_refinement) +from ..op_fusion import INTER_WORDS @OPERATORS.register_module('words_num_filter') @@ -57,8 +57,9 @@ def compute_stats(self, sample, context=False): if context and words_key in sample[Fields.context]: words = sample[Fields.context][words_key] else: - tokenizer = get_model(self.model_key, lang=self.lang, - model_type='sentencepiece') + tokenizer = get_model(self.model_key, + lang=self.lang, + model_type='sentencepiece') words = get_words_from_document( sample[self.text_key], token_func=tokenizer.encode_as_pieces if tokenizer else None) diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py index f5bf8f311..1d8d232ce 100644 --- a/data_juicer/ops/filter/word_repetition_filter.py +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -4,13 +4,13 @@ from jsonargparse.typing import ClosedUnitInterval, PositiveInt -from data_juicer.utils.constant import Fields, StatsKeys, InterVars -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.constant import Fields, InterVars, StatsKeys +from data_juicer.utils.model_utils import get_model, 
prepare_model from ..base_op import OPERATORS, Filter -from ..op_fusion import INTER_WORDS from ..common import (SPECIAL_CHARACTERS, get_words_from_document, words_refinement) +from ..op_fusion import INTER_WORDS @OPERATORS.register_module('word_repetition_filter') @@ -63,8 +63,9 @@ def compute_stats(self, sample, context=False): if context and words_key in sample[Fields.context]: words = sample[Fields.context][words_key] else: - tokenizer = get_model(self.model_key, lang=self.lang, - model_type='sentencepiece') + tokenizer = get_model(self.model_key, + lang=self.lang, + model_type='sentencepiece') words = get_words_from_document( sample[self.text_key], token_func=tokenizer.encode_as_pieces if tokenizer else None) @@ -77,10 +78,9 @@ def compute_stats(self, sample, context=False): if context and refined_words_key in sample[Fields.context]: words = sample[Fields.context][refined_words_key] else: - words = words_refinement( - words, - lower_case=True, - strip_chars=SPECIAL_CHARACTERS) + words = words_refinement(words, + lower_case=True, + strip_chars=SPECIAL_CHARACTERS) if context: sample[Fields.context][refined_words_key] = words word_ngrams = [ diff --git a/data_juicer/ops/load.py b/data_juicer/ops/load.py index a48f6cf44..e8d1ed65e 100644 --- a/data_juicer/ops/load.py +++ b/data_juicer/ops/load.py @@ -1,6 +1,7 @@ from .base_op import OPERATORS from .op_fusion import fuse_operators + def load_ops(process_list, op_fusion=False): """ Load op list according to the process list from config file. 
diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index ff2589ce5..6a5148c7b 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -1,11 +1,10 @@ +from copy import deepcopy + import nlpaug.augmenter.char as nac import nlpaug.augmenter.word as naw import nlpaug.flow as naf - -from nlpaug.util import Action - from loguru import logger -from copy import deepcopy +from nlpaug.util import Action from ..base_op import OPERATORS, Mapper @@ -138,4 +137,3 @@ def process(self, samples): res_samples[key] += res_samples[key] * self.aug_num \ * len(self.aug) return res_samples - diff --git a/data_juicer/ops/mapper/nlpcda_zh_mapper.py b/data_juicer/ops/mapper/nlpcda_zh_mapper.py index e1a9bf816..51cf50e49 100644 --- a/data_juicer/ops/mapper/nlpcda_zh_mapper.py +++ b/data_juicer/ops/mapper/nlpcda_zh_mapper.py @@ -1,10 +1,11 @@ +from copy import deepcopy from loguru import logger -from copy import deepcopy -from ..base_op import OPERATORS, Mapper from data_juicer.utils.logger_utils import HiddenPrints +from ..base_op import OPERATORS, Mapper + @OPERATORS.register_module('nlpcda_zh_mapper') class NlpcdaZhMapper(Mapper): @@ -152,4 +153,3 @@ def process(self, samples): res_samples[key] = res_samples[key] * \ len(res_samples[self.text_key]) return res_samples - diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py index 9a8b94f17..c6f7c5e43 100644 --- a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py +++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py @@ -1,6 +1,6 @@ from jsonargparse.typing import List -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Mapper from ..common import (SPECIAL_CHARACTERS, 
get_words_from_document, @@ -44,7 +44,9 @@ def should_keep_word_with_incorrect_substrings(self, word, substrings): def process(self, sample): if self.tokenization: - tokenizer = get_model(self.model_key, lang=self.lang, model_type='sentencepiece') + tokenizer = get_model(self.model_key, + lang=self.lang, + model_type='sentencepiece') sentences = get_words_from_document( sample[self.text_key], token_func=tokenizer.encode_as_pieces if tokenizer else None) diff --git a/data_juicer/ops/mapper/sentence_split_mapper.py b/data_juicer/ops/mapper/sentence_split_mapper.py index 51317575c..65a02308e 100644 --- a/data_juicer/ops/mapper/sentence_split_mapper.py +++ b/data_juicer/ops/mapper/sentence_split_mapper.py @@ -1,4 +1,4 @@ -from data_juicer.utils.model_utils import prepare_model, get_model +from data_juicer.utils.model_utils import get_model, prepare_model from ..base_op import OPERATORS, Mapper from ..common import get_sentences_from_document @@ -22,7 +22,9 @@ def __init__(self, lang: str = 'en', *args, **kwargs): def process(self, sample): - nltk_model = get_model(self.model_key, lang=self.lang, model_type='nltk') + nltk_model = get_model(self.model_key, + lang=self.lang, + model_type='nltk') sample[self.text_key] = get_sentences_from_document( sample[self.text_key], model_func=nltk_model.tokenize if nltk_model else None) diff --git a/data_juicer/ops/mapper/whitespace_normalization_mapper.py b/data_juicer/ops/mapper/whitespace_normalization_mapper.py index 01fb204c0..6fa44b559 100644 --- a/data_juicer/ops/mapper/whitespace_normalization_mapper.py +++ b/data_juicer/ops/mapper/whitespace_normalization_mapper.py @@ -5,6 +5,7 @@ from ..base_op import OPERATORS, Mapper from ..common.special_characters import VARIOUS_WHITESPACES + @OPERATORS.register_module('whitespace_normalization_mapper') class WhitespaceNormalizationMapper(Mapper): """ @@ -29,8 +30,8 @@ def process(self, sample): text = sample[self.text_key].strip() # replace all kinds of whitespaces with ' ' - 
sample[self.text_key] = ''.join( - [char if char not in VARIOUS_WHITESPACES else ' ' - for char in text]) + sample[self.text_key] = ''.join([ + char if char not in VARIOUS_WHITESPACES else ' ' for char in text + ]) return sample diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py index fdda23fed..099fc28d0 100644 --- a/data_juicer/ops/op_fusion.py +++ b/data_juicer/ops/op_fusion.py @@ -1,15 +1,17 @@ - from typing import List + from loguru import logger -from .base_op import Filter from data_juicer.utils.constant import Fields, InterVars from data_juicer.utils.registry import Registry +from .base_op import Filter + # Type of intermediate vars INTER_LINES = Registry(InterVars.lines) INTER_WORDS = Registry(InterVars.words) + def fuse_operators(process_list, ops): """ Fuse the input ops list and return the fused ops list. @@ -49,6 +51,7 @@ def fuse_operators(process_list, ops): fused_ops.extend(fused_group) return fused_op_def, fused_ops + def fuse_filter_group(original_filter_group): """ Fuse single filter group and return the fused filter group. 
@@ -60,8 +63,10 @@ def fuse_filter_group(original_filter_group): fused_group_def = [] fused_group = [] all_intermediate_vars = [INTER_LINES, INTER_WORDS] - all_fused_filters = {inter_vars: [] - for inter_vars in all_intermediate_vars} + all_fused_filters = { + inter_vars: [] + for inter_vars in all_intermediate_vars + } # group these filters by their intermediate vars for process, op in original_filter_group: op_name, op_args = list(process.items())[0] @@ -86,8 +91,10 @@ def fuse_filter_group(original_filter_group): defs, ops = zip(*inter_vars_filter) # new definition: new name and a definition list of fused op list fused_filter_def = { - 'OpFusion:(%s)' % ','.join([list(process.items())[0][0] - for process in defs]): list(defs) + 'OpFusion:(%s)' % ','.join([ + list(process.items())[0][0] for process in defs + ]): + list(defs) } logger.info(f'Ops are fused into one op ' f'{list(fused_filter_def.keys())[0]}.') @@ -104,8 +111,10 @@ def fuse_filter_group(original_filter_group): return fused_group_def, fused_group + class FusedFilter(Filter): """A fused operator for filters.""" + def __init__(self, fused_filters: List): """ Initialization method. 
diff --git a/data_juicer/tools/__init__.py b/data_juicer/tools/__init__.py index 1a17b4f33..b74f54887 100644 --- a/data_juicer/tools/__init__.py +++ b/data_juicer/tools/__init__.py @@ -9,7 +9,7 @@ from importlib import abc, util from pathlib import Path -_TOOLS_PATH = Path(__file__).resolve().parent.parent.parent / "tools" +_TOOLS_PATH = Path(__file__).resolve().parent.parent.parent / 'tools' if _TOOLS_PATH.is_dir(): # This is true only for in-place installation @@ -20,12 +20,12 @@ class _PathFinder(abc.MetaPathFinder): def find_spec(self, name, path, target=None): - if not name.startswith("data_juicer.tools."): + if not name.startswith('data_juicer.tools.'): return - project_name = name.split(".")[-1] + ".py" + project_name = name.split('.')[-1] + '.py' target_file = _TOOLS_PATH / project_name if not target_file.is_file(): return return util.spec_from_file_location(name, target_file) - sys.meta_path.append(_PathFinder()) \ No newline at end of file + sys.meta_path.append(_PathFinder()) diff --git a/data_juicer/utils/compress.py b/data_juicer/utils/compress.py index 7f2aabccf..3ffd9fada 100644 --- a/data_juicer/utils/compress.py +++ b/data_juicer/utils/compress.py @@ -2,9 +2,9 @@ import re import shutil from abc import ABC, abstractmethod +from multiprocessing import Pool from pathlib import Path from typing import Dict, List, Optional, Type, Union -from multiprocessing import Pool from datasets import Dataset from datasets.utils.extract import Extractor as HF_Extractor @@ -343,7 +343,10 @@ def compress(self, f'Compressing cache file to {formatted_cache_name}') if num_proc > 1: pool.apply_async(self.compress_manager.compress, - args=(full_name, compress_filename,)) + args=( + full_name, + compress_filename, + )) else: self.compress_manager.compress(full_name, compress_filename) @@ -399,7 +402,10 @@ def decompress(self, files_printed.add(formatted_cache_name) if num_proc > 1: pool.apply_async(self.compress_manager.decompress, - args=(full_name, raw_filename,)) + 
args=( + full_name, + raw_filename, + )) else: self.compress_manager.decompress(full_name, raw_filename) else: @@ -447,8 +453,10 @@ def cleanup_cache_files(self, ds): os.remove(full_name) return len(f_names) + class CompressionOff: """Define a range that turn off the cache compression temporarily.""" + def __enter__(self): """ Record the original cache compression method and turn it off. @@ -464,6 +472,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): from . import cache_utils cache_utils.CACHE_COMPRESS = self.original_cache_compress + def compress(prev_ds, this_ds=None, num_proc=1): if cache_utils.CACHE_COMPRESS: CacheCompressManager(cache_utils.CACHE_COMPRESS).compress( diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 31bc9f0f8..24a918999 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -31,6 +31,7 @@ class HashKeys(object): minhash = DEFAULT_PREFIX + 'minhash' simhash = DEFAULT_PREFIX + 'simhash' + class InterVars(object): lines = DEFAULT_PREFIX + 'lines' words = DEFAULT_PREFIX + 'words' diff --git a/data_juicer/utils/logger_utils.py b/data_juicer/utils/logger_utils.py index 5d4c6d970..ec99b1530 100644 --- a/data_juicer/utils/logger_utils.py +++ b/data_juicer/utils/logger_utils.py @@ -92,7 +92,11 @@ def get_log_file_path(): return handler._sink._file.name -def setup_logger(save_dir, distributed_rank=0, filename='log.txt', mode='o', redirect=True): +def setup_logger(save_dir, + distributed_rank=0, + filename='log.txt', + mode='o', + redirect=True): """ Setup logger for training and testing. 
@@ -133,8 +137,10 @@ def setup_logger(save_dir, distributed_rank=0, filename='log.txt', mode='o', red redirect_sys_output('INFO') LOGGER_SETUP = True + class HiddenPrints: """Define a range that hide the outputs within this range.""" + def __enter__(self): """ Store the original standard output and redirect the standard output to diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 8d416d9f2..3ac0d7973 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -169,6 +169,7 @@ def prepare_huggingface_tokenizer(tokenizer_name): trust_remote_code=True) return tokenizer + def prepare_diversity_model(model_name, lang): """ Prepare diversity model for specific language. @@ -248,4 +249,4 @@ def get_model(model_key, lang='en', model_type='sentencepiece'): """ if model_key not in MODEL_ZOO: prepare_model(lang=lang, model_type=model_type, model_key=model_key) - return MODEL_ZOO.get(model_key, None) \ No newline at end of file + return MODEL_ZOO.get(model_key, None) diff --git a/demos/README.md b/demos/README.md index a45f4f510..000f78246 100644 --- a/demos/README.md +++ b/demos/README.md @@ -47,4 +47,4 @@ streamlit run app.py - This demo splits a dataset to different sub-datasets by language. - Data mixture (`data_mixture`) - - This demo selects and mixes samples from multiple datasets and exports them into a new dataset. \ No newline at end of file + - This demo selects and mixes samples from multiple datasets and exports them into a new dataset. 
diff --git a/demos/README_ZH.md b/demos/README_ZH.md index 939232e39..218fe1e64 100644 --- a/demos/README_ZH.md +++ b/demos/README_ZH.md @@ -47,4 +47,4 @@ streamlit run app.py - 该示例按照语言将数据集拆分为不同的子数据集。 - 数据混合 (`data_mixture`) - - 该示例从多份数据集中进行采样并混合为一个新的数据集。 \ No newline at end of file + - 该示例从多份数据集中进行采样并混合为一个新的数据集。 diff --git a/demos/auto_evaluation_helm/README_ZH.md b/demos/auto_evaluation_helm/README_ZH.md index b6b50d24e..2a30e1cbe 100644 --- a/demos/auto_evaluation_helm/README_ZH.md +++ b/demos/auto_evaluation_helm/README_ZH.md @@ -42,7 +42,7 @@ wandb login ```shell docker commit data-juicer-eval -``` +``` #### 2. 将数据集预处理为 Megatron-LM 可识别的格式 diff --git a/demos/auto_evaluation_helm/app.py b/demos/auto_evaluation_helm/app.py index a9f7edbd2..03b1d9e8b 100644 --- a/demos/auto_evaluation_helm/app.py +++ b/demos/auto_evaluation_helm/app.py @@ -1,7 +1,9 @@ import os import re + import streamlit as st + class Visualize: @staticmethod @@ -29,6 +31,7 @@ def setup(): def visualize(): Visualize.setup() + def main(): def make_image(line): @@ -38,10 +41,10 @@ def make_image(line): Visualize.visualize() buffer = [] - with open("README_ZH.md", 'r', encoding='utf-8') as f: + with open('README_ZH.md', 'r', encoding='utf-8') as f: lines = f.readlines() for line in lines: - if "imgs/" in line: + if 'imgs/' in line: st.markdown('\n'.join(buffer)) make_image(line) buffer.clear() @@ -50,5 +53,6 @@ def make_image(line): st.markdown('\n'.join(buffer)) # hello() + if __name__ == '__main__': main() diff --git a/demos/data_process_hpo/app.py b/demos/data_process_hpo/app.py index bc7a78baf..a5f62552c 100644 --- a/demos/data_process_hpo/app.py +++ b/demos/data_process_hpo/app.py @@ -1,6 +1,7 @@ -import os + import streamlit as st + class Visualize: @staticmethod @@ -29,18 +30,24 @@ def setup(): def visualize(): Visualize.setup() + def main(): + def hello(): - st.image('imgs/data-juicer.png', output_format='png', use_column_width = True) + st.image('imgs/data-juicer.png', + 
output_format='png', + use_column_width=True) demo = 'The demo is coming soon😊' st.markdown( f'
{demo} \
', unsafe_allow_html=True, ) + Visualize.visualize() hello() + if __name__ == '__main__': main() diff --git a/demos/data_process_loop/app.py b/demos/data_process_loop/app.py index 5a2bafd3b..f5307b57e 100644 --- a/demos/data_process_loop/app.py +++ b/demos/data_process_loop/app.py @@ -136,11 +136,11 @@ def process_and_show_res(): analyzer.analysis_path = os.path.dirname( cfg_for_processed_data.export_path) + '/analysis' analyzer.run() - + overall_file = os.path.join(analyzer.analysis_path, 'overall.csv') if os.path.exists(overall_file): analysis_res_processed = pd.read_csv(overall_file) - + if os.path.exists(analyzer.analysis_path): for f_path in os.listdir(analyzer.analysis_path): if '.png' in f_path and 'all-stats' in f_path: diff --git a/demos/data_process_loop/data/demo-dataset.jsonl b/demos/data_process_loop/data/demo-dataset.jsonl index 14aa71f9a..a212d42f4 100644 --- a/demos/data_process_loop/data/demo-dataset.jsonl +++ b/demos/data_process_loop/data/demo-dataset.jsonl @@ -8,4 +8,4 @@ {"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément."} {"text": "欢迎来到阿里巴巴!"} {"text": "This paper proposed a novel method on LLM pretraining."} -{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} \ No newline at end of file +{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} diff --git a/demos/data_visualization_diversity/data/demo-dataset.jsonl 
b/demos/data_visualization_diversity/data/demo-dataset.jsonl index 3f9e7640b..9df0c2513 100644 --- a/demos/data_visualization_diversity/data/demo-dataset.jsonl +++ b/demos/data_visualization_diversity/data/demo-dataset.jsonl @@ -144,4 +144,4 @@ "input": "Radius = 4", "output": "The area of a circle with a radius of 4 is equal to 12.5664 square units. This is calculated by using the formula A = \u03c0r2, where A is the area, \u03c0 is roughly equal to 3.1416 and r is the radius of the circle." } -] \ No newline at end of file +] diff --git a/demos/data_visualization_op_effect/app.py b/demos/data_visualization_op_effect/app.py index bb9b9ebb3..9e4ff4928 100644 --- a/demos/data_visualization_op_effect/app.py +++ b/demos/data_visualization_op_effect/app.py @@ -112,7 +112,7 @@ def analyze_and_show_res(dataset_file): analysis_res_ori = pd.DataFrame() if os.path.exists(overall_file): analysis_res_ori = pd.read_csv(overall_file) - + if os.path.exists(analyzer.analysis_path): for f_path in os.listdir(analyzer.analysis_path): if '.png' in f_path and 'all-stats' in f_path: @@ -321,7 +321,7 @@ def op_effect_analyze(): @staticmethod def filter_dataset(dataset): if Fields.stats not in dataset.features: - return + return text_key = st.session_state.get('text_key', 'text') text = dataset[text_key] stats = pd.DataFrame(dataset[Fields.stats]) diff --git a/demos/data_visualization_statistics/app.py b/demos/data_visualization_statistics/app.py index eecce2856..abae47096 100644 --- a/demos/data_visualization_statistics/app.py +++ b/demos/data_visualization_statistics/app.py @@ -97,7 +97,7 @@ def analyze_and_show_res(dataset_file): analysis_res_ori = pd.DataFrame() if os.path.exists(overall_file): analysis_res_ori = pd.read_csv(overall_file) - + if os.path.exists(analyzer.analysis_path): for f_path in os.listdir(analyzer.analysis_path): if '.png' in f_path and 'all-stats' in f_path: diff --git a/demos/overview_scan/app.py b/demos/overview_scan/app.py index 1050cc353..d1b109ac8 100644 
--- a/demos/overview_scan/app.py +++ b/demos/overview_scan/app.py @@ -119,8 +119,8 @@ | clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp | | expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | | fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | -| nlpaug_en_mapper | General | en | Simply augment texts in English based on the `nlpaug` library | -| nlpcda_zh_mapper | General | zh | Simply augment texts in Chinese based on the `nlpcda` library | +| nlpaug_en_mapper | General | en | Simply augment texts in English based on the `nlpaug` library | +| nlpcda_zh_mapper | General | zh | Simply augment texts in Chinese based on the `nlpcda` library | | punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | | remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | | remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | @@ -215,11 +215,12 @@ def run_demo(): analysis_res_ori = pd.DataFrame() if os.path.exists(overall_file): analysis_res_ori = pd.read_csv(overall_file) - + if os.path.exists(analyzer.analysis_path): for f_path in os.listdir(analyzer.analysis_path): if '.png' in f_path and 'all-stats' in f_path: - images_ori.append(os.path.join(analyzer.analysis_path, f_path)) + images_ori.append( + os.path.join(analyzer.analysis_path, f_path)) st.subheader('Statistics') st.dataframe(analysis_res_ori, use_container_width=True) diff --git a/demos/overview_scan/data/demo-dataset.jsonl b/demos/overview_scan/data/demo-dataset.jsonl index 14aa71f9a..a212d42f4 100644 --- a/demos/overview_scan/data/demo-dataset.jsonl +++ b/demos/overview_scan/data/demo-dataset.jsonl @@ -8,4 +8,4 @@ {"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément."} {"text": 
"欢迎来到阿里巴巴!"} {"text": "This paper proposed a novel method on LLM pretraining."} -{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} \ No newline at end of file +{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} diff --git a/demos/process_cft_zh_data/data/alpaca-cot.jsonl b/demos/process_cft_zh_data/data/alpaca-cot.jsonl index 16b2cc4c3..85a5912f6 100644 --- a/demos/process_cft_zh_data/data/alpaca-cot.jsonl +++ b/demos/process_cft_zh_data/data/alpaca-cot.jsonl @@ -8,4 +8,4 @@ {"instruction":"为一组电影列表添加与电影类型相关的标签。","input":"某个电影推荐网站的前100名电影列表,包括电影名称和类型。","output":"添加标签:\n\n1. 动作片:XXX、XXX、XXX...\n2. 喜剧片:XXX、XXX、XXX...\n3. 科幻片:XXX、XXX、XXX...\n4. 恐怖片:XXX、XXX、XXX...\n5. 爱情片:XXX、XXX、XXX...\n6. 剧情片:XXX、XXX、XXX...\n7. 纪录片:XXX、XXX、XXX...\n8. 动画片:XXX、XXX、XXX...\n9. 战争片:XXX、XXX、XXX...\n10. 悬疑片:XXX、XXX、XXX...","meta":{"Dataset":"belle_cn","Gen":"SI","IFT":false,"Lang":"CN", "CFT-MR":false, "CFT-P":false, "CFT-SR":true,"Task":"TS\/MT","origin_path":"Alpaca-CoT\/belle_cn\/belle_instructions\/belle_data1M_cn.jsonl"},"text":"为一组电影列表添加与电影类型相关的标签。 某个电影推荐网站的前100名电影列表,包括电影名称和类型。 添加标签:\n\n1. 动作片:XXX、XXX、XXX...\n2. 喜剧片:XXX、XXX、XXX...\n3. 科幻片:XXX、XXX、XXX...\n4. 恐怖片:XXX、XXX、XXX...\n5. 爱情片:XXX、XXX、XXX...\n6. 剧情片:XXX、XXX、XXX...\n7. 纪录片:XXX、XXX、XXX...\n8. 动画片:XXX、XXX、XXX...\n9. 战争片:XXX、XXX、XXX...\n10. 
悬疑片:XXX、XXX、XXX..."} {"instruction":"磷","input":"","output":"磷","meta":{"Dataset":"ShareGPT","Gen":"MIX","IFT":false,"Lang":"EN", "CFT-MR":true, "CFT-P":false, "CFT-SR":true,"Task":"MT","origin_path":"Alpaca-CoT\/ShareGPT\/sharegpt_zh.jsonl"},"text":"磷 磷"} {"instruction":"酶","input":"","output":"酶","meta":{"Dataset":"ShareGPT","Gen":"MIX","IFT":false,"Lang":"EN", "CFT-MR":true, "CFT-P":false, "CFT-SR":true,"Task":"MT","origin_path":"Alpaca-CoT\/ShareGPT\/sharegpt_zh.jsonl"},"text":"酶 酶"} -{"instruction":"输出一个正方形","input":"","output":"","meta":{"Dataset":"Guanaco","Gen":"SI","IFT":false,"Lang":"ML", "CFT-MR":false, "CFT-P":false, "CFT-SR":true,"Task":"MT","origin_path":"Alpaca-CoT\/Guanaco\/GuanacoDataset.jsonl"},"text":"输出一个正方形 "} \ No newline at end of file +{"instruction":"输出一个正方形","input":"","output":"","meta":{"Dataset":"Guanaco","Gen":"SI","IFT":false,"Lang":"ML", "CFT-MR":false, "CFT-P":false, "CFT-SR":true,"Task":"MT","origin_path":"Alpaca-CoT\/Guanaco\/GuanacoDataset.jsonl"},"text":"输出一个正方形 "} diff --git a/demos/process_on_ray/data/demo-dataset.json b/demos/process_on_ray/data/demo-dataset.json index 14aa71f9a..a212d42f4 100644 --- a/demos/process_on_ray/data/demo-dataset.json +++ b/demos/process_on_ray/data/demo-dataset.json @@ -8,4 +8,4 @@ {"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément."} {"text": "欢迎来到阿里巴巴!"} {"text": "This paper proposed a novel method on LLM pretraining."} -{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} \ No newline at end of file +{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 
楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} diff --git a/demos/tool_dataset_splitting_by_language/data/demo-dataset.jsonl b/demos/tool_dataset_splitting_by_language/data/demo-dataset.jsonl index 14aa71f9a..a212d42f4 100644 --- a/demos/tool_dataset_splitting_by_language/data/demo-dataset.jsonl +++ b/demos/tool_dataset_splitting_by_language/data/demo-dataset.jsonl @@ -8,4 +8,4 @@ {"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément."} {"text": "欢迎来到阿里巴巴!"} {"text": "This paper proposed a novel method on LLM pretraining."} -{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} \ No newline at end of file +{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} diff --git a/demos/tool_quality_classifier/data/demo-dataset.jsonl b/demos/tool_quality_classifier/data/demo-dataset.jsonl index 14aa71f9a..a212d42f4 100644 --- a/demos/tool_quality_classifier/data/demo-dataset.jsonl +++ b/demos/tool_quality_classifier/data/demo-dataset.jsonl @@ -8,4 +8,4 @@ {"text": "Sur la plateforme MT4, plusieurs manières d'accéder à ces fonctionnalités sont conçues simultanément."} {"text": "欢迎来到阿里巴巴!"} {"text": "This paper proposed a novel method on LLM pretraining."} -{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 
各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} \ No newline at end of file +{"text":"世界十大网投平台_2022年卡塔尔世界杯官网\n177-8228-4819\n网站首页\n关于我们\n产品展示\n广告牌制作 广告灯箱制作 标识牌制作 楼宇亮化工程 门头店招制作 不锈钢金属字制作 LED发光字制作 形象墙Logo墙背景墙制作 LED显示屏制作 装饰装潢工程 铜字铜牌制作 户外广告 亚克力制品 各类广告设计 建筑工地广告制作 楼顶大字制作|楼顶发光字制作 霓虹灯制作 三维扣板|3D扣板|广告扣板 房地产广告制作设计 精神堡垒|立牌|指示牌制作 大型商业喷绘写真 展览展示 印刷服务\n合作伙伴\n新闻资讯\n公司新闻 行业新闻 制作知识 设计知识\n成功案例\n技术园地\n联系方式\n"} diff --git a/demos/tool_quality_classifier/quality_classifier/qc_utils.py b/demos/tool_quality_classifier/quality_classifier/qc_utils.py index 862e6f1bd..86187acbd 100644 --- a/demos/tool_quality_classifier/quality_classifier/qc_utils.py +++ b/demos/tool_quality_classifier/quality_classifier/qc_utils.py @@ -186,6 +186,7 @@ def eval(model_path, ds, tokenizer=None): logger.info(f'FP: {FP}, TN: {TN}') logger.info(f'P: {precision}, R: {recall}, F1: {F1}') + def predict(model, ds, tokenizer=None, keep_method='label'): logger.info('Start scoring dataset...') if tokenizer: diff --git a/setup.py b/setup.py index dd237c9ea..741b2e29a 100644 --- a/setup.py +++ b/setup.py @@ -58,8 +58,8 @@ def get_install_requirements(require_f_paths, env_dir='environments'): long_description=readme_md, long_description_content_type='text/markdown', license='Apache License 2.0', - packages=setuptools.find_packages(exclude=['tests*', 'tools*']) - + list(get_package_dir().keys()), + packages=setuptools.find_packages(exclude=['tests*', 'tools*']) + + list(get_package_dir().keys()), package_dir=get_package_dir(), entry_points={ 'console_scripts': [ diff --git a/tests/format/test_unify_format.py b/tests/format/test_unify_format.py index 7b561bee4..2f64d0dcf 100644 --- a/tests/format/test_unify_format.py +++ b/tests/format/test_unify_format.py @@ -53,8 +53,7 @@ def test_text_key(self): ] self.run_test(samples[0]) self.run_test(samples[1], args={'text_keys': ['content']}) - self.run_test(samples[2], - 
args={'text_keys': ['input', 'instruction']}) + self.run_test(samples[2], args={'text_keys': ['input', 'instruction']}) def test_empty_text(self): # filter out samples containing None field, but '' is OK diff --git a/tests/ops/filter/test_character_repetition_filter.py b/tests/ops/filter/test_character_repetition_filter.py index 1d5683c82..b54d76a71 100644 --- a/tests/ops/filter/test_character_repetition_filter.py +++ b/tests/ops/filter/test_character_repetition_filter.py @@ -41,9 +41,7 @@ def test_case(self): 'text': '中文也是一个字算一个长度' }] dataset = Dataset.from_list(ds_list) - op = CharacterRepetitionFilter(rep_len=5, - min_ratio=0.0, - max_ratio=0.4) + op = CharacterRepetitionFilter(rep_len=5, min_ratio=0.0, max_ratio=0.4) self._run_character_repetition_filter(dataset, tgt_list, op) diff --git a/tests/ops/filter/test_token_num_filter.py b/tests/ops/filter/test_token_num_filter.py index ab1efaeb6..a830e91fe 100644 --- a/tests/ops/filter/test_token_num_filter.py +++ b/tests/ops/filter/test_token_num_filter.py @@ -10,16 +10,35 @@ class WordNumFilterTest(unittest.TestCase): def test_token_num(self): src = [ - {"text": "Today is Sunday and it's a happy day!"}, - {"text": "Do you need a cup of coffee?"}, - {"text": "你好,请问你是谁"}, - {"text": "Sur la plateforme MT4, plusieurs manières d'accéder à " - "ces fonctionnalités sont conçues simultanément."}, - {"text": "欢迎来到阿里巴巴!"}, - {"text": "This paper proposed a novel method on LLM pretraining."}, + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': 'Do you need a cup of coffee?' + }, + { + 'text': '你好,请问你是谁' + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' 
+ }, ] tgt = [ - 10, 8, 9, 31, 14, 12, + 10, + 8, + 9, + 31, + 14, + 12, ] ds = Dataset.from_list(src) op = TokenNumFilter() @@ -30,18 +49,39 @@ def test_token_num(self): def test_token_num_filter(self): src = [ - {"text": "Today is Sunday and it's a happy day!"}, - {"text": "Do you need a cup of coffee?"}, - {"text": "你好,请问你是谁"}, - {"text": "Sur la plateforme MT4, plusieurs manières d'accéder à " - "ces fonctionnalités sont conçues simultanément."}, - {"text": "欢迎来到阿里巴巴!"}, - {"text": "This paper proposed a novel method on LLM pretraining."}, + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': 'Do you need a cup of coffee?' + }, + { + 'text': '你好,请问你是谁' + }, + { + 'text': + "Sur la plateforme MT4, plusieurs manières d'accéder à " + 'ces fonctionnalités sont conçues simultanément.' + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, ] tgt = [ - {"text": "Today is Sunday and it's a happy day!"}, - {"text": "欢迎来到阿里巴巴!"}, - {"text": "This paper proposed a novel method on LLM pretraining."}, + { + 'text': "Today is Sunday and it's a happy day!" + }, + { + 'text': '欢迎来到阿里巴巴!' + }, + { + 'text': + 'This paper proposed a novel method on LLM pretraining.' + }, ] ds = Dataset.from_list(src) tgt = Dataset.from_list(tgt) diff --git a/tests/ops/mapper/test_nlpaug_en_mapper.py b/tests/ops/mapper/test_nlpaug_en_mapper.py index 34e8e0c79..628a8f57c 100644 --- a/tests/ops/mapper/test_nlpaug_en_mapper.py +++ b/tests/ops/mapper/test_nlpaug_en_mapper.py @@ -12,10 +12,7 @@ def setUp(self): 'I am a deep learning engineer. 
I love LLM.', 'A short test with numbers 2023' ], - 'meta': [ - 'meta information', - 'meta information with numbers' - ], + 'meta': ['meta information', 'meta information with numbers'], }) def test_number_of_generated_samples_with_sequential_on(self): diff --git a/tests/ops/mapper/test_nlpcda_zh_mapper.py b/tests/ops/mapper/test_nlpcda_zh_mapper.py index 3d9a8b3dd..0f4437794 100644 --- a/tests/ops/mapper/test_nlpcda_zh_mapper.py +++ b/tests/ops/mapper/test_nlpcda_zh_mapper.py @@ -3,6 +3,7 @@ from data_juicer.core import NestedDataset from data_juicer.ops.mapper.nlpcda_zh_mapper import NlpcdaZhMapper + class NlpaugEnMapperTest(unittest.TestCase): def setUp(self): @@ -68,9 +69,9 @@ def test_number_of_generated_samples_with_sequential_off(self): ) self.assertEqual(len(op.aug_pipeline), aug_method_num) result = self.samples.map(op.process) - self.assertLessEqual( - len(result['text']), - (aug_num * aug_method_num + 1) * len(self.samples['text'])) + self.assertLessEqual(len(result['text']), + (aug_num * aug_method_num + 1) * + len(self.samples['text'])) self.assertGreaterEqual(len(result['text']), len(self.samples['text'])) self.assertEqual(len(result['meta']), len(result['text'])) @@ -135,9 +136,9 @@ def test_all_aug_methods_with_sequential_off(self): ) self.assertEqual(len(op.aug_pipeline), aug_method_num) result = self.samples.map(op.process) - self.assertLessEqual( - len(result['text']), - (aug_num * aug_method_num + 1) * len(self.samples['text'])) + self.assertLessEqual(len(result['text']), + (aug_num * aug_method_num + 1) * + len(self.samples['text'])) self.assertGreaterEqual(len(result['text']), len(self.samples['text'])) self.assertEqual(len(result['meta']), len(result['text'])) diff --git a/tests/ops/test_op_fusion.py b/tests/ops/test_op_fusion.py index 18c36d14c..5fc46d9bb 100644 --- a/tests/ops/test_op_fusion.py +++ b/tests/ops/test_op_fusion.py @@ -2,6 +2,7 @@ from data_juicer.ops.load import load_ops + class OpFusionTest(unittest.TestCase): def 
_run_op_fusion(self, original_process_list, target_process_list): @@ -108,106 +109,124 @@ def test_regular_config(self): 'window_size': 6 } }] - target_process = [{ - 'language_id_score_filter': { - 'lang': 'en', - 'min_score': 0.8, - 'text_key': 'text' - } - }, { - 'whitespace_normalization_mapper': { - 'text_key': 'text' - } - }, { - 'punctuation_normalization_mapper': { - 'text_key': 'text' - } - }, { - 'fix_unicode_mapper': { - 'text_key': 'text' - } - }, { - 'remove_words_with_incorrect_substrings_mapper': { - 'lang': 'en', - 'substrings': None, - 'text_key': 'text', - 'tokenization': False - } - }, { - 'remove_long_words_mapper': { - 'max_len': 25, - 'min_len': 1, - 'text_key': 'text' - } - }, { - 'character_repetition_filter': { - 'max_ratio': 0.106, - 'min_ratio': 0.0, - 'rep_len': 10, - 'text_key': 'text' - } - }, { - 'special_characters_filter': { - 'max_ratio': 0.4, - 'min_ratio': 0.0, - 'text_key': 'text' - } - }, { - 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': [{ # noqa: E501 - 'words_num_filter': { + target_process = [ + { + 'language_id_score_filter': { 'lang': 'en', - 'max_num': 100000, - 'min_num': 20, - 'text_key': 'text', - 'tokenization': False + 'min_score': 0.8, + 'text_key': 'text' + } + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' } - }, { - 'word_repetition_filter': { + }, + { + 'remove_words_with_incorrect_substrings_mapper': { 'lang': 'en', - 'max_ratio': 0.19, - 'min_ratio': 0.0, - 'rep_len': 5, + 'substrings': None, 'text_key': 'text', 'tokenization': False } - }, { - 'stopwords_filter': { - 'lang': 'en', - 'min_ratio': 0.3, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 
'min_len': 1, + 'text_key': 'text' } - }, { - 'flagged_words_filter': { - 'lang': 'en', - 'max_ratio': 0.01, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' } - }, { - 'perplexity_filter': { - 'lang': 'en', - 'max_ppl': 1500, + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, 'text_key': 'text' } - }] - }, { - 'document_simhash_deduplicator': { - 'hamming_distance': 4, - 'ignore_pattern': '\\p{P}', - 'lowercase': True, - 'num_blocks': 6, - 'text_key': 'text', - 'tokenization': 'space', - 'window_size': 6 + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': + [ + { # noqa: E501 + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } } - }] + ] self._run_op_fusion(original_process, target_process) def 
test_only_mapper(self): @@ -524,108 +543,125 @@ def test_multiple_groups(self): 'window_size': 6 } }] - target_process = [{ - 'language_id_score_filter': { - 'lang': 'en', - 'min_score': 0.8, - 'text_key': 'text' - } - }, { - 'OpFusion:(stopwords_filter,flagged_words_filter)': [{ - 'stopwords_filter': { + target_process = [ + { + 'language_id_score_filter': { 'lang': 'en', - 'min_ratio': 0.3, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' + 'min_score': 0.8, + 'text_key': 'text' } - }, { - 'flagged_words_filter': { - 'lang': 'en', - 'max_ratio': 0.01, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' + }, + { + 'OpFusion:(stopwords_filter,flagged_words_filter)': [{ + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }] + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' } - }] - }, { - 'whitespace_normalization_mapper': { - 'text_key': 'text' - } - }, { - 'punctuation_normalization_mapper': { - 'text_key': 'text' - } - }, { - 'fix_unicode_mapper': { - 'text_key': 'text' - } - }, { - 'remove_words_with_incorrect_substrings_mapper': { - 'lang': 'en', - 'substrings': None, - 'text_key': 'text', - 'tokenization': False - } - }, { - 'remove_long_words_mapper': { - 'max_len': 25, - 'min_len': 1, - 'text_key': 'text' - } - }, { - 'character_repetition_filter': { - 'max_ratio': 0.106, - 'min_ratio': 0.0, - 'rep_len': 10, - 'text_key': 'text' - } - }, { - 'special_characters_filter': { - 'max_ratio': 0.4, - 'min_ratio': 0.0, - 'text_key': 
'text' - } - }, { - 'OpFusion:(words_num_filter,word_repetition_filter,perplexity_filter)': [{ # noqa: E501 - 'words_num_filter': { + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' + } + }, + { + 'remove_words_with_incorrect_substrings_mapper': { 'lang': 'en', - 'max_num': 100000, - 'min_num': 20, + 'substrings': None, 'text_key': 'text', 'tokenization': False } - }, { - 'word_repetition_filter': { - 'lang': 'en', - 'max_ratio': 0.19, + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' + } + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, 'min_ratio': 0.0, - 'rep_len': 5, - 'text_key': 'text', - 'tokenization': False + 'rep_len': 10, + 'text_key': 'text' } - }, { - 'perplexity_filter': { - 'lang': 'en', - 'max_ppl': 1500, + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, 'text_key': 'text' } - }] - }, { - 'document_simhash_deduplicator': { - 'hamming_distance': 4, - 'ignore_pattern': '\\p{P}', - 'lowercase': True, - 'num_blocks': 6, - 'text_key': 'text', - 'tokenization': 'space', - 'window_size': 6 + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,perplexity_filter)': + [ + { # noqa: E501 + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } } - }] + ] self._run_op_fusion(original_process, target_process) def test_only_fusible_ops(self): @@ 
-675,50 +711,57 @@ def test_only_fusible_ops(self): } }] target_process = [{ - 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': [{ # noqa: E501 - 'words_num_filter': { - 'lang': 'en', - 'max_num': 100000, - 'min_num': 20, - 'text_key': 'text', - 'tokenization': False + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': + [ + { # noqa: E501 + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } } - }, { - 'word_repetition_filter': { - 'lang': 'en', - 'max_ratio': 0.19, - 'min_ratio': 0.0, - 'rep_len': 5, - 'text_key': 'text', - 'tokenization': False - } - }, { - 'stopwords_filter': { - 'lang': 'en', - 'min_ratio': 0.3, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' - } - }, { - 'flagged_words_filter': { - 'lang': 'en', - 'max_ratio': 0.01, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' - } - }, { - 'perplexity_filter': { - 'lang': 'en', - 'max_ppl': 1500, - 'text_key': 'text' - } - }] + ] }] self._run_op_fusion(original_process, target_process) @@ 
-832,118 +875,141 @@ def test_different_intermediate_vars(self): 'window_size': 6 } }] - target_process = [{ - 'language_id_score_filter': { - 'lang': 'en', - 'min_score': 0.8, - 'text_key': 'text' - } - }, { - 'whitespace_normalization_mapper': { - 'text_key': 'text' - } - }, { - 'punctuation_normalization_mapper': { - 'text_key': 'text' - } - }, { - 'fix_unicode_mapper': { - 'text_key': 'text' - } - }, { - 'remove_words_with_incorrect_substrings_mapper': { - 'lang': 'en', - 'substrings': None, - 'text_key': 'text', - 'tokenization': False - } - }, { - 'remove_long_words_mapper': { - 'max_len': 25, - 'min_len': 1, - 'text_key': 'text' - } - }, { - 'character_repetition_filter': { - 'max_ratio': 0.106, - 'min_ratio': 0.0, - 'rep_len': 10, - 'text_key': 'text' - } - }, { - 'special_characters_filter': { - 'max_ratio': 0.4, - 'min_ratio': 0.0, - 'text_key': 'text' - } - }, { - 'OpFusion:(average_line_length_filter,maximum_line_length_filter)': [{ # noqa: E501 - 'average_line_length_filter': { - 'min_len': 10, - 'text_key': 'text', + target_process = [ + { + 'language_id_score_filter': { + 'lang': 'en', + 'min_score': 0.8, + 'text_key': 'text' } - }, { - 'maximum_line_length_filter': { - 'min_len': 20, - 'text_key': 'text', + }, + { + 'whitespace_normalization_mapper': { + 'text_key': 'text' } - }] - }, { - 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': [{ # noqa: E501 - 'words_num_filter': { - 'lang': 'en', - 'max_num': 100000, - 'min_num': 20, - 'text_key': 'text', - 'tokenization': False + }, + { + 'punctuation_normalization_mapper': { + 'text_key': 'text' + } + }, + { + 'fix_unicode_mapper': { + 'text_key': 'text' } - }, { - 'word_repetition_filter': { + }, + { + 'remove_words_with_incorrect_substrings_mapper': { 'lang': 'en', - 'max_ratio': 0.19, - 'min_ratio': 0.0, - 'rep_len': 5, + 'substrings': None, 'text_key': 'text', 'tokenization': False } - }, { - 'stopwords_filter': { - 'lang': 'en', - 
'min_ratio': 0.3, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' + }, + { + 'remove_long_words_mapper': { + 'max_len': 25, + 'min_len': 1, + 'text_key': 'text' } - }, { - 'flagged_words_filter': { - 'lang': 'en', - 'max_ratio': 0.01, - 'text_key': 'text', - 'tokenization': False, - 'use_words_aug': False, - 'words_aug_group_sizes': [2], - 'words_aug_join_char': '' + }, + { + 'character_repetition_filter': { + 'max_ratio': 0.106, + 'min_ratio': 0.0, + 'rep_len': 10, + 'text_key': 'text' } - }, { - 'perplexity_filter': { - 'lang': 'en', - 'max_ppl': 1500, + }, + { + 'special_characters_filter': { + 'max_ratio': 0.4, + 'min_ratio': 0.0, 'text_key': 'text' } - }] - }, { - 'document_simhash_deduplicator': { - 'hamming_distance': 4, - 'ignore_pattern': '\\p{P}', - 'lowercase': True, - 'num_blocks': 6, - 'text_key': 'text', - 'tokenization': 'space', - 'window_size': 6 + }, + { + 'OpFusion:(average_line_length_filter,maximum_line_length_filter)': + [ + { # noqa: E501 + 'average_line_length_filter': { + 'min_len': 10, + 'text_key': 'text', + } + }, + { + 'maximum_line_length_filter': { + 'min_len': 20, + 'text_key': 'text', + } + } + ] + }, + { + 'OpFusion:(words_num_filter,word_repetition_filter,stopwords_filter,flagged_words_filter,perplexity_filter)': + [ + { # noqa: E501 + 'words_num_filter': { + 'lang': 'en', + 'max_num': 100000, + 'min_num': 20, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'word_repetition_filter': { + 'lang': 'en', + 'max_ratio': 0.19, + 'min_ratio': 0.0, + 'rep_len': 5, + 'text_key': 'text', + 'tokenization': False + } + }, + { + 'stopwords_filter': { + 'lang': 'en', + 'min_ratio': 0.3, + 'text_key': 'text', + 'tokenization': False, + 'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'flagged_words_filter': { + 'lang': 'en', + 'max_ratio': 0.01, + 'text_key': 'text', + 'tokenization': False, + 
'use_words_aug': False, + 'words_aug_group_sizes': [2], + 'words_aug_join_char': '' + } + }, + { + 'perplexity_filter': { + 'lang': 'en', + 'max_ppl': 1500, + 'text_key': 'text' + } + } + ] + }, + { + 'document_simhash_deduplicator': { + 'hamming_distance': 4, + 'ignore_pattern': '\\p{P}', + 'lowercase': True, + 'num_blocks': 6, + 'text_key': 'text', + 'tokenization': 'space', + 'window_size': 6 + } } - }] + ] self._run_op_fusion(original_process, target_process) diff --git a/tools/converter/convert_gpt_to_transformers.py b/tools/converter/convert_gpt_to_transformers.py index 07308aeed..751866181 100644 --- a/tools/converter/convert_gpt_to_transformers.py +++ b/tools/converter/convert_gpt_to_transformers.py @@ -26,57 +26,60 @@ import argparse import json -import sys import os import re - +import sys import types import torch - from transformers import AutoTokenizer, LlamaConfig +from transformers.modeling_utils import (WEIGHTS_INDEX_NAME, WEIGHTS_NAME, + shard_checkpoint) + from modeling_megatron_llama import MegatronLlamaConfig -from transformers.modeling_utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, shard_checkpoint def add_checkpointing_args(parser): - parser.add_argument("--megatron-path", type=str, default=None, help="Base directory of Megatron repository") + parser.add_argument('--megatron-path', + type=str, + default=None, + help='Base directory of Megatron repository') parser.add_argument( - "--load_path", + '--load_path', type=str, required=True, - help="Path to the checkpoint to convert.", + help='Path to the checkpoint to convert.', ) parser.add_argument( - "--save_path", + '--save_path', type=str, required=True, - help="Path to the converted checkpoint.", + help='Path to the converted checkpoint.', ) - parser.add_argument("--print-checkpoint-structure", action="store_true") + parser.add_argument('--print-checkpoint-structure', action='store_true') return parser def add_transformers_checkpoint_args(parser): parser.add_argument( - "--tokenizer_name", + 
'--tokenizer_name', type=str, default=None, - help=( - "The name of the pre-trained tokenizer to save. " - "If not None, the tokenizer will be saved. " - "Only used when converting a Megatron checkpoint to a Transformers checkpoint." - ), + help= + ('The name of the pre-trained tokenizer to save. ' + 'If not None, the tokenizer will be saved. ' + 'Only used when converting a Megatron checkpoint to a Transformers checkpoint.' + ), ) parser.add_argument( - "--max_shard_size", + '--max_shard_size', type=str, - default="10GB", - help=( - "The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size " - "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). " - "Only used when converting a Megatron checkpoint to a Transformers checkpoint." - ), + default='10GB', + help= + ('The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size ' + 'lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). ' + 'Only used when converting a Megatron checkpoint to a Transformers checkpoint.' + ), ) return parser @@ -84,33 +87,36 @@ def add_transformers_checkpoint_args(parser): # The simple map of names for "automated" rules. 
megatron_to_transformers = { - "attention.dense": ".self_attn.o_proj.", - "self_attention.dense": ".self_attn.o_proj.", + 'attention.dense': '.self_attn.o_proj.', + 'self_attention.dense': '.self_attn.o_proj.', # TODO: one to two vectors - "mlp.dense_h_to_4h": ".mlp.{}_proj.", - "mlp.dense_4h_to_h": ".mlp.down_proj.", + 'mlp.dense_h_to_4h': '.mlp.{}_proj.', + 'mlp.dense_4h_to_h': '.mlp.down_proj.', +} +transformers_to_megatron = { + v[1:-1]: k + for k, v in megatron_to_transformers.items() } -transformers_to_megatron = {v[1:-1]: k for k, v in megatron_to_transformers.items()} tensor_parallel_params = [ # megatron-lm layers to merge across tp ranks - "self_attention.query_key_value.weight", - "self_attention.query_key_value.bias", - "self_attention.dense.weight", - "mlp.dense_h_to_4h.weight", - "mlp.dense_h_to_4h.bias", - "mlp.dense_4h_to_h.weight", + 'self_attention.query_key_value.weight', + 'self_attention.query_key_value.bias', + 'self_attention.dense.weight', + 'mlp.dense_h_to_4h.weight', + 'mlp.dense_h_to_4h.bias', + 'mlp.dense_4h_to_h.weight', # deprecated - "attention.query_key_value.weight", - "attention.query_key_value.bias", - "attention.dense.weight", + 'attention.query_key_value.weight', + 'attention.query_key_value.bias', + 'attention.dense.weight', # transformers layers to split across tp ranks - "attn.c_attn.weight", - "attn.c_attn.bias", - "attn.c_proj.weight", - "mlp.c_fc.weight", - "mlp.c_fc.bias", - "mlp.c_proj.weight", + 'attn.c_attn.weight', + 'attn.c_attn.bias', + 'attn.c_proj.weight', + 'mlp.c_fc.weight', + 'mlp.c_fc.bias', + 'mlp.c_proj.weight', ] @@ -127,7 +133,7 @@ def recursive_print(name, val, spaces=0): if name is None: msg = None else: - fmt = "." * max(0, spaces - 2) + "# {:" + str(50 - spaces) + "s}" + fmt = '.' * max(0, spaces - 2) + '# {:' + str(50 - spaces) + 's}' msg = fmt.format(name) # Print and recurse (if needed). 
@@ -137,14 +143,13 @@ def recursive_print(name, val, spaces=0): for k in val.keys(): recursive_print(k, val[k], spaces + 2) elif isinstance(val, torch.Tensor): - print(msg, ":", val.size()) + print(msg, ':', val.size()) else: - print(msg, ":", val) + print(msg, ':', val) def megatron_to_transformers_fix_query_key_value_ordering( - param, checkpoint_version, num_splits, num_heads, hidden_size -): + param, checkpoint_version, num_splits, num_heads, hidden_size): """ Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :] for compatibility with later versions of NVIDIA Megatron-LM. The inverse operation is performed inside Megatron-LM to read checkpoints: @@ -177,8 +182,7 @@ def megatron_to_transformers_fix_query_key_value_ordering( def transformers_to_megatron_fix_query_key_value_ordering( - param, checkpoint_version, num_splits, num_heads, hidden_size -): + param, checkpoint_version, num_splits, num_heads, hidden_size): """ Permutes layout of param tensor to the one compatible with respective NVIDIA Megatron-LM chekpoint versions. 
Input is [num_splits * num_heads * hidden_size, :] and output is [num_heads * hidden_size * num_splits, :] for version @@ -220,8 +224,9 @@ def merge_transformers_sharded_states(path, num_checkpoints): """ state_dict = {} for i in range(1, num_checkpoints + 1): - checkpoint_path = os.path.join(path, f"pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin") - current_chunk = torch.load(checkpoint_path, map_location="cpu") + checkpoint_path = os.path.join( + path, f'pytorch_model-{i:05d}-of-{num_checkpoints:05d}.bin') + current_chunk = torch.load(checkpoint_path, map_location='cpu') state_dict.update(current_chunk) return state_dict @@ -239,10 +244,12 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank): """ tp_state_dicts = [] for i in range(tp_size): - sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}" - checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0] - checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name) - state_dict = torch.load(checkpoint_path, map_location="cpu") + sub_dir_name = f'mp_rank_{i:02d}' if pp_size == 1 else f'mp_rank_{i:02d}_{pp_rank:03d}' + checkpoint_name = os.listdir(os.path.join(args.load_path, + sub_dir_name))[0] + checkpoint_path = os.path.join(args.load_path, sub_dir_name, + checkpoint_name) + state_dict = torch.load(checkpoint_path, map_location='cpu') tp_state_dicts.append(state_dict) return tp_state_dicts @@ -255,7 +262,7 @@ def get_element_from_dict_by_path(d, path): d (dict): the dictionary to get the element from path (list): the path to the element which is delimited by "." 
""" - path = path.split(".") + path = path.split('.') for k in path: if k not in d: d[k] = {} @@ -275,44 +282,45 @@ def convert_checkpoint_from_megatron_to_transformers(args): """ # Load Megatron-LM checkpoint arguments from the state dict sub_dirs = os.listdir(args.load_path) - possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"] + possible_sub_dirs = ['mp_rank_00', 'mp_rank_00_000'] for sub_dir in possible_sub_dirs: if sub_dir in sub_dirs: - rank0_checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir))[0] - rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, rank0_checkpoint_name) + rank0_checkpoint_name = os.listdir( + os.path.join(args.load_path, sub_dir))[0] + rank0_checkpoint_path = os.path.join(args.load_path, sub_dir, + rank0_checkpoint_name) break - print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}") - state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") - megatron_args = state_dict.get("args", None) + print( + f'Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}' + ) + state_dict = torch.load(rank0_checkpoint_path, map_location='cpu') + megatron_args = state_dict.get('args', None) if megatron_args is None: raise ValueError( - "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints" - " containing all the megatron arguments. This is because it loads all config related to model" - " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" - " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron" - " arguments to use this utility." - ) + 'Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints' + ' containing all the megatron arguments. 
This is because it loads all config related to model' + ' architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to' + ' manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron' + ' arguments to use this utility.') # Create Transformers GPT2 config from Megatron-LM arguments if megatron_args is not None: # dawei: use swish as activation function if megatron_args.swiglu: - activation_function = "silu" + activation_function = 'silu' elif megatron_args.bias_gelu_fusion: - activation_function = "gelu_fast" + activation_function = 'gelu_fast' elif megatron_args.openai_gelu: - activation_function = "gelu_new" + activation_function = 'gelu_new' else: - activation_function = "gelu" + activation_function = 'gelu' else: # in the very early days this used to be "gelu_new" - activation_function = "gelu_new" - vocab_size = ( - megatron_args.padded_vocab_size - if getattr(megatron_args, "orig_vocab_size", None) is None - else megatron_args.orig_vocab_size - ) - print("vocab size:", vocab_size) + activation_function = 'gelu_new' + vocab_size = (megatron_args.padded_vocab_size + if getattr(megatron_args, 'orig_vocab_size', None) is None + else megatron_args.orig_vocab_size) + print('vocab size:', vocab_size) config = MegatronLlamaConfig( # dawei: from megatron-lm @@ -333,48 +341,48 @@ def convert_checkpoint_from_megatron_to_transformers(args): bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, - architectures=["MegatronLlamaForCausalLM"], - + architectures=['MegatronLlamaForCausalLM'], use_bias=True, ) output_state_dict = {} - checkpoint_version = state_dict.get("checkpoint_version", 0.0) + checkpoint_version = state_dict.get('checkpoint_version', 0.0) tp_size = megatron_args.tensor_model_parallel_size pp_size = megatron_args.pipeline_model_parallel_size dtype = torch.float32 # The regex to extract layer names. 
- layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") + layer_re = re.compile(r'layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)') # Convert. - print("Converting") + print('Converting') # Embeddings - print("Converting embeddings") + print('Converting embeddings') tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) # Convert and store the position embeddings. position_embeddings = get_element_from_dict_by_path( - tp_state_dicts[0], "model.language_model.embedding.position_embeddings.weight" - ) - output_state_dict["model.embed_position.weight"] = position_embeddings.to(dtype) + tp_state_dicts[0], + 'model.language_model.embedding.position_embeddings.weight') + output_state_dict['model.embed_position.weight'] = position_embeddings.to( + dtype) # Convert and store the word embeddings. word_embeddings = torch.cat( [ get_element_from_dict_by_path( - tp_state_dicts[tp_rank], "model.language_model.embedding.word_embeddings.weight" - ) + tp_state_dicts[tp_rank], + 'model.language_model.embedding.word_embeddings.weight') for tp_rank in range(tp_size) ], dim=0, ) word_embeddings = word_embeddings[:vocab_size].to(dtype) - output_state_dict["model.embed_tokens.weight"] = word_embeddings + output_state_dict['model.embed_tokens.weight'] = word_embeddings # Transformer Layers - print("Converting transformer layers") + print('Converting transformer layers') # The number of heads. heads = config.num_attention_heads # The hidden_size per head. @@ -384,17 +392,18 @@ def convert_checkpoint_from_megatron_to_transformers(args): for pp_rank in range(pp_size): if pp_size > 0: - print(f"Converting pipeline parallel rank {pp_rank}") - tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, pp_rank) + print(f'Converting pipeline parallel rank {pp_rank}') + tp_state_dicts = get_megatron_sharded_states( + args, tp_size, pp_size, pp_rank) # The transformer. 
- path = ( - "model.language_model.transformer" - if "transformer" in get_element_from_dict_by_path(tp_state_dicts[0], "model.language_model").keys() - else "model.language_model.encoder" - ) + path = ('model.language_model.transformer' + if 'transformer' in get_element_from_dict_by_path( + tp_state_dicts[0], 'model.language_model').keys() else + 'model.language_model.encoder') # Extract the layers. - for key, val in get_element_from_dict_by_path(tp_state_dicts[0], path).items(): + for key, val in get_element_from_dict_by_path(tp_state_dicts[0], + path).items(): # Match the name. m = layer_re.match(key) # Stop if that's not a layer @@ -410,9 +419,9 @@ def convert_checkpoint_from_megatron_to_transformers(args): weight_or_bias = m.group(3) # The name of the layer. - layer_name = f"model.layers.{layer_idx}" + layer_name = f'model.layers.{layer_idx}' - if op_name + "." + weight_or_bias not in tensor_parallel_params: + if op_name + '.' + weight_or_bias not in tensor_parallel_params: # dawei: input_layernorm.weight, input_layernorm.bias, self_attention.dense.bias, # dawei: self_attention_layernorm.weight, self_attention_layernorm.bias, mlp.dense_4h_to_h.bias # dawei: post_attention_layernorm.weight, post_attention_layernorm.bias @@ -421,16 +430,20 @@ def convert_checkpoint_from_megatron_to_transformers(args): # dawei: self_attention.query_key_value.weight, self_attention_query_value.bias, self_attention.dense.weight, # mlp.dense_h_to_4h.weight, mlp.dense_h_to_4h.bias, # mlp.dense_4h_to_h.weight - dim = 1 if op_name in ["self_attention.dense", "mlp.dense_4h_to_h", "attention.dense"] else 0 + dim = 1 if op_name in [ + 'self_attention.dense', 'mlp.dense_4h_to_h', + 'attention.dense' + ] else 0 # dawei: maybe only stored in the first chunk # dawei: fix bug in swiglu and dense_h_to_4h.weight - if op_name == "mlp.dense_h_to_4h" and weight_or_bias == "weight": + if op_name == 'mlp.dense_h_to_4h' and weight_or_bias == 'weight': params_list = [val] + [ - 
get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key] + get_element_from_dict_by_path(tp_state_dicts[tp_rank], + f'{path}')[key] for tp_rank in range(1, tp_size) - ] + ] ws, vs = list(), list() for p in params_list: w, v = torch.chunk(p, 2, dim=0) @@ -440,9 +453,9 @@ def convert_checkpoint_from_megatron_to_transformers(args): else: params = torch.cat( - [val] - + [ - get_element_from_dict_by_path(tp_state_dicts[tp_rank], f"{path}")[key] + [val] + [ + get_element_from_dict_by_path( + tp_state_dicts[tp_rank], f'{path}')[key] for tp_rank in range(1, tp_size) ], dim=dim, @@ -450,17 +463,19 @@ def convert_checkpoint_from_megatron_to_transformers(args): # For layernorm(s), simply store the layer norm. # dawei: ignore the bias for layernorm - if op_name.endswith("layernorm"): + if op_name.endswith('layernorm'): # dawei: input_layernorm & post_attention_layernorm - if weight_or_bias == "weight": + if weight_or_bias == 'weight': # dawei: skip bias - ln_name = "input_layernorm" if op_name.startswith("input") else "post_attention_layernorm" - output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params + ln_name = 'input_layernorm' if op_name.startswith( + 'input') else 'post_attention_layernorm' + output_state_dict[layer_name + '.' + ln_name + '.' + + weight_or_bias] = params # Transpose the QKV matrix. - elif ( - op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" - ) and weight_or_bias == "weight": + elif (op_name == 'attention.query_key_value' + or op_name == 'self_attention.query_key_value' + ) and weight_or_bias == 'weight': # dawei: (gpt2) self_attention.query_key_value.weight out_val = megatron_to_transformers_fix_query_key_value_ordering( params, @@ -478,70 +493,73 @@ def convert_checkpoint_from_megatron_to_transformers(args): # (3*D) x D ==> D x D, still [out_dim, in_dim] q, k, v = torch.chunk(out_val, 3, dim=0) # Store. 
- output_state_dict[layer_name + ".self_attn.q_proj.weight"] = q - output_state_dict[layer_name + ".self_attn.k_proj.weight"] = k - output_state_dict[layer_name + ".self_attn.v_proj.weight"] = v + output_state_dict[layer_name + '.self_attn.q_proj.weight'] = q + output_state_dict[layer_name + '.self_attn.k_proj.weight'] = k + output_state_dict[layer_name + '.self_attn.v_proj.weight'] = v # Transpose the bias. - elif ( - op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value" - ) and weight_or_bias == "bias": + elif (op_name == 'attention.query_key_value' + or op_name == 'self_attention.query_key_value' + ) and weight_or_bias == 'bias': # dawei: (gpt2) self_attention.query_key_value.bias out_val = megatron_to_transformers_fix_query_key_value_ordering( - params, checkpoint_version, 3, heads, hidden_size_per_head - ) + params, checkpoint_version, 3, heads, hidden_size_per_head) # dawei: split in to 3 bias q_b, k_b, v_b = torch.chunk(out_val, 3, dim=0) # Store. No change of shape. 
- output_state_dict[layer_name + ".self_attn.q_proj.bias"] = q_b - output_state_dict[layer_name + ".self_attn.k_proj.bias"] = k_b - output_state_dict[layer_name + ".self_attn.v_proj.bias"] = v_b + output_state_dict[layer_name + '.self_attn.q_proj.bias'] = q_b + output_state_dict[layer_name + '.self_attn.k_proj.bias'] = k_b + output_state_dict[layer_name + '.self_attn.v_proj.bias'] = v_b - elif ( - op_name == "mlp.dense_h_to_4h" and weight_or_bias == "weight" - ): + elif (op_name == 'mlp.dense_h_to_4h' + and weight_or_bias == 'weight'): # dawei: mlp.dense_h_to_4h.weight out_name = megatron_to_transformers[op_name] gate, up = torch.chunk(params, 2, dim=0) - output_state_dict[layer_name + out_name.format("gate") + "weight"] = gate - output_state_dict[layer_name + out_name.format("up") + "weight"] = up + output_state_dict[layer_name + out_name.format('gate') + + 'weight'] = gate + output_state_dict[layer_name + out_name.format('up') + + 'weight'] = up # Transpose the weights. - elif weight_or_bias == "weight": + elif weight_or_bias == 'weight': # dawei: self_attention.dense.weight, mlp.dense_4h_to_h.weight out_name = megatron_to_transformers[op_name] - output_state_dict[layer_name + out_name + "weight"] = params + output_state_dict[layer_name + out_name + 'weight'] = params - elif ( - op_name == "mlp.dense_h_to_4h" and weight_or_bias == "bias" - ): + elif (op_name == 'mlp.dense_h_to_4h' and weight_or_bias == 'bias'): # dawei: mlp.dense_h_to_4h.bias out_name = megatron_to_transformers[op_name] gate_b, up_b = torch.chunk(params, 2, dim=0) - output_state_dict[layer_name + out_name.format("gate") + "bias"] = gate_b - output_state_dict[layer_name + out_name.format("up") + "bias"] = up_b + output_state_dict[layer_name + out_name.format('gate') + + 'bias'] = gate_b + output_state_dict[layer_name + out_name.format('up') + + 'bias'] = up_b # Copy the bias. 
- elif weight_or_bias == "bias": + elif weight_or_bias == 'bias': # dawei: (gpt2) self_attention.query_key_value.bias out_name = megatron_to_transformers[op_name] - output_state_dict[layer_name + out_name + "bias"] = params + output_state_dict[layer_name + out_name + 'bias'] = params if config.num_hidden_layers != (layer_idx + 1): - raise ValueError(f"Expected {config.num_hidden_layers} layers but found {layer_idx + 1}") + raise ValueError( + f'Expected {config.num_hidden_layers} layers but found {layer_idx + 1}' + ) # The final layernorm. - print("Converting final layernorm") + print('Converting final layernorm') params = get_element_from_dict_by_path(tp_state_dicts[0], str(path)) - output_state_dict["model.norm.weight"] = params["final_layernorm.weight"].to(dtype) + output_state_dict['model.norm.weight'] = params[ + 'final_layernorm.weight'].to(dtype) # For LM head, transformers' wants the matrix to weight embeddings. - print("Converting LM head") - output_state_dict["lm_head.weight"] = word_embeddings.to(dtype) + print('Converting LM head') + output_state_dict['lm_head.weight'] = word_embeddings.to(dtype) # It should be done! - print("Conversion from Megatron-LM to Transformers is done!") + print('Conversion from Megatron-LM to Transformers is done!') # Print the structure of converted state dict. if args.print_checkpoint_structure: @@ -550,9 +568,9 @@ def convert_checkpoint_from_megatron_to_transformers(args): # Add tokenizer class info to config # see https://github.com/huggingface/transformers/issues/13906) - print("Tokenizer_name: ", args.tokenizer_name) + print('Tokenizer_name: ', args.tokenizer_name) if args.tokenizer_name is None: - tokenizer_name = "gpt2" + tokenizer_name = 'gpt2' else: tokenizer_name = args.tokenizer_name @@ -561,35 +579,38 @@ def convert_checkpoint_from_megatron_to_transformers(args): config.tokenizer_class = tokenizer_class # Store the config to file. 
- print("Saving config") + print('Saving config') config.save_pretrained(args.save_path) # Save tokenizer based on args if args.tokenizer_name is not None: - print(f"Adding {tokenizer_class} tokenizer files") + print(f'Adding {tokenizer_class} tokenizer files') tokenizer.save_pretrained(args.save_path) # Store the state_dict to file. - max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size - shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size) + max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit( + ) else args.max_shard_size + shards, index = shard_checkpoint(output_state_dict, + max_shard_size=max_shard_size) # Save the model for shard_file, shard in shards.items(): torch.save(shard, os.path.join(args.save_path, shard_file)) if index is None: - print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}") + print( + f'Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}' + ) else: save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME) # Save the index as well - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" + with open(save_index_file, 'w', encoding='utf-8') as f: + content = json.dumps(index, indent=2, sort_keys=True) + '\n' f.write(content) print( - f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + f'The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be ' + f'split in {len(shards)} checkpoint shards. 
You can find where each parameters has been saved in the ' + f'index located at {save_index_file}.') def main(): @@ -601,5 +622,5 @@ def main(): convert_checkpoint_from_megatron_to_transformers(args) -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/tools/converter/modeling_megatron_llama.py b/tools/converter/modeling_megatron_llama.py index f35e6665c..268856001 100644 --- a/tools/converter/modeling_megatron_llama.py +++ b/tools/converter/modeling_megatron_llama.py @@ -30,23 +30,25 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "MegatronLlamaConfig" - +_CONFIG_FOR_DOC = 'MegatronLlamaConfig' LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} class MegatronLlamaConfig(PretrainedConfig): - model_type = "megatron-llama" + model_type = 'megatron-llama' def __init__( self, @@ -55,7 +57,7 @@ def __init__( intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, - hidden_act="silu", + hidden_act='silu', max_position_embeddings=2048, initializer_range=0.02, rms_norm_eps=1e-6, @@ -91,39 +93,53 @@ def __init__( # Copied from transformers.models.bart.modeling_bart._make_causal_mask -def 
_make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0): """ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask = torch.full((tgt_len, tgt_len), + torch.tensor(torch.finfo(dtype).min, device=device), + device=device) mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + mask = torch.cat([ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device), + mask + ], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) # Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
""" bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), + torch.finfo(dtype).min) class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ LlamaRMSNorm is equivalent to T5LayerNorm @@ -133,8 +149,10 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + variance = hidden_states.to(torch.float32).pow(2).mean(-1, + keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -144,31 +162,49 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) + inv_freq = 1.0 / (base + **(torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer('inv_freq', inv_freq) # Build here to make `torch.jit.trace` work. 
self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) + t = torch.arange(self.max_seq_len_cached, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + self.register_buffer('cos_cached', + emb.cos()[None, None, :, :], + persistent=False) + self.register_buffer('sin_cached', + emb.sin()[None, None, :, :], + persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
if seq_len > self.max_seq_len_cached: self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) + t = torch.arange(self.max_seq_len_cached, + device=x.device, + dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + self.register_buffer('cos_cached', + emb.cos()[None, None, :, :], + persistent=False) + self.register_buffer('sin_cached', + emb.sin()[None, None, :, :], + persistent=False) return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), @@ -177,7 +213,7 @@ def forward(self, x, seq_len=None): def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] + x1 = x[..., :x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) @@ -194,16 +230,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): class LlamaMLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - use_bias: bool - ): + + def __init__(self, hidden_size: int, intermediate_size: int, + hidden_act: str, use_bias: bool): super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=use_bias) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=use_bias) + self.gate_proj = nn.Linear(hidden_size, + intermediate_size, + bias=use_bias) + self.down_proj = nn.Linear(intermediate_size, + hidden_size, + bias=use_bias) self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=use_bias) self.act_fn = ACT2FN[hidden_act] @@ -213,7 
+249,7 @@ def forward(self, x): class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - + def __init__(self, config: MegatronLlamaConfig): super().__init__() self.config = config @@ -224,38 +260,53 @@ def __init__(self, config: MegatronLlamaConfig): if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias) - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).') + self.q_proj = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=config.use_bias) + self.k_proj = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=config.use_bias) + self.v_proj = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=config.use_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, + self.hidden_size, + bias=config.use_bias) + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() def forward( - self, - hidden_states: torch.Tensor, - attention_mask: 
Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids) # [bsz, nh, t, hd] if past_key_value is not None: @@ -265,31 +316,35 @@ def forward( past_key_value = (key_states, value_states) if use_cache else None - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, 
key_states.transpose( + 2, 3)) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) + f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' + f' {attn_weights.size()}') if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' ) attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.max( + attn_weights, + torch.tensor(torch.finfo(attn_weights.dtype).min)) # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}') attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -303,6 +358,7 @@ def forward( class LlamaDecoderLayer(nn.Module): + def __init__(self, config: MegatronLlamaConfig): super().__init__() self.hidden_size = config.hidden_size @@ -313,18 +369,21 @@ def __init__(self, config: MegatronLlamaConfig): hidden_act=config.hidden_act, use_bias=config.use_bias, ) - self.input_layernorm = 
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -360,13 +419,13 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: - outputs += (self_attn_weights,) + outputs += (self_attn_weights, ) if use_cache: - outputs += (present_key_value,) + outputs += (present_key_value, ) return outputs @@ -389,15 +448,15 @@ def forward( @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.', LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): config_class = MegatronLlamaConfig - base_model_prefix = "model" + base_model_prefix = 'model' supports_gradient_checkpointing = True - 
_no_split_modules = ["LlamaDecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + _no_split_modules = ['LlamaDecoderLayer'] + _keys_to_ignore_on_load_unexpected = [r'decoder\.version'] def _init_weights(self, module): std = self.config.initializer_range @@ -480,7 +539,7 @@ def _set_gradient_checkpointing(self, module, value=False): @add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + 'The bare LLaMA Model outputting raw hidden-states without any specific head on top.', LLAMA_START_DOCSTRING, ) class LlamaModel(LlamaPreTrainedModel): @@ -497,11 +556,15 @@ def __init__(self, config: MegatronLlamaConfig): self.vocab_size = config.vocab_size # TODO: position embeddings, should be removed if rotary position embedding - self.embed_position = nn.Embedding(config.max_sequence_length, config.hidden_size) + self.embed_position = nn.Embedding(config.max_sequence_length, + config.hidden_size) # word embeddings - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, + self.padding_idx) + self.layers = nn.ModuleList([ + LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) + ]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -515,7 +578,8 @@ def set_input_embeddings(self, value): self.embed_tokens = value # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 
combined_attention_mask = None @@ -529,45 +593,51 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) + expanded_attn_mask = _expand_mask(attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1]).to( + inputs_embeds.device) + combined_attention_mask = (expanded_attn_mask + if combined_attention_mask is None else + expanded_attn_mask + + combined_attention_mask) return combined_attention_mask @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + 
output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError( + 'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time' + ) elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError( + 'You have to specify either decoder_input_ids or decoder_inputs_embeds' + ) seq_length_with_past = seq_length past_key_values_length = 0 @@ -578,9 +648,10 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) + position_ids = torch.arange(past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -594,19 +665,19 @@ def forward( # embed positions if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + attention_mask, 
(batch_size, seq_length), inputs_embeds, + past_key_values_length) hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' ) use_cache = False @@ -617,13 +688,15 @@ def forward( for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: - all_hidden_states += (hidden_states,) + all_hidden_states += (hidden_states, ) - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_value = past_key_values[ + idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): + def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -650,20 +723,24 @@ def custom_forward(*inputs): hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache += ( + layer_outputs[2 if output_attentions else 1], ) if output_attentions: - all_self_attns += (layer_outputs[1],) + all_self_attns += (layer_outputs[1], ) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: - all_hidden_states += (hidden_states,) + all_hidden_states += (hidden_states, ) next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -673,6 +750,7 @@ def custom_forward(*inputs): class MegatronLlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, 
config): super().__init__(config) # used for model @@ -680,7 +758,9 @@ def __init__(self, config): self.model = LlamaModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.hidden_size, + config.vocab_size, + bias=False) # Initialize weights and apply final processing self.post_init() @@ -704,19 +784,20 @@ def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC) def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -745,9 +826,9 @@ def forward( ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = 
(output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -780,8 +861,8 @@ def forward( loss = loss_fct(shift_logits, shift_labels) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, @@ -791,13 +872,16 @@ def forward( attentions=outputs.attentions, ) - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs): if past_key_values: input_ids = input_ids[:, -1:] - position_ids = kwargs.get("position_ids", None) + position_ids = kwargs.get('position_ids', None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -807,25 +891,25 @@ def prepare_inputs_for_generation( # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {'inputs_embeds': inputs_embeds} else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) 
return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + reordered_past += (tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past), ) return reordered_past @@ -845,7 +929,7 @@ def _reorder_cache(past_key_values, beam_idx): LLAMA_START_DOCSTRING, ) class LlamaForSequenceClassification(LlamaPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + _keys_to_ignore_on_load_missing = [r'lm_head.weight'] def __init__(self, config): super().__init__(config) @@ -864,17 +948,17 @@ def set_input_embeddings(self, value): @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -904,43 +988,50 @@ def forward( batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 
1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + raise ValueError( + 'Cannot handle batch sizes > 1 if no padding token is defined.' + ) if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + sequence_lengths = ( + torch.ne(input_ids, self.config.pad_token_id).sum(-1) - + 1).to(logits.device) else: sequence_lengths = -1 - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + pooled_logits = logits[torch.arange(batch_size, device=logits.device), + sequence_lengths] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long + or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' else: - self.config.problem_type = "multi_label_classification" + self.config.problem_type = 'multi_label_classification' - if self.config.problem_type == "regression": + if self.config.problem_type == 'regression': loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": + elif self.config.problem_type == 'single_label_classification': loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": + loss = loss_fct(pooled_logits.view(-1, self.num_labels), + labels.view(-1)) + elif self.config.problem_type == 
'multi_label_classification': loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output + output = (pooled_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, diff --git a/tools/evaluator/config/evaluator_example.yaml b/tools/evaluator/config/evaluator_example.yaml index ee69c0f3f..3c5be1907 100644 --- a/tools/evaluator/config/evaluator_example.yaml +++ b/tools/evaluator/config/evaluator_example.yaml @@ -11,7 +11,7 @@ auto_eval: merge_path: max_tokens: token_per_iteration: - # tokenizer_path: + # tokenizer_path: # log_path: helm: helm_spec_template_path: @@ -30,4 +30,4 @@ auto_eval: result_file: wandb: project: - base_url: \ No newline at end of file + base_url: diff --git a/tools/evaluator/config/helm_spec_template.conf b/tools/evaluator/config/helm_spec_template.conf index d21560a75..e9109cdee 100644 --- a/tools/evaluator/config/helm_spec_template.conf +++ b/tools/evaluator/config/helm_spec_template.conf @@ -104,4 +104,4 @@ entries: [ {description: "civil_comments:model=,demographic=other_religions,data_augmentation=canonical", priority: 2} {description: "civil_comments:model=,demographic=black,data_augmentation=canonical", priority: 2} {description: "civil_comments:model=,demographic=white,data_augmentation=canonical", priority: 2} -] \ No newline at end of file +] diff --git a/tools/evaluator/evaluator.py b/tools/evaluator/evaluator.py index 5efc443be..c783fb4e1 100644 --- a/tools/evaluator/evaluator.py +++ b/tools/evaluator/evaluator.py @@ -1,9 +1,10 @@ import argparse -import yaml import os +import shutil import subprocess import time -import shutil + +import yaml from gpt_eval.gpt_evaluator import GPTEvaluator from recorder.wandb_writer import HelmWriter @@ -12,8 +13,9 @@ def parse_args(): parser = 
argparse.ArgumentParser() parser.add_argument('--config', type=str, required=True) - parser.add_argument( - '--model-type', choices=['megatron', 'huggingface'], default='megatron') + parser.add_argument('--model-type', + choices=['megatron', 'huggingface'], + default='megatron') parser.add_argument('--eval-type', choices=['helm', 'gpt'], default='helm') parser.add_argument('--iteration-interval', type=int, default=1000) parser.add_argument('--begin-iteration', type=int, default=None) @@ -25,10 +27,13 @@ def parse_args(): def check_args(args): if args.begin_iteration == None: print( - f"--begin-iteration is not provided, use the value of --iteration-interval ({args.iteration_interval}).") + f'--begin-iteration is not provided, use the value of --iteration-interval ({args.iteration_interval}).' + ) args.begin_iteration = args.iteration_interval if args.end_iteration == None: - print(f"--end-iteration is not provided, evaluator will monitor the traning process continuously.") + print( + f'--end-iteration is not provided, evaluator will monitor the traning process continuously.' 
+ ) args.end_iteration = float('inf') @@ -50,8 +55,9 @@ def load_config(self): self.full_name = f'{self.project_name}-{self.model_name}' # load cache dir self.cur_dir = os.path.abspath(os.getcwd()) - self.cache_dir = self.config['cache_dir'] if 'cache_dir' in self.config else os.path.join( - self.cur_dir, 'cache') + self.cache_dir = self.config[ + 'cache_dir'] if 'cache_dir' in self.config else os.path.join( + self.cur_dir, 'cache') if not os.path.exists(self.cache_dir): os.makedirs(self.cache_dir) # load megatron config @@ -59,7 +65,8 @@ def load_config(self): os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' os.environ['OMP_NUM_THREADS'] = '4' self.megatron_process_num = self.config['megatron']['process_num'] - self.megatron_checkpoint_path = self.config['megatron']['checkpoint_path'] + self.megatron_checkpoint_path = self.config['megatron'][ + 'checkpoint_path'] # for different tokenizer if self.config['megatron']['tokenizer_type'] == 'sentencepiece': self.tokenizer_type = 'sentencepiece' @@ -73,9 +80,10 @@ def load_config(self): self.tokenizer_path = None else: raise NotImplementedError( - f"tokenizer type: {self.config['megatron']['tokenizer_type']} is not supported") - self.megatron_log_path = os.path.join( - self.cache_dir, 'megatron.log') + f"tokenizer type: {self.config['megatron']['tokenizer_type']} is not supported" + ) + self.megatron_log_path = os.path.join(self.cache_dir, + 'megatron.log') if 'log_path' in self.config['megatron']: self.megatron_log_path = self.config['megatron']['log_path'] self.megatron_server_port = 5000 @@ -89,21 +97,27 @@ def load_config(self): self.max_tokens = self.config['megatron']['max_tokens'] self.megatron_token_per_iteration = 0 if 'token_per_iteration' in self.config['megatron']: - self.megatron_token_per_iteration = self.config['megatron']['token_per_iteration'] + self.megatron_token_per_iteration = self.config['megatron'][ + 'token_per_iteration'] # load helm config if 'helm' in self.config: - self.helm_spec_template_path 
= self.config['helm']['helm_spec_template_path'] + self.helm_spec_template_path = self.config['helm'][ + 'helm_spec_template_path'] self.helm_output_path = self.config['helm']['helm_output_path'] - self.helm_spec_path = os.path.join( - self.cache_dir, 'helm_spec.conf') + self.helm_spec_path = os.path.join(self.cache_dir, + 'helm_spec.conf') self.helm_cache_path = os.path.join(self.cache_dir, 'helm_cache') self.helm_suite_name = self.full_name - self.helm_conda_env = self.config['helm']['helm_env_name'] if 'helm_env_name' in self.config['helm'] else 'crfm-helm' + self.helm_conda_env = self.config['helm'][ + 'helm_env_name'] if 'helm_env_name' in self.config[ + 'helm'] else 'crfm-helm' self.helm_eval_instances = self.config['helm'][ - 'eval_instances'] if 'eval_instances' in self.config['helm'] else 100 - self.helm_benchmarks = self.config['helm']['benchmarks'] if 'benchmarks' in self.config['helm'] else None - self.helm_mymodel_config = os.path.join( - self.cache_dir, 'helm_config.yaml') + 'eval_instances'] if 'eval_instances' in self.config[ + 'helm'] else 100 + self.helm_benchmarks = self.config['helm'][ + 'benchmarks'] if 'benchmarks' in self.config['helm'] else None + self.helm_mymodel_config = os.path.join(self.cache_dir, + 'helm_config.yaml') with open(self.helm_mymodel_config, 'w', encoding='utf-8') as f: mymodel_config = { 'port': self.megatron_server_port, @@ -116,11 +130,15 @@ def load_config(self): } yaml.dump(mymodel_config, f) if self.eval_type == 'gpt': - self.gpt_question_file = self.config['gpt_evaluation']['question_file'] + self.gpt_question_file = self.config['gpt_evaluation'][ + 'question_file'] self.gpt_answer_file = self.config['gpt_evaluation']['answer_file'] if 'wandb' in self.config: - self.wandb_base_url = self.config['wandb']['base_url'] if 'base_url' in self.config['wandb'] else None - self.wandb_project = self.config['wandb']['project'] if 'project' in self.config['wandb'] else self.project_name + self.wandb_base_url = 
self.config['wandb'][ + 'base_url'] if 'base_url' in self.config['wandb'] else None + self.wandb_project = self.config['wandb'][ + 'project'] if 'project' in self.config[ + 'wandb'] else self.project_name def _set_megatron_tokenizer(self, args): if self.tokenizer_type == 'gpt2': @@ -140,19 +158,25 @@ def run_megatron_server(self, iteration): time.sleep(self.check_iterval * 60) # setup megatron server print( - f'Start megatron text generation server for checkpoint iter_{iteration}') - args = ['torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', '--nproc_per_node', str(self.megatron_process_num), '--nnodes', '1', '--node_rank', '0', os.path.join(self.megatron_home, 'tools/run_text_generation_server.py'), '--port', - str(self.megatron_server_port), '--use-checkpoint-args', '--load', self.megatron_checkpoint_path, - '--load-iteration', str(iteration), '--tokenizer-type'] + f'Start megatron text generation server for checkpoint iter_{iteration}' + ) + args = [ + 'torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', + '--nproc_per_node', + str(self.megatron_process_num), '--nnodes', '1', '--node_rank', + '0', + os.path.join(self.megatron_home, + 'tools/run_text_generation_server.py'), '--port', + str(self.megatron_server_port), '--use-checkpoint-args', '--load', + self.megatron_checkpoint_path, '--load-iteration', + str(iteration), '--tokenizer-type' + ] self._set_megatron_tokenizer(args) logfile = open(self.megatron_log_path, 'w') os.chdir(self.megatron_home) process = subprocess.Popen(args, stdout=logfile, stderr=logfile) os.chdir(self.cur_dir) - return { - 'process': process, - 'logfile': logfile - } + return {'process': process, 'logfile': logfile} def stop_megatron_server(self, process, logfile): process.terminate() @@ -164,12 +188,18 @@ def run_megatron_inference(self, iteration): time.sleep(self.check_iterval * 60) print(f'Wait for megatron checkpoint {iteration}') print(f'Start megatron inference for checkpoint iter_{iteration}') - args 
= ['torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', '--nproc_per_node', '1', '--nnodes', - str(self.megatron_process_num), '--node_rank', '0', 'tools/inference.py', '--use-checkpoint-args', - '--formatter', 'gpt_eval', '--tokens-to-generate', str( - self.max_tokens), '--input', self.gpt_question_file, - '--output', self.gpt_answer_file, '--load', self.megatron_checkpoint_path, '--load-iteration', - str(iteration), '--model-name', f'{self.full_name}/{iteration}', '--tokenizer-type'] + args = [ + 'torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', + '--nproc_per_node', '1', '--nnodes', + str(self.megatron_process_num), '--node_rank', '0', + 'tools/inference.py', '--use-checkpoint-args', '--formatter', + 'gpt_eval', '--tokens-to-generate', + str(self.max_tokens), '--input', self.gpt_question_file, + '--output', self.gpt_answer_file, '--load', + self.megatron_checkpoint_path, '--load-iteration', + str(iteration), '--model-name', f'{self.full_name}/{iteration}', + '--tokenizer-type' + ] self._set_megatron_tokenizer(args) logfile = open(self.megatron_log_path, 'w') os.chdir(self.megatron_home) @@ -179,16 +209,21 @@ def run_megatron_inference(self, iteration): return {} def megatron_checkpoint_exists(self, iteration): - with open(os.path.join(self.megatron_checkpoint_path, 'latest_checkpointed_iteration.txt'), 'r') as f: + with open( + os.path.join(self.megatron_checkpoint_path, + 'latest_checkpointed_iteration.txt'), 'r') as f: latest_checkpoint_iter = int(f.readline()) if iteration > latest_checkpoint_iter: return False - checkpoint_path = os.path.join( - self.megatron_checkpoint_path, 'iter_{:07d}'.format(iteration)) + checkpoint_path = os.path.join(self.megatron_checkpoint_path, + 'iter_{:07d}'.format(iteration)) return os.path.exists(checkpoint_path) def replace_pattern(self, input_file, output_file, pattern, s): - with open(input_file, 'r', encoding='utf-8') as input, open(output_file, 'w', encoding='utf-8') as output: + with 
open(input_file, 'r', + encoding='utf-8') as input, open(output_file, + 'w', + encoding='utf-8') as output: lines = input.readlines() for i in range(len(lines)): lines[i] = lines[i].replace(pattern, s) @@ -199,15 +234,23 @@ def run_helm_eval(self, iteration): if os.path.exists(self.helm_cache_path): shutil.rmtree(self.helm_cache_path) self.replace_pattern(self.helm_spec_template_path, self.helm_spec_path, - '', f'mymodel/{self.full_name}/{iteration}') - helm_run_args = ['conda', 'run', '-n', self.helm_conda_env, '--no-capture-output', 'helm-run', '-n', '4', '-m', str(self.helm_eval_instances), - '--conf-paths', self.helm_spec_path, '--my-config-path', self.helm_mymodel_config, - '--local-path', self.helm_cache_path, - '--suite', self.helm_suite_name, '-o', self.helm_output_path] + '', + f'mymodel/{self.full_name}/{iteration}') + helm_run_args = [ + 'conda', 'run', '-n', self.helm_conda_env, '--no-capture-output', + 'helm-run', '-n', '4', '-m', + str(self.helm_eval_instances), '--conf-paths', self.helm_spec_path, + '--my-config-path', self.helm_mymodel_config, '--local-path', + self.helm_cache_path, '--suite', self.helm_suite_name, '-o', + self.helm_output_path + ] subprocess.check_call(helm_run_args) print(f'run helm summarize for checkpoint iter_{iteration}') - helm_summarize_args = ['conda', 'run', '-n', self.helm_conda_env, '--no-capture-output', - 'helm-summarize', '--suite', self.helm_suite_name, '-o', self.helm_output_path] + helm_summarize_args = [ + 'conda', 'run', '-n', self.helm_conda_env, '--no-capture-output', + 'helm-summarize', '--suite', self.helm_suite_name, '-o', + self.helm_output_path + ] subprocess.check_call(helm_summarize_args) print(f'Finish helm evaluation for checkpoint iter_{iteration}') @@ -226,9 +269,11 @@ def write_wandb(self): if self.helm_benchmarks is not None: helm_config['benchmarks'] = self.helm_benchmarks HelmWriter(project_name=self.wandb_project, - base_url=self.wandb_base_url, helm_config=helm_config) + 
base_url=self.wandb_base_url, + helm_config=helm_config) - def evaluate(self, start_gen_func, start_eval_func, stop_gen_func, stop_eval_func): + def evaluate(self, start_gen_func, start_eval_func, stop_gen_func, + stop_eval_func): cur_iter = self.begin_iteration while cur_iter <= self.end_iteration: states = start_gen_func(cur_iter) @@ -251,8 +296,8 @@ def run(self): start_eval_func = self.run_gpt_eval stop_gen_func = self.dummy_stop stop_eval_func = self.dummy_stop - self.evaluate(start_gen_func, start_eval_func, - stop_gen_func, stop_eval_func) + self.evaluate(start_gen_func, start_eval_func, stop_gen_func, + stop_eval_func) if __name__ == '__main__': diff --git a/tools/evaluator/gpt_eval/README.md b/tools/evaluator/gpt_eval/README.md index 32408b24c..5cd216b8b 100644 --- a/tools/evaluator/gpt_eval/README.md +++ b/tools/evaluator/gpt_eval/README.md @@ -71,4 +71,4 @@ 3. Run the script. ```shell python gpt_evaluator.py --config - ``` \ No newline at end of file + ``` diff --git a/tools/evaluator/gpt_eval/answer_generator.py b/tools/evaluator/gpt_eval/answer_generator.py index a8e8e6f09..c9b5cc233 100644 --- a/tools/evaluator/gpt_eval/answer_generator.py +++ b/tools/evaluator/gpt_eval/answer_generator.py @@ -1,16 +1,17 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -import subprocess -import yaml -import jsonlines import argparse -import openai -import time import json import os -import requests - +import subprocess +import time from abc import ABC, abstractmethod + +import jsonlines +import requests +import yaml from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +import openai def parse_args(): @@ -45,12 +46,17 @@ def __init__(self, config): def generate(self, texts, max_tokens, temperature): texts = [format_question(text) for text in texts] - inputs = self.tokenizer( - texts, return_tensors='pt', padding=True).to(self.model.device) - outputs = self.model.generate( - **inputs, max_new_tokens=max_tokens, 
do_sample=True, temperature=temperature) - return [self.tokenizer.decode( - output[inputs.input_ids.shape[1]:], skip_special_tokens=True) for output in outputs] + inputs = self.tokenizer(texts, return_tensors='pt', + padding=True).to(self.model.device) + outputs = self.model.generate(**inputs, + max_new_tokens=max_tokens, + do_sample=True, + temperature=temperature) + return [ + self.tokenizer.decode(output[inputs.input_ids.shape[1]:], + skip_special_tokens=True) + for output in outputs + ] class OpenAIGenerator(AbstractGenerator): @@ -71,31 +77,31 @@ def __init__(self, config): def generate(self, texts, max_tokens, temperature): outputs = [] for text in texts: - output = "" + output = '' for _ in range(self.max_retry): try: response = openai.ChatCompletion.create( model=self.model, messages=[ { - "role": "system", - "content": "You are a helpful assistant." + 'role': 'system', + 'content': 'You are a helpful assistant.' }, { - "role": "user", - "content": text, + 'role': 'user', + 'content': text, }, ], temperature=temperature, max_tokens=max_tokens, ) - output = response["choices"][0]["message"]["content"] + output = response['choices'][0]['message']['content'] break except Exception as e: print(e) time.sleep(self.retry_wait) if len(output) == 0: - print(f"Failed to answer [{text}]") + print(f'Failed to answer [{text}]') outputs.append(output) return outputs @@ -121,11 +127,11 @@ def __init__(self, config): self.merge_path = config['merge_path'] self.tokenizer_path = None else: - raise NotImplementedError("Unsupported tokenizer type") + raise NotImplementedError('Unsupported tokenizer type') self.megatron_home = self.cur_dir if 'megatron_home' in config: self.megatron_home = config['megatron_home'] - print(f"Megatron-LM home: {self.megatron_home}") + print(f'Megatron-LM home: {self.megatron_home}') self.server_port = config['port'] if 'port' in config else 5000 self.handle = self._run_megatron_server() self.url = f'http://localhost:{self.server_port}/api' @@ 
-148,28 +154,36 @@ def _set_megatron_tokenizer(self, args): args.append(self.tokenizer_path) def _run_megatron_server(self): - args = ['torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', '--nproc_per_node', '1', '--nnodes', str(self.process_num), '--node_rank', '0', 'tools/run_text_generation_server.py', '--port', str( - self.server_port), '--use-checkpoint-args', '--load', self.checkpoint_path, '--load-iteration', str(self.load_iteration), '--tokenizer-type'] + args = [ + 'torchrun', '--master_addr', '127.0.0.1', '--master_port', '5950', + '--nproc_per_node', '1', '--nnodes', + str(self.process_num), '--node_rank', '0', + 'tools/run_text_generation_server.py', '--port', + str(self.server_port), '--use-checkpoint-args', '--load', + self.checkpoint_path, '--load-iteration', + str(self.load_iteration), '--tokenizer-type' + ] self._set_megatron_tokenizer(args) os.chdir(self.megatron_home) - process = subprocess.Popen( - args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + process = subprocess.Popen(args, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL) os.chdir(self.cur_dir) return process def _request(self, prompts, max_tokens, temperature): for _ in range(5): try: - response = requests.put(self.url, headers=self.header, data=json.dumps({ - 'prompts': prompts, - 'tokens_to_generate': max_tokens, - 'temperature': temperature, - 'echo_prompts': False - })).json() + response = requests.put(self.url, + headers=self.header, + data=json.dumps({ + 'prompts': prompts, + 'tokens_to_generate': max_tokens, + 'temperature': temperature, + 'echo_prompts': False + })).json() except Exception as e: - response = { - 'message': e - } + response = {'message': e} if 'text' not in response: print(f'Error in megatron response: {response}, retry in 10s') time.sleep(10) @@ -186,15 +200,18 @@ def close(self): class TextGenerator(): + def __init__(self, args): with open(args.config, 'r') as f: config = yaml.safe_load(f)['answer_generation'] - self.questions = 
[q for q in jsonlines.open( - config['question_file'], 'r')] + self.questions = [ + q for q in jsonlines.open(config['question_file'], 'r') + ] if not os.path.exists(os.path.dirname(config['answer_file'])): os.makedirs(os.path.dirname(config['answer_file'])) - self.answer_writer = jsonlines.open( - config['answer_file'], 'w', flush=True) + self.answer_writer = jsonlines.open(config['answer_file'], + 'w', + flush=True) self.batch_size = config['batch_size'] self.max_tokens = config['max_tokens'] self.temperature = config['temperature'] @@ -206,12 +223,12 @@ def __init__(self, args): elif 'megatron' in config: self.generator = MegatronGenerator(config['megatron']) else: - raise NotImplementedError("Generator not found") + raise NotImplementedError('Generator not found') def generate(self, questions): texts = [question['text'] for question in questions] - answer_texts = self.generator.generate( - texts, self.max_tokens, self.temperature) + answer_texts = self.generator.generate(texts, self.max_tokens, + self.temperature) for (question, answer_text) in zip(questions, answer_texts): self.answer_writer.write({ 'question_id': question['question_id'], diff --git a/tools/evaluator/gpt_eval/config/config.yaml b/tools/evaluator/gpt_eval/config/config.yaml index 08e06d264..7133a99f9 100644 --- a/tools/evaluator/gpt_eval/config/config.yaml +++ b/tools/evaluator/gpt_eval/config/config.yaml @@ -33,4 +33,4 @@ gpt_evaluation: baseline_file: ./answer/openai/gpt-3.5-turbo.jsonl prompt_file: ./config/prompt.jsonl reviewer_file: ./config/reviewer.jsonl - result_file: ./review/myorg/mymodel-gpt3.5-turbo.jsonl \ No newline at end of file + result_file: ./review/myorg/mymodel-gpt3.5-turbo.jsonl diff --git a/tools/evaluator/gpt_eval/config/reviewer.jsonl b/tools/evaluator/gpt_eval/config/reviewer.jsonl index bc9b6fefb..8d49b2f93 100644 --- a/tools/evaluator/gpt_eval/config/reviewer.jsonl +++ b/tools/evaluator/gpt_eval/config/reviewer.jsonl @@ -1,3 +1,3 @@ {"category": "general", 
"metadata": {"temperature": 0.2, "max_tokens": 1024, "model": "gpt-3.5-turbo"}} {"category": "coding", "metadata": {"temperature": 0.2, "max_tokens": 1024, "model": "gpt-3.5-turbo"}} -{"category": "math", "metadata": {"temperature": 0.2, "max_tokens": 1024, "model": "gpt-3.5-turbo"}} \ No newline at end of file +{"category": "math", "metadata": {"temperature": 0.2, "max_tokens": 1024, "model": "gpt-3.5-turbo"}} diff --git a/tools/evaluator/gpt_eval/gpt_evaluator.py b/tools/evaluator/gpt_eval/gpt_evaluator.py index 1561829e9..e3d8c8adb 100644 --- a/tools/evaluator/gpt_eval/gpt_evaluator.py +++ b/tools/evaluator/gpt_eval/gpt_evaluator.py @@ -2,29 +2,38 @@ # https://github.com/lm-sys/FastChat # -------------------------------------------------------- -import jsonlines -import openai -import logging -import time import argparse -import yaml +import logging import os +import time from multiprocessing import Pool + +import jsonlines +import yaml from tqdm import tqdm +import openai + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--config', type=str, required=True, - help="Config file path") - parser.add_argument('--worker-num', type=int, default=4, - help="Number of workers for OpenAI API") - parser.add_argument("--max-retry", type=int, default=5, + parser.add_argument('--config', + type=str, + required=True, + help='Config file path') + parser.add_argument('--worker-num', + type=int, + default=4, + help='Number of workers for OpenAI API') + parser.add_argument('--max-retry', + type=int, + default=5, help='Retry times for OpenAI API') - parser.add_argument("--debug", action='store_true', + parser.add_argument('--debug', + action='store_true', help='Run without calling OpenAI API') return parser.parse_args() @@ -34,20 +43,26 @@ class GPTEvaluator(): def __init__(self, config): openai.organization = config['openai_organization'] openai.api_key = 
config['openai_api_key'] - self.questions = [q for q in jsonlines.open( - config['question_file'], 'r')] - self.answers = [a for a in jsonlines.open( - config['answer_file'], 'r')] - self.baseline = [b for b in jsonlines.open( - config['baseline_file'], 'r')] + self.questions = [ + q for q in jsonlines.open(config['question_file'], 'r') + ] + self.answers = [a for a in jsonlines.open(config['answer_file'], 'r')] + self.baseline = [ + b for b in jsonlines.open(config['baseline_file'], 'r') + ] self.prompt_templates = { - p['category']: p for p in jsonlines.open(config['prompt_file'], 'r')} + p['category']: p + for p in jsonlines.open(config['prompt_file'], 'r') + } self.reviewers = { - z['category']: z for z in jsonlines.open(config['reviewer_file'], 'r')} + z['category']: z + for z in jsonlines.open(config['reviewer_file'], 'r') + } if not os.path.exists(os.path.dirname(config['result_file'])): os.makedirs(os.path.dirname(config['result_file'])) - self.result_writer = jsonlines.open( - config['result_file'], 'w', flush=True) + self.result_writer = jsonlines.open(config['result_file'], + 'w', + flush=True) self.worker_num = config['worker_num'] if 'worker_num' in config else 4 self.max_retry = config['max_retry'] if 'max_retry' in config else 5 self.debug = config['debug'] if 'debug' in config else False @@ -59,32 +74,32 @@ def generate_prompt(self, question, answer, baseline, prompts): else: reviewer = self.reviewers['general'] prompt_json = prompts['general'] - sys_prompt = prompt_json["system_prompt"] - prompt_template = prompt_json["prompt_template"] - defaults = prompt_json["defaults"] - prompt1 = prompt_template.format( - question=question['text'], answer_1=answer['text'], answer_2=baseline['text'], **defaults - ) - prompt2 = prompt_template.format( - question=question['text'], answer_1=baseline['text'], answer_2=answer['text'], **defaults - ) + sys_prompt = prompt_json['system_prompt'] + prompt_template = prompt_json['prompt_template'] + defaults = 
prompt_json['defaults'] + prompt1 = prompt_template.format(question=question['text'], + answer_1=answer['text'], + answer_2=baseline['text'], + **defaults) + prompt2 = prompt_template.format(question=question['text'], + answer_1=baseline['text'], + answer_2=answer['text'], + **defaults) return sys_prompt, prompt1, prompt2, reviewer def parse_score(self, review): review = review.strip('\n') - score_pair = review.split("\n")[-1] + score_pair = review.split('\n')[-1] score_pair.strip() - sp = score_pair.split(",") + sp = score_pair.split(',') try: if len(sp) == 2: return [float(sp[0]), float(sp[1])] else: - logger.error( - f"Invalid score pair." - ) + logger.error(f'Invalid score pair.') return [0, 0] except Exception as e: - logger.error("Invalid answer") + logger.error('Invalid answer') return [0, 0] def run(self): @@ -93,7 +108,8 @@ def run(self): question_num = len(self.questions) for i in range(question_num): sys_prompt, prompt1, prompt2, reviewer = self.generate_prompt( - self.questions[i], self.answers[i], self.baseline[i], self.prompt_templates) + self.questions[i], self.answers[i], self.baseline[i], + self.prompt_templates) results.append({ 'question_id': self.questions[i]['question_id'], 'metadata': reviewer['metadata'], @@ -102,9 +118,23 @@ def run(self): }) pool = Pool(processes=self.worker_num) requests.append({ - 'sys_prompt': sys_prompt, 'user_prompt': prompt1, 'temperature': reviewer['metadata']['temperature'], 'max_tokens': reviewer['metadata']['max_tokens'], 'model': reviewer['metadata']['model'], 'debug': self.debug, 'retry': self.max_retry}) + 'sys_prompt': sys_prompt, + 'user_prompt': prompt1, + 'temperature': reviewer['metadata']['temperature'], + 'max_tokens': reviewer['metadata']['max_tokens'], + 'model': reviewer['metadata']['model'], + 'debug': self.debug, + 'retry': self.max_retry + }) requests.append({ - 'sys_prompt': sys_prompt, 'user_prompt': prompt2, 'temperature': reviewer['metadata']['temperature'], 'max_tokens': 
reviewer['metadata']['max_tokens'], 'model': reviewer['metadata']['model'], 'debug': self.debug, 'retry': self.max_retry}) + 'sys_prompt': sys_prompt, + 'user_prompt': prompt2, + 'temperature': reviewer['metadata']['temperature'], + 'max_tokens': reviewer['metadata']['max_tokens'], + 'model': reviewer['metadata']['model'], + 'debug': self.debug, + 'retry': self.max_retry + }) reviews = pool.map(eval, requests) target_score = 0.0 baseline_score = 0.0 @@ -126,10 +156,10 @@ def run(self): cnt += 1 target_avg_score = target_score / cnt / 2 baseline_avg_score = baseline_score / cnt / 2 - print("-------------------------") + print('-------------------------') print(f"> {results[0]['model1']}: {target_avg_score}") print(f"> {results[0]['model2']}: {baseline_avg_score}") - print("-------------------------") + print('-------------------------') self.result_writer.write({ f"{results[0]['model1']}": target_avg_score, f"{results[0]['model2']}": baseline_avg_score @@ -140,34 +170,38 @@ def run(self): def eval(request): if request['debug']: logger.info(f"Fake response {request['user_prompt']}") - return "Fake response\n10,9\n" + return 'Fake response\n10,9\n' for _ in range(request['retry']): try: response = openai.ChatCompletion.create( model=request['model'], messages=[ - {"role": "system", "content": request['sys_prompt']}, { - "role": "user", - "content": request['user_prompt'], + 'role': 'system', + 'content': request['sys_prompt'] + }, + { + 'role': 'user', + 'content': request['user_prompt'], }, ], temperature=request['temperature'], max_tokens=request['max_tokens'], ) - content = response["choices"][0]["message"]["content"] + content = response['choices'][0]['message']['content'] logger.info(content) return content except Exception as e: logger.error(e) time.sleep(5) logger.error(f"Failed after {request['retry']} retries.") - return "error" + return 'error' -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - config = 
yaml.safe_load(open(args.config, 'r', encoding='utf-8'))['gpt_evaluation'] + config = yaml.safe_load(open(args.config, 'r', + encoding='utf-8'))['gpt_evaluation'] config['worker_num'] = args.worker_num config['max_retry'] = args.max_retry config['debug'] = args.debug diff --git a/tools/evaluator/recorder/README.md b/tools/evaluator/recorder/README.md index b31e55d98..739da3434 100644 --- a/tools/evaluator/recorder/README.md +++ b/tools/evaluator/recorder/README.md @@ -110,4 +110,4 @@ excluded_models: # models that do not participate in the leaderboard - ... ``` -> We use 16 core metrics of HELM as the default leaderboard metrics if the `leaderboard_metrics` field is not provided, the 16 metrics are as same as the default benchmark metrics. \ No newline at end of file +> We use 16 core metrics of HELM as the default leaderboard metrics if the `leaderboard_metrics` field is not provided, the 16 metrics are as same as the default benchmark metrics. diff --git a/tools/evaluator/recorder/config/leaderboard_example.yaml b/tools/evaluator/recorder/config/leaderboard_example.yaml index 1a954facb..d1444f39b 100644 --- a/tools/evaluator/recorder/config/leaderboard_example.yaml +++ b/tools/evaluator/recorder/config/leaderboard_example.yaml @@ -10,4 +10,4 @@ leaderboard_metrics: - ... excluded_runs: - - - ... \ No newline at end of file + - ... 
diff --git a/tools/evaluator/recorder/config/llama_example.yaml b/tools/evaluator/recorder/config/llama_example.yaml index 646a2001c..faba17bf3 100644 --- a/tools/evaluator/recorder/config/llama_example.yaml +++ b/tools/evaluator/recorder/config/llama_example.yaml @@ -37,4 +37,4 @@ evals: hellaswag: EM: 0.747 openbookqa: - EM: 0.574 \ No newline at end of file + EM: 0.574 diff --git a/tools/evaluator/recorder/config/mymodel_example.yaml b/tools/evaluator/recorder/config/mymodel_example.yaml index f0231294a..fa80ae877 100644 --- a/tools/evaluator/recorder/config/mymodel_example.yaml +++ b/tools/evaluator/recorder/config/mymodel_example.yaml @@ -23,4 +23,4 @@ evals: - name: hellaswag metrics: - EM - - ... \ No newline at end of file + - ... diff --git a/tools/evaluator/recorder/wandb_writer.py b/tools/evaluator/recorder/wandb_writer.py index 188a428b9..95fa2744f 100644 --- a/tools/evaluator/recorder/wandb_writer.py +++ b/tools/evaluator/recorder/wandb_writer.py @@ -1,15 +1,14 @@ -import wandb import argparse import json -import yaml import os +import wandb +import yaml + def get_args(): parser = argparse.ArgumentParser( - description="write evaluation result into wandb", - allow_abbrev=False - ) + description='write evaluation result into wandb', allow_abbrev=False) parser.add_argument('--config', type=str, required=True) parser.add_argument('--summary-only', action='store_true') parser.add_argument('--print-only', action='store_true') @@ -22,36 +21,69 @@ def __init__(self, project_name, base_url=None, print_only=False, - summary_only=False - ) -> None: + summary_only=False) -> None: self.project = project_name self.base_url = base_url self.print_only = print_only self.summary_only = summary_only -DEFAULT_HELM_BENCHMARKS = [ - {"name": "mmlu", "metrics": ["EM"]}, - {"name": "raft", "metrics": ["EM"]}, - {"name": "imdb", "metrics": ["EM"]}, - {"name": "truthful_qa", "metrics": ["EM"]}, - {"name": "summarization_cnndm", "metrics": ["ROUGE-2"]}, - {"name": 
"summarization_xsum", "metrics": ["ROUGE-2"]}, - {"name": "boolq", "metrics": ["EM"]}, - {"name": "msmarco_trec", "metrics": ["NDCG@10"]}, - {"name": "msmarco_regular", "metrics": ["RR@10"]}, - {"name": "narrative_qa", "metrics": ["F1"]}, - {"name": "natural_qa_closedbook", "metrics": ["F1"]}, - {"name": "natural_qa_openbook_longans", "metrics": ["F1"]}, - {"name": "quac", "metrics": ["F1"]}, - {"name": "civil_comments", "metrics": ["EM"]}, - {"name": "hellaswag", "metrics": ["EM"]}, - {"name": "openbookqa", "metrics": ["EM"]} -] +DEFAULT_HELM_BENCHMARKS = [{ + 'name': 'mmlu', + 'metrics': ['EM'] +}, { + 'name': 'raft', + 'metrics': ['EM'] +}, { + 'name': 'imdb', + 'metrics': ['EM'] +}, { + 'name': 'truthful_qa', + 'metrics': ['EM'] +}, { + 'name': 'summarization_cnndm', + 'metrics': ['ROUGE-2'] +}, { + 'name': 'summarization_xsum', + 'metrics': ['ROUGE-2'] +}, { + 'name': 'boolq', + 'metrics': ['EM'] +}, { + 'name': 'msmarco_trec', + 'metrics': ['NDCG@10'] +}, { + 'name': 'msmarco_regular', + 'metrics': ['RR@10'] +}, { + 'name': 'narrative_qa', + 'metrics': ['F1'] +}, { + 'name': 'natural_qa_closedbook', + 'metrics': ['F1'] +}, { + 'name': 'natural_qa_openbook_longans', + 'metrics': ['F1'] +}, { + 'name': 'quac', + 'metrics': ['F1'] +}, { + 'name': 'civil_comments', + 'metrics': ['EM'] +}, { + 'name': 'hellaswag', + 'metrics': ['EM'] +}, { + 'name': 'openbookqa', + 'metrics': ['EM'] +}] DEFAULT_HELM_METRICS = [ - "mmlu.EM", "raft.EM", "imdb.EM", "truthful_qa.EM", "summarization_cnndm.ROUGE-2", "summarization_xsum.ROUGE-2", "boolq.EM", "msmarco_trec.NDCG@10", "msmarco_regular.RR@10", - "narrative_qa.F1", "natural_qa_closedbook.F1", "natural_qa_openbook_longans.F1", "civil_comments.EM", "hellaswag.EM", "openbookqa.EM" + 'mmlu.EM', 'raft.EM', 'imdb.EM', 'truthful_qa.EM', + 'summarization_cnndm.ROUGE-2', 'summarization_xsum.ROUGE-2', 'boolq.EM', + 'msmarco_trec.NDCG@10', 'msmarco_regular.RR@10', 'narrative_qa.F1', + 'natural_qa_closedbook.F1', 
'natural_qa_openbook_longans.F1', + 'civil_comments.EM', 'hellaswag.EM', 'openbookqa.EM' ] @@ -68,9 +100,10 @@ def __init__(self, self.conf = helm_config self.leaderboard = leaderboard if self.leaderboard: - self.leaderboard_metrics = self.conf['leaderboard_metrics'] if 'leaderboard_metrics' in self.conf else DEFAULT_HELM_METRICS - self.excluded_models = self.conf['excluded_models'] if 'excluded_models' in self.conf else [ - ] + self.leaderboard_metrics = self.conf[ + 'leaderboard_metrics'] if 'leaderboard_metrics' in self.conf else DEFAULT_HELM_METRICS + self.excluded_models = self.conf[ + 'excluded_models'] if 'excluded_models' in self.conf else [] return self.parse_from_helm = False self.parse_from_file = False @@ -96,9 +129,7 @@ def __init__(self, self.default_iteration = self.conf['default_iteration'] def make_leaderboard(self): - api = wandb.Api(overrides={ - 'base_url': self.base_url - }) + api = wandb.Api(overrides={'base_url': self.base_url}) runs = api.runs(path=f'{self.project}', filters={'tags': 'summary'}) result = {} token_num = {} @@ -108,12 +139,15 @@ def make_leaderboard(self): continue print(run.id) run_name = run.group - history = run.scan_history( - keys=['_step'] + self.leaderboard_metrics, page_size=2000, min_step=0) + history = run.scan_history(keys=['_step'] + + self.leaderboard_metrics, + page_size=2000, + min_step=0) if 'token_num' in run.config: token_num[run_name] = run.config['token_num'] if 'token_per_iteration' in run.config: - token_per_iteration[run_name] = run.config['token_per_iteration'] + token_per_iteration[run_name] = run.config[ + 'token_per_iteration'] for step in history: for metric_name, score in step.items(): if metric_name in ['_step', 'average']: @@ -137,14 +171,13 @@ def make_leaderboard(self): if self.print_only: print(sum_scores) else: - run = wandb.init( - project=self.project, - group='leaderboard', - name='leaderboard', - save_code=False, - id=f'{self.project}-leaderboard', - tags=['leaderboard'], - reinit=True) + 
run = wandb.init(project=self.project, + group='leaderboard', + name='leaderboard', + save_code=False, + id=f'{self.project}-leaderboard', + tags=['leaderboard'], + reinit=True) data = [] for name, iters in sum_scores.items(): for iter, score in iters.items(): @@ -155,8 +188,8 @@ def make_leaderboard(self): [name, iter * token_per_iteration[name], score]) else: data.append([name, None, score]) - table = wandb.Table(data=data, columns=[ - 'model', 'token_num', 'score']) + table = wandb.Table(data=data, + columns=['model', 'token_num', 'score']) wandb.log( {'benchmark_score': wandb.plot.bar(table, 'model', 'score')}) run.finish() @@ -169,8 +202,8 @@ def cal_score(self, scores): min_score = min(min(iters.values()), min_score) for subject, iters in scores.items(): for iter, score in iters.items(): - scores[subject][iter] = ( - score - min_score) / (max_score - min_score) + scores[subject][iter] = (score - min_score) / (max_score - + min_score) def write(self): if self.leaderboard: @@ -178,9 +211,9 @@ def write(self): elif self.parse_from_helm: self.parse_scenarios() elif self.parse_from_file: - self.write_wandb('summary', { - self.default_iteration: self.eval_result - }, 'summary') + self.write_wandb('summary', + {self.default_iteration: self.eval_result}, + 'summary') else: print('do nothing, please check your config file') @@ -188,8 +221,8 @@ def parse_scenarios(self): summary = {} for scenario in self.scenarios: try: - result = self.parse_scenario( - scenario['name'], scenario['metrics'], self.model) + result = self.parse_scenario(scenario['name'], + scenario['metrics'], self.model) if not self.summary_only: self.write_wandb(scenario['name'], result, 'detail') self.make_summary(scenario['name'], result, summary) @@ -198,7 +231,7 @@ def parse_scenarios(self): self.write_wandb('summary', summary, 'summary') def make_summary(self, scenario_name, eval_result, summary): - print(f"summarize for {scenario_name}") + print(f'summarize for {scenario_name}') for iteration, 
scenarios in eval_result.items(): if iteration not in summary: summary[iteration] = dict() @@ -221,8 +254,10 @@ def make_average(self, summary): def parse_scenario(self, scenario_name, scenario_metrics, model=None): evaluate_result = {} - with open(os.path.join(self.helm_root, 'runs', self.suite_name, 'groups', f'{scenario_name}.json')) as f: - print(f"parsing {scenario_name}.json") + with open( + os.path.join(self.helm_root, 'runs', self.suite_name, 'groups', + f'{scenario_name}.json')) as f: + print(f'parsing {scenario_name}.json') subjects = json.load(f) for subject in subjects: print(f" parsing {subject['title']}") @@ -244,11 +279,11 @@ def parse_scenario(self, scenario_name, scenario_metrics, model=None): evaluate_result[iteration] = dict() if scenario_name not in evaluate_result[iteration]: evaluate_result[iteration][scenario_name] = dict() - evaluate_result[iteration][scenario_name][subject['title'].split(',')[ - 0]] = dict() + evaluate_result[iteration][scenario_name][ + subject['title'].split(',')[0]] = dict() for metric, i in record_column_idx.items(): - evaluate_result[iteration][scenario_name][subject['title'].split(',')[ - 0]][metric] = row[i]['value'] + evaluate_result[iteration][scenario_name][subject[ + 'title'].split(',')[0]][metric] = row[i]['value'] return evaluate_result def write_wandb(self, name, result, tag): @@ -260,18 +295,17 @@ def write_wandb(self, name, result, tag): config['token_num'] = self.conf['token_num'] if 'token_per_iteration' in self.conf: config['token_per_iteration'] = self.conf['token_per_iteration'] - run = wandb.init( - project=self.project, - group=self.model, - name=name, - save_code=False, - id=f'{self.project}-{self.model}-{name}', - tags=['evalate', tag], - config=config, - reinit=True) - print(f"write {name} to wandb") + run = wandb.init(project=self.project, + group=self.model, + name=name, + save_code=False, + id=f'{self.project}-{self.model}-{name}', + tags=['evalate', tag], + config=config, + reinit=True) + 
print(f'write {name} to wandb') for iteration in sorted(result.keys()): - print(f" write iteration {iteration} to wandb") + print(f' write iteration {iteration} to wandb') wandb.log(result[iteration], int(iteration)) run.finish() @@ -282,25 +316,21 @@ def main(): eval_configs = config['evals'] if 'evals' in config else [] for eval in eval_configs: if eval['eval_type'] == 'helm': - HelmWriter( - project_name=config['project'], - base_url=config['base_url'], - print_only=args.print_only, - summary_only=args.summary_only, - helm_config=eval - ).write() + HelmWriter(project_name=config['project'], + base_url=config['base_url'], + print_only=args.print_only, + summary_only=args.summary_only, + helm_config=eval).write() else: raise NotImplementedError( f"Unsupported type for eval type {eval['eval_type']}") if 'leaderboard' in config and config['leaderboard'] == True: - HelmWriter( - project_name=config['project'], - base_url=config['base_url'], - leaderboard=True, - helm_config=config, - print_only=args.print_only - ).write() + HelmWriter(project_name=config['project'], + base_url=config['base_url'], + leaderboard=True, + helm_config=config, + print_only=args.print_only).write() -if __name__ == "__main__": +if __name__ == '__main__': main() diff --git a/tools/hpo/README.md b/tools/hpo/README.md index 8728c043e..3b3e8f126 100644 --- a/tools/hpo/README.md +++ b/tools/hpo/README.md @@ -41,11 +41,60 @@ python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality We provide an illustrative objective "quality_score" in `hpo/objects.py`, which uses quality scorer to measure the processed data, and links the average scores to hyper-parameters of data recipes. 
+After running it, you will get the result similar to: ![img](https://img.alicdn.com/imgextra/i2/O1CN017fT4Al1bVldeuCmiI_!!6000000003471-2-tps-2506-1710.png) -You can implement your own HPO objective in `get_hpo_objective` function, e.g., that links the data + +You can implement your own HPO objective in `get_hpo_objective` function, e.g., linking the data recipes to - model_loss (by replacing the quality scorer into a training procedure), -- downstream_task (by eplacing the quality scorer into a training and an -evaluation procedure), or -- some synergy measures that combines metrics you are interested, such that - the trade-offs from different views can be explored. +- downstream_task (by eplacing the quality scorer into a training and an# Hyper-parameter Optimization for Data Recipe + +## Auto-HPO + +We incorporate an automated HPO tool, WandB [Sweep](https://docs.wandb.ai/guides/sweeps), into Data-Juicer to streamline the finding of good data processing hyper-parameters. +With this tool, users can investigate correlations and importance scores of +specific hyper-parameters of data recipes from the HPO view. + +*Note*: this is an experimental feature. Auto-HPO for data recipes still has +a large room to explore. Feel free to provide more suggestions, discussions, and contributions via new PRs! + + +## Prerequisite +You need to install data-juicer first. +Besides, the tool leverages WandB, install it via `pip install wandb`. +Before using this tool, you need to run ` +```wandb login``` and enter your WandB +API key. +If you have your own instance of WandB (e.g., [locally-hosted machine](https://docs.wandb.ai/guides/hosting/)), run the following script: + +```shell +wandb login --host +# enter your api key +``` + + + +## Usage and Customization + +Given a data recipe, characterized by a specified configuration file +``, you can use `execute_hpo.py` to search the +hyper-parameter space defined by ``. 
+```shell +# cd tools/hpo +python execute_hpo.py --config --hpo_config + +# e.g., +python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml +``` + +We provide an illustrative objective "quality_score" in `hpo/objects.py`, +which uses quality scorer to measure the processed data, and links the average scores to hyper-parameters of data recipes. +After running it, you will get the result similar to: ![img](https://img.alicdn.com/imgextra/i2/O1CN017fT4Al1bVldeuCmiI_!!6000000003471-2-tps-2506-1710.png) + + +You can implement your own HPO objective in `get_hpo_objective` function, e.g., linking the data +recipes to +- model_loss (by replacing the quality scorer with a training procedure), +- downstream_task (by replacing the quality scorer with training and + evaluation procedures), or +- some synergy measures that combine metrics you are interested in, such that the trade-offs from different views can be explored. diff --git a/tools/hpo/README_ZH.md b/tools/hpo/README_ZH.md new file mode 100644 index 000000000..ab84e7341 --- /dev/null +++ b/tools/hpo/README_ZH.md @@ -0,0 +1,49 @@ +# 数据菜谱的自动化超参优化 + +## Auto-HPO + +我们将自动化 HPO (hyper-parameters optimization) 工具 WandB [Sweep](https://docs.wandb.ai/guides/sweeps) 结合到 +Data-Juicer 中,以简化改良数据处理超参数的过程。 +使用此工具,用户可以研究探索 *数据配方的特定超参数* 和 *指定目标度量(如数据质量分、模型loss等)* 之间的 相关性和重要性得分 + +*注意*:这是一个实验性功能。 用于数据配方的 Auto-HPO 仍然有 +一个极大的探索空间,暂无标准做法。 欢迎大家提出更多的建议、讨论、 +并通过新的 PR 做出贡献! 
+ + +## 前置条件 +您需要先安装 data-juicer。 +此外,该工具利用了 WandB,通过`pip install wandb`安装它。 +在使用此工具之前,您需要运行`wandb login`并输入您的 WandB +API 密钥。 +如果您有自己的 WandB 实例(例如 [本地托管模式](https://docs.wandb.ai/guides/hosting/) ),请运行以下脚本: + +```shell +wandb login --host +# enter your api key +``` + + + +## 使用和定制化 + +给定一个数据配方,以指定的配置文件所定义``,您可以使用 `execute_hpo.py` 来搜索 +由``定义的超参数空间。 + +```shell +# cd tools/hpo +python execute_hpo.py --config --hpo_config + +# e.g., +python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml +``` + +我们在`hpo/objects.py`中提供了一个示意性的搜索目标 `quality_score`, +它使用质量评分器来度量处理后的数据,并将平均质分数链接到数据配方的超参数。 +运行后,你会得到类似如下的结果:![img](https://img.alicdn.com/imgextra/i2/O1CN017fT4Al1bVldeuCmiI_!!6000000003471-2-tps-2506-1710.png) + + +您可以在 `get_hpo_objective` 函数中实现您自己的 HPO 目标,例如,将数据配方链接到 +- model_loss(通过用训练程序 替换质量评分器), +- 下游任务(通过用训练和评测程序 替换质量评分器),或 +- 一些您感兴趣的指标的综合考量,以便可以探索不同角度的权衡(如size-quality-diversity)。 diff --git a/tools/hpo/configs/quality_score_hpo.yaml b/tools/hpo/configs/quality_score_hpo.yaml index 0f5ef41a8..543e7b64b 100644 --- a/tools/hpo/configs/quality_score_hpo.yaml +++ b/tools/hpo/configs/quality_score_hpo.yaml @@ -19,8 +19,8 @@ parameters: values: [0.3, 0.5, 0.7] text_length_filter.min_len: distribution: q_log_uniform_values - min: 8 - max: 512 + min: 256 + max: 8192 #early_terminate: diff --git a/tools/hpo/execute_hpo.py b/tools/hpo/execute_hpo.py index 66658564e..1293bfa66 100644 --- a/tools/hpo/execute_hpo.py +++ b/tools/hpo/execute_hpo.py @@ -1,9 +1,9 @@ import sys +import wandb import yaml from jsonargparse import namespace_to_dict -import wandb from data_juicer.config import init_configs, merge_config from objects import get_hpo_objective @@ -32,7 +32,7 @@ def search(): wandb.config = namespace_to_dict(dj_cfg) # for configuration track # 2.2: calculate objective using new hyper-parameters, track the results - score = object_func(dj_cfg) + score = float(object_func(dj_cfg)) wandb.log({sweep_configuration['metric']['name']: score}) diff 
--git a/tools/hpo/objects.py b/tools/hpo/objects.py index f6749bc58..eff8eff59 100644 --- a/tools/hpo/objects.py +++ b/tools/hpo/objects.py @@ -1,3 +1,6 @@ +import os +import shutil + from data_juicer.core import Executor from tools.quality_classifier.predict import predict_score @@ -45,8 +48,16 @@ def obj_quality_score(dj_cfg): # [--tokenizer ] \ # [--keep_method ] \ # [--text_key ] \ + + tmp_res_export_path = dj_cfg.export_path + '.tmp_hpo.jsonl' + if os.path.exists(tmp_res_export_path): + if os.path.isfile(tmp_res_export_path): + os.remove(tmp_res_export_path) + if os.path.isdir(tmp_res_export_path): + shutil.rmtree(tmp_res_export_path) + overall_quality_stats = predict_score(dj_cfg.export_path, - dj_cfg.export_path, + tmp_res_export_path, overall_stats=True) # by default, using the mean quality score of processed data as final score diff --git a/tools/postprocess/README.md b/tools/postprocess/README.md index de2dfe4bb..5aa437eeb 100644 --- a/tools/postprocess/README.md +++ b/tools/postprocess/README.md @@ -14,7 +14,7 @@ python tools/postprocess/count_token.py \ --text_keys \ --tokenizer_method \ --num_proc - + # get help python tools/postprocess/count_token.py --help ``` @@ -71,4 +71,4 @@ python tools/postprocess/deserialize_meta.py --help - `serialized_key`: the key corresponding to the field that will be deserialized. Default it's 'source_info'. - `num_proc` (optional): number of process workers. Default it's 1. -**Note:** After deserialization, all serialized fields in the original file will be placed in `'serialized_key'`, this is to ensure that the fields generated after data-juicer processing will not conflict with the original meta fields. \ No newline at end of file +**Note:** After deserialization, all serialized fields in the original file will be placed in `'serialized_key'`, this is to ensure that the fields generated after data-juicer processing will not conflict with the original meta fields. 
diff --git a/tools/postprocess/README_ZH.md b/tools/postprocess/README_ZH.md index 99b5579c9..b59bebe62 100644 --- a/tools/postprocess/README_ZH.md +++ b/tools/postprocess/README_ZH.md @@ -14,7 +14,7 @@ python tools/postprocess/count_token.py \ --text_keys \ --tokenizer_method \ --num_proc - + # get help python tools/postprocess/count_token.py --help ``` @@ -69,4 +69,4 @@ python tools/postprocess/deserialize_meta.py --help - `serialized_key`: 将被反序列化的字段对应的 key, 默认为“source_info”。. - `num_proc` (optional): worker 进程数量,默认为 1 -**注意事项:** 经过反序列化后原始文件中所有被序列化的字段都会放在`‘serialized_key’`中,这样做是为了保证 data-juicer 处理后生成的字段不会和原有的元字段冲突。 \ No newline at end of file +**注意事项:** 经过反序列化后原始文件中所有被序列化的字段都会放在`‘serialized_key’`中,这样做是为了保证 data-juicer 处理后生成的字段不会和原有的元字段冲突。 diff --git a/tools/postprocess/count_token.py b/tools/postprocess/count_token.py index 8edfee2dd..ccb6864aa 100644 --- a/tools/postprocess/count_token.py +++ b/tools/postprocess/count_token.py @@ -1,14 +1,14 @@ +from multiprocessing import Pool import fire import jsonlines as jl - -from tqdm import tqdm -from multiprocessing import Pool from loguru import logger +from tqdm import tqdm from transformers import AutoTokenizer TOKENIZER = None + def count_token_single(sample, text_keys): global TOKENIZER num = 0 @@ -23,6 +23,7 @@ def prepare_tokenizer(tokenizer_method): TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_method, trust_remote_code=True) + def main(data_path, text_keys='text', tokenizer_method='EleutherAI/pythia-6.9b-deduped', @@ -45,8 +46,12 @@ def main(data_path, result_list = [] with Pool(num_proc) as p: for sample in tqdm(reader): - result_list.append(p.apply_async(count_token_single, - args=(sample, text_keys,))) + result_list.append( + p.apply_async(count_token_single, + args=( + sample, + text_keys, + ))) for res in tqdm(result_list): token_count += res.get() diff --git a/tools/postprocess/deserialize_meta.py b/tools/postprocess/deserialize_meta.py index f7f67bedd..c9659f78c 100644 --- 
a/tools/postprocess/deserialize_meta.py +++ b/tools/postprocess/deserialize_meta.py @@ -1,9 +1,10 @@ -import os import json +import os import pathlib -import jsonlines from multiprocessing import Pool + import fire +import jsonlines def fp_iter(src_dir): @@ -48,16 +49,16 @@ def main(src_dir, target_dir, serialized_key='source_info', num_proc=1): if not os.path.exists(target_dir): os.makedirs(target_dir, exist_ok=True) - pool = Pool(num_proc) for fp in fp_iter(src_dir): print(fp) jsonl_fp = os.path.join(target_dir, fp.name) - pool.apply_async(meta_deserialize, args=(str(fp), jsonl_fp, serialized_key)) + pool.apply_async(meta_deserialize, + args=(str(fp), jsonl_fp, serialized_key)) pool.close() pool.join() if __name__ == '__main__': - fire.Fire(main) \ No newline at end of file + fire.Fire(main) diff --git a/tools/preprocess/README.md b/tools/preprocess/README.md index c8ca7a6b7..6a33910ed 100644 --- a/tools/preprocess/README.md +++ b/tools/preprocess/README.md @@ -165,4 +165,4 @@ python tools/preprocess/serialize_meta.py --help - `target_dir`: path to save the converted jsonl files. - `text_key`: the key corresponding to the field that will not be serialized. Defaul it's 'text'. - `serialized_key`: the key corresponding to the field that the serialized info saved. Default it's 'source_info'. -- `num_proc` (optional): number of process workers. Default it's 1. \ No newline at end of file +- `num_proc` (optional): number of process workers. Default it's 1. 
diff --git a/tools/preprocess/README_ZH.md b/tools/preprocess/README_ZH.md index cdd0d7c2b..f715a50df 100644 --- a/tools/preprocess/README_ZH.md +++ b/tools/preprocess/README_ZH.md @@ -158,4 +158,4 @@ python tools/preprocess/serialize_meta.py --help - `target_dir`: 保存转换后的 jsonl 文件的路径。 - `text_key`: 不会被序列化的字段对应的 key, 默认为 “text”。 - `serialized_key`: 序列化后的信息保存的字段对应的 key, 默认为 “source_info”。 -- `num_proc` (可选): worker 进程数量,默认为 1 \ No newline at end of file +- `num_proc` (可选): worker 进程数量,默认为 1 diff --git a/tools/preprocess/raw_alpaca_cot_merge_add_meta.py b/tools/preprocess/raw_alpaca_cot_merge_add_meta.py index 52b3fb5f4..fdef96795 100644 --- a/tools/preprocess/raw_alpaca_cot_merge_add_meta.py +++ b/tools/preprocess/raw_alpaca_cot_merge_add_meta.py @@ -21,8 +21,7 @@ 'IFT': True, # whether is IFT data, added by Data-Juicer 'CFT-SR': False, # whether is CFT single-round data, added by # Data-Juicer - 'CFT-P': - False, # whether is Preference data, added by Data-Juicer + 'CFT-P': False, # whether is Preference data, added by Data-Juicer }, 'GPT4all': { 'Task': 'MT', @@ -364,45 +363,45 @@ 'CFT-SR': True, 'CFT-P': True, }, - "ConvAI2": { - "Task": "TS", - "Gen": "HG", - "Lang": "EN", - "Dataset": "ConvAI2", - "CFT-MR": False, - "IFT": False, - "CFT-SR": True, - "CFT-P": False, - }, - "FastChat": { - "Task": "MT", - "Gen": "SI", - "Lang": "EN", - "Dataset": "FastChat", - "CFT-MR": False, - "IFT": False, - "CFT-SR": True, - "CFT-P": False, + 'ConvAI2': { + 'Task': 'TS', + 'Gen': 'HG', + 'Lang': 'EN', + 'Dataset': 'ConvAI2', + 'CFT-MR': False, + 'IFT': False, + 'CFT-SR': True, + 'CFT-P': False, + }, + 'FastChat': { + 'Task': 'MT', + 'Gen': 'SI', + 'Lang': 'EN', + 'Dataset': 'FastChat', + 'CFT-MR': False, + 'IFT': False, + 'CFT-SR': True, + 'CFT-P': False, }, 'Tabular-LLM-Data': { 'Task': 'MT', 'Gen': 'COL', 'Lang': 'EN/CN', - "Dataset": "Tabular-LLM-Data", - "CFT-MR": False, - "IFT": True, - "CFT-SR": False, - "CFT-P": False, + 'Dataset': 'Tabular-LLM-Data', + 'CFT-MR': False, 
+ 'IFT': True, + 'CFT-SR': False, + 'CFT-P': False, }, 'ThoughtSource': { 'Task': 'MT', 'Gen': 'COL', 'Lang': 'EN', - "Dataset": "ThoughtSource", - "CFT-MR": False, - "IFT": True, - "CFT-SR": False, - "CFT-P": False, + 'Dataset': 'ThoughtSource', + 'CFT-MR': False, + 'IFT': True, + 'CFT-SR': False, + 'CFT-P': False, } } diff --git a/tools/quality_classifier/predict.py b/tools/quality_classifier/predict.py index 71a258ce7..f3a88ef6b 100644 --- a/tools/quality_classifier/predict.py +++ b/tools/quality_classifier/predict.py @@ -60,8 +60,9 @@ import fire from loguru import logger -from .qc_utils import (export_result, init_spark, load_dataset, predict, - prepare_model) +from tools.quality_classifier.qc_utils import (export_result, init_spark, + load_dataset, predict, + prepare_model) @logger.catch @@ -104,6 +105,12 @@ def predict_score(dataset_path, keep_method = 'gpt3' # initialize a spark session + if '_JAVA_OPTIONS' in os.environ and \ + '-Djava.net.preferIPv6Addresses=true' \ + in os.environ['_JAVA_OPTIONS']: + os.environ['_JAVA_OPTIONS'] = os.environ['_JAVA_OPTIONS'].replace( + '-Djava.net.preferIPv6Addresses=true', + '-Djava.net.preferIPv6Addresses=false') spark = init_spark() # load the quality classifier model model = prepare_model(model_name=model) diff --git a/tools/quality_classifier/qc_utils.py b/tools/quality_classifier/qc_utils.py index 0ab2c42ed..448ea2368 100644 --- a/tools/quality_classifier/qc_utils.py +++ b/tools/quality_classifier/qc_utils.py @@ -2,12 +2,8 @@ import zipfile import numpy as np - import sentencepiece as spm import wget -from data_juicer.utils.cache_utils import DATA_JUICER_MODELS_CACHE -from data_juicer.utils.model_utils import (MODEL_LINKS, - prepare_sentencepiece_model) from loguru import logger from pyspark.ml import Pipeline, PipelineModel from pyspark.ml.classification import LogisticRegression @@ -16,6 +12,10 @@ from pyspark.sql.functions import col, rand, udf from pyspark.sql.types import ArrayType, DoubleType, IntegerType, 
StringType +from data_juicer.utils.cache_utils import DATA_JUICER_MODELS_CACHE +from data_juicer.utils.model_utils import (MODEL_LINKS, + prepare_sentencepiece_model) + def init_spark(spark_executor_memory=None, spark_driver_memory=None, @@ -31,7 +31,7 @@ def init_spark(spark_executor_memory=None, spark_driver_memory = '64g' if not spark_executor_memoryOverhead: spark_executor_memoryOverhead = '20000' - spark = (SparkSession.builder.config( + spark = (SparkSession.builder.master('local[*]').config( 'spark.driver.memory', spark_driver_memory).config( 'spark.executor.memory', spark_executor_memory).config( 'spark.sql.shuffle.partitions', '300').config( From c96b7cfaf63b28b3b5154f7d956b1cc68968545c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Wed, 8 Nov 2023 12:15:53 +0800 Subject: [PATCH 3/6] minor fix for relative import --- data_juicer/config/config.py | 3 +-- demos/data_process_hpo/app.py | 1 - tools/quality_classifier/eval.py | 2 +- tools/quality_classifier/train.py | 3 ++- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 0fbbb80f2..cda9e5c36 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -468,8 +468,7 @@ def merge_config(ori_cfg, new_cfg: Dict): ori_cfg_val = ori_cfg.process[op_order][op_name][para_name] print( '=' * 15, f'\nBefore merging, the cfg item is: ' - f'{new_k}: {ori_cfg_val}' - ) + f'{new_k}: {ori_cfg_val}') ori_cfg.process[op_order][op_name][para_name] = new_v print( f'After merging, the cfg item is: ' diff --git a/demos/data_process_hpo/app.py b/demos/data_process_hpo/app.py index a5f62552c..aecf23ba2 100644 --- a/demos/data_process_hpo/app.py +++ b/demos/data_process_hpo/app.py @@ -1,4 +1,3 @@ - import streamlit as st diff --git a/tools/quality_classifier/eval.py b/tools/quality_classifier/eval.py index 47da8ef03..be1bfa622 100644 --- a/tools/quality_classifier/eval.py +++ b/tools/quality_classifier/eval.py @@ -24,7 
+24,7 @@ import fire from loguru import logger -from .qc_utils import eval, init_spark, load_datasets +from tools.quality_classifier.qc_utils import eval, init_spark, load_datasets @logger.catch diff --git a/tools/quality_classifier/train.py b/tools/quality_classifier/train.py index dab2b2759..e0f7fd5aa 100644 --- a/tools/quality_classifier/train.py +++ b/tools/quality_classifier/train.py @@ -30,7 +30,8 @@ import fire from loguru import logger -from .qc_utils import eval, init_spark, load_datasets, shuffle, train +from tools.quality_classifier.qc_utils import (eval, init_spark, load_datasets, + shuffle, train) @logger.catch From 6293d8d85af9e5dc90dce891d450b1e997a09213 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Wed, 8 Nov 2023 16:19:20 +0800 Subject: [PATCH 4/6] fix according to yilun's comments --- data_juicer/config/config.py | 31 ++++++------- tools/hpo/README.md | 55 ++---------------------- tools/hpo/README_ZH.md | 4 +- tools/hpo/configs/quality_score_hpo.yaml | 9 ++-- 4 files changed, 22 insertions(+), 77 deletions(-) diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index cda9e5c36..3a5ff0ed0 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -436,12 +436,7 @@ def merge_config(ori_cfg, new_cfg: Dict): try: ori_specified_op_names = set() ori_specified_op_idx = {} # {op_name: op_order} - # format of ori_cfg.process - # ori_cfg.process[i] = { - # op_in_process_name: - # None if internal_op_para is None else - # namespace_to_dict(internal_op_para) - # } + for op_order, op_in_process in enumerate(ori_cfg.process): op_name = list(op_in_process.keys())[0] ori_specified_op_names.add(op_name) @@ -450,13 +445,13 @@ def merge_config(ori_cfg, new_cfg: Dict): for new_k, new_v in new_cfg.items(): # merge parameters other than `cfg.process` and DJ-OPs if new_k in ori_cfg and new_k != 'process' and '.' 
not in new_k: - print( - '=' * 15, f'\nBefore merging, the cfg item is: ' - f'{new_k}: {ori_cfg[new_k]}') + logger.info('=' * 15) + logger.info(f'Before merging, the cfg item is: ' + f'{new_k}: {ori_cfg[new_k]}') ori_cfg[new_k] = new_v - print( - f'After merging, the cfg item is: ' - f'{new_k}: {new_v}\n', '=' * 15, '\n') + logger.info(f'After merging, the cfg item is: ' + f'{new_k}: {new_v}') + logger.info('=' * 15) else: # merge parameters of DJ-OPs into cfg.process # for nested style, e.g., `remove_table_text_mapper.min_col: 2` @@ -466,13 +461,13 @@ def merge_config(ori_cfg, new_cfg: Dict): op_name, para_name = key_as_groups[0], key_as_groups[1] op_order = ori_specified_op_idx[op_name] ori_cfg_val = ori_cfg.process[op_order][op_name][para_name] - print( - '=' * 15, f'\nBefore merging, the cfg item is: ' - f'{new_k}: {ori_cfg_val}') + logger.info('=' * 15) + logger.info(f'Before merging, the cfg item is: ' + f'{new_k}: {ori_cfg_val}') ori_cfg.process[op_order][op_name][para_name] = new_v - print( - f'After merging, the cfg item is: ' - f'{new_k}: {new_v}\n', '=' * 15, '\n') + logger.info(f'After merging, the cfg item is: ' + f'{new_k}: {new_v}') + logger.info('=' * 15) ori_cfg = init_setup_from_cfg(ori_cfg) diff --git a/tools/hpo/README.md b/tools/hpo/README.md index 3b3e8f126..f7b7108b3 100644 --- a/tools/hpo/README.md +++ b/tools/hpo/README.md @@ -6,7 +6,7 @@ We incorporate an automated HPO tool, WandB [Sweep](https://docs.wandb.ai/guides With this tool, users can investigate correlations and importance scores of specific hyper-parameters of data recipes from the HPO view. -*Note*: this is an experimental feature. Auto-HPO for data recipes still has +**Note**: this is an experimental feature. Auto-HPO for data recipes still has a large room to explore. Feel free to provide more suggestions, discussion, and contribution via new PRs! @@ -14,7 +14,7 @@ and contribution via new PRs! ## Prerequisite You need to install data-juicer first. 
Besides, the tool leverages WandB, install it via `pip install wandb`. -Before using this tool, you need to run ` +Before using this tool, you need to run ```wandb login``` and enter your WandB API key. If you have your own instance of WandB (e.g., [locally-hosted machine](https://docs.wandb.ai/guides/hosting/)), run the following script: @@ -47,54 +47,5 @@ After running it, you will get the result similar to: ![img](https://img.alicdn. You can implement your own HPO objective in `get_hpo_objective` function, e.g., linking the data recipes to - model_loss (by replacing the quality scorer into a training procedure), -- downstream_task (by eplacing the quality scorer into a training and an# Hyper-parameter Optimization for Data Recipe - -## Auto-HPO - -We incorporate an automated HPO tool, WandB [Sweep](https://docs.wandb.ai/guides/sweeps), into Data-Juicer to streamline the finding of good data processing hyper-parameters. -With this tool, users can investigate correlations and importance scores of -specific hyper-parameters of data recipes from the HPO view. - -*Note*: this is an experimental feature. Auto-HPO for data recipes still has -a large room to explore. Feel free to provide more suggestions, discussions, and contributions via new PRs! - - -## Prerequisite -You need to install data-juicer first. -Besides, the tool leverages WandB, install it via `pip install wandb`. -Before using this tool, you need to run ` -```wandb login``` and enter your WandB -API key. -If you have your own instance of WandB (e.g., [locally-hosted machine](https://docs.wandb.ai/guides/hosting/)), run the following script: - -```shell -wandb login --host -# enter your api key -``` - - - -## Usage and Customization - -Given a data recipe, characterized by a specified configuration file -``, you can use `execute_hpo.py` to search the -hyper-parameter space defined by ``. 
-```shell -# cd tools/hpo -python execute_hpo.py --config --hpo_config - -# e.g., -python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality_score_hpo.yaml -``` - -We provide an illustrative objective "quality_score" in `hpo/objects.py`, -which uses quality scorer to measure the processed data, and links the average scores to hyper-parameters of data recipes. -After running it, you will get the result similar to: ![img](https://img.alicdn.com/imgextra/i2/O1CN017fT4Al1bVldeuCmiI_!!6000000003471-2-tps-2506-1710.png) - - -You can implement your own HPO objective in `get_hpo_objective` function, e.g., linking the data -recipes to -- model_loss (by replacing the quality scorer with a training procedure), -- downstream_task (by replacing the quality scorer with training and - evaluation procedures), or +- downstream_task (by replacing the quality scorer with training and evaluation procedures), or - some synergy measures that combine metrics you are interested in, such that the trade-offs from different views can be explored. diff --git a/tools/hpo/README_ZH.md b/tools/hpo/README_ZH.md index ab84e7341..7cdb1a407 100644 --- a/tools/hpo/README_ZH.md +++ b/tools/hpo/README_ZH.md @@ -6,7 +6,7 @@ Data-Juicer 中,以简化改良数据处理超参数的过程。 使用此工具,用户可以研究探索 *数据配方的特定超参数* 和 *指定目标度量(如数据质量分、模型loss等)* 之间的 相关性和重要性得分 -*注意*:这是一个实验性功能。 用于数据配方的 Auto-HPO 仍然有 +**注意**:这是一个实验性功能。 用于数据配方的 Auto-HPO 仍然有 一个极大的探索空间,暂无标准做法。 欢迎大家提出更多的建议、讨论、 并通过新的 PR 做出贡献! 
@@ -39,7 +39,7 @@ python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality ``` 我们在`hpo/objects.py`中提供了一个示意性的搜索目标 `quality_score`, -它使用质量评分器来度量处理后的数据,并将平均质分数链接到数据配方的超参数。 +它使用质量评分器来度量处理后的数据,并将平均质量分数链接到数据配方的超参数。 运行后,你会得到类似如下的结果:![img](https://img.alicdn.com/imgextra/i2/O1CN017fT4Al1bVldeuCmiI_!!6000000003471-2-tps-2506-1710.png) diff --git a/tools/hpo/configs/quality_score_hpo.yaml b/tools/hpo/configs/quality_score_hpo.yaml index 543e7b64b..cedae00ed 100644 --- a/tools/hpo/configs/quality_score_hpo.yaml +++ b/tools/hpo/configs/quality_score_hpo.yaml @@ -22,8 +22,7 @@ parameters: min: 256 max: 8192 - -#early_terminate: -# type: hyperband -# max_iter: 27 -# s: 2 +early_terminate: + type: hyperband + max_iter: 27 + s: 2 From 2b9763c63c8c4b36346734a1074aaa5c6f1b8c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Wed, 8 Nov 2023 17:46:07 +0800 Subject: [PATCH 5/6] fix according to yilun's comments --- .gitignore | 1 + tools/hpo/README.md | 5 +++++ tools/hpo/README_ZH.md | 8 +++++++- tools/hpo/configs/quality_score_hpo.yaml | 2 +- tools/hpo/execute_hpo.py | 4 ++-- 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 7972867ed..d5d1d0782 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,5 @@ dist # others .DS_Store .idea/ +wandb/ __pycache__ diff --git a/tools/hpo/README.md b/tools/hpo/README.md index f7b7108b3..0323e1855 100644 --- a/tools/hpo/README.md +++ b/tools/hpo/README.md @@ -39,6 +39,11 @@ python execute_hpo.py --config --hpo_config `, +please see more details in our [guidance](https://github. +com/alibaba/data-juicer#build-up-config-files). As for the configuration +for HPO, i.e., ``, please refer to sweep [guidance](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration). + We provide an illustrative objective "quality_score" in `hpo/objects.py`, which uses quality scorer to measure the processed data, and links the average scores to hyper-parameters of data recipes. 
After running it, you will get the result similar to: ![img](https://img.alicdn.com/imgextra/i2/O1CN017fT4Al1bVldeuCmiI_!!6000000003471-2-tps-2506-1710.png) diff --git a/tools/hpo/README_ZH.md b/tools/hpo/README_ZH.md index 7cdb1a407..b2437302b 100644 --- a/tools/hpo/README_ZH.md +++ b/tools/hpo/README_ZH.md @@ -4,7 +4,7 @@ 我们将自动化 HPO (hyper-parameters optimization) 工具 WandB [Sweep](https://docs.wandb.ai/guides/sweeps) 结合到 Data-Juicer 中,以简化改良数据处理超参数的过程。 -使用此工具,用户可以研究探索 *数据配方的特定超参数* 和 *指定目标度量(如数据质量分、模型loss等)* 之间的 相关性和重要性得分 +使用此工具,用户可以研究探索 *数据配方的特定超参数* 和 *指定目标度量(如数据质量分、模型loss等)* 之间的 相关性和重要性得分。 **注意**:这是一个实验性功能。 用于数据配方的 Auto-HPO 仍然有 一个极大的探索空间,暂无标准做法。 欢迎大家提出更多的建议、讨论、 @@ -38,6 +38,12 @@ python execute_hpo.py --config --hpo_config `, +请参阅我们的 [指南](https://github.com/alibaba/data-juicer/blob/main/README_ZH.md#%E6%9E%84%E5%BB%BA%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6) +获取更多详细信息。 +对于HPO的配置,即``,请参阅Sweep提供的 [指南](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration) 。 + + 我们在`hpo/objects.py`中提供了一个示意性的搜索目标 `quality_score`, 它使用质量评分器来度量处理后的数据,并将平均质量分数链接到数据配方的超参数。 运行后,你会得到类似如下的结果:![img](https://img.alicdn.com/imgextra/i2/O1CN017fT4Al1bVldeuCmiI_!!6000000003471-2-tps-2506-1710.png) diff --git a/tools/hpo/configs/quality_score_hpo.yaml b/tools/hpo/configs/quality_score_hpo.yaml index cedae00ed..1f7766f98 100644 --- a/tools/hpo/configs/quality_score_hpo.yaml +++ b/tools/hpo/configs/quality_score_hpo.yaml @@ -1,6 +1,6 @@ sweep_name: hpo_for_data-juicer -# sweep_count: 10 +sweep_max_count: 1000 # the maximal number of trials; `None` for unlimited # hpo configuration from original sweep, see more options and details in # https://docs.wandb.ai/guides/sweeps/define-sweep-configuration diff --git a/tools/hpo/execute_hpo.py b/tools/hpo/execute_hpo.py index 1293bfa66..7f7bb2d09 100644 --- a/tools/hpo/execute_hpo.py +++ b/tools/hpo/execute_hpo.py @@ -42,5 +42,5 @@ def search(): wandb.agent(sweep_id, function=search, - count=sweep_configuration['sweep_count'] - if 'sweep_count' in 
sweep_configuration else None) + count=sweep_configuration['sweep_max_count'] + if 'sweep_max_count' in sweep_configuration else None) From ca2f564c9a0094e0da81bbbc0bcd218384c1fea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=81=93=E8=BE=95?= Date: Wed, 8 Nov 2023 17:52:36 +0800 Subject: [PATCH 6/6] fix according to yilun's comments --- tools/hpo/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/hpo/README.md b/tools/hpo/README.md index 0323e1855..01a3735f5 100644 --- a/tools/hpo/README.md +++ b/tools/hpo/README.md @@ -40,8 +40,7 @@ python execute_hpo.py --config configs/process.yaml --hpo_config configs/quality ``` For the configuration for data recipe, i.e., `<data-process-cfg-file-path>`, -please see more details in our [guidance](https://github. -com/alibaba/data-juicer#build-up-config-files). As for the configuration +please see more details in our [guidance](https://github.com/alibaba/data-juicer#build-up-config-files). As for the configuration for HPO, i.e., `<hpo-cfg-file-path>`, please refer to sweep [guidance](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration). We provide an illustrative objective "quality_score" in `hpo/objects.py`,