Skip to content

Commit

Permalink
Remove abstract resampler class (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Jan 9, 2024
1 parent fb53d5d commit 830adc8
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 92 deletions.

This file was deleted.

32 changes: 24 additions & 8 deletions sparse_autoencoder/activation_resampler/activation_resampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Activation resampler."""
from dataclasses import dataclass
from typing import Annotated, NamedTuple

from einops import rearrange
Expand All @@ -8,10 +9,6 @@
from torch import Tensor
from torch.utils.data import DataLoader

from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
AbstractActivationResampler,
ParameterUpdateResults,
)
from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
get_component_slice_tensor,
)
Expand All @@ -23,6 +20,27 @@
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


@dataclass
class ParameterUpdateResults:
"""Parameter update results from resampling dead neurons."""

dead_neuron_indices: Int64[Tensor, Axis.LEARNT_FEATURE_IDX]
"""Dead neuron indices."""

dead_encoder_weight_updates: Float[
Tensor, Axis.names(Axis.DEAD_FEATURE, Axis.INPUT_OUTPUT_FEATURE)
]
"""Dead encoder weight updates."""

dead_encoder_bias_updates: Float[Tensor, Axis.DEAD_FEATURE]
"""Dead encoder bias updates."""

dead_decoder_weight_updates: Float[
Tensor, Axis.names(Axis.INPUT_OUTPUT_FEATURE, Axis.DEAD_FEATURE)
]
"""Dead decoder weight updates."""


class LossInputActivationsTuple(NamedTuple):
"""Loss and corresponding input activations tuple."""

Expand All @@ -32,7 +50,7 @@ class LossInputActivationsTuple(NamedTuple):
]


class ActivationResampler(AbstractActivationResampler):
class ActivationResampler:
"""Activation resampler.
Collates the number of times each neuron fires over a set number of learned activation vectors,
Expand Down Expand Up @@ -510,9 +528,7 @@ def resample_dead_neurons(

def step_resampler(
self,
batch_neuron_activity: Int64[
Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
],
batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
activation_store: ActivationStore,
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
loss_fn: AbstractLoss,
Expand Down
8 changes: 4 additions & 4 deletions sparse_autoencoder/train/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from transformer_lens import HookedTransformer
import wandb

from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
AbstractActivationResampler,
from sparse_autoencoder.activation_resampler.activation_resampler import (
ActivationResampler,
ParameterUpdateResults,
)
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
Expand Down Expand Up @@ -48,7 +48,7 @@ class Pipeline:
hyperparameters.
"""

activation_resampler: AbstractActivationResampler | None
activation_resampler: ActivationResampler | None
"""Activation resampler to use."""

autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder]
Expand Down Expand Up @@ -96,7 +96,7 @@ def n_components(self) -> int:
@validate_call(config={"arbitrary_types_allowed": True})
def __init__(
self,
activation_resampler: AbstractActivationResampler | None,
activation_resampler: ActivationResampler | None,
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
cache_names: list[str],
layer: NonNegativeInt,
Expand Down
4 changes: 2 additions & 2 deletions sparse_autoencoder/train/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
Pipeline,
SparseAutoencoder,
)
from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
from sparse_autoencoder.activation_resampler.activation_resampler import (
ActivationResampler,
ParameterUpdateResults,
)
from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoderConfig
from sparse_autoencoder.metrics.abstract_metric import MetricResult
Expand Down

0 comments on commit 830adc8

Please sign in to comment.