Skip to content

Commit

Permalink
Default to average metric aggregation across components (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Jan 9, 2024
1 parent 6b9c2a5 commit e62c470
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 13 deletions.
7 changes: 4 additions & 3 deletions sparse_autoencoder/metrics/abstract_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from jaxtyping import Float, Int
import numpy as np
from strenum import LowercaseStrEnum, SnakeCaseStrEnum
import torch
from torch import Tensor
from wandb import data_types

Expand Down Expand Up @@ -99,7 +100,7 @@ def __init__(
| Int[Tensor, Axis.names(Axis.COMPONENT)],
name: str,
location: MetricLocation,
aggregate_approach: ComponentAggregationApproach | None = ComponentAggregationApproach.ALL,
aggregate_approach: ComponentAggregationApproach | None = ComponentAggregationApproach.MEAN,
aggregate_value: Any | None = None, # noqa: ANN401
postfix: str | None = None,
) -> None:
Expand Down Expand Up @@ -195,9 +196,9 @@ def aggregate_value( # noqa: PLR0911
):
match self.aggregate_approach:
case ComponentAggregationApproach.MEAN:
return self.component_wise_values.mean(dim=0)
return self.component_wise_values.mean(dim=0, dtype=torch.float32)
case ComponentAggregationApproach.SUM:
return self.component_wise_values.sum(dim=0)
return self.component_wise_values.sum(dim=0, dtype=torch.float32)
case ComponentAggregationApproach.ALL:
return self.component_wise_values
case _:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
'component_3/train/learned_activations_l0_norm': tensor(8.),
'component_4/train/learned_activations_l0_norm': tensor(8.),
'component_5/train/learned_activations_l0_norm': tensor(8.),
'train/learned_activations_l0_norm': tensor([8., 8., 8., 8., 8., 8.]),
'train/learned_activations_l0_norm/component_mean': tensor(8.),
}),
])
# ---
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
'component_3/train/learned_neuron_activity/dead_over_10_activations': tensor(0),
'component_4/train/learned_neuron_activity/dead_over_10_activations': tensor(0),
'component_5/train/learned_neuron_activity/dead_over_10_activations': tensor(0),
'train/learned_neuron_activity/dead_over_10_activations': tensor([0, 0, 0, 0, 0, 0]),
'train/learned_neuron_activity/dead_over_10_activations/component_mean': tensor(0.),
}),
dict({
'component_0/train/learned_neuron_activity/alive_over_10_activations': tensor(8),
Expand All @@ -17,7 +17,7 @@
'component_3/train/learned_neuron_activity/alive_over_10_activations': tensor(8),
'component_4/train/learned_neuron_activity/alive_over_10_activations': tensor(8),
'component_5/train/learned_neuron_activity/alive_over_10_activations': tensor(8),
'train/learned_neuron_activity/alive_over_10_activations': tensor([8, 8, 8, 8, 8, 8]),
'train/learned_neuron_activity/alive_over_10_activations/component_mean': tensor(8.),
}),
dict({
'component_0/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram(
Expand Down Expand Up @@ -366,7 +366,7 @@
'component_3/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0),
'component_4/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0),
'component_5/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0),
'train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor([0, 0, 0, 0, 0, 0]),
'train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations/component_mean': tensor(0.),
}),
dict({
'component_0/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0),
Expand All @@ -375,7 +375,7 @@
'component_3/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0),
'component_4/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0),
'component_5/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0),
'train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor([0, 0, 0, 0, 0, 0]),
'train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations/component_mean': tensor(0.),
}),
])
# ---
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
'component_3/validate/reconstruction_score/baseline_loss': tensor(0.4598),
'component_4/validate/reconstruction_score/baseline_loss': tensor(0.4281),
'component_5/validate/reconstruction_score/baseline_loss': tensor(0.4961),
'validate/reconstruction_score/baseline_loss': tensor([0.3800, 0.5251, 0.4923, 0.4598, 0.4281, 0.4961]),
'validate/reconstruction_score/baseline_loss/component_mean': tensor(0.4636),
}),
dict({
'component_0/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6111),
Expand All @@ -17,7 +17,7 @@
'component_3/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6497),
'component_4/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4929),
'component_5/validate/reconstruction_score/loss_with_reconstruction': tensor(0.3723),
'validate/reconstruction_score/loss_with_reconstruction': tensor([0.6111, 0.5219, 0.4063, 0.6497, 0.4929, 0.3723]),
'validate/reconstruction_score/loss_with_reconstruction/component_mean': tensor(0.5090),
}),
dict({
'component_0/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.2891),
Expand All @@ -26,7 +26,7 @@
'component_3/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.4740),
'component_4/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5452),
'component_5/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3733),
'validate/reconstruction_score/loss_with_zero_ablation': tensor([0.2891, 0.3879, 0.5850, 0.4740, 0.5452, 0.3733]),
'validate/reconstruction_score/loss_with_zero_ablation/component_mean': tensor(0.4424),
}),
dict({
'component_0/validate/reconstruction_score': tensor(3.5422),
Expand All @@ -35,8 +35,7 @@
'component_3/validate/reconstruction_score': tensor(-12.3338),
'component_4/validate/reconstruction_score': tensor(0.4468),
'component_5/validate/reconstruction_score': tensor(-0.0081),
'validate/reconstruction_score': tensor([ 3.5422e+00, 9.7672e-01, 1.9278e+00, -1.2334e+01, 4.4681e-01,
-8.1113e-03]),
'validate/reconstruction_score/component_mean': tensor(-0.9081),
}),
])
# ---
6 changes: 6 additions & 0 deletions sparse_autoencoder/train/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,12 @@ def calculate(self, data: ValidationMetricData) -> list[MetricResult]:
dummy_metric.data.source_model_loss_with_zero_ablation is not None
), "Source model loss with zero ablation should be calculated."

# Check the dimensions are correct
ndim_with_component = 2
assert dummy_metric.data.source_model_loss.ndim == ndim_with_component
assert dummy_metric.data.source_model_loss_with_reconstruction.ndim == ndim_with_component
assert dummy_metric.data.source_model_loss_with_zero_ablation.ndim == ndim_with_component


class TestSaveCheckpoint:
"""Test the save_checkpoint method."""
Expand Down

0 comments on commit e62c470

Please sign in to comment.