Add dead features metric #92
Conversation
Thanks for adding this!
I've added some comments to help frame it - I think we need to think a bit about where to log & how much of the processing wandb can do. Happy to chat on a call if easier.
class HistogramMetric(AbstractTrainMetric):
    """Histogram metric — log the histogram, as well as the number of dead features."""
Needs a descriptive name - e.g. BatchDeadNeuronsMetric
@final
def calculate(self, data: TrainMetricData) -> OrderedDict[str, Any]:
    """Create a log item for Weights and Biases."""
    feature_mean_acts = data.learned_activations.mean(dim=0)
I think this should be sum? What I think you're getting at is "how many neurons didn't fire at all in this batch".
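For concreteness, a minimal sketch of the sum-based version (this reuses the TrainMetricData and dead_threshold names from this PR, and assumes activations are non-negative, e.g. post-ReLU):

    # Total activation per neuron across the batch; a neuron that never
    # fired sums to zero (assuming non-negative activations).
    total_acts = data.learned_activations.sum(dim=0)
    num_dead_features = (total_acts <= self.dead_threshold).sum().item()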
def calculate(self, data: TrainMetricData) -> OrderedDict[str, Any]:
    """Create a log item for Weights and Biases."""
    feature_mean_acts = data.learned_activations.mean(dim=0)
    num_dead_features = (feature_mean_acts <= self.dead_threshold).sum().item()
I think this is meaningless with such a small sample size (one batch). However, if you just log a table of all neurons (and their indices), it may be possible to then do the analysis on wandb? Might be too much data to send.
Alternatively we may need to make this a "resample neurons" metric and log it there rather than for every train batch. Have you experimented with logging this on wandb to see what it looks like?
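If it helps, here's a rough sketch of the "log everything and analyse on wandb" option using wandb's histogram type (wandb.Histogram is the real API; the neuron_activity key and the > 0 firing test are just illustrative):

    import wandb

    def calculate(self, data: TrainMetricData) -> OrderedDict[str, Any]:
        """Log per-neuron firing counts so aggregation can happen on wandb."""
        # Number of times each neuron fired in this batch.
        fires_per_neuron = (data.learned_activations > 0).sum(dim=0)
        return OrderedDict(neuron_activity=wandb.Histogram(fires_per_neuron.cpu().numpy()))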
"""L0 (sparsity) norm metric.""" | ||
from collections import OrderedDict | ||
from typing import final | ||
|
||
import torch | ||
|
||
from sparse_autoencoder.metrics.train.abstract_train_metric import ( | ||
AbstractTrainMetric, | ||
TrainMetricData, | ||
) | ||
|
||
|
||
@final | ||
class L0NormMetric(AbstractTrainMetric): | ||
"""L0 (sparsity) norm metric.""" | ||
|
||
@final | ||
def calculate(self, data: TrainMetricData) -> OrderedDict[str, float]: | ||
"""Create a log item for Weights and Biases.""" | ||
# The L0 norm is the number of non-zero elements | ||
# (We're averaging over the batch) | ||
acts = data.learned_activations | ||
value = (torch.sum(acts != 0) / acts.size(0)).item() | ||
return OrderedDict(l0_norm=value) |
Has been added in from the other PR - should be deleted here
"""Metric reducer.""" | ||
from collections import OrderedDict | ||
from collections.abc import Iterator | ||
from typing import final | ||
|
||
from sparse_autoencoder.metrics.train.abstract_train_metric import ( | ||
AbstractTrainMetric, | ||
TrainMetricData, | ||
) | ||
|
||
|
||
@final | ||
class TrainingMetricReducer(AbstractTrainMetric): | ||
"""Training metric reducer. | ||
|
||
Reduces multiple training metrics into a single training metric (by merging | ||
their OrderedDicts). | ||
""" | ||
|
||
_modules: list["AbstractTrainMetric"] | ||
"""Children training metric modules.""" | ||
|
||
@final | ||
def __init__( | ||
self, | ||
*metric_modules: AbstractTrainMetric, | ||
): | ||
"""Initialize the training metric reducer. | ||
|
||
Args: | ||
metric_modules: Training metric modules to reduce. | ||
|
||
Raises: | ||
ValueError: If the training metric reducer has no training metric modules. | ||
""" | ||
super().__init__() | ||
|
||
self._modules = list(metric_modules) | ||
|
||
if len(self) == 0: | ||
error_message = "Training metric reducer must have at least one training metric module." | ||
raise ValueError(error_message) | ||
|
||
@final | ||
def calculate(self, data: TrainMetricData) -> OrderedDict[str, float]: | ||
"""Create a log item for Weights and Biases.""" | ||
result = OrderedDict() | ||
for module in self._modules: | ||
result.update(module.calculate(data)) | ||
return result | ||
|
||
def __dir__(self) -> list[str]: | ||
"""Dir dunder method.""" | ||
return list(self._modules.__dir__()) | ||
|
||
def __getitem__(self, idx: int) -> AbstractTrainMetric: | ||
"""Get item dunder method.""" | ||
return self._modules[idx] | ||
|
||
def __iter__(self) -> Iterator[AbstractTrainMetric]: | ||
"""Iterator dunder method.""" | ||
return iter(self._modules) | ||
|
||
def __len__(self) -> int: | ||
"""Length dunder method.""" | ||
return len(self._modules) |
I don't know if we need this. Currently the pipeline just goes through and logs any metrics it's given.
We could abstract it, but then we probably don't need all these dunder methods (and, as a small point, it would then be a mapping function, not a reducer).
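To make the comparison concrete, the reducer-free version is roughly just this loop in the pipeline (the metrics iterable and the wandb.log call are assumptions about the pipeline internals):

    log_items: dict[str, float] = {}
    for metric in metrics:  # any iterable of AbstractTrainMetric
        log_items.update(metric.calculate(data))
    wandb.log(log_items)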
"""Tests for the L0NormMetric class.""" | ||
import pytest | ||
import torch | ||
|
||
from sparse_autoencoder.metrics.l0_norm_metric import L0NormMetric | ||
from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData | ||
|
||
|
||
@pytest.fixture() | ||
def l0_norm_metric() -> L0NormMetric: | ||
"""Fixture for L0NormMetric.""" | ||
return L0NormMetric() | ||
|
||
|
||
def test_l0_norm_metric(l0_norm_metric: L0NormMetric) -> None: | ||
"""Test the L0NormMetric.""" | ||
learned_activations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.01, 2.0]]) | ||
data = TrainMetricData( | ||
input_activations=torch.zeros_like(learned_activations), | ||
learned_activations=learned_activations, | ||
decoded_activations=torch.zeros_like(learned_activations), | ||
) | ||
log = l0_norm_metric.calculate(data) | ||
expected = 2 / 3 | ||
assert log["l0_norm"] == expected |
Shouldn't be here
FYI I'm adding resample metrics here - #98 - so will pull that part out. I think a metric per batch is still useful but it's different (e.g. maybe we just call it TrainingBatchNeuronActivity?)
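A minimal sketch of what that per-batch metric could look like, using the name suggested above (the metric key and the > 0 firing test are assumptions):

    @final
    class TrainingBatchNeuronActivity(AbstractTrainMetric):
        """Per-train-batch neuron activity metric."""

        @final
        def calculate(self, data: TrainMetricData) -> OrderedDict[str, Any]:
            """Count neurons that didn't fire at all in this batch."""
            fires_per_neuron = (data.learned_activations > 0).sum(dim=0)
            num_dead = (fires_per_neuron == 0).sum().item()
            return OrderedDict(train_batch_dead_neurons=num_dead)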
Adds histogram and num_dead_features logging to wandb, closes #88