From f1ee38754b17f2061ccebccb0b377e990335cc38 Mon Sep 17 00:00:00 2001
From: Harikrishnan Balagopal
Date: Sat, 21 Dec 2024 00:49:27 +0530
Subject: [PATCH] fix: allow for padding free + pretraining

Signed-off-by: Harikrishnan Balagopal
---
 tuning/data/setup_dataprocessor.py | 25 ++++++++++++++++++-------
 tuning/sft_trainer.py              |  6 +++++-
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py
index b6f09c323..c45fb5d5a 100644
--- a/tuning/data/setup_dataprocessor.py
+++ b/tuning/data/setup_dataprocessor.py
@@ -107,15 +107,21 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):
 
 
 ### Data format 2
-def _get_dataset_formatting_handlers(data_args, packing):
+def _get_dataset_formatting_handlers(data_args, packing, padding_free: str = ""):
 
     if data_args.response_template is None:
         if packing is False:
-            raise ValueError(
-                "Since dataset_text_field or data_formatter_template \
-                is provided and packing is disabled, \
-                needs a corresponding response template for masking"
-            )
+            if padding_free:
+                logger.debug(
+                    "Packing is False, padding_free is used and no response"
+                    + " template is provided, so this is a pretraining scenario."
+                )
+            else:
+                raise ValueError(
+                    "Since dataset_text_field or data_formatter_template \
+                    is provided and packing is disabled, \
+                    needs a corresponding response template for masking"
+                )
 
     if data_args.response_template:
         # To use Response template, pass datasets with single sequence instances \
@@ -209,6 +215,7 @@ def _process_raw_data_args(
     packing: bool,
     max_seq_length: int,
     additional_data_handlers: Dict[str, Callable] = None,
+    padding_free: str = "",
 ):
 
     # Create a data processor with default processor config
@@ -266,7 +273,7 @@
     elif data_args.data_formatter_template or data_args.dataset_text_field:
         # Data Format 3: Single Sequence Dataset
         handlers, dataset_text_field = _get_dataset_formatting_handlers(
-            data_args, packing
+            data_args, packing, padding_free=padding_free,
         )
     else:
         # Default Data Format: Dataset with Input/Output Fields
@@ -300,6 +307,7 @@ def process_dataargs(
     tokenizer: AutoTokenizer,
     train_args: TrainingArguments,
     additional_data_handlers: Dict[str, Callable] = None,
+    padding_free: str = "",
 ):
     """
     Args:
@@ -310,6 +318,8 @@
             Used for packing and max_seq_length
         additional_data_handlers: A Dict of [str, callable] data handlers
             which need to be registered with the data preprocessor
+        padding_free: str
+            The padding free method, if any.
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
             tuple containing
@@ -345,6 +355,7 @@
         train_args.packing,
         max_seq_length,
         additional_data_handlers,
+        padding_free=padding_free,
     )
 
     # Note: This check should not be removed.
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 2afdd2dac..92cde4f49 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -303,6 +303,10 @@ def train(
     logger.info("Packing is set to %s ", train_args.packing)
 
     data_preprocessing_time = time.time()
+    padding_free = ""
+    if attention_and_distributed_packing_config:
+        if attention_and_distributed_packing_config.padding_free:
+            padding_free = attention_and_distributed_packing_config.padding_free
     (
         formatted_train_dataset,
         formatted_validation_dataset,
@@ -310,7 +314,7 @@ def train(
         data_collator,
         train_args.max_seq_length,
         dataset_kwargs,
-    ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
+    ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers, padding_free=padding_free)
     additional_metrics["data_preprocessing_time"] = (
         time.time() - data_preprocessing_time
    )
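
To make the new behaviour concrete, below is a minimal, self-contained sketch of the decision
logic this patch introduces. It is not part of the patch and not the project's actual code: the
helper name resolve_masking_mode and its return values are hypothetical and exist only to
illustrate how a non-empty padding_free value turns the "no response template + packing
disabled" case into a pretraining run instead of an error.

# Hypothetical sketch; not the fms-hf-tuning implementation.
import logging

logger = logging.getLogger(__name__)


def resolve_masking_mode(response_template, packing, padding_free=""):
    """Decide how loss masking should be handled for a formatted dataset."""
    if response_template is None and packing is False:
        if padding_free:
            # Padding free without a response template: treat the run as
            # pretraining, i.e. every token contributes to the loss.
            logger.debug("padding_free=%r and no response template: pretraining.", padding_free)
            return "pretraining"
        raise ValueError(
            "A response template is required for masking when packing is disabled."
        )
    return "packed" if packing else "masked"


# Usage (values purely illustrative):
print(resolve_masking_mode(None, packing=False, padding_free="huggingface"))  # -> pretraining
print(resolve_masking_mode("### Response:", packing=False))                   # -> masked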