fix: allow for padding free + pretraining #426

Open · wants to merge 1 commit into base: main
@@ -33,3 +33,7 @@ class AttentionAndDistributedPackingConfig:
     def __post_init__(self):
         # ensure nested dataclasses initialized
         ensure_nested_dataclasses_initialized(self)
+
+    @property
+    def is_padding_free(self):
+        return self.padding_free is not None
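
For reference, a minimal usage sketch of the new property. The import path is an assumption based on the class name in this diff, not verified against the repository:

# Hypothetical usage: branch on the helper instead of touching the raw field.
from tuning.config.acceleration_configs import AttentionAndDistributedPackingConfig

config = AttentionAndDistributedPackingConfig(padding_free=None)

if config.is_padding_free:
    print("padding-free plugin configured")
else:
    print("regular padded collation")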
tuning/data/setup_dataprocessor.py (23 changes: 16 additions & 7 deletions)
@@ -107,15 +107,21 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):


 ### Data format 2
-def _get_dataset_formatting_handlers(data_args, packing):
+def _get_dataset_formatting_handlers(data_args, packing, padding_free=None):

     if data_args.response_template is None:
         if packing is False:
-            raise ValueError(
-                "Since dataset_text_field or data_formatter_template \
-                is provided and packing is disabled, \
-                needs a corresponding response template for masking"
-            )
+            if padding_free:
+                logger.debug(
+                    "Assuming extended pretraining scenario because packing is false"
+                    + ", padding_free is used, and no response template was provided."
+                )
+            else:
+                raise ValueError(
+                    "Since dataset_text_field or data_formatter_template \
+                    is provided and packing is disabled, \
+                    needs a corresponding response template for masking"
+                )

     if data_args.response_template:
         # To use Response template, pass datasets with single sequence instances \
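
To make the new branch concrete, here is a self-contained sketch of the guard with simplified names (not the library's verbatim code): when padding_free is set, a missing response template is treated as a signal for extended pretraining instead of an error.

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

def check_formatting_args(response_template, packing, padding_free=None):
    # Mirrors the guard added to _get_dataset_formatting_handlers above.
    if response_template is None and packing is False:
        if padding_free:
            # Pretraining trains on every token, so no template is needed
            # to mask out a prompt portion of each example.
            logger.debug("Assuming extended pretraining scenario.")
        else:
            raise ValueError("A response template is required for masking.")

# Old behavior (still intact for the non-padding-free case): this raises.
# check_formatting_args(response_template=None, packing=False)

# New behavior: a debug log instead of an exception.
check_formatting_args(response_template=None, packing=False, padding_free=object())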
Expand Down Expand Up @@ -209,6 +215,7 @@ def _process_raw_data_args(
packing: bool,
max_seq_length: int,
additional_data_handlers: Dict[str, Callable] = None,
**kwargs,
):

# Create a data processor with default processor config
Expand Down Expand Up @@ -266,7 +273,7 @@ def _process_raw_data_args(
elif data_args.data_formatter_template or data_args.dataset_text_field:
# Data Format 3: Single Sequence Dataset
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing
data_args, packing, **kwargs
)
else:
# Default Data Format: Dataset with Input/Output Fields
Expand Down Expand Up @@ -300,6 +307,7 @@ def process_dataargs(
tokenizer: AutoTokenizer,
train_args: TrainingArguments,
additional_data_handlers: Dict[str, Callable] = None,
**kwargs,
):
"""
Args:
Expand Down Expand Up @@ -345,6 +353,7 @@ def process_dataargs(
train_args.packing,
max_seq_length,
additional_data_handlers,
**kwargs,
)

# Note: This check should not be removed.
Expand Down
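
The new keyword travels from process_dataargs through _process_raw_data_args down to _get_dataset_formatting_handlers purely via **kwargs; only the innermost function declares it. A toy, runnable illustration of that threading pattern (all names are stand-ins):

def innermost(data_args, packing, padding_free=None):
    # Only this function actually declares the keyword.
    return f"padding_free={padding_free}"

def middle(data_args, packing, **kwargs):
    return innermost(data_args, packing, **kwargs)

def entry_point(data_args, packing, **kwargs):
    return middle(data_args, packing, **kwargs)

print(entry_point(None, False, padding_free=True))  # -> padding_free=True

One consequence, judging only from the visible hunks: the Data Format 3 branch is the only one that forwards the kwargs to a handler factory, so with the other data formats an extra keyword would simply be ignored rather than rejected.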
tuning/sft_trainer.py (11 changes: 10 additions & 1 deletion)
@@ -302,6 +302,9 @@ def train(
     data_collator = None
     logger.info("Packing is set to %s ", train_args.packing)

+    padding_free = None
+    if attention_and_distributed_packing_config is not None:
+        padding_free = attention_and_distributed_packing_config.padding_free
     data_preprocessing_time = time.time()
     (
         formatted_train_dataset,
@@ -310,7 +313,13 @@
         data_collator,
         train_args.max_seq_length,
         dataset_kwargs,
-    ) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
+    ) = process_dataargs(
+        data_args,
+        tokenizer,
+        train_args,
+        additional_data_handlers,
+        padding_free=padding_free,
+    )
     additional_metrics["data_preprocessing_time"] = (
         time.time() - data_preprocessing_time
     )
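
Putting the three changes together, the scenario this PR unblocks looks roughly like the following. This is a hedged sketch: the dataclass fields, import paths, and the positional signature of train() are assumptions drawn from this diff, not a verified invocation.

# Extended pretraining with padding-free attention and no response template.
# All constructor fields below are illustrative.
data_args = DataArguments(
    training_data_path="pretrain_corpus.jsonl",
    dataset_text_field="text",   # single-sequence data format (Data Format 3)
    response_template=None,      # previously fatal when packing is disabled
)
train_args = TrainingArguments(output_dir="out", packing=False)
apc = AttentionAndDistributedPackingConfig(padding_free=PaddingFree())

# train() now derives padding_free from the acceleration config and forwards
# it to process_dataargs(), so the masking ValueError is no longer raised.
train(
    model_args,
    data_args,
    train_args,
    attention_and_distributed_packing_config=apc,
)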