
Add dead features metric #92

Closed

Conversation

lucyfarnik (Contributor)

Adds histogram and num_dead_features logging to wandb, closes #88

@alan-cooney changed the title from "88 Histogram Metric" to "Add dead features metric" on Nov 22, 2023
@alan-cooney (Collaborator) left a comment


Thanks for adding this!

I've added some comments to help frame it - I think we need to think a bit about where to log & how much of the processing wandb can do. Happy to chat on a call if easier.

Comment on lines +14 to +15
class HistogramMetric(AbstractTrainMetric):
    """Histogram metric — log the histogram, as well as the number of dead features."""

Needs a descriptive name - e.g. BatchDeadNeuronsMetric

    @final
    def calculate(self, data: TrainMetricData) -> OrderedDict[str, Any]:
        """Create a log item for Weights and Biases."""
        feature_mean_acts = data.learned_activations.mean(dim=0)

I think this should be sum? What I think you're getting at is "how many neurons didn't fire at all in this batch".
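
For concreteness, a minimal sketch of the sum-based version (shapes assumed to be [batch, num_features], matching the rest of this PR):

import torch

# Dummy batch of learned activations: [batch=2, num_features=3].
# (Assumes non-negative, post-ReLU activations, so a zero sum means
# the feature never fired on any example.)
learned_activations = torch.tensor([[1.0, 0.0, 0.0], [0.5, 0.0, 2.0]])

# Sum over the batch dimension rather than averaging.
feature_total_acts = learned_activations.sum(dim=0)
num_dead_features = (feature_total_acts == 0).sum().item()  # 1 (the middle feature)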

    def calculate(self, data: TrainMetricData) -> OrderedDict[str, Any]:
        """Create a log item for Weights and Biases."""
        feature_mean_acts = data.learned_activations.mean(dim=0)
        num_dead_features = (feature_mean_acts <= self.dead_threshold).sum().item()

I think this is meaningless with such a small sample size (one batch). However, if you just log a table of all neurons (and the indices of these) it may be possible to then do the analysis on wandb? Might be too much data to send.

Alternatively we may need to make this a "resample neurons" metric and log it there rather than for every train batch. Have you experimented with logging this on wandb to see what it looks like?
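
For what it's worth, a rough sketch of that experiment (wandb.Histogram and wandb.log are standard wandb calls; the dummy data and key name are made up for illustration):

import torch
import wandb

wandb.init(project="sparse-autoencoder")  # assumes a configured wandb account

# Dummy post-ReLU activations: [batch=64, num_features=1024].
learned_activations = torch.relu(torch.randn(64, 1024))

# Per-neuron count of examples in the batch on which the neuron fired.
fired_counts = (learned_activations > 0).sum(dim=0)

# A histogram summary is cheap to send every batch; a full per-neuron
# table would be num_features rows per train batch, which may be too much.
wandb.log({"neuron_activity/batch_fired_counts": wandb.Histogram(fired_counts.tolist())})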

Comment on lines +1 to +24
"""L0 (sparsity) norm metric."""
from collections import OrderedDict
from typing import final

import torch

from sparse_autoencoder.metrics.train.abstract_train_metric import (
AbstractTrainMetric,
TrainMetricData,
)


@final
class L0NormMetric(AbstractTrainMetric):
"""L0 (sparsity) norm metric."""

@final
def calculate(self, data: TrainMetricData) -> OrderedDict[str, float]:
"""Create a log item for Weights and Biases."""
# The L0 norm is the number of non-zero elements
# (We're averaging over the batch)
acts = data.learned_activations
value = (torch.sum(acts != 0) / acts.size(0)).item()
return OrderedDict(l0_norm=value)

Has been added in from the other PR - should be deleted here

Comment on lines +1 to +66
"""Metric reducer."""
from collections import OrderedDict
from collections.abc import Iterator
from typing import final

from sparse_autoencoder.metrics.train.abstract_train_metric import (
AbstractTrainMetric,
TrainMetricData,
)


@final
class TrainingMetricReducer(AbstractTrainMetric):
"""Training metric reducer.

Reduces multiple training metrics into a single training metric (by merging
their OrderedDicts).
"""

_modules: list["AbstractTrainMetric"]
"""Children training metric modules."""

@final
def __init__(
self,
*metric_modules: AbstractTrainMetric,
):
"""Initialize the training metric reducer.

Args:
metric_modules: Training metric modules to reduce.

Raises:
ValueError: If the training metric reducer has no training metric modules.
"""
super().__init__()

self._modules = list(metric_modules)

if len(self) == 0:
error_message = "Training metric reducer must have at least one training metric module."
raise ValueError(error_message)

@final
def calculate(self, data: TrainMetricData) -> OrderedDict[str, float]:
"""Create a log item for Weights and Biases."""
result = OrderedDict()
for module in self._modules:
result.update(module.calculate(data))
return result

def __dir__(self) -> list[str]:
"""Dir dunder method."""
return list(self._modules.__dir__())

def __getitem__(self, idx: int) -> AbstractTrainMetric:
"""Get item dunder method."""
return self._modules[idx]

def __iter__(self) -> Iterator[AbstractTrainMetric]:
"""Iterator dunder method."""
return iter(self._modules)

def __len__(self) -> int:
"""Length dunder method."""
return len(self._modules)

I don't know if we need this. Currently the pipeline just goes through and logs any metrics it's given.

We could abstract it, but then we maybe don't need all these dunder methods (and as a small point, it's a mapping function then, not a reducer).
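
For example, the pipeline-side version could be a single mapping helper (hypothetical code, not from this PR):

from collections import OrderedDict

from sparse_autoencoder.metrics.train.abstract_train_metric import (
    AbstractTrainMetric,
    TrainMetricData,
)


def calculate_all(
    metrics: list[AbstractTrainMetric], data: TrainMetricData
) -> OrderedDict[str, float]:
    """Map over the metrics and merge their log items into one dict."""
    result: OrderedDict[str, float] = OrderedDict()
    for metric in metrics:
        result.update(metric.calculate(data))
    return result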

Comment on lines +1 to +25
"""Tests for the L0NormMetric class."""
import pytest
import torch

from sparse_autoencoder.metrics.l0_norm_metric import L0NormMetric
from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData


@pytest.fixture()
def l0_norm_metric() -> L0NormMetric:
"""Fixture for L0NormMetric."""
return L0NormMetric()


def test_l0_norm_metric(l0_norm_metric: L0NormMetric) -> None:
"""Test the L0NormMetric."""
learned_activations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.01, 2.0]])
data = TrainMetricData(
input_activations=torch.zeros_like(learned_activations),
learned_activations=learned_activations,
decoded_activations=torch.zeros_like(learned_activations),
)
log = l0_norm_metric.calculate(data)
expected = 2 / 3
assert log["l0_norm"] == expected

Shouldn't be here

@alan-cooney (Collaborator)

FYI I'm adding resample metrics here - #98 - so I'll pull that part out. I think a metric per batch is still useful, but it's different (e.g. maybe we just call it TrainingBatchNeuronActivity?)

Successfully merging this pull request may close these issues: "Log feature density histograms to weights and biases" (#88).