Commit

Permalink
add greedy_zero_padding
DesmonDay committed Aug 15, 2024
1 parent 75c7636 commit 274b881
Showing 5 changed files with 144 additions and 63 deletions.
11 changes: 4 additions & 7 deletions llm/alignment/dpo/dpo_argument.py
@@ -63,11 +63,10 @@ class DPODataArgument:
         default=False,
         metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch."},
     )
-    greedy_intokens: bool = field(
-        default=True,
-        metadata={"help": "Whether apply greedy intokens."},
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Greedy Zero Padding data stream."},
     )
-    buffer_size: int = field(default=500, metadata={"help": "Buffer size for greedy_intokens strategy."})


@dataclass
@@ -87,9 +86,7 @@ class DPOModelArgument:
"help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
},
)
flash_mask: bool = field(
default=False, metadata={"help": "Whether to use flash mask in flash attention."}
)
flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
virtual_pp_degree: int = field(
default=1,
metadata={"help": "virtual_pp_degree"},
18 changes: 9 additions & 9 deletions llm/alignment/dpo/run_dpo.py
@@ -17,7 +17,6 @@
 import os
 import sys
 import time
-import inspect
 from functools import partial
 
 import paddle
@@ -30,17 +29,19 @@
     get_last_checkpoint,
     set_seed,
 )
-from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    LlamaForCausalLM,
+    LlamaForCausalLMPipe,
+)
 from paddlenlp.trl import (
     DPOTrainer,
     calculate_effective_tokens,
     preference_collate_fn,
     preprocess_preference_data,
 )
-from paddlenlp.transformers import (
-    LlamaForCausalLM,
-    LlamaForCausalLMPipe,
-)
 from paddlenlp.utils.log import logger
 
 flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
@@ -132,9 +133,7 @@ def main():
     model.set_state_dict(ref_model.state_dict())
 
     if model_args.flash_mask and not model.config.use_flash_attention:
-        logger.warning(
-            "`flash_mask` must use with zero padding and flash attention."
-        )
+        logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True
 
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
@@ -161,6 +160,7 @@ def main():
             train_ds.map(trans_func),
             tokenizer=tokenizer,
             max_length=data_args.max_seq_len,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
2 changes: 2 additions & 0 deletions llm/run_finetune.py
@@ -391,6 +391,7 @@ def neft_post_hook(module, input, output):
             train_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if train_ds is not None
         else None
@@ -400,6 +401,7 @@ def neft_post_hook(module, input, output):
             ptq_ds,
             tokenizer=tokenizer,
             max_length=data_args.max_length,
+            greedy_zero_padding=data_args.greedy_zero_padding,
         )
         if ptq_ds is not None
         else None
6 changes: 6 additions & 0 deletions llm/utils/argument.py
@@ -86,6 +86,12 @@ class DataArgument:
     dataset_name_or_path: str = field(default=None, metadata={"help": "Name or path for dataset"})
     task_name: str = field(default=None, metadata={"help": "Additional name to select a more specific task."})
     zero_padding: bool = field(default=False, metadata={"help": "Whether to use Zero Padding data stream"})
+    greedy_zero_padding: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use Greedy Zero Padding data stream, should be used together with `zero_padding=True`."
+        },
+    )
     pad_to_multiple_of: int = field(
         default=None, metadata={"help": "If set will pad the sequence to a multiple of the provided value."}
     )
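For orientation, a minimal usage sketch (mine, not part of this commit): `greedy_zero_padding` only takes effect on top of the Zero Padding data stream, and the greedy buffer is hard-coded to 500 examples in zero_padding_dataset.py below, replacing the configurable `buffer_size` DPO argument removed above.

# Hypothetical sketch; DataArgument is the dataclass shown above,
# and the import path is assumed from the file layout in this diff.
from llm.utils.argument import DataArgument

data_args = DataArgument(
    dataset_name_or_path="./my_dataset",  # placeholder path
    zero_padding=True,                    # enable the Zero Padding data stream
    greedy_zero_padding=True,             # pack greedily within a 500-example buffer
)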
170 changes: 123 additions & 47 deletions paddlenlp/datasets/zero_padding_dataset.py
@@ -17,6 +17,27 @@
 from scipy.linalg import block_diag
 
 
+def generate_greedy_packs(examples, max_length):
+    left_len = np.zeros([len(examples)]) - 1
+    left_len[0] = max_length  # At the beginning, only the first pack is valid.
+    generate_packs = [[] for i in range(len(examples))]
+    index, left_index = 0, 0
+
+    while index < len(examples):
+        record = examples[index]
+        max_left_index = left_len.argmax()
+        # Put the current sequence into the largest left space valid pack.
+        if len(record["input_ids"]) <= left_len[max_left_index]:
+            generate_packs[max_left_index].append(record)
+            left_len[max_left_index] -= len(record["input_ids"])
+            index += 1
+        else:
+            left_index += 1
+            left_len[left_index] = max_length
+
+    return generate_packs
+
+
 class ZeroPadding:
     required_output_keys = ["input_ids", "labels", "attention_mask"]
     # Only supported the following keys for ZeroPadding. Keys outside of the set will be ignored.
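To make the greedy strategy concrete, a small worked example (mine, not part of the diff): with max_length=8 and sequence lengths 5, 4, 3, 2, the 5-token sequence opens pack 0 (3 slots left), the 4-token sequence fits nowhere and opens pack 1 (4 slots left), the 3-token sequence goes to pack 1, which now has the most room, and the 2-token sequence lands back in pack 0.

# Toy check of generate_greedy_packs; the import location follows this diff.
from paddlenlp.datasets.zero_padding_dataset import generate_greedy_packs

examples = [{"input_ids": list(range(n))} for n in (5, 4, 3, 2)]
packs = generate_greedy_packs(examples, max_length=8)
print([[len(r["input_ids"]) for r in pack] for pack in packs if pack])
# -> [[5, 2], [4, 3]]; each pack fills 7 of its 8 token slots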
@@ -80,38 +101,66 @@ def _pad_batch_records(cls, batch_records):


 class ZeroPaddingMapDataset(ZeroPadding, Dataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.greedy_zero_padding = greedy_zero_padding
         self.new_data = self._create_zero_padding_data(data)
 
     def _create_zero_padding_data(self, data):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-
         total_data = []
-        for i in range(len(data)):
-            record = data[i]
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
-                padded_list = self._pad_batch_records(batch_records)
-                total_data.append(padded_list)
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                cur_len_so_far += len(record["input_ids"])
-
-        # remaining data
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            total_data.append(padded_list)
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for i in range(len(data)):
+                record = data[i]
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    total_data.append(padded_list)
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    cur_len_so_far += len(record["input_ids"])
+
+            # remaining data
+            if batch_records:
+                padded_list = self._pad_batch_records(batch_records)
+                total_data.append(padded_list)
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            total_data.append(padded_list)
+                    examples = [record]
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        total_data.append(padded_list)
 
         return total_data
 
     def __getitem__(self, idx):
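As a usage sketch (assumed, not part of the commit), the map-style dataset re-packs everything eagerly at construction time, so switching strategies is a one-argument change; tokenized_ds stands in for any sequence of dicts carrying an "input_ids" key.

# Hypothetical construction; names other than the class and its
# parameters are placeholders.
packed_ds = ZeroPaddingMapDataset(
    tokenized_ds,
    tokenizer=tokenizer,
    max_length=2048,
    greedy_zero_padding=True,  # buffer 500 examples, then pack greedily
)
print(len(packed_ds))  # packed batches, not raw examples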
@@ -122,34 +171,61 @@ def __len__(self):


 class ZeroPaddingIterableDataset(ZeroPadding, IterableDataset):
-    def __init__(self, data, tokenizer, max_length):
+    def __init__(self, data, tokenizer, max_length, greedy_zero_padding=False):
         self.data = data
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.zero_padding_global_step = 0
+        self.greedy_zero_padding = greedy_zero_padding
 
     def __iter__(self):
-        batch_records, max_len = [], 0
-        cur_len_so_far = 0
-        for record in self.data:
-            max_len = max(max_len, len(record["input_ids"]))
-            to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
-            if to_append:
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-            else:
-                # exceed max length
-                padded_list = self._pad_batch_records(batch_records)
-                yield padded_list
-                # reset
-                batch_records, max_len = [], 0
-                cur_len_so_far = 0
-                # append current data
-                batch_records.append(record)
-                self.zero_padding_global_step += 1
-                cur_len_so_far += len(record["input_ids"])
-        if batch_records:
-            padded_list = self._pad_batch_records(batch_records)
-            yield padded_list
+        if not self.greedy_zero_padding:
+            batch_records = []
+            cur_len_so_far = 0
+            for record in self.data:
+                to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
+                if to_append:
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+                else:
+                    # exceed max length
+                    padded_list = self._pad_batch_records(batch_records)
+                    yield padded_list
+                    # reset
+                    batch_records = []
+                    cur_len_so_far = 0
+                    # append current data
+                    batch_records.append(record)
+                    self.zero_padding_global_step += 1
+                    cur_len_so_far += len(record["input_ids"])
+            if batch_records:
+                padded_list = self._pad_batch_records(batch_records)
+                yield padded_list
+        else:
+            examples = []
+            buffer_size = 500
+            i = 0
+            for record in self.data:
+                if len(record["input_ids"]) > self.max_length:
+                    continue
+                if i < buffer_size:
+                    examples.append(record)
+                    self.zero_padding_global_step += 1
+                    i += 1
+                else:
+                    # Running greedy strategy in examples.
+                    generate_packs = generate_greedy_packs(examples, self.max_length)
+                    for batch_records in generate_packs:
+                        if len(batch_records) > 0:
+                            padded_list = self._pad_batch_records(batch_records)
+                            yield padded_list
+                    examples = [record]
+                    self.zero_padding_global_step += 1
+                    i = 1
+            if len(examples) > 0:
+                generate_packs = generate_greedy_packs(examples, self.max_length)
+                for batch_records in generate_packs:
+                    if len(batch_records) > 0:
+                        padded_list = self._pad_batch_records(batch_records)
+                        yield padded_list
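A design note on the greedy branch above: it buffers up to 500 examples, then repeatedly drops the next sequence into whichever pack has the most room left, opening a new pack only when the sequence fits nowhere; compared with the sequential branch it emits denser packs at the cost of reordering examples within each buffer. A sketch of consuming the iterable variant (assumed usage, not from the commit):

# Hypothetical streaming consumption; tokenized_stream is a placeholder
# for any iterable of {"input_ids": ...} dicts.
stream = ZeroPaddingIterableDataset(
    tokenized_stream,
    tokenizer=tokenizer,
    max_length=2048,
    greedy_zero_padding=True,
)
for pack in stream:  # each pack is one padded, packed batch record
    print(len(pack["input_ids"]))  # assumes packs expose "input_ids", per required_output_keys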
