Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deepspeed support #186

Merged
merged 5 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .vscode/cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"jaxtyping",
"kaiming",
"keepdim",
"logit",
"lognormal",
"loguniform",
"loguniformvalues",
Expand All @@ -73,6 +74,7 @@
"neox",
"nonlinerity",
"numel",
"onebit",
"openwebtext",
"optim",
"penality",
Expand Down Expand Up @@ -116,6 +118,7 @@
"venv",
"virtualenv",
"virtualenvs",
"wandb"
"wandb",
"zoadam"
]
}
515 changes: 320 additions & 195 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@
readme="README.md"
version="0.0.0"

# Note: Zstandard is required for downloading datasets such as The Pile
[tool.poetry.dependencies]
datasets=">=2.15.0"
deepspeed={version=">=0.12.6", extras=["deepspeed"], optional=false}
einops=">=0.6"
mpi4py={version=">=3.1.5", extras=["deepspeed"], optional=true}
pydantic=">=2.5.2"
python=">=3.10, <3.12"
strenum=">=0.4.15"
tokenizers=">=0.15.0"
torch=">=2.1.1"
transformers=">=4.35.2"
wandb=">=0.16.1"
zstandard=">=0.22.0" # Required for downloading datasets such as The Pile
zstandard=">=0.22.0"

[tool.poetry.group]
[tool.poetry.group.dev.dependencies]
Expand Down Expand Up @@ -54,6 +57,9 @@
pymdown-extensions=">=10.5"
pytkdocs-tweaks=">=0.0.7"

[tool.poetry.extras]
deepspeed=["deepspeed", "mpi4py"]

[tool.poetry.scripts]
join-sae-sweep='sparse_autoencoder.train.join_sweep:run'

Expand Down
2 changes: 2 additions & 0 deletions sparse_autoencoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sparse_autoencoder.metrics.train.capacity import CapacityMetric
from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric
from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
from sparse_autoencoder.optimizer.deepspeed_adam_with_reset import ZeroOneAdamWithReset
from sparse_autoencoder.source_data.pretokenized_dataset import PreTokenizedDataset
from sparse_autoencoder.source_data.text_dataset import TextDataset
from sparse_autoencoder.train.pipeline import Pipeline
Expand Down Expand Up @@ -83,4 +84,5 @@
"TensorActivationStore",
"TextDataset",
"TrainBatchFeatureDensityMetric",
"ZeroOneAdamWithReset",
]
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from dataclasses import dataclass
from typing import Annotated, NamedTuple

from deepspeed import DeepSpeedEngine
from einops import rearrange
from jaxtyping import Bool, Float, Int64
from pydantic import Field, NonNegativeInt, PositiveInt, validate_call
import torch
from torch import Tensor
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader

from sparse_autoencoder.activation_resampler.utils.component_slice_tensor import (
Expand All @@ -17,7 +19,6 @@
from sparse_autoencoder.loss.abstract_loss import AbstractLoss
from sparse_autoencoder.tensor_types import Axis
from sparse_autoencoder.train.utils.get_model_device import get_model_device
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


@dataclass
Expand Down Expand Up @@ -207,7 +208,7 @@ def _get_dead_neuron_indices(
def compute_loss_and_get_activations(
self,
store: ActivationStore,
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
loss_fn: AbstractLoss,
train_batch_size: int,
) -> LossInputActivationsTuple:
Expand Down Expand Up @@ -440,7 +441,7 @@ def renormalize_and_scale(
def resample_dead_neurons(
self,
activation_store: ActivationStore,
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
loss_fn: AbstractLoss,
train_batch_size: int,
) -> list[ParameterUpdateResults]:
Expand Down Expand Up @@ -530,7 +531,7 @@ def step_resampler(
self,
batch_neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)],
activation_store: ActivationStore,
autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
loss_fn: AbstractLoss,
train_batch_size: int,
) -> list[ParameterUpdateResults] | None:
Expand Down
2 changes: 1 addition & 1 deletion sparse_autoencoder/autoencoder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def load(
The loaded model.
"""
# Load the file
serialized_state = torch.load(file_path)
serialized_state = torch.load(file_path, map_location=torch.device("cpu"))
state = SparseAutoencoderState.model_validate(serialized_state)

# Initialise the model
Expand Down
2 changes: 1 addition & 1 deletion sparse_autoencoder/optimizer/adam_with_reset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__( # (extending existing implementation)
lr: float | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
weight_decay: float = 0.0,
*,
amsgrad: bool = False,
foreach: bool | None = None,
Expand Down
194 changes: 194 additions & 0 deletions sparse_autoencoder/optimizer/deepspeed_adam_with_reset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Deepspeed Zero One Adam Optimizer with a reset method.

This reset method is useful when resampling dead neurons during training.
"""
from collections.abc import Iterator
from typing import final

from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
from jaxtyping import Int
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.optim.optimizer import params_t

from sparse_autoencoder.optimizer.abstract_optimizer import AbstractOptimizerWithReset
from sparse_autoencoder.tensor_types import Axis


@final
class ZeroOneAdamWithReset(ZeroOneAdam, AbstractOptimizerWithReset):
    """Deepspeed Zero One Adam Optimizer with a reset method.

    https://deepspeed.readthedocs.io/en/latest/optimizers.html#zerooneadam-gpu

    The :meth:`reset_state_all_parameters` and :meth:`reset_neurons_state` methods are useful when
    manually editing the model parameters during training (e.g. when resampling dead neurons). This
    is because Adam maintains running averages of the gradients and the squares of gradients, which
    will be incorrect if the parameters are changed.

    Otherwise this is the same as the standard ZeroOneAdam optimizer.

    Warning:
        Requires a distributed torch backend.
    """

    parameter_names: list[str]
    """Parameter Names.

    The names of the parameters, so that we can find them later when resetting the state.
    """

    _has_components_dim: bool
    """Whether the parameters have a components dimension."""

    def __init__(
        self,
        params: params_t,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        *,
        named_parameters: Iterator[tuple[str, Parameter]],
        has_components_dim: bool,
    ) -> None:
        """Initialize the optimizer.

        Warning:
            Named parameters must be with default settings (remove duplicates and not recursive).

        Args:
            params: Iterable of parameters to optimize or dicts defining parameter groups.
            lr: Learning rate.
            betas: Coefficients used for computing running averages of gradient and its square.
            eps: Term added to the denominator to improve numerical stability.
            weight_decay: Weight decay (L2 penalty).
            named_parameters: An iterator over the named parameters of the model. This is used to
                find the parameters when resetting their state. You should set this as
                `model.named_parameters()`.
            has_components_dim: If the parameters have a components dimension (i.e. if you are
                training an SAE on more than one component).

        Raises:
            ValueError: If the number of parameter names does not match the number of parameters.
        """
        # Initialise the parent class (note we repeat the parameter names so that type hints work).
        super().__init__(
            params=params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )

        self._has_components_dim = has_components_dim

        # Store the names of the parameters, so that we can find them later when resetting the
        # state. The iterator is consumed here, so it cannot be reused by the caller afterwards.
        self.parameter_names = [name for name, _value in named_parameters]

        # NOTE(review): only the first parameter group is compared against the names; with several
        # parameter groups this check would need revisiting - confirm against callers.
        if len(self.parameter_names) != len(self.param_groups[0]["params"]):
            error_message = (
                "The number of parameter names does not match the number of parameters. "
                "If using model.named_parameters() make sure remove_duplicates is True "
                "and recursive is False (the default settings)."
            )
            raise ValueError(error_message)

    def reset_state_all_parameters(self) -> None:
        """Reset the state for all parameters.

        Iterates over all parameters and resets both the running averages of the gradients and the
        squares of gradients. Parameters that have no optimizer state yet are skipped.
        """
        for group in self.param_groups:
            for parameter in group["params"]:
                state = self.state[parameter]

                # An empty state means the optimizer has not yet initialised this parameter, so
                # there is nothing to reset.
                if not state:
                    continue

                # Reset the running averages in place.
                exp_avg: Tensor = state["exp_avg"]
                exp_avg.zero_()
                exp_avg_sq: Tensor = state["exp_avg_sq"]
                exp_avg_sq.zero_()

                # `max_exp_avg_sq` is the running maximum kept by the AMSGrad variant of Adam; it
                # is only reset when the state actually contains it.
                if "max_exp_avg_sq" in state:
                    max_exp_avg_sq: Tensor = state["max_exp_avg_sq"]
                    max_exp_avg_sq.zero_()

    def reset_neurons_state(
        self,
        parameter: Parameter,
        neuron_indices: Int[Tensor, Axis.names(Axis.LEARNT_FEATURE_IDX)],
        axis: int,
        component_idx: int = 0,
    ) -> None:
        """Reset the state for specific neurons, on a specific parameter.

        Does nothing if the optimizer holds no state for the parameter yet, or if there are no
        neuron indices to reset.

        Args:
            parameter: The parameter to be reset. Examples from the standard sparse autoencoder
                implementation include `tied_bias`, `_encoder._weight`, `_encoder._bias`.
            neuron_indices: The indices of the neurons to reset.
            axis: The axis of the state values to reset (i.e. the input/output features axis, as
                we're resetting all input/output features for a specific dead neuron).
            component_idx: The component index of the state values to reset.

        Raises:
            ValueError: If the parameter has a components dimension, but has_components_dim is
                False.
        """
        # Get the state of the parameter
        state = self.state[parameter]

        # Check if state is initialized. This guard has to come before any `state["exp_avg"]`
        # lookup below, otherwise an uninitialised parameter raises a KeyError instead of being
        # skipped.
        if not state:
            return

        # If the number of dimensions is 3, we definitely have a components dimension. If 2, we may
        # do (as the bias has 2 dimensions with components, but the weight has 2 dimensions without
        # components).
        definitely_has_components_dimension = 3
        if (
            not self._has_components_dim
            and state["exp_avg"].ndim == definitely_has_components_dimension
        ):
            error_message = (
                "The parameter has a components dimension, but has_components_dim is False. "
                "This should not happen."
            )
            raise ValueError(error_message)

        # Check there are any neurons to reset
        if neuron_indices.numel() == 0:
            return

        # Move the neuron indices to the same device as the optimizer state
        neuron_indices = neuron_indices.to(device=state["exp_avg"].device)

        # Zero the selected neurons in every running statistic that is present. `max_exp_avg_sq`
        # is the running maximum kept by the AMSGrad variant of Adam, so it may be absent.
        for state_key in ("exp_avg", "exp_avg_sq", "max_exp_avg_sq"):
            if state_key not in state:
                continue

            running_value: Tensor = state[state_key]

            # Integer indexing returns a view, so the in-place fill below also updates the
            # tensor held in the optimizer state.
            target = running_value[component_idx] if self._has_components_dim else running_value
            target.index_fill_(axis, neuron_indices, 0)
23 changes: 12 additions & 11 deletions sparse_autoencoder/source_model/replace_activations_hook.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
"""Replace activations hook."""
from typing import TYPE_CHECKING

from deepspeed import DeepSpeedEngine
from jaxtyping import Float
from torch import Tensor
from torch.nn.parallel import DataParallel
from transformer_lens.hook_points import HookPoint

from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


if TYPE_CHECKING:
from sparse_autoencoder.tensor_types import Axis
from jaxtyping import Float


def replace_activations_hook(
value: Tensor,
hook: HookPoint, # noqa: ARG001
sparse_autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
sparse_autoencoder: SparseAutoencoder | DataParallel[SparseAutoencoder] | DeepSpeedEngine,
component_idx: int | None = None,
n_components: int | None = None,
) -> Tensor:
"""Replace activations hook.

This should be pre-initialised with `functools.partial`.

Args:
value: The activations to replace.
hook: The hook point.
sparse_autoencoder: The sparse autoencoder. This should be pre-initialised with
`functools.partial`.
sparse_autoencoder: The sparse autoencoder.
component_idx: The component index to replace the activations with, if just replacing
activations for a single component. Requires the model to have a component axis.
n_components: The number of components that the SAE is trained on.

Returns:
Replaced activations.
Expand All @@ -43,11 +47,8 @@ def replace_activations_hook(
)

if component_idx is not None:
if sparse_autoencoder.config.n_components is None:
error_message = (
"Cannot replace for a specific component, if the model does not have a "
"component axis."
)
if n_components is None:
error_message = "The number of model components must be set if component_idx is set."
raise RuntimeError(error_message)

# The approach here is to run a forward pass with dummy values for all components other than
Expand All @@ -56,7 +57,7 @@ def replace_activations_hook(
# components.
expanded_shape = [
squashed_value.shape[0],
sparse_autoencoder.config.n_components,
n_components,
squashed_value.shape[-1],
]
expanded = squashed_value.unsqueeze(1).expand(*expanded_shape)
Expand Down
Loading