From 58552fe389d4f965da64d062f95be0cbeb627946 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 6 Aug 2023 17:33:40 -0400
Subject: [PATCH] add unit tests for cu_seqlens, add ability to build
 cu_seqlens from position ids, fix prompt test

---
 .../monkeypatch/llama_attn_hijack_flash.py |  42 +------
 src/axolotl/monkeypatch/utils.py           | 103 ++++++++++++++++++
 src/axolotl/prompters.py                   |   2 +-
 .../test_llama_attn_hijack_flash.py        |  30 +++++
 tests/test_prompt_tokenizers.py            |  12 +-
 5 files changed, 144 insertions(+), 45 deletions(-)
 create mode 100644 src/axolotl/monkeypatch/utils.py
 create mode 100644 tests/monkeypatch/test_llama_attn_hijack_flash.py

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index ada4ce73e..179cfc5fa 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -17,47 +17,7 @@
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
-
-def get_cu_seqlens(attn_mask):
-    device = attn_mask.device
-    # Exclude zeros to avoid adding their positions to the mask
-    t_non_zeros = attn_mask[attn_mask != 0]
-    # Find where the sequence number changes (including the first position)
-    seq_change = torch.cat(
-        [
-            torch.tensor([1], dtype=torch.int32, device=device),
-            t_non_zeros[1:] != t_non_zeros[:-1],
-        ]
-    )
-    # Get the indices where the sequence changes
-    change_indices = torch.cat(
-        [
-            (seq_change == 1).nonzero(as_tuple=True)[0],
-            torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
-        ]
-    )
-    # Calculate the sequence lengths
-    seq_lengths = change_indices[1:] - change_indices[:-1]
-    # Calculate the length of the final sequence or padding
-    final_seq_length = attn_mask.shape[1] - change_indices[-1]
-    # Append the length of the final sequence or padding to seq_lengths
-    if final_seq_length.item():
-        seq_lengths = torch.cat(
-            [
-                seq_lengths,
-                torch.tensor(
-                    [final_seq_length.item()], dtype=torch.int32, device=device
-                ),
-            ]
-        )
-    # Calculate the cumulative sequence lengths
-    cu_seqlens = torch.cat(
-        [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
-    )
-
-    max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-
-    return cu_seqlens.to(dtype=torch.int32), max_seq_len
+from axolotl.monkeypatch.utils import get_cu_seqlens
 
 
 def forward(
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
new file mode 100644
index 000000000..3b007e05d
--- /dev/null
+++ b/src/axolotl/monkeypatch/utils.py
@@ -0,0 +1,103 @@
+"""
+Shared utils for the monkeypatches
+"""
+import torch
+
+
+def get_cu_seqlens(attn_mask):
+    """generate cumulative sequence lengths (cu_seqlens) for flash attention from an attention mask"""
+    if len(attn_mask.shape) == 1:
+        attn_mask = attn_mask.unsqueeze(0)
+
+    device = attn_mask.device
+    results = []
+    max_seq_lens = []
+
+    for row in attn_mask:
+        # Exclude zeros to avoid adding their positions to the mask
+        t_non_zeros = row[row != 0]
+        # Find where the sequence number changes (including the first position)
+        seq_change = torch.cat(
+            [
+                torch.tensor([1], dtype=torch.int32, device=device),
+                t_non_zeros[1:] != t_non_zeros[:-1],
+            ]
+        )
+        # Get the indices where the sequence changes
+        change_indices = torch.cat(
+            [
+                (seq_change == 1).nonzero(as_tuple=True)[0],
+                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = change_indices[1:] - change_indices[:-1]
+        # Calculate the length of the final sequence or padding
+        final_seq_length = len(row) - change_indices[-1]
+        # Append the length of the final sequence or padding to seq_lengths
+        if final_seq_length.item():
+            seq_lengths = torch.cat(
+                [
+                    seq_lengths,
+                    torch.tensor(
+                        [final_seq_length.item()], dtype=torch.int32, device=device
+                    ),
+                ]
+            )
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
+
+
+def get_cu_seqlens_from_pos_ids(position_ids):
+    """generate cumulative sequence lengths (cu_seqlens) for flash attention from position ids"""
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+
+    device = position_ids.device
+    results = []
+    max_seq_lens = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequence starts
+        start_indices = torch.cat(
+            [
+                (seq_starts).nonzero(as_tuple=True)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        # Append the padding length to the cumulative sequence lengths
+        if padding_length:
+            cu_seqlens = torch.cat(
+                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
+            )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 226cb86af..facac17b6 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -45,7 +45,7 @@ def match_prompt_style(self):
         if self.prompt_style == PromptStyle.CHAT.value:
             self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
             self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
-            self.system_format = "SYSTEM:{system}\n"
+            self.system_format = "SYSTEM: {system}\n"
         if self.prompt_style == PromptStyle.CHATML.value:
             self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
             self.turn_no_input_format = (
diff --git a/tests/monkeypatch/test_llama_attn_hijack_flash.py b/tests/monkeypatch/test_llama_attn_hijack_flash.py
new file mode 100644
index 000000000..289c01a86
--- /dev/null
+++ b/tests/monkeypatch/test_llama_attn_hijack_flash.py
@@ -0,0 +1,30 @@
+"""
+Unit tests for the monkeypatch utils
+"""
+import unittest
+
+import torch
+
+from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
+
+
+class TestMonkeyPatchUtils(unittest.TestCase):
+    """
+    Unit test class for monkeypatch utils
+    """
+
+    def test_get_cu_seqlens_1d(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
+        self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
+
+    def test_get_cu_seqlens_from_pos_ids_1d(self):
+        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
+        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
+        self.assertTrue(
+            torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 0b9545f43..d6496187d 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -130,9 +130,15 @@ def test_system_alpaca(self):
             "output": "Hi! How can I help?",
         }
         example = strat.tokenize_prompt(sample)
-        assert example["input_ids"][0:4] == [1, 835, 2184, 29901]  # "### System:"
-        assert example["input_ids"][5:7] == [1509, 20118]  # "use cot"
-        assert example["input_ids"][9] == 11889  # USER
+        assert example["input_ids"][0:5] == [
+            1,
+            28962,
+            1254,
+            12665,
+            29901,
+        ]  # "SYSTEM:"
+        assert example["input_ids"][5:7] == [671, 20118]  # " use cot"
+        assert example["input_ids"][8] == 11889  # USER
 
 
 if __name__ == "__main__":
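
As a quick sanity check of the new helpers, here is a minimal usage sketch (not part of the patch itself; the example tensors are made up for illustration, while the import path and function names come from the diff above):

    import torch

    from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids

    # A batch of one row: two packed sequences (mask ids 1 and 2) followed by
    # two padding positions.
    attn_mask = torch.tensor([[1, 1, 1, 2, 2, 0, 0]])
    cu_seqlens, max_seq_len = get_cu_seqlens(attn_mask)
    print(cu_seqlens)   # tensor([[0, 3, 5, 7]], dtype=torch.int32)
    print(max_seq_len)  # tensor([3])

    # The same packing expressed as position ids that reset to 0 at the start
    # of each new sequence.
    position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 0]])
    cu_seqlens, max_seq_len = get_cu_seqlens_from_pos_ids(position_ids)
    print(cu_seqlens)   # tensor([[0, 3, 5, 7]], dtype=torch.int32)

Both helpers return one row of cumulative offsets per batch row and treat trailing padding as a final pseudo-sequence (the offsets 5 and 7 above), which keeps padding tokens in their own attention block when the offsets are handed to flash attention's variable-length kernels.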