diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index f5664084..dc8ff77a 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -84,7 +84,7 @@ def pipeline( # Create a dataloader from the store dataloader = DataLoader( activation_store, - batch_size=8192, + batch_size=sweep_parameters.batch_size, ) # Train the autoencoder diff --git a/sparse_autoencoder/train/sweep_config.py b/sparse_autoencoder/train/sweep_config.py index 495bb46c..7b099dd8 100644 --- a/sparse_autoencoder/train/sweep_config.py +++ b/sparse_autoencoder/train/sweep_config.py @@ -54,6 +54,11 @@ class SweepParameterConfig(Parameters): paper](https://transformer-circuits.pub/2023/monosemantic-features/index.html). """ + batch_size: Parameter[int] | None + """Batch size. + + Used in SAE Forward Pass.""" + # NOTE: This must be kept in sync with SweepParameterConfig @dataclass(frozen=True) @@ -72,6 +77,8 @@ class SweepParametersRuntime(dict[str, Any]): l1_coefficient: float = 0.01 + batch_size: int = 8192 + def to_dict(self) -> dict[str, Any]: """Return dict representation of this object.""" return asdict(self)