Add context size checks to pre-tokenized datasets (#166)
alan-cooney authored Jan 2, 2024
1 parent 065f101 commit 5132eb6
Showing 3 changed files with 29 additions and 42 deletions.
10 changes: 6 additions & 4 deletions sparse_autoencoder/source_data/abstract_dataset.py
@@ -162,8 +162,10 @@ def __init__(
             data_files=dataset_files,
         )

-        # Setup preprocessing
-        existing_columns: list[str] = list(next(iter(dataset)).keys())
+        # Setup preprocessing (we remove all columns except for input ids)
+        remove_columns: list[str] = list(next(iter(dataset)).keys())
+        if "input_ids" in remove_columns:
+            remove_columns.remove("input_ids")

         if pre_download:
             if not isinstance(dataset, Dataset):
@@ -179,7 +181,7 @@ def __init__(
                 batched=True,
                 batch_size=preprocess_batch_size,
                 fn_kwargs={"context_size": context_size},
-                remove_columns=existing_columns,
+                remove_columns=remove_columns,
                 num_proc=n_processes_preprocessing,
             )
             self.dataset = mapped_dataset.shuffle()
@@ -199,7 +201,7 @@ def __init__(
                 batched=True,
                 batch_size=preprocess_batch_size,
                 fn_kwargs={"context_size": context_size},
-                remove_columns=existing_columns,
+                remove_columns=remove_columns,
             )
             self.dataset = mapped_dataset.shuffle(buffer_size=buffer_size)  # type: ignore
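
Editor's note: the new remove_columns list matters because Dataset.map drops every column named in remove_columns from its output. A minimal sketch of that Hugging Face datasets behaviour follows; the toy dataset and the stand-in lambda are assumptions for illustration, not code from this commit.

# Minimal sketch (assumed toy data) of the remove_columns behaviour the change
# relies on: every column listed in remove_columns is dropped from the mapped
# output, so "input_ids" must be excluded from that list to survive preprocessing.
from datasets import Dataset

raw = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3, 4], [5, 6, 7, 8]],
        "meta": ["a", "b"],  # an extra column we want removed
    }
)

remove_columns = list(next(iter(raw)).keys())
if "input_ids" in remove_columns:
    remove_columns.remove("input_ids")  # keep the tokenized prompts

mapped = raw.map(
    lambda batch: {"input_ids": batch["input_ids"]},  # stand-in for preprocess()
    batched=True,
    remove_columns=remove_columns,
)
print(mapped.column_names)  # ['input_ids']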
11 changes: 11 additions & 0 deletions sparse_autoencoder/source_data/pretokenized_dataset.py
@@ -40,9 +40,20 @@ def preprocess(
         Returns:
             Tokenized prompts.
+
+        Raises:
+            ValueError: If the context size is larger than the tokenized prompt size.
         """
         tokenized_prompts: list[list[int]] = source_batch[self._dataset_column_name]

+        # Check the context size is not too large
+        if context_size > len(tokenized_prompts[0]):
+            error_message = (
+                f"The context size ({context_size}) is larger than the "
+                f"tokenized prompt size ({len(tokenized_prompts[0])})."
+            )
+            raise ValueError(error_message)
+
         # Chunk each tokenized prompt into blocks of context_size,
         # discarding the last block if too small.
         context_size_prompts = []
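
Editor's note: the guard fails fast when context_size exceeds the length of the stored prompts, since not even one full block of context_size tokens could be produced. Below is a minimal sketch of the guarded step written as a standalone helper with a hypothetical name; the real logic lives in PreTokenizedDataset.preprocess, and the chunking loop is reconstructed as an assumption because the hunk ends at context_size_prompts = [].

def chunk_tokenized_prompts(
    tokenized_prompts: list[list[int]], context_size: int
) -> list[list[int]]:
    """Split each prompt into context_size blocks, failing fast if impossible."""
    if context_size > len(tokenized_prompts[0]):
        error_message = (
            f"The context size ({context_size}) is larger than the "
            f"tokenized prompt size ({len(tokenized_prompts[0])})."
        )
        raise ValueError(error_message)

    context_size_prompts = []
    for prompt in tokenized_prompts:
        # Chunk into blocks of context_size, discarding a too-small final block.
        for start in range(0, len(prompt) - context_size + 1, context_size):
            context_size_prompts.append(prompt[start : start + context_size])
    return context_size_prompts


print(chunk_tokenized_prompts([[1, 2, 3, 4, 5]], context_size=2))
# [[1, 2], [3, 4]] -- the trailing [5] is discarded
# chunk_tokenized_prompts([[1, 2, 3]], context_size=4) would raise the ValueError above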
50 changes: 12 additions & 38 deletions sparse_autoencoder/source_data/tests/test_pretokenized_dataset.py
@@ -1,51 +1,17 @@
 """Tests for General Pre-Tokenized Dataset."""
 import pytest

+from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset

-TEST_DATASET = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"
-
-
-# Mock class for PreTokenizedDataset
-class MockPreTokenizedDataset:
-    """Mock class for PreTokenizedDataset used in testing.
-
-    Attributes:
-        dataset_path: Path to the dataset.
-        context_size: The context size of the tokenized prompts.
-        dataset: The mock dataset.
-    """
-
-    def __init__(self, dataset_path: str, context_size: int) -> None:
-        """Initializes the mock PreTokenizedDataset with a dataset path and context size.
-
-        Args:
-            dataset_path: Path to the dataset.
-            context_size: The context size of the tokenized prompts.
-        """
-        self.dataset_path = dataset_path
-        self.context_size = context_size
-        self.dataset = self._generate_mock_data()
-
-    def _generate_mock_data(self) -> list[dict]:
-        """Generates mock data for testing.
-
-        Returns:
-            list[dict]: A list of dictionaries representing mock data items.
-        """
-        mock_data = []
-        for _ in range(10):
-            item = {"input_ids": list(range(self.context_size))}
-            mock_data.append(item)
-        return mock_data
+TEST_DATASET = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"


 @pytest.mark.integration_test()
-@pytest.mark.parametrize("context_size", [50, 250])
+@pytest.mark.parametrize("context_size", [128, 256])
 def test_tokenized_prompts_correct_size(context_size: int) -> None:
     """Test that the tokenized prompts have the correct context size."""
-    # Use an appropriate tokenizer and dataset path
-
-    data = MockPreTokenizedDataset(dataset_path=TEST_DATASET, context_size=context_size)
+    data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=context_size)

     # Check the first k items
     iterable = iter(data.dataset)
@@ -56,3 +22,11 @@ def test_tokenized_prompts_correct_size(context_size: int) -> None:
         # Check the tokens are integers
         for token in item["input_ids"]:
             assert isinstance(token, int)
+
+
+@pytest.mark.integration_test()
+def test_fails_context_size_too_large() -> None:
+    """Test that it fails if the context size is set as larger than the source dataset on HF."""
+    data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=512)
+    with pytest.raises(ValueError, match=r"larger than the tokenized prompt size"):
+        next(iter(data))