Add column name support for source datasets (#165)
alan-cooney authored Dec 31, 2023
1 parent 75c9b7e commit 065f101
Showing 6 changed files with 25 additions and 17 deletions.
6 changes: 6 additions & 0 deletions sparse_autoencoder/source_data/abstract_dataset.py
@@ -70,6 +70,9 @@ class SourceDataset(ABC, Generic[HuggingFaceDatasetItem]):
Hugging Face `Dataset` objects are confusingly not the same as PyTorch `Dataset` objects.
"""

_dataset_column_name: str
"""Dataset column name for the prompts."""

@abstractmethod
def preprocess(
self,
@@ -111,6 +114,7 @@ def __init__(
buffer_size: int = 1000,
dataset_dir: str | None = None,
dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
dataset_column_name: str = "input_ids",
n_processes_preprocessing: int | None = None,
preprocess_batch_size: int = 1000,
*,
@@ -136,6 +140,7 @@ def __init__(
tokenized prompts once the preprocessing function has been applied.
dataset_dir: Defining the `data_dir` of the dataset configuration.
dataset_files: Path(s) to source data file(s).
dataset_column_name: The column name for the prompts.
n_processes_preprocessing: The number of processes to use for preprocessing.
preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
tokenizing prompts).
@@ -145,6 +150,7 @@ def __init__(
TypeError: If the loaded dataset is not a Hugging Face `Dataset` or `IterableDataset`.
"""
self.context_size = context_size
self._dataset_column_name = dataset_column_name

# Load the dataset
should_stream = not pre_download
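The new `_dataset_column_name` attribute is meant to be read inside `preprocess`. A minimal sketch of a hypothetical subclass (not part of this commit) that does so; it assumes `preprocess` is the only abstract method and that `TokenizedPrompts` is keyed by "input_ids", neither of which this excerpt confirms:

from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts


class ColumnAwarePreTokenizedSource(SourceDataset[dict]):
    """Hypothetical source dataset that chunks prompts from the configured column."""

    def preprocess(self, source_batch: dict, *, context_size: int) -> TokenizedPrompts:
        # Read the prompts from whichever column name was passed to __init__.
        prompts: list[list[int]] = source_batch[self._dataset_column_name]
        # Split each prompt into full context_size blocks, dropping any short tail.
        chunks = [
            prompt[i : i + context_size]
            for prompt in prompts
            for i in range(0, len(prompt) - context_size + 1, context_size)
        ]
        return {"input_ids": chunks}  # the "input_ids" key is an assumption here
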
22 changes: 7 additions & 15 deletions sparse_autoencoder/source_data/pretokenized_dataset.py
@@ -12,32 +12,21 @@
"""
from collections.abc import Mapping, Sequence
from typing import TypedDict, final
from typing import final

from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts


class PreTokenizedDataBatch(TypedDict):
"""General Pre-Tokenized Dataset Item.
Structure depends on the specific dataset from Hugging Face.
"""

tokens: list[
list[int]
] # This assumes that the dataset structure is similar to the original Neel Nanda dataset.


@final
class PreTokenizedDataset(SourceDataset[PreTokenizedDataBatch]):
class PreTokenizedDataset(SourceDataset[dict]):
"""General Pre-Tokenized Dataset from Hugging Face.
Can be used for various datasets available on Hugging Face.
"""

def preprocess(
self,
source_batch: PreTokenizedDataBatch,
source_batch: dict,
*,
context_size: int,
) -> TokenizedPrompts:
@@ -52,7 +41,7 @@ def preprocess(
Returns:
Tokenized prompts.
"""
tokenized_prompts: list[list[int]] = source_batch["tokens"]
tokenized_prompts: list[list[int]] = source_batch[self._dataset_column_name]

# Chunk each tokenized prompt into blocks of context_size,
# discarding the last block if too small.
@@ -75,6 +64,7 @@ def __init__(
dataset_dir: str | None = None,
dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
dataset_split: str = "train",
dataset_column_name: str = "input_ids",
preprocess_batch_size: int = 1000,
*,
pre_download: bool = False,
@@ -95,6 +85,7 @@ def __init__(
dataset_dir: Defining the `data_dir` of the dataset configuration.
dataset_files: Path(s) to source data file(s).
dataset_split: Dataset split (e.g. `train`).
dataset_column_name: The column name for the tokenized prompts.
preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
tokenizing prompts).
pre_download: Whether to pre-download the whole dataset.
@@ -106,6 +97,7 @@ def __init__(
dataset_files=dataset_files,
dataset_path=dataset_path,
dataset_split=dataset_split,
dataset_column_name=dataset_column_name,
pre_download=pre_download,
preprocess_batch_size=preprocess_batch_size,
)
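
A usage sketch for the new parameter (not part of the diff; the dataset path and column name below are placeholders, and it assumes `dataset_path` and `context_size` are accepted as keyword arguments, as the `setup_source_data` call in sweep.py further down suggests):

from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset

# Hypothetical pre-tokenized dataset whose prompts live in a "tokens" column
# rather than the default "input_ids"; everything else keeps the defaults
# shown in __init__ above.
source_data = PreTokenizedDataset(
    dataset_path="my-org/my-pretokenized-dataset",
    context_size=128,
    dataset_column_name="tokens",
)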
5 changes: 4 additions & 1 deletion sparse_autoencoder/source_data/text_dataset.py
@@ -53,7 +53,7 @@ def preprocess(

# Chunk each tokenized prompt into blocks of context_size, discarding incomplete blocks.
context_size_prompts = []
for encoding in list(tokenized_prompts["input_ids"]): # type: ignore
for encoding in list(tokenized_prompts[self._dataset_column_name]): # type: ignore
chunks = [
encoding[i : i + context_size]
for i in range(0, len(encoding), context_size)
@@ -72,6 +72,7 @@ def __init__(
dataset_dir: str | None = None,
dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
dataset_split: str = "train",
dataset_column_name: str = "input_ids",
n_processes_preprocessing: int | None = None,
preprocess_batch_size: int = 1000,
*,
@@ -95,6 +96,7 @@ def __init__(
dataset_dir: Defining the `data_dir` of the dataset configuration.
dataset_files: Path(s) to source data file(s).
dataset_split: Dataset split (e.g., 'train').
dataset_column_name: The column name for the prompts.
n_processes_preprocessing: Number of processes to use for preprocessing.
preprocess_batch_size: Batch size for preprocessing (tokenizing prompts).
pre_download: Whether to pre-download the whole dataset.
@@ -108,6 +110,7 @@ def __init__(
dataset_files=dataset_files,
dataset_path=dataset_path,
dataset_split=dataset_split,
dataset_column_name=dataset_column_name,
n_processes_preprocessing=n_processes_preprocessing,
pre_download=pre_download,
preprocess_batch_size=preprocess_batch_size,
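The same column handling, restated as a standalone plain-Python function so it can be tried in isolation (the discard condition and the "input_ids" return key are filled in from the surrounding comments and should be treated as assumptions, since those lines are collapsed in this view):

def chunk_tokenized_column(
    tokenized_prompts: dict, column_name: str, context_size: int
) -> dict:
    """Chunk each prompt in the given column into blocks of context_size,
    discarding incomplete trailing blocks."""
    context_size_prompts = []
    for encoding in list(tokenized_prompts[column_name]):
        chunks = [
            encoding[i : i + context_size]
            for i in range(0, len(encoding), context_size)
            if len(encoding[i : i + context_size]) == context_size
        ]
        context_size_prompts.extend(chunks)
    return {"input_ids": context_size_prompts}


# With context_size=3, a 7-token prompt stored under "tokens" yields two full
# blocks and drops the trailing single token.
batch = {"tokens": [[1, 2, 3, 4, 5, 6, 7]]}
print(chunk_tokenized_column(batch, "tokens", context_size=3))
# -> {'input_ids': [[1, 2, 3], [4, 5, 6]]}
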
2 changes: 2 additions & 0 deletions sparse_autoencoder/train/sweep.py
@@ -161,6 +161,7 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
dataset_dir=dataset_dir,
dataset_files=dataset_files,
dataset_path=hyperparameters["source_data"]["dataset_path"],
dataset_column_name=hyperparameters["source_data"]["dataset_column_name"],
pre_download=hyperparameters["source_data"]["pre_download"],
)

@@ -175,6 +176,7 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:

return TextDataset(
context_size=hyperparameters["source_data"]["context_size"],
dataset_column_name=hyperparameters["source_data"]["dataset_column_name"],
dataset_dir=dataset_dir,
dataset_files=dataset_files,
dataset_path=hyperparameters["source_data"]["dataset_path"],
4 changes: 4 additions & 0 deletions sparse_autoencoder/train/sweep_config.py
@@ -174,6 +174,9 @@ class SourceDataHyperparameters(NestedParameter):
context_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_CONTEXT_SIZE))
"""Context size."""

dataset_column_name: Parameter[str] | None = field(default=Parameter(value="input_ids"))
"""Dataset column name."""

dataset_dir: Parameter[str] | None = field(default=None)
"""Dataset directory (within the HF dataset)"""

@@ -211,6 +214,7 @@ class SourceDataRuntimeHyperparameters(TypedDict):
"""Source data runtime hyperparameters."""

context_size: int
dataset_column_name: str
dataset_dir: str | None
dataset_files: list[str] | None
dataset_path: str
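A hedged sketch of overriding the new default in a sweep configuration (not from this commit; it assumes `Parameter` is importable from `sweep_config` and that `dataset_path` is the only field without a default, which the collapsed parts of this diff don't show):

from sparse_autoencoder.train.sweep_config import Parameter, SourceDataHyperparameters

# Hypothetical sweep fragment: read prompts from a "tokens" column instead of
# the default "input_ids".
source_data = SourceDataHyperparameters(
    dataset_path=Parameter("my-org/my-pretokenized-dataset"),
    dataset_column_name=Parameter("tokens"),
)
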
3 changes: 2 additions & 1 deletion sparse_autoencoder/train/tests/test_sweep.py
@@ -48,12 +48,13 @@ def dummy_hyperparameters() -> RuntimeHyperparameters:
"random_seed": 49,
"source_data": {
"context_size": 128,
"dataset_column_name": "input_ids",
"dataset_dir": None,
"dataset_files": None,
"dataset_path": "NeelNanda/c4-code-tokenized-2b",
"pre_download": False,
"pre_tokenized": True,
"tokenizer_name": None,
"pre_download": False,
},
"source_model": {
"dtype": "float32",
