Simplify dataparallel approach (#191)
alan-cooney authored Jan 26, 2024
1 parent 2ec1af7 commit abc0291
Showing 16 changed files with 763 additions and 748 deletions.
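The central change across these files is replacing the repository's custom `DataParallelWithModelAttributes` wrapper with PyTorch's built-in `torch.nn.parallel.DataParallel`, and passing values such as feature counts explicitly instead of reading them from attributes of the wrapped model. A minimal sketch of the behaviour this relies on (`nn.Linear` stands in for the real `SparseAutoencoder`; the rest is illustrative):

import torch
from torch import nn
from torch.nn.parallel import DataParallel

model = nn.Linear(16, 64)  # stand-in for a SparseAutoencoder
wrapped = DataParallel(model)

# DataParallel does not proxy arbitrary attributes of the wrapped module, so
# they must be reached via .module or passed around explicitly (as this commit
# does with n_input_features and n_learned_features).
print(wrapped.module.out_features)  # 64

# The forward pass is unchanged: the wrapper scatters batches across GPUs,
# or falls back to a plain call on a single device.
output = wrapped(torch.randn(8, 16))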
5 changes: 4 additions & 1 deletion .vscode/cspell.json
@@ -51,6 +51,7 @@
     "jaxtyping",
     "kaiming",
     "keepdim",
+    "logit",
     "lognormal",
     "loguniform",
     "loguniformvalues",
@@ -73,6 +74,7 @@
     "neox",
     "nonlinerity",
     "numel",
+    "onebit",
     "openwebtext",
     "optim",
     "penality",
@@ -116,6 +118,7 @@
     "venv",
     "virtualenv",
     "virtualenvs",
-    "wandb"
+    "wandb",
+    "zoadam"
   ]
 }
1,253 changes: 648 additions & 605 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@
 pytest-integration=">=0.2.3"
 pytest-timeout=">=2.2.0"
 pytest-xdist="^3.5.0"
-ruff=">=0.1.4"
+ruff=">=0.1.14"
 syrupy=">=4.6.0"

 [tool.poetry.group.demos.dependencies]
@@ -7,6 +7,7 @@
 from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
 import torch
 from torch import Tensor
+from torch.nn.parallel import DataParallel
 from torch.utils.data import DataLoader

 from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
@@ -17,7 +18,6 @@
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
-from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 @dataclass
@@ -207,7 +207,7 @@ def _get_dead_neuron_indices(
     def compute_loss_and_get_activations(
         self,
         store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> LossInputActivationsTuple:
@@ -440,7 +440,7 @@ def renormalize_and_scale(
     def resample_dead_neurons(
         self,
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults]:
@@ -530,7 +530,7 @@ def step_resampler(
         self,
         batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults] | None:
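Both the bare autoencoder and the wrapper are callable with the same forward signature, which is why the union annotations above leave call sites unchanged. A small illustration of the pattern (names are illustrative, not from the repository):

from torch import Tensor, nn
from torch.nn.parallel import DataParallel

def encode_batch(autoencoder: nn.Module | DataParallel[nn.Module], batch: Tensor) -> Tensor:
    # The caller does not need to know whether it received the wrapper;
    # calling either form dispatches to the same forward pass.
    return autoencoder(batch)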
2 changes: 1 addition & 1 deletion sparse_autoencoder/autoencoder/model.py
@@ -276,7 +276,7 @@ def load(
             The loaded model.
         """
         # Load the file
-        serialized_state = torch.load(file_path)
+        serialized_state = torch.load(file_path, map_location=torch.device("cpu"))
         state = SparseAutoencoderState.model_validate(serialized_state)

         # Initialise the model
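Adding `map_location` makes checkpoints saved on a GPU machine loadable on CPU-only hosts; a minimal sketch of the behaviour (file name illustrative):

import torch

# Without map_location, a checkpoint containing CUDA tensors fails to load on
# a machine with no GPU; remapping to CPU makes loading device-agnostic.
serialized_state = torch.load("sae.pt", map_location=torch.device("cpu"))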
@@ -1,15 +1,13 @@
 """Test the model reconstruction score metric."""

-from jaxtyping import Float
 import pytest
 from syrupy.session import SnapshotSession
 import torch
-from torch import Tensor
+from torch import Tensor, tensor

 from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result
 from sparse_autoencoder.metrics.validate.abstract_validate_metric import ValidationMetricData
 from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore
-from sparse_autoencoder.tensor_types import Axis


 def test_model_reconstruction_score_empty_data() -> None:
@@ -19,9 +17,9 @@ def test_model_reconstruction_score_empty_data() -> None:
     is provided (i.e., at the end of training or in similar scenarios).
     """
     data = ValidationMetricData(
-        source_model_loss=Float[Tensor, Axis.ITEMS]([]),
-        source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS]([]),
-        source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS]([]),
+        source_model_loss=tensor([]),
+        source_model_loss_with_reconstruction=tensor([]),
+        source_model_loss_with_zero_ablation=tensor([]),
     )
     metric = ModelReconstructionScore()
     result = metric.calculate(data)
@@ -41,13 +39,9 @@ def test_model_reconstruction_score_empty_data() -> None:
         ),
         (
             ValidationMetricData(
-                source_model_loss=Float[Tensor, Axis.ITEMS]([[0.5], [1.5], [2.5]]),
-                source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS](
-                    [[1.5], [2.5], [3.5]]
-                ),
-                source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS](
-                    [[8.0], [7.0], [4.0]]
-                ),
+                source_model_loss=tensor([[0.5], [1.5], [2.5]]),
+                source_model_loss_with_reconstruction=tensor([[1.5], [2.5], [3.5]]),
+                source_model_loss_with_zero_ablation=tensor([[8.0], [7.0], [4.0]]),
             ),
             0.79,
         ),
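The fixtures now construct tensors with plain `torch.tensor`, keeping jaxtyping types purely as static annotations; a short sketch of the pattern (the shape string is illustrative):

from jaxtyping import Float
from torch import Tensor, tensor

# jaxtyping types describe shape and dtype for type checkers;
# construction uses torch.tensor directly.
source_model_loss: Float[Tensor, "items"] = tensor([0.5, 1.5, 2.5])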
2 changes: 1 addition & 1 deletion sparse_autoencoder/optimizer/adam_with_reset.py
@@ -42,7 +42,7 @@ def __init__(  # (extending existing implementation)
         lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
         betas: tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
-        weight_decay: float = 0,
+        weight_decay: float = 0.0,
         *,
         amsgrad: bool = False,
         foreach: bool | None = None,
22 changes: 11 additions & 11 deletions sparse_autoencoder/source_model/replace_activations_hook.py
@@ -1,33 +1,36 @@
 """Replace activations hook."""
 from typing import TYPE_CHECKING

-from jaxtyping import Float
 from torch import Tensor
+from torch.nn.parallel import DataParallel
 from transformer_lens.hook_points import HookPoint

 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
-from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 if TYPE_CHECKING:
     from sparse_autoencoder.tensor_types import Axis
+    from jaxtyping import Float


 def replace_activations_hook(
     value: Tensor,
     hook: HookPoint,  # noqa: ARG001
-    sparse_autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+    sparse_autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
     component_idx: int | None = None,
+    n_components: int | None = None,
 ) -> Tensor:
     """Replace activations hook.

+    This should be pre-initialised with `functools.partial`.

     Args:
         value: The activations to replace.
         hook: The hook point.
-        sparse_autoencoder: The sparse autoencoder. This should be pre-initialised with
-            `functools.partial`.
+        sparse_autoencoder: The sparse autoencoder.
         component_idx: The component index to replace the activations with, if just replacing
             activations for a single component. Requires the model to have a component axis.
+        n_components: The number of components that the SAE is trained on.

     Returns:
         Replaced activations.
@@ -43,11 +46,8 @@ def replace_activations_hook(
     )

     if component_idx is not None:
-        if sparse_autoencoder.config.n_components is None:
-            error_message = (
-                "Cannot replace for a specific component, if the model does not have a "
-                "component axis."
-            )
+        if n_components is None:
+            error_message = "The number of model components must be set if component_idx is set."
             raise RuntimeError(error_message)

         # The approach here is to run a forward pass with dummy values for all components other than
@@ -56,7 +56,7 @@ def replace_activations_hook(
         # components.
         expanded_shape = [
             squashed_value.shape[0],
-            sparse_autoencoder.config.n_components,
+            n_components,
             squashed_value.shape[-1],
         ]
         expanded = squashed_value.unsqueeze(1).expand(*expanded_shape)
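The dummy-expansion step above can be read in isolation: activations for a single component are broadcast across a component axis so the multi-component SAE can run one forward pass. A minimal sketch with illustrative shapes:

import torch

batch_size, n_components, n_features = 4, 2, 16
squashed_value = torch.randn(batch_size, n_features)

# Broadcast the single component's activations across the component axis;
# expand() returns a view, so no data is copied.
expanded = squashed_value.unsqueeze(1).expand(batch_size, n_components, n_features)
assert expanded.shape == (batch_size, n_components, n_features)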
@@ -66,7 +66,12 @@ def test_hook_replaces_activations_2_components() -> None:
         fwd_hooks=[
             (
                 "blocks.0.hook_mlp_out",
-                partial(replace_activations_hook, sparse_autoencoder=autoencoder, component_idx=1),
+                partial(
+                    replace_activations_hook,
+                    sparse_autoencoder=autoencoder,
+                    component_idx=1,
+                    n_components=2,
+                ),
             )
         ],
     )
34 changes: 24 additions & 10 deletions sparse_autoencoder/train/pipeline.py
@@ -9,6 +9,7 @@
 from pydantic import NonNegativeInt, PositiveInt, validate_call
 import torch
 from torch import Tensor
+from torch.nn.parallel import DataParallel
 from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
@@ -32,7 +33,6 @@
 from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
-from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


 if TYPE_CHECKING:
@@ -51,9 +51,15 @@ class Pipeline:
     activation_resampler: ActivationResampler | None
     """Activation resampler to use."""

-    autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder]
+    autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder]
     """Sparse autoencoder to train."""

+    n_input_features: int
+    """Number of input features in the sparse autoencoder."""
+
+    n_learned_features: int
+    """Number of learned features in the sparse autoencoder."""
+
     cache_names: list[str]
     """Names of the cache hook points to use in the source model."""

@@ -81,7 +87,7 @@ class Pipeline:
     source_dataset: SourceDataset
     """Source dataset to generate activation data from (tokenized prompts)."""

-    source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer]
+    source_model: HookedTransformer | DataParallel[HookedTransformer]
     """Source model to get activations from."""

     total_activations_trained_on: int = 0
@@ -97,13 +103,15 @@ def n_components(self) -> int:
     def __init__(
         self,
         activation_resampler: ActivationResampler | None,
-        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
+        autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder],
         cache_names: list[str],
         layer: NonNegativeInt,
         loss: AbstractLoss,
         optimizer: AbstractOptimizerWithReset,
         source_dataset: SourceDataset,
-        source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer],
+        source_model: HookedTransformer | DataParallel[HookedTransformer],
+        n_input_features: int,
+        n_learned_features: int,
         run_name: str = "sparse_autoencoder",
         checkpoint_directory: Path = DEFAULT_CHECKPOINT_DIRECTORY,
         lr_scheduler: LRScheduler | None = None,
@@ -124,6 +132,8 @@ def __init__(
             optimizer: Optimizer to use.
             source_dataset: Source dataset to get data from.
             source_model: Source model to get activations from.
+            n_input_features: Number of input features in the sparse autoencoder.
+            n_learned_features: Number of learned features in the sparse autoencoder.
             run_name: Name of the run for saving checkpoints.
             checkpoint_directory: Directory to save checkpoints to.
             lr_scheduler: Learning rate scheduler to use.
@@ -146,6 +156,8 @@ def __init__(
         self.source_data_batch_size = source_data_batch_size
         self.source_dataset = source_dataset
         self.source_model = source_model
+        self.n_input_features = n_input_features
+        self.n_learned_features = n_learned_features

         # Create a stateful iterator
         source_dataloader = source_dataset.get_dataloader(
@@ -175,9 +187,10 @@ def generate_activations(self, store_size: PositiveInt) -> TensorActivationStore
             raise ValueError(error_message)

         # Setup the store
-        n_neurons: int = self.autoencoder.config.n_input_features
         source_model_device: torch.device = get_model_device(self.source_model)
-        store = TensorActivationStore(store_size, n_neurons, n_components=self.n_components)
+        store = TensorActivationStore(
+            store_size, self.n_input_features, n_components=self.n_components
+        )

         # Add the hook to the model (will automatically store the activations every time the model
         # runs)
@@ -225,9 +238,9 @@ def train_autoencoder(
         learned_activations_fired_count: Int64[
             Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)
         ] = torch.zeros(
-            (self.n_components, self.autoencoder.config.n_learned_features),
+            (self.n_components, self.n_learned_features),
             dtype=torch.int64,
-            device=autoencoder_device,
+            device=torch.device("cpu"),
         )

         for store_batch in activations_dataloader:
@@ -260,7 +273,7 @@ def train_autoencoder(
                 # Store count of how many neurons have fired
                 with torch.no_grad():
                     fired = learned_activations > 0
-                    learned_activations_fired_count.add_(fired.sum(dim=0))
+                    learned_activations_fired_count.add_(fired.sum(dim=0).cpu())

                 # Backwards pass
                 total_loss.backward()
@@ -358,6 +371,7 @@ def validate_sae(self, validation_n_activations: PositiveInt) -> None:
                 replace_activations_hook,
                 sparse_autoencoder=self.autoencoder,
                 component_idx=component_idx,
+                n_components=self.n_components,
             )

             with torch.no_grad():
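With the two new constructor parameters, a pipeline over a `DataParallel`-wrapped autoencoder might be assembled roughly as follows (a hedged sketch: the resampler, loss, optimizer, dataset, and source model objects are assumed to exist already, and remaining arguments keep their defaults):

from torch.nn.parallel import DataParallel

pipeline = Pipeline(
    activation_resampler=activation_resampler,
    autoencoder=DataParallel(autoencoder),  # or the bare SparseAutoencoder
    cache_names=["blocks.0.hook_mlp_out"],
    layer=0,
    loss=loss,
    optimizer=optimizer,
    source_dataset=source_dataset,
    source_model=source_model,
    # Feature counts are now passed explicitly rather than read from
    # autoencoder.config, which plain DataParallel does not proxy.
    n_input_features=autoencoder.config.n_input_features,
    n_learned_features=autoencoder.config.n_learned_features,
)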