diff --git a/sparse_autoencoder/source_data/abstract_dataset.py b/sparse_autoencoder/source_data/abstract_dataset.py
index 5b90583c..12594d9b 100644
--- a/sparse_autoencoder/source_data/abstract_dataset.py
+++ b/sparse_autoencoder/source_data/abstract_dataset.py
@@ -70,6 +70,9 @@ class SourceDataset(ABC, Generic[HuggingFaceDatasetItem]):
     Hugging Face `Dataset` objects are confusingly not the same as PyTorch `Dataset` objects.
     """
 
+    _dataset_column_name: str
+    """Dataset column name for the prompts."""
+
     @abstractmethod
     def preprocess(
         self,
@@ -111,6 +114,7 @@ def __init__(
         buffer_size: int = 1000,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
+        dataset_column_name: str = "input_ids",
         n_processes_preprocessing: int | None = None,
         preprocess_batch_size: int = 1000,
         *,
@@ -136,6 +140,7 @@ def __init__(
                 tokenized prompts once the preprocessing function has been applied.
             dataset_dir: Defining the `data_dir` of the dataset configuration.
             dataset_files: Path(s) to source data file(s).
+            dataset_column_name: The column name for the prompts.
             n_processes_preprocessing: The number of processes to use for preprocessing.
             preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
                 tokenizing prompts).
@@ -145,6 +150,7 @@ def __init__(
             TypeError: If the loaded dataset is not a Hugging Face `Dataset` or `IterableDataset`.
         """
         self.context_size = context_size
+        self._dataset_column_name = dataset_column_name
 
         # Load the dataset
         should_stream = not pre_download
diff --git a/sparse_autoencoder/source_data/pretokenized_dataset.py b/sparse_autoencoder/source_data/pretokenized_dataset.py
index a29fd997..ae1c2892 100644
--- a/sparse_autoencoder/source_data/pretokenized_dataset.py
+++ b/sparse_autoencoder/source_data/pretokenized_dataset.py
@@ -12,24 +12,13 @@
 """
 from collections.abc import Mapping, Sequence
-from typing import TypedDict, final
+from typing import final
 
 from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
 
 
-class PreTokenizedDataBatch(TypedDict):
-    """General Pre-Tokenized Dataset Item.
-
-    Structure depends on the specific dataset from Hugging Face.
-    """
-
-    tokens: list[
-        list[int]
-    ]  # This assumes that the dataset structure is similar to the original Neel Nanda dataset.
-
-
 @final
-class PreTokenizedDataset(SourceDataset[PreTokenizedDataBatch]):
+class PreTokenizedDataset(SourceDataset[dict]):
     """General Pre-Tokenized Dataset from Hugging Face.
 
     Can be used for various datasets available on Hugging Face.
@@ -37,7 +26,7 @@ class PreTokenizedDataset(SourceDataset[PreTokenizedDataBatch]):
 
     def preprocess(
         self,
-        source_batch: PreTokenizedDataBatch,
+        source_batch: dict,
         *,
         context_size: int,
     ) -> TokenizedPrompts:
@@ -52,7 +41,7 @@ def preprocess(
         Returns:
             Tokenized prompts.
         """
-        tokenized_prompts: list[list[int]] = source_batch["tokens"]
+        tokenized_prompts: list[list[int]] = source_batch[self._dataset_column_name]
 
         # Chunk each tokenized prompt into blocks of context_size,
         # discarding the last block if too small.
@@ -75,6 +64,7 @@ def __init__(
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
+        dataset_column_name: str = "input_ids",
         preprocess_batch_size: int = 1000,
         *,
         pre_download: bool = False,
@@ -95,6 +85,7 @@ def __init__(
             dataset_dir: Defining the `data_dir` of the dataset configuration.
             dataset_files: Path(s) to source data file(s).
             dataset_split: Dataset split (e.g. `train`).
+            dataset_column_name: The column name for the tokenized prompts.
             preprocess_batch_size: The batch size to use just for preprocessing the dataset (e.g.
                 tokenizing prompts).
             pre_download: Whether to pre-download the whole dataset.
@@ -106,6 +97,7 @@ def __init__(
             dataset_files=dataset_files,
             dataset_path=dataset_path,
             dataset_split=dataset_split,
+            dataset_column_name=dataset_column_name,
             pre_download=pre_download,
             preprocess_batch_size=preprocess_batch_size,
         )
diff --git a/sparse_autoencoder/source_data/text_dataset.py b/sparse_autoencoder/source_data/text_dataset.py
index becbc3a6..a2afcc9b 100644
--- a/sparse_autoencoder/source_data/text_dataset.py
+++ b/sparse_autoencoder/source_data/text_dataset.py
@@ -53,7 +53,7 @@ def preprocess(
 
         # Chunk each tokenized prompt into blocks of context_size, discarding incomplete blocks.
         context_size_prompts = []
-        for encoding in list(tokenized_prompts["input_ids"]):  # type: ignore
+        for encoding in list(tokenized_prompts[self._dataset_column_name]):  # type: ignore
             chunks = [
                 encoding[i : i + context_size]
                 for i in range(0, len(encoding), context_size)
@@ -72,6 +72,7 @@ def __init__(
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
+        dataset_column_name: str = "input_ids",
         n_processes_preprocessing: int | None = None,
         preprocess_batch_size: int = 1000,
         *,
@@ -95,6 +96,7 @@ def __init__(
            dataset_dir: Defining the `data_dir` of the dataset configuration.
            dataset_files: Path(s) to source data file(s).
            dataset_split: Dataset split (e.g., 'train').
+           dataset_column_name: The column name for the prompts.
            n_processes_preprocessing: Number of processes to use for preprocessing.
            preprocess_batch_size: Batch size for preprocessing (tokenizing prompts).
            pre_download: Whether to pre-download the whole dataset.
@@ -108,6 +110,7 @@ def __init__(
            dataset_files=dataset_files,
            dataset_path=dataset_path,
            dataset_split=dataset_split,
+           dataset_column_name=dataset_column_name,
            n_processes_preprocessing=n_processes_preprocessing,
            pre_download=pre_download,
            preprocess_batch_size=preprocess_batch_size,
diff --git a/sparse_autoencoder/train/sweep.py b/sparse_autoencoder/train/sweep.py
index fc83a9c6..7d62c914 100644
--- a/sparse_autoencoder/train/sweep.py
+++ b/sparse_autoencoder/train/sweep.py
@@ -161,6 +161,7 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
             dataset_dir=dataset_dir,
             dataset_files=dataset_files,
             dataset_path=hyperparameters["source_data"]["dataset_path"],
+            dataset_column_name=hyperparameters["source_data"]["dataset_column_name"],
             pre_download=hyperparameters["source_data"]["pre_download"],
         )
 
@@ -175,6 +176,7 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset:
 
     return TextDataset(
         context_size=hyperparameters["source_data"]["context_size"],
+        dataset_column_name=hyperparameters["source_data"]["dataset_column_name"],
         dataset_dir=dataset_dir,
         dataset_files=dataset_files,
         dataset_path=hyperparameters["source_data"]["dataset_path"],
diff --git a/sparse_autoencoder/train/sweep_config.py b/sparse_autoencoder/train/sweep_config.py
index 806ed3bf..4239c06c 100644
--- a/sparse_autoencoder/train/sweep_config.py
+++ b/sparse_autoencoder/train/sweep_config.py
@@ -174,6 +174,9 @@ class SourceDataHyperparameters(NestedParameter):
     context_size: Parameter[int] = field(default=Parameter(DEFAULT_SOURCE_CONTEXT_SIZE))
     """Context size."""
 
+    dataset_column_name: Parameter[str] | None = field(default=Parameter(value="input_ids"))
+    """Dataset column name."""
+
     dataset_dir: Parameter[str] | None = field(default=None)
     """Dataset directory (within the HF dataset)"""
 
@@ -211,6 +214,7 @@ class SourceDataRuntimeHyperparameters(TypedDict):
     """Source data runtime hyperparameters."""
 
     context_size: int
+    dataset_column_name: str
     dataset_dir: str | None
     dataset_files: list[str] | None
     dataset_path: str
diff --git a/sparse_autoencoder/train/tests/test_sweep.py b/sparse_autoencoder/train/tests/test_sweep.py
index 57eda404..e18862e4 100644
--- a/sparse_autoencoder/train/tests/test_sweep.py
+++ b/sparse_autoencoder/train/tests/test_sweep.py
@@ -48,12 +48,13 @@ def dummy_hyperparameters() -> RuntimeHyperparameters:
         "random_seed": 49,
         "source_data": {
             "context_size": 128,
+            "dataset_column_name": "input_ids",
             "dataset_dir": None,
             "dataset_files": None,
             "dataset_path": "NeelNanda/c4-code-tokenized-2b",
+            "pre_download": False,
             "pre_tokenized": True,
             "tokenizer_name": None,
-            "pre_download": False,
         },
         "source_model": {
             "dtype": "float32",
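Usage note (illustrative, not part of the patch): the new `dataset_column_name` argument lets callers keep using datasets whose tokenized prompts are stored under a column other than the new `input_ids` default, such as the `tokens` column that was previously hard-coded for the Neel Nanda datasets. A minimal sketch, using the `dataset_path` and `context_size` keyword arguments that already appear in this diff:

# Illustrative only: load a pre-tokenized dataset whose prompts live in a
# "tokens" column (the value that preprocess() previously hard-coded).
from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset

source_data = PreTokenizedDataset(
    dataset_path="NeelNanda/c4-code-tokenized-2b",  # example dataset from the tests
    context_size=128,
    dataset_column_name="tokens",  # overrides the new "input_ids" default
)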