Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the activation store support multiple component dimensions #160

Merged
merged 11 commits into from
Dec 17, 2023
2 changes: 0 additions & 2 deletions sparse_autoencoder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Sparse Autoencoder Library."""
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
from sparse_autoencoder.activation_store.disk_store import DiskActivationStore
from sparse_autoencoder.activation_store.list_store import ListActivationStore
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.loss.abstract_loss import LossLogType, LossReductionType
Expand Down Expand Up @@ -65,7 +64,6 @@
"Kind",
"L2ReconstructionLoss",
"LearnedActivationsL1Loss",
"ListActivationStore",
"LossHyperparameters",
"LossLogType",
"LossReducer",
Expand Down
49 changes: 42 additions & 7 deletions sparse_autoencoder/activation_store/base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from sparse_autoencoder.tensor_types import Axis


class ActivationStore(Dataset[Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]], ABC):
class ActivationStore(
Dataset[Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]], ABC
):
"""Activation Store Abstract Class.

Extends the `torch.utils.data.Dataset` class to provide an activation store, with additional
Expand All @@ -25,6 +27,15 @@ class ActivationStore(Dataset[Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]], ABC):
Example:
>>> import torch
>>> class MyActivationStore(ActivationStore):
...
... @property
... def current_activations_stored_per_component(self):
... raise NotImplementedError
...
... @property
... def num_components(self):
... raise NotImplementedError
...
... def __init__(self):
... super().__init__()
... self._data = [] # In this example, we just store in a list
Expand All @@ -51,33 +62,55 @@ class ActivationStore(Dataset[Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]], ABC):
"""

@abstractmethod
def append(self, item: Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]) -> Future | None:
def append(
self,
item: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE)],
component_idx: int = 0,
) -> Future | None:
"""Add a Single Item to the Store."""

@abstractmethod
def extend(
self, batch: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)]
self,
batch: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)],
component_idx: int = 0,
) -> Future | None:
"""Add a Batch to the Store."""

@abstractmethod
def empty(self) -> None:
"""Empty the Store."""

@property
@abstractmethod
def num_components(self) -> int:
"""Number of components."""

@property
@abstractmethod
def current_activations_stored_per_component(self) -> list[int]:
"""Current activations stored per component."""

@abstractmethod
def __len__(self) -> int:
"""Get the Length of the Store."""

@abstractmethod
def __getitem__(self, index: int) -> Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]:
def __getitem__(
self, index: int
) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
"""Get an Item from the Store."""

def shuffle(self) -> None:
"""Optional shuffle method."""

@final
def fill_with_test_data(
self, num_batches: int = 16, batch_size: int = 16, input_features: int = 256
self,
num_batches: int = 16,
batch_size: int = 16,
num_components: int = 1,
input_features: int = 256,
) -> None:
"""Fill the store with test data.

Expand All @@ -99,11 +132,13 @@ def fill_with_test_data(
Args:
num_batches: Number of batches to fill the store with.
batch_size: Number of items per batch.
num_components: Number of source model components the SAE is trained on.
input_features: Number of input features per item.
"""
for _ in range(num_batches):
sample = torch.rand((batch_size, input_features))
self.extend(sample)
for component_idx in range(num_components):
sample = torch.rand(batch_size, input_features)
self.extend(sample, component_idx)


class StoreFullError(IndexError):
Expand Down
Loading