diff --git a/sparse_autoencoder/activation_store/base_store.py b/sparse_autoencoder/activation_store/base_store.py
index 0ea2a26a..f92e0c34 100644
--- a/sparse_autoencoder/activation_store/base_store.py
+++ b/sparse_autoencoder/activation_store/base_store.py
@@ -4,6 +4,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -105,12 +106,13 @@ def shuffle(self) -> None:
         """Optional shuffle method."""
 
     @final
+    @validate_call
     def fill_with_test_data(
         self,
-        n_batches: int = 1,
-        batch_size: int = 16,
-        n_components: int = 1,
-        input_features: int = 256,
+        n_batches: PositiveInt = 1,
+        batch_size: PositiveInt = 16,
+        n_components: PositiveInt = 1,
+        input_features: PositiveInt = 256,
     ) -> None:
         """Fill the store with test data.
 
diff --git a/sparse_autoencoder/activation_store/disk_store.py b/sparse_autoencoder/activation_store/disk_store.py
index 1552531f..f6ed5137 100644
--- a/sparse_autoencoder/activation_store/disk_store.py
+++ b/sparse_autoencoder/activation_store/disk_store.py
@@ -4,6 +4,7 @@
 import tempfile
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 
@@ -70,12 +71,13 @@ def current_activations_stored_per_component(self) -> list[int]:
         disk_items_stored = len(self)
         return [cache_items + disk_items_stored for cache_items in self._items_stored]
 
+    @validate_call
     def __init__(
         self,
-        n_neurons: int,
+        n_neurons: PositiveInt,
         storage_path: Path = DEFAULT_DISK_ACTIVATION_STORE_PATH,
-        max_cache_size: int = 10_000,
-        n_components: int = 1,
+        max_cache_size: PositiveInt = 10_000,
+        n_components: PositiveInt = 1,
         *,
         empty_dir: bool = False,
     ):
diff --git a/sparse_autoencoder/activation_store/tensor_store.py b/sparse_autoencoder/activation_store/tensor_store.py
index 3cb45422..cfd85ca7 100644
--- a/sparse_autoencoder/activation_store/tensor_store.py
+++ b/sparse_autoencoder/activation_store/tensor_store.py
@@ -1,5 +1,6 @@
 """Tensor Activation Store."""
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 
@@ -75,11 +76,12 @@ def current_activations_stored_per_component(self) -> list[int]:
         """Number of activations stored per component."""
         return self._items_stored
 
+    @validate_call(config={"arbitrary_types_allowed": True})
    def __init__(
         self,
-        max_items: int,
-        n_neurons: int,
-        n_components: int = 1,
+        max_items: PositiveInt,
+        n_neurons: PositiveInt,
+        n_components: PositiveInt = 1,
         device: torch.device | None = None,
     ) -> None:
         """Initialise the Tensor Activation Store.
diff --git a/sparse_autoencoder/autoencoder/components/abstract_decoder.py b/sparse_autoencoder/autoencoder/components/abstract_decoder.py
index 5b282e74..69e163cc 100644
--- a/sparse_autoencoder/autoencoder/components/abstract_decoder.py
+++ b/sparse_autoencoder/autoencoder/components/abstract_decoder.py
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float, Int64
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Module, Parameter
@@ -24,11 +25,12 @@ class AbstractDecoder(Module, ABC):
 
     _n_components: int | None
 
+    @validate_call
     def __init__(
         self,
-        learnt_features: int,
-        decoded_features: int,
-        n_components: int | None,
+        learnt_features: PositiveInt,
+        decoded_features: PositiveInt,
+        n_components: PositiveInt | None,
     ) -> None:
         """Initialise the decoder.
 
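Note on the pattern used throughout this diff: `@validate_call` makes pydantic validate the annotated arguments at call time, and `PositiveInt` rejects zero or negative values before any storage is allocated. A minimal sketch of the behaviour (class and field names here are invented for illustration, not taken from the library):

```python
from pydantic import PositiveInt, ValidationError, validate_call


class StoreSketch:
    """Hypothetical stand-in for an activation store constructor."""

    @validate_call
    def __init__(self, max_items: PositiveInt, n_neurons: PositiveInt = 256) -> None:
        # Arguments are only assigned after pydantic has validated them.
        self.max_items = max_items
        self.n_neurons = n_neurons


StoreSketch(max_items=10_000)  # fine

try:
    StoreSketch(max_items=0)  # PositiveInt requires a value > 0
except ValidationError as error:
    print(error)  # reports which argument failed and why
```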
diff --git a/sparse_autoencoder/autoencoder/components/abstract_encoder.py b/sparse_autoencoder/autoencoder/components/abstract_encoder.py
index 480f538d..080dfc20 100644
--- a/sparse_autoencoder/autoencoder/components/abstract_encoder.py
+++ b/sparse_autoencoder/autoencoder/components/abstract_encoder.py
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float, Int64
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Module, Parameter
@@ -25,11 +26,12 @@ class AbstractEncoder(Module, ABC):
 
     _n_components: int | None
 
+    @validate_call
     def __init__(
         self,
-        input_features: int,
-        learnt_features: int,
-        n_components: int | None,
+        input_features: PositiveInt,
+        learnt_features: PositiveInt,
+        n_components: PositiveInt | None,
     ) -> None:
         """Initialise the encoder.
 
diff --git a/sparse_autoencoder/autoencoder/components/linear_encoder.py b/sparse_autoencoder/autoencoder/components/linear_encoder.py
index 3ce1d46e..50b27a39 100644
--- a/sparse_autoencoder/autoencoder/components/linear_encoder.py
+++ b/sparse_autoencoder/autoencoder/components/linear_encoder.py
@@ -4,6 +4,7 @@
 
 import einops
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Parameter, ReLU, init
@@ -77,11 +78,12 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
     activation_function: ReLU
     """Activation function."""
 
+    @validate_call
     def __init__(
         self,
-        input_features: int,
-        learnt_features: int,
-        n_components: int | None,
+        input_features: PositiveInt,
+        learnt_features: PositiveInt,
+        n_components: PositiveInt | None,
     ):
         """Initialize the linear encoder layer.
 
diff --git a/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py b/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
index e1dba03d..ff0e5701 100644
--- a/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
+++ b/sparse_autoencoder/autoencoder/components/unit_norm_decoder.py
@@ -3,6 +3,7 @@
 
 import einops
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Parameter, init
@@ -76,11 +77,12 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
         """
         return [(self.weight, -1)]
 
+    @validate_call
     def __init__(
         self,
-        learnt_features: int,
-        decoded_features: int,
-        n_components: int | None,
+        learnt_features: PositiveInt,
+        decoded_features: PositiveInt,
+        n_components: PositiveInt | None,
         *,
         enable_gradient_hook: bool = True,
     ) -> None:
diff --git a/sparse_autoencoder/autoencoder/model.py b/sparse_autoencoder/autoencoder/model.py
index 09753c42..0c3280f5 100644
--- a/sparse_autoencoder/autoencoder/model.py
+++ b/sparse_autoencoder/autoencoder/model.py
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn.parameter import Parameter
@@ -72,15 +73,16 @@ def post_decoder_bias(self) -> TiedBias:
         """Post-decoder bias."""
         return self._post_decoder_bias
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
-        n_input_features: int,
-        n_learned_features: int,
+        n_input_features: PositiveInt,
+        n_learned_features: PositiveInt,
         geometric_median_dataset: Float[
             Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
         ]
         | None = None,
-        n_components: int | None = None,
+        n_components: PositiveInt | None = None,
     ) -> None:
         """Initialize the Sparse Autoencoder Model.
 
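`TensorActivationStore.__init__` and `SparseAutoencoder.__init__` additionally pass `config={"arbitrary_types_allowed": True}` because some of their arguments (`torch.device`, the geometric-median `Tensor`) are not types pydantic can build a schema for; the flag makes it fall back to a plain `isinstance` check. A rough sketch with a hypothetical helper rather than the real constructor:

```python
import torch
from pydantic import PositiveInt, validate_call


@validate_call(config={"arbitrary_types_allowed": True})
def build_sketch(
    n_input_features: PositiveInt,
    geometric_median: torch.Tensor | None = None,
) -> dict:
    """Hypothetical helper: torch.Tensor is not a pydantic-native type, so the
    config flag makes pydantic accept it via a plain isinstance check."""
    return {"n_input_features": n_input_features, "median": geometric_median}


build_sketch(512, torch.zeros(512))  # accepted
build_sketch(512, None)              # also fine
# Without the config flag, pydantic raises a schema-generation error for
# torch.Tensor as soon as the decorated function is defined.
```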
None: """Initialize the Sparse Autoencoder Model. diff --git a/sparse_autoencoder/loss/learned_activations_l1.py b/sparse_autoencoder/loss/learned_activations_l1.py index e917aea2..e302efea 100644 --- a/sparse_autoencoder/loss/learned_activations_l1.py +++ b/sparse_autoencoder/loss/learned_activations_l1.py @@ -2,6 +2,7 @@ from typing import final from jaxtyping import Float +from pydantic import PositiveFloat, validate_call import torch from torch import Tensor @@ -37,8 +38,9 @@ def log_name(self) -> str: """ return "learned_activations_l1_loss_penalty" + @validate_call(config={"arbitrary_types_allowed": True}) def __init__( - self, l1_coefficient: float | Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)] + self, l1_coefficient: PositiveFloat | Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)] ) -> None: """Initialize the absolute error loss. diff --git a/sparse_autoencoder/metrics/train/feature_density.py b/sparse_autoencoder/metrics/train/feature_density.py index 9922d94d..8c37c4c6 100644 --- a/sparse_autoencoder/metrics/train/feature_density.py +++ b/sparse_autoencoder/metrics/train/feature_density.py @@ -3,6 +3,7 @@ from jaxtyping import Float import numpy as np from numpy import histogram +from pydantic import NonNegativeFloat, validate_call import torch from torch import Tensor import wandb @@ -32,9 +33,10 @@ class TrainBatchFeatureDensityMetric(AbstractTrainMetric): threshold: float + @validate_call def __init__( self, - threshold: float = 0.0, + threshold: NonNegativeFloat = 0.0, ) -> None: """Initialise the train batch feature density metric. diff --git a/sparse_autoencoder/source_data/abstract_dataset.py b/sparse_autoencoder/source_data/abstract_dataset.py index 3029a59f..5cdd1d3c 100644 --- a/sparse_autoencoder/source_data/abstract_dataset.py +++ b/sparse_autoencoder/source_data/abstract_dataset.py @@ -5,6 +5,7 @@ from datasets import Dataset, IterableDataset, load_dataset from jaxtyping import Int +from pydantic import PositiveInt, validate_call from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data import Dataset as TorchDataset @@ -106,17 +107,18 @@ def preprocess( """ @abstractmethod + @validate_call def __init__( self, dataset_path: str, dataset_split: str, - context_size: int, - buffer_size: int = 1000, + context_size: PositiveInt, + buffer_size: PositiveInt = 1000, dataset_dir: str | None = None, dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None, dataset_column_name: str = "input_ids", - n_processes_preprocessing: int | None = None, - preprocess_batch_size: int = 1000, + n_processes_preprocessing: PositiveInt | None = None, + preprocess_batch_size: PositiveInt = 1000, *, pre_download: bool = False, ): diff --git a/sparse_autoencoder/source_data/mock_dataset.py b/sparse_autoencoder/source_data/mock_dataset.py index bd6f9d53..35668604 100644 --- a/sparse_autoencoder/source_data/mock_dataset.py +++ b/sparse_autoencoder/source_data/mock_dataset.py @@ -7,6 +7,7 @@ from datasets import IterableDataset from jaxtyping import Int +from pydantic import PositiveInt, validate_call import torch from torch import Tensor from transformers import PreTrainedTokenizerFast @@ -139,11 +140,12 @@ def preprocess( # Nothing to do here return source_batch + @validate_call def __init__( self, - context_size: int = 250, - buffer_size: int = 1000, # noqa: ARG002 - preprocess_batch_size: int = 1000, # noqa: ARG002 + context_size: PositiveInt = 250, + buffer_size: PositiveInt = 1000, # noqa: ARG002 + preprocess_batch_size: 
diff --git a/sparse_autoencoder/source_data/pretokenized_dataset.py b/sparse_autoencoder/source_data/pretokenized_dataset.py
index a80826ae..f84761fd 100644
--- a/sparse_autoencoder/source_data/pretokenized_dataset.py
+++ b/sparse_autoencoder/source_data/pretokenized_dataset.py
@@ -14,6 +14,8 @@
 from collections.abc import Mapping, Sequence
 from typing import final
 
+from pydantic import PositiveInt, validate_call
+
 from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
 
 
@@ -67,16 +69,17 @@ def preprocess(
 
         return {"input_ids": context_size_prompts}
 
+    @validate_call
     def __init__(
         self,
         dataset_path: str,
-        context_size: int = 256,
-        buffer_size: int = 1000,
+        context_size: PositiveInt = 256,
+        buffer_size: PositiveInt = 1000,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
         dataset_column_name: str = "input_ids",
-        preprocess_batch_size: int = 1000,
+        preprocess_batch_size: PositiveInt = 1000,
         *,
         pre_download: bool = False,
     ):
diff --git a/sparse_autoencoder/source_data/text_dataset.py b/sparse_autoencoder/source_data/text_dataset.py
index a2afcc9b..2d12f043 100644
--- a/sparse_autoencoder/source_data/text_dataset.py
+++ b/sparse_autoencoder/source_data/text_dataset.py
@@ -9,6 +9,7 @@
 from typing import TypedDict, final
 
 from datasets import IterableDataset
+from pydantic import PositiveInt, validate_call
 from transformers import PreTrainedTokenizerBase
 
 from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
@@ -63,18 +64,19 @@ def preprocess(
 
         return {"input_ids": context_size_prompts}
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
         dataset_path: str,
         tokenizer: PreTrainedTokenizerBase,
-        buffer_size: int = 1000,
-        context_size: int = 256,
+        buffer_size: PositiveInt = 1000,
+        context_size: PositiveInt = 256,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
         dataset_column_name: str = "input_ids",
-        n_processes_preprocessing: int | None = None,
-        preprocess_batch_size: int = 1000,
+        n_processes_preprocessing: PositiveInt | None = None,
+        preprocess_batch_size: PositiveInt = 1000,
         *,
         pre_download: bool = False,
     ):
@@ -116,12 +118,13 @@ def __init__(
             preprocess_batch_size=preprocess_batch_size,
         )
 
+    @validate_call
     def push_to_hugging_face_hub(
         self,
         repo_id: str,
         commit_message: str = "Upload preprocessed dataset using sparse_autoencoder.",
         max_shard_size: str | None = None,
-        n_shards: int = 64,
+        n_shards: PositiveInt = 64,
         revision: str = "main",
         *,
         private: bool = False,
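One behavioural detail worth keeping in mind for these dataset constructors: `@validate_call` runs in pydantic's lax mode by default, so values that convert losslessly (an integral float, a numeric string) are coerced to `int` rather than rejected, while the positivity constraint is still enforced. An illustrative sketch using a hypothetical stand-in constructor:

```python
from pydantic import PositiveInt, ValidationError, validate_call


@validate_call
def make_dataset(context_size: PositiveInt = 256, buffer_size: PositiveInt = 1000) -> tuple[int, int]:
    """Hypothetical stand-in for a SourceDataset subclass constructor."""
    return context_size, buffer_size


# Lax mode coerces "safe" values: an integral float or a numeric string becomes an int.
print(make_dataset(context_size=128.0))   # -> (128, 1000)
print(make_dataset(buffer_size="2000"))   # -> (256, 2000)

try:
    make_dataset(context_size=0)  # still rejected: must be > 0
except ValidationError as error:
    print(error)
```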
diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py
index b1603d1d..933a529c 100644
--- a/sparse_autoencoder/train/pipeline.py
+++ b/sparse_autoencoder/train/pipeline.py
@@ -7,6 +7,7 @@
 from urllib.parse import quote_plus
 
 from jaxtyping import Float, Int, Int64
+from pydantic import NonNegativeInt, PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.utils.data import DataLoader
@@ -91,21 +92,22 @@ def n_components(self) -> int:
         return len(self.cache_names)
 
     @final
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
         activation_resampler: AbstractActivationResampler | None,
         autoencoder: SparseAutoencoder,
         cache_names: list[str],
-        layer: int,
+        layer: NonNegativeInt,
         loss: AbstractLoss,
         optimizer: AbstractOptimizerWithReset,
         source_dataset: SourceDataset,
         source_model: HookedTransformer,
         run_name: str = "sparse_autoencoder",
         checkpoint_directory: Path = DEFAULT_CHECKPOINT_DIRECTORY,
-        log_frequency: int = 100,
+        log_frequency: PositiveInt = 100,
         metrics: MetricsContainer = default_metrics,
-        source_data_batch_size: int = 12,
+        source_data_batch_size: PositiveInt = 12,
     ) -> None:
         """Initialize the pipeline.
 
@@ -142,7 +144,8 @@ def __init__(
         source_dataloader = source_dataset.get_dataloader(source_data_batch_size)
         self.source_data = self.stateful_dataloader_iterable(source_dataloader)
 
-    def generate_activations(self, store_size: int) -> TensorActivationStore:
+    @validate_call
+    def generate_activations(self, store_size: PositiveInt) -> TensorActivationStore:
         """Generate activations.
 
         Args:
@@ -152,12 +155,9 @@ def generate_activations(self, store_size: int) -> TensorActivationStore:
             Activation store for the train section.
 
         Raises:
-            ValueError: If the store size is not positive or is not divisible by the batch size.
+            ValueError: If the store size is not divisible by the batch size.
         """
-        # Check the store size is positive and divisible by the batch size
-        if store_size <= 0:
-            error_message = f"Store size must be positive, got {store_size}"
-            raise ValueError(error_message)
+        # Check the store size is divisible by the batch size
         if store_size % (self.source_data_batch_size * self.source_dataset.context_size) != 0:
             error_message = (
                 f"Store size must be divisible by the batch size ({self.source_data_batch_size}), "
@@ -195,8 +195,9 @@ def generate_activations(self, store_size: int) -> TensorActivationStore:
 
         return store
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def train_autoencoder(
-        self, activation_store: TensorActivationStore, train_batch_size: int
+        self, activation_store: TensorActivationStore, train_batch_size: PositiveInt
     ) -> Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)]:
         """Train the sparse autoencoder.
 
@@ -310,7 +311,8 @@ def update_parameters(self, parameter_updates: list[ParameterUpdateResults]) ->
                 component_idx=component_idx,
             )
 
-    def validate_sae(self, validation_n_activations: int) -> None:
+    @validate_call
+    def validate_sae(self, validation_n_activations: PositiveInt) -> None:
         """Get validation metrics.
 
         Args:
@@ -416,14 +418,15 @@ def save_checkpoint(self, *, is_final: bool = False) -> Path:
 
         return file_path
 
+    @validate_call
     def run_pipeline(
         self,
-        train_batch_size: int,
-        max_store_size: int,
-        max_activations: int,
-        validation_n_activations: int = 1024,
-        validate_frequency: int | None = None,
-        checkpoint_frequency: int | None = None,
+        train_batch_size: PositiveInt,
+        max_store_size: PositiveInt,
+        max_activations: PositiveInt,
+        validation_n_activations: PositiveInt = 1024,
+        validate_frequency: PositiveInt | None = None,
+        checkpoint_frequency: PositiveInt | None = None,
     ) -> None:
         """Run the full training pipeline.
 
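With `store_size: PositiveInt` on `generate_activations`, the hand-written `store_size <= 0` guard becomes unreachable, which is why it is removed above; only the divisibility check still needs explicit code. A simplified sketch of the resulting shape (the constants and helper name are invented for the example):

```python
from pydantic import PositiveInt, validate_call

# Constants invented for the sketch; the real values live on the Pipeline instance.
SOURCE_DATA_BATCH_SIZE = 12
CONTEXT_SIZE = 256


@validate_call
def generate_activations_sketch(store_size: PositiveInt) -> int:
    """Positivity is enforced by the annotation, so only the divisibility
    constraint needs a hand-written check."""
    if store_size % (SOURCE_DATA_BATCH_SIZE * CONTEXT_SIZE) != 0:
        message = (
            f"Store size must be divisible by the batch size ({SOURCE_DATA_BATCH_SIZE}) "
            f"times the context size ({CONTEXT_SIZE}), got {store_size}."
        )
        raise ValueError(message)
    return store_size // (SOURCE_DATA_BATCH_SIZE * CONTEXT_SIZE)


generate_activations_sketch(12 * 256 * 4)  # 4 source batches of activations
```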
diff --git a/sparse_autoencoder/train/sweep.py b/sparse_autoencoder/train/sweep.py
index 7d62c914..44673c6f 100644
--- a/sparse_autoencoder/train/sweep.py
+++ b/sparse_autoencoder/train/sweep.py
@@ -84,6 +84,7 @@ def setup_autoencoder(
     return SparseAutoencoder(
         n_input_features=autoencoder_input_dim,
         n_learned_features=autoencoder_input_dim * expansion_factor,
+        n_components=len(hyperparameters["source_model"]["cache_names"]),
     ).to(device)
 
 
diff --git a/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr b/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr
index 092a9007..7cb942aa 100644
--- a/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr
+++ b/sparse_autoencoder/train/tests/__snapshots__/test_sweep.ambr
@@ -7,10 +7,10 @@
   SparseAutoencoder(
     (_pre_encoder_bias): TiedBias(position=pre_encoder)
     (_encoder): LinearEncoder(
-      input_features=512, learnt_features=2048, n_components=None
+      input_features=512, learnt_features=2048, n_components=1
       (activation_function): ReLU()
     )
-    (_decoder): UnitNormDecoder(learnt_features=2048, decoded_features=512, n_components=None)
+    (_decoder): UnitNormDecoder(learnt_features=2048, decoded_features=512, n_components=1)
     (_post_decoder_bias): TiedBias(position=post_decoder)
   )
   '''
diff --git a/sparse_autoencoder/train/tests/test_pipeline.py b/sparse_autoencoder/train/tests/test_pipeline.py
index a7b8cdb0..eee37e09 100644
--- a/sparse_autoencoder/train/tests/test_pipeline.py
+++ b/sparse_autoencoder/train/tests/test_pipeline.py
@@ -37,7 +37,7 @@ def pipeline_fixture() -> Pipeline:
     device = torch.device("cpu")
     src_model = HookedTransformer.from_pretrained("tiny-stories-1M", device=device)
     autoencoder = SparseAutoencoder(
-        src_model.cfg.d_model, src_model.cfg.d_model * 2, n_components=2
+        src_model.cfg.d_model, int(src_model.cfg.d_model * 2), n_components=2
     )
     loss = LossReducer(
         LearnedActivationsL1Loss(
@@ -83,10 +83,11 @@ def test_generates_store(self, pipeline_fixture: Pipeline) -> None:
     def test_store_has_unique_items(self, pipeline_fixture: Pipeline) -> None:
         """Test that each item from the store iterable is unique."""
         store_size: int = 1000
-        store = pipeline_fixture.generate_activations(store_size)
+        store = pipeline_fixture.generate_activations(store_size // 2)
+        store2 = pipeline_fixture.generate_activations(store_size // 2)
 
         # Get the number of unique activations generated
-        activations = list(iter(store))
+        activations = list(iter(store)) + list(iter(store2))
         activations_tensor = torch.stack(activations)
         unique_activations = activations_tensor.unique(dim=0)
 
@@ -357,7 +358,7 @@ def test_run_pipeline_calls_all_methods(self, pipeline_fixture: Pipeline) -> Non
             max_activations=store_size * 5,
             validation_n_activations=store_size,
             validate_frequency=store_size * (total_loops // validate_expected_calls),
-            checkpoint_frequency=store_size * (total_loops // checkpoint_expected_calls - 1),
+            checkpoint_frequency=store_size,
         )
 
         # Check the number of calls
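Because invalid arguments now surface as `pydantic.ValidationError` rather than a manually raised `ValueError`, tests that exercise these guard rails should assert on that exception type. A hypothetical example of the pattern (not a test that exists in this diff):

```python
import pytest
from pydantic import PositiveInt, ValidationError, validate_call


@validate_call
def run_pipeline_sketch(train_batch_size: PositiveInt, max_activations: PositiveInt) -> None:
    """Hypothetical stand-in for Pipeline.run_pipeline, used only to show the test pattern."""


def test_rejects_non_positive_batch_size() -> None:
    with pytest.raises(ValidationError):
        run_pipeline_sketch(train_batch_size=0, max_activations=1_000)
```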