Skip to content

Commit

Permalink
Make the activation store support multiple component dimensions (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Dec 17, 2023
1 parent 73fd7bd commit 8854183
Show file tree
Hide file tree
Showing 21 changed files with 521 additions and 734 deletions.
2 changes: 0 additions & 2 deletions sparse_autoencoder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Sparse Autoencoder Library."""
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
from sparse_autoencoder.activation_store.disk_store import DiskActivationStore
from sparse_autoencoder.activation_store.list_store import ListActivationStore
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.loss.abstract_loss import LossLogType, LossReductionType
Expand Down Expand Up @@ -65,7 +64,6 @@
"Kind",
"L2ReconstructionLoss",
"LearnedActivationsL1Loss",
"ListActivationStore",
"LossHyperparameters",
"LossLogType",
"LossReducer",
Expand Down
49 changes: 42 additions & 7 deletions sparse_autoencoder/activation_store/base_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from sparse_autoencoder.tensor_types import Axis


class ActivationStore(Dataset[Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]], ABC):
class ActivationStore(
Dataset[Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]], ABC
):
"""Activation Store Abstract Class.
Extends the `torch.utils.data.Dataset` class to provide an activation store, with additional
Expand All @@ -25,6 +27,15 @@ class ActivationStore(Dataset[Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]], ABC):
Example:
>>> import torch
>>> class MyActivationStore(ActivationStore):
...
... @property
... def current_activations_stored_per_component(self):
... raise NotImplementedError
...
... @property
... def num_components(self):
... raise NotImplementedError
...
... def __init__(self):
... super().__init__()
... self._data = [] # In this example, we just store in a list
Expand All @@ -51,33 +62,55 @@ class ActivationStore(Dataset[Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]], ABC):
"""

@abstractmethod
def append(self, item: Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]) -> Future | None:
def append(
self,
item: Float[Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE)],
component_idx: int = 0,
) -> Future | None:
"""Add a Single Item to the Store."""

@abstractmethod
def extend(
self, batch: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)]
self,
batch: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)],
component_idx: int = 0,
) -> Future | None:
"""Add a Batch to the Store."""

@abstractmethod
def empty(self) -> None:
"""Empty the Store."""

@property
@abstractmethod
def num_components(self) -> int:
"""Number of source model components the store holds activations for."""

@property
@abstractmethod
def current_activations_stored_per_component(self) -> list[int]:
"""Number of activations currently stored for each component (one count per component)."""

@abstractmethod
def __len__(self) -> int:
"""Get the Length of the Store."""

@abstractmethod
def __getitem__(self, index: int) -> Float[Tensor, Axis.INPUT_OUTPUT_FEATURE]:
def __getitem__(
self, index: int
) -> Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
"""Get an Item from the Store."""

def shuffle(self) -> None:
"""Shuffle the stored activations (optional; this base implementation does nothing)."""

@final
def fill_with_test_data(
self, num_batches: int = 16, batch_size: int = 16, input_features: int = 256
self,
num_batches: int = 16,
batch_size: int = 16,
num_components: int = 1,
input_features: int = 256,
) -> None:
"""Fill the store with test data.
Expand All @@ -99,11 +132,13 @@ def fill_with_test_data(
Args:
num_batches: Number of batches to fill the store with.
batch_size: Number of items per batch.
num_components: Number of source model components the SAE is trained on.
input_features: Number of input features per item.
"""
for _ in range(num_batches):
sample = torch.rand((batch_size, input_features))
self.extend(sample)
for component_idx in range(num_components):
sample = torch.rand(batch_size, input_features)
self.extend(sample, component_idx)


class StoreFullError(IndexError):
Expand Down
Loading

0 comments on commit 8854183

Please sign in to comment.