Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Oct 23, 2024
1 parent 99633aa commit 57d9a89
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
7 changes: 4 additions & 3 deletions tests/data/test_alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def test_alpaca(mock_tokenizer, alpaca_path):
train_batch = next(iter(train_dataloader))
val_batch = next(iter(val_dataloader))

assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels"}
assert all(seq.shape == (2, 10) for seq in train_batch.values())
assert all(seq.shape == (2, 10) for seq in val_batch.values())
assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels", "token_counts"}
for key in ["input_ids", "labels"]:
assert train_batch[key].shape == (2, 10), f"Unexpected shape for train_batch[{key}]"
assert val_batch[key].shape == (2, 10), f"Unexpected shape for val_batch[{key}]"

assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle)
assert isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle)
Expand Down
19 changes: 14 additions & 5 deletions tests/data/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,32 @@ def apply(self, prompt, **kwargs):
def test_sft_collate_fn_padding(pad_id, ignore_index):
    """Collate ragged SFT samples: pad `input_ids` with `pad_id`, pad `labels`
    with `ignore_index`, and stack per-sample `token_counts` into (B, 1) tensors.

    `pad_id` and `ignore_index` are pytest fixtures/parameters supplying the
    padding values under test.
    """
    collate = get_sft_collate_fn(pad_id=pad_id, ignore_index=ignore_index)
    # Two samples of unequal length (3 vs 5 tokens) to force padding of the shorter one.
    samples = [
        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([10, 20, 30]), "token_counts": {"raw": 3, "raw_plus_prompt_template": 25}},
        {"input_ids": torch.tensor([4, 5, 6, 7, 8]), "labels": torch.tensor([40, 50, 60, 70, 80]), "token_counts": {"raw": 5, "raw_plus_prompt_template": 27}},
    ]
    expected = {
        # Shorter sequence is right-padded up to the batch max length (5).
        "input_ids": torch.tensor([[1, 2, 3, pad_id, pad_id], [4, 5, 6, 7, 8]]),
        "labels": torch.tensor([[10, 20, 30, ignore_index, ignore_index], [40, 50, 60, 70, 80]]),
        # Scalar counts are collated into column vectors of shape (batch, 1).
        "token_counts": {"raw": torch.tensor([[3], [5]]), "raw_plus_prompt_template": torch.tensor([[25], [27]])},
    }
    batch = collate(samples)
    assert all(torch.equal(batch[k], expected[k]) for k in ("input_ids", "labels"))
    for key in ("raw", "raw_plus_prompt_template"):
        assert torch.equal(batch["token_counts"][key], expected["token_counts"][key]), f"Token count mismatch for {key}"


def test_sft_collate_fn_truncation():
    """Collate with `max_seq_length=2`: both `input_ids` and `labels` are
    truncated to the first two tokens, while `token_counts` still report the
    original (untruncated) per-sample lengths.
    """
    collate = get_sft_collate_fn(max_seq_length=2)
    # Both samples exceed max_seq_length, so both must be truncated.
    samples = [
        {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([10, 20, 30]), "token_counts": {"raw": 3, "raw_plus_prompt_template": 25}},
        {"input_ids": torch.tensor([4, 5, 6, 7, 8]), "labels": torch.tensor([40, 50, 60, 70, 80]), "token_counts": {"raw": 5, "raw_plus_prompt_template": 27}},
    ]
    expected = {
        "input_ids": torch.tensor([[1, 2], [4, 5]]),
        "labels": torch.tensor([[10, 20], [40, 50]]),
        # Counts are NOT truncated — they reflect the pre-truncation token totals.
        "token_counts": {"raw": torch.tensor([[3], [5]]), "raw_plus_prompt_template": torch.tensor([[25], [27]])},
    }
    batch = collate(samples)
    assert all(torch.equal(batch[k], expected[k]) for k in ("input_ids", "labels"))
    for key in ("raw", "raw_plus_prompt_template"):
        assert torch.equal(batch["token_counts"][key], expected["token_counts"][key]), f"Token count mismatch for {key}"
7 changes: 4 additions & 3 deletions tests/data/test_dolly.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ def test_dolly(mock_tokenizer, dolly_path):
train_batch = next(iter(train_dataloader))
val_batch = next(iter(val_dataloader))

assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels"}
assert all(seq.shape == (2, 10) for seq in train_batch.values())
assert all(seq.shape == (2, 10) for seq in val_batch.values())
assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels", "token_counts"}
for key in ["input_ids", "labels"]:
assert train_batch[key].shape == (2, 10), f"Unexpected shape for train_batch[{key}]"
assert val_batch[key].shape == (2, 10), f"Unexpected shape for val_batch[{key}]"

assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle)
assert isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle)
Expand Down
7 changes: 4 additions & 3 deletions tests/data/test_longform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def test_longform(mock_tokenizer, longform_path):
train_batch = next(iter(train_dataloader))
val_batch = next(iter(val_dataloader))

assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels"}
assert all(seq.shape == (2, 10) for seq in train_batch.values())
assert all(seq.shape == (2, 10) for seq in val_batch.values())
assert train_batch.keys() == val_batch.keys() == {"input_ids", "labels", "token_counts"}
for key in ["input_ids", "labels"]:
assert train_batch[key].shape == (2, 10), f"Unexpected shape for train_batch[{key}]"
assert val_batch[key].shape == (2, 10), f"Unexpected shape for val_batch[{key}]"

assert isinstance(train_dataloader.dataset.prompt_style, LongFormPromptStyle)
assert isinstance(val_dataloader.dataset.prompt_style, LongFormPromptStyle)
Expand Down

0 comments on commit 57d9a89

Please sign in to comment.