Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context size checks to pre-tokenized datasets #166

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sparse_autoencoder/source_data/abstract_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,10 @@ def __init__(
data_files=dataset_files,
)

# Setup preprocessing
existing_columns: list[str] = list(next(iter(dataset)).keys())
# Setup preprocessing (we remove all columns except for input ids)
remove_columns: list[str] = list(next(iter(dataset)).keys())
if "input_ids" in remove_columns:
remove_columns.remove("input_ids")

if pre_download:
if not isinstance(dataset, Dataset):
Expand All @@ -179,7 +181,7 @@ def __init__(
batched=True,
batch_size=preprocess_batch_size,
fn_kwargs={"context_size": context_size},
remove_columns=existing_columns,
remove_columns=remove_columns,
num_proc=n_processes_preprocessing,
)
self.dataset = mapped_dataset.shuffle()
Expand All @@ -199,7 +201,7 @@ def __init__(
batched=True,
batch_size=preprocess_batch_size,
fn_kwargs={"context_size": context_size},
remove_columns=existing_columns,
remove_columns=remove_columns,
)
self.dataset = mapped_dataset.shuffle(buffer_size=buffer_size) # type: ignore

Expand Down
11 changes: 11 additions & 0 deletions sparse_autoencoder/source_data/pretokenized_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,20 @@ def preprocess(

Returns:
Tokenized prompts.

Raises:
ValueError: If the context size is larger than the tokenized prompt size.
"""
tokenized_prompts: list[list[int]] = source_batch[self._dataset_column_name]

# Check the context size is not too large
if context_size > len(tokenized_prompts[0]):
error_message = (
f"The context size ({context_size}) is larger than the "
f"tokenized prompt size ({len(tokenized_prompts[0])})."
)
raise ValueError(error_message)

# Chunk each tokenized prompt into blocks of context_size,
# discarding the last block if too small.
context_size_prompts = []
Expand Down
50 changes: 12 additions & 38 deletions sparse_autoencoder/source_data/tests/test_pretokenized_dataset.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,17 @@
"""Tests for General Pre-Tokenized Dataset."""
import pytest

from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset

TEST_DATASET = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"


# Mock class for PreTokenizedDataset
class MockPreTokenizedDataset:
    """Lightweight stand-in for PreTokenizedDataset used in testing.

    Avoids any network access by generating a small in-memory dataset of
    fixed-size tokenized prompts.

    Attributes:
        dataset_path: Path to the dataset.
        context_size: The context size of the tokenized prompts.
        dataset: The mock dataset (a list of ``{"input_ids": [...]}`` items).
    """

    def __init__(self, dataset_path: str, context_size: int) -> None:
        """Initializes the mock PreTokenizedDataset with a dataset path and context size.

        Args:
            dataset_path: Path to the dataset.
            context_size: The context size of the tokenized prompts.
        """
        self.dataset_path = dataset_path
        self.context_size = context_size
        self.dataset = self._generate_mock_data()

    def _generate_mock_data(self) -> list[dict]:
        """Generates mock data for testing.

        Returns:
            list[dict]: Ten items, each a dict with an ``input_ids`` list of
                length ``context_size`` (token ids 0..context_size-1).
        """
        return [
            {"input_ids": list(range(self.context_size))} for _ in range(10)
        ]
TEST_DATASET = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"


@pytest.mark.integration_test()
@pytest.mark.parametrize("context_size", [50, 250])
@pytest.mark.parametrize("context_size", [128, 256])
def test_tokenized_prompts_correct_size(context_size: int) -> None:
"""Test that the tokenized prompts have the correct context size."""
# Use an appropriate tokenizer and dataset path

data = MockPreTokenizedDataset(dataset_path=TEST_DATASET, context_size=context_size)
data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=context_size)

# Check the first k items
iterable = iter(data.dataset)
Expand All @@ -56,3 +22,11 @@ def test_tokenized_prompts_correct_size(context_size: int) -> None:
# Check the tokens are integers
for token in item["input_ids"]:
assert isinstance(token, int)


@pytest.mark.integration_test()
def test_fails_context_size_too_large() -> None:
    """Test that it fails if the context size is set as larger than the source dataset on HF.

    The ValueError from ``preprocess`` surfaces lazily, on first iteration of
    the mapped dataset, rather than at construction time.
    """
    data = PreTokenizedDataset(dataset_path=TEST_DATASET, context_size=512)
    # Iterate via `data.dataset` for consistency with
    # test_tokenized_prompts_correct_size, which uses `iter(data.dataset)`.
    with pytest.raises(ValueError, match=r"larger than the tokenized prompt size"):
        next(iter(data.dataset))