Commit

add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test

winglian committed Aug 6, 2023
1 parent b6fd675 commit 58552fe
Showing 5 changed files with 144 additions and 45 deletions.
42 changes: 1 addition & 41 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -17,47 +17,7 @@

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def get_cu_seqlens(attn_mask):
    device = attn_mask.device
    # Exclude zeros to avoid adding their positions to the mask
    t_non_zeros = attn_mask[attn_mask != 0]
    # Find where the sequence number changes (including the first position)
    seq_change = torch.cat(
        [
            torch.tensor([1], dtype=torch.int32, device=device),
            t_non_zeros[1:] != t_non_zeros[:-1],
        ]
    )
    # Get the indices where the sequence changes
    change_indices = torch.cat(
        [
            (seq_change == 1).nonzero(as_tuple=True)[0],
            torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
        ]
    )
    # Calculate the sequence lengths
    seq_lengths = change_indices[1:] - change_indices[:-1]
    # Calculate the length of the final sequence or padding
    final_seq_length = attn_mask.shape[1] - change_indices[-1]
    # Append the length of the final sequence or padding to seq_lengths
    if final_seq_length.item():
        seq_lengths = torch.cat(
            [
                seq_lengths,
                torch.tensor(
                    [final_seq_length.item()], dtype=torch.int32, device=device
                ),
            ]
        )
    # Calculate the cumulative sequence lengths
    cu_seqlens = torch.cat(
        [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
    )

    max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

    return cu_seqlens.to(dtype=torch.int32), max_seq_len
from axolotl.monkeypatch.utils import get_cu_seqlens


def forward(
103 changes: 103 additions & 0 deletions src/axolotl/monkeypatch/utils.py
@@ -0,0 +1,103 @@
"""
Shared utils for the monkeypatches
"""
import torch


def get_cu_seqlens(attn_mask):
"""generate a cumulative sequence length mask for flash attention using attn mask"""
if len(attn_mask.shape) == 1:
attn_mask = attn_mask.unsqueeze(0)

device = attn_mask.device
results = []
max_seq_lens = []

for row in attn_mask:
# Exclude zeros to avoid adding their positions to the mask
t_non_zeros = row[row != 0]
# Find where the sequence number changes (including the first position)
seq_change = torch.cat(
[
torch.tensor([1], dtype=torch.int32, device=device),
t_non_zeros[1:] != t_non_zeros[:-1],
]
)
# Get the indices where the sequence changes
change_indices = torch.cat(
[
(seq_change == 1).nonzero(as_tuple=True)[0],
torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = change_indices[1:] - change_indices[:-1]
# Calculate the length of the final sequence or padding
final_seq_length = len(row) - change_indices[-1]
# Append the length of the final sequence or padding to seq_lengths
if final_seq_length.item():
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[final_seq_length.item()], dtype=torch.int32, device=device
),
]
)
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
)
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)

return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
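
# Usage sketch (hedged; the sample values below are illustrative and not part
# of the commit): two packed sequences of lengths 3 and 2 followed by one
# padding token. Each change of sequence id marks a boundary, and the trailing
# padding is closed out as its own final segment.
#
#   >>> mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
#   >>> get_cu_seqlens(mask)
#   (tensor([[0, 3, 5, 6]], dtype=torch.int32), tensor([3]))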


def get_cu_seqlens_from_pos_ids(position_ids):
    """generate cumulative sequence lengths for flash attention from position ids"""
    if len(position_ids.shape) == 1:
        position_ids = position_ids.unsqueeze(0)

    device = position_ids.device
    results = []
    max_seq_lens = []

    for row in position_ids:
        # Count the number of consecutive zeros from the right side
        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()

        # Adjust the row to exclude padding
        adjusted_row = row[:-padding_length] if padding_length else row.clone()

        # Find where the position resets to 0 (indicating a new sequence)
        seq_starts = torch.cat(
            [
                torch.tensor([True], dtype=torch.bool, device=device),
                adjusted_row[1:] == 0,
            ]
        )
        # Get the indices where the sequence starts
        start_indices = torch.cat(
            [
                (seq_starts).nonzero(as_tuple=True)[0],
                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
            ]
        )
        # Calculate the sequence lengths
        seq_lengths = start_indices[1:] - start_indices[:-1]
        # Calculate the cumulative sequence lengths
        cu_seqlens = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
        )
        # Append the padding length to the cumulative sequence lengths
        if padding_length:
            cu_seqlens = torch.cat(
                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
            )
        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        results.append(cu_seqlens)
        max_seq_lens.append(max_seq_len)

    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
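
# Usage sketch (hedged; illustrative values, not part of the commit): position
# ids restart at 0 for each packed sequence, and trailing zeros are counted as
# right-side padding, whose span is appended as a final boundary.
#
#   >>> position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 0]])
#   >>> get_cu_seqlens_from_pos_ids(position_ids)
#   (tensor([[0, 3, 5, 7]], dtype=torch.int32), tensor([3]))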
2 changes: 1 addition & 1 deletion src/axolotl/prompters.py
@@ -45,7 +45,7 @@ def match_prompt_style(self):
        if self.prompt_style == PromptStyle.CHAT.value:
            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
            self.system_format = "SYSTEM:{system}\n"
            self.system_format = "SYSTEM: {system}\n"
        if self.prompt_style == PromptStyle.CHATML.value:
            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
            self.turn_no_input_format = (
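
The one-line fix above adds the missing space after "SYSTEM:" in the CHAT-style system prefix. A minimal rendering sketch (the sample strings are illustrative; the format strings are those from the diff):

    system_format = "SYSTEM: {system}\n"
    turn_no_input_format = "USER: {instruction}\nASSISTANT:"
    prompt = system_format.format(system="use cot") + turn_no_input_format.format(
        instruction="What is 2 + 2?"
    )
    # -> 'SYSTEM: use cot\nUSER: What is 2 + 2?\nASSISTANT:'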
30 changes: 30 additions & 0 deletions tests/monkeypatch/test_llama_attn_hijack_flash.py
@@ -0,0 +1,30 @@
"""
Unit tests for the monkeypatch utils
"""
import unittest

import torch

from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids


class TestMonkeyPatchUtils(unittest.TestCase):
"""
Unit test class for monkeypatch utils
"""

def test_get_cu_seqlens_1d(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))

def test_get_cu_seqlens_from_pos_ids_1d(self):
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
self.assertTrue(
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)


if __name__ == "__main__":
unittest.main()
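
Both tests encode the same packing: segments of lengths 4, 3, 5, and 2, followed by two padding tokens, so the attention-mask path and the position-id path should agree on the boundaries [0, 4, 7, 12, 14, 16].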
12 changes: 9 additions & 3 deletions tests/test_prompt_tokenizers.py
@@ -130,9 +130,15 @@ def test_system_alpaca(self):
            "output": "Hi! How can I help?",
        }
        example = strat.tokenize_prompt(sample)
        assert example["input_ids"][0:4] == [1, 835, 2184, 29901]  # "<s>### System:"
        assert example["input_ids"][5:7] == [1509, 20118]  # "use cot"
        assert example["input_ids"][9] == 11889  # USER
        assert example["input_ids"][0:5] == [
            1,
            28962,
            1254,
            12665,
            29901,
        ]  # "<s>SYSTEM:"
        assert example["input_ids"][5:7] == [671, 20118]  # " use cot"
        assert example["input_ids"][8] == 11889  # USER


if __name__ == "__main__":
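
The reworked assertions appear to track the "SYSTEM: {system}" system prefix: "<s>SYSTEM:" tokenizes to [1, 28962, 1254, 12665, 29901] in place of the previously asserted "### System:" ids, shifting the positions of the later tokens.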
