Add pydantic validation #167

Merged: 1 commit, Jan 2, 2024
10 changes: 6 additions & 4 deletions sparse_autoencoder/activation_store/base_store.py
```diff
@@ -4,6 +4,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -105,12 +106,13 @@ def shuffle(self) -> None:
         """Optional shuffle method."""
 
     @final
+    @validate_call
     def fill_with_test_data(
         self,
-        n_batches: int = 1,
-        batch_size: int = 16,
-        n_components: int = 1,
-        input_features: int = 256,
+        n_batches: PositiveInt = 1,
+        batch_size: PositiveInt = 16,
+        n_components: PositiveInt = 1,
+        input_features: PositiveInt = 256,
     ) -> None:
         """Fill the store with test data.
```
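For readers unfamiliar with the pattern: `@validate_call` checks arguments against the type annotations at call time and raises `pydantic.ValidationError` when they fail. A minimal free-standing sketch (assumes pydantic v2; the function below is a stand-in with the same constraint pattern, not repo code):

```python
from pydantic import PositiveInt, ValidationError, validate_call


@validate_call
def fill_with_test_data(n_batches: PositiveInt = 1, batch_size: PositiveInt = 16) -> None:
    """Stand-in mirroring the decorated method's signature."""


fill_with_test_data(n_batches=2)  # valid: both arguments are > 0
try:
    fill_with_test_data(n_batches=0)  # invalid: PositiveInt requires > 0
except ValidationError as err:
    print(err)  # reports which argument failed and why
```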
8 changes: 5 additions & 3 deletions sparse_autoencoder/activation_store/disk_store.py
```diff
@@ -4,6 +4,7 @@
 import tempfile
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 
@@ -70,12 +71,13 @@ def current_activations_stored_per_component(self) -> list[int]:
         disk_items_stored = len(self)
         return [cache_items + disk_items_stored for cache_items in self._items_stored]
 
+    @validate_call
     def __init__(
         self,
-        n_neurons: int,
+        n_neurons: PositiveInt,
         storage_path: Path = DEFAULT_DISK_ACTIVATION_STORE_PATH,
-        max_cache_size: int = 10_000,
-        n_components: int = 1,
+        max_cache_size: PositiveInt = 10_000,
+        n_components: PositiveInt = 1,
         *,
         empty_dir: bool = False,
     ):
```
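One behaviour worth noting: `validate_call` runs in pydantic's default lax mode, so compatible inputs are coerced rather than rejected. A hypothetical `open_store` (illustrative, not the repo's store class) shows both coercion paths:

```python
from pathlib import Path

from pydantic import PositiveInt, validate_call


@validate_call
def open_store(n_neurons: PositiveInt, storage_path: Path = Path("activations")) -> str:
    # Both parameters arrive here already validated (and coerced).
    return f"{n_neurons} neurons stored at {storage_path}"


# Lax mode coerces the numeric string to int and the plain string to Path:
print(open_store("512", storage_path="cache/run1"))
```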
8 changes: 5 additions & 3 deletions sparse_autoencoder/activation_store/tensor_store.py
```diff
@@ -1,5 +1,6 @@
 """Tensor Activation Store."""
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 
@@ -75,11 +76,12 @@ def current_activations_stored_per_component(self) -> list[int]:
         """Number of activations stored per component."""
         return self._items_stored
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
-        max_items: int,
-        n_neurons: int,
-        n_components: int = 1,
+        max_items: PositiveInt,
+        n_neurons: PositiveInt,
+        n_components: PositiveInt = 1,
         device: torch.device | None = None,
     ) -> None:
         """Initialise the Tensor Activation Store.
```
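This file needs `config={"arbitrary_types_allowed": True}` because `torch.device` is not a type pydantic can build a schema for; with the flag set, pydantic falls back to a plain `isinstance()` check. A sketch (hypothetical `make_store`, assuming pydantic v2 and torch):

```python
import torch
from pydantic import PositiveInt, ValidationError, validate_call


# Without arbitrary_types_allowed, decorating a signature containing
# torch.device would fail at decoration time with a schema-generation error.
@validate_call(config={"arbitrary_types_allowed": True})
def make_store(max_items: PositiveInt, device: torch.device | None = None) -> None:
    ...


make_store(100, device=torch.device("cpu"))  # passes the isinstance() check
try:
    make_store(100, device="cpu")  # a str is neither torch.device nor None
except ValidationError:
    print("device must be a torch.device or None")
```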
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/components/abstract_decoder.py
```diff
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float, Int64
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Module, Parameter
@@ -24,11 +25,12 @@ class AbstractDecoder(Module, ABC):
 
     _n_components: int | None
 
+    @validate_call
     def __init__(
         self,
-        learnt_features: int,
-        decoded_features: int,
-        n_components: int | None,
+        learnt_features: PositiveInt,
+        decoded_features: PositiveInt,
+        n_components: PositiveInt | None,
     ) -> None:
         """Initialise the decoder.
```
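`PositiveInt | None` keeps `None` as an explicit "no component axis" value while ruling out zero and negative counts. A small sketch (hypothetical function, same union as the signature above):

```python
from pydantic import PositiveInt, ValidationError, validate_call


@validate_call
def configure(n_components: PositiveInt | None) -> str:
    return "no component axis" if n_components is None else f"{n_components} components"


print(configure(None))  # valid: None is an allowed branch of the union
print(configure(4))     # valid: 4 > 0
try:
    configure(0)        # invalid: fails both branches of the union
except ValidationError:
    print("n_components must be a positive int or None")
```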
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/components/abstract_encoder.py
```diff
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float, Int64
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Module, Parameter
@@ -25,11 +26,12 @@ class AbstractEncoder(Module, ABC):
 
     _n_components: int | None
 
+    @validate_call
     def __init__(
         self,
-        input_features: int,
-        learnt_features: int,
-        n_components: int | None,
+        input_features: PositiveInt,
+        learnt_features: PositiveInt,
+        n_components: PositiveInt | None,
     ) -> None:
         """Initialise the encoder.
```
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/components/linear_encoder.py
```diff
@@ -4,6 +4,7 @@
 
 import einops
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Parameter, ReLU, init
@@ -77,11 +78,12 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
     activation_function: ReLU
     """Activation function."""
 
+    @validate_call
     def __init__(
         self,
-        input_features: int,
-        learnt_features: int,
-        n_components: int | None,
+        input_features: PositiveInt,
+        learnt_features: PositiveInt,
+        n_components: PositiveInt | None,
     ):
         """Initialize the linear encoder layer.
```
8 changes: 5 additions & 3 deletions (file name not captured)

```diff
@@ -3,6 +3,7 @@
 
 import einops
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn import Parameter, init
@@ -76,11 +77,12 @@ def reset_optimizer_parameter_details(self) -> list[tuple[Parameter, int]]:
         """
         return [(self.weight, -1)]
 
+    @validate_call
     def __init__(
         self,
-        learnt_features: int,
-        decoded_features: int,
-        n_components: int | None,
+        learnt_features: PositiveInt,
+        decoded_features: PositiveInt,
+        n_components: PositiveInt | None,
         *,
         enable_gradient_hook: bool = True,
     ) -> None:
```
8 changes: 5 additions & 3 deletions sparse_autoencoder/autoencoder/model.py
```diff
@@ -3,6 +3,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from torch.nn.parameter import Parameter
@@ -72,15 +73,16 @@ def post_decoder_bias(self) -> TiedBias:
         """Post-decoder bias."""
         return self._post_decoder_bias
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
-        n_input_features: int,
-        n_learned_features: int,
+        n_input_features: PositiveInt,
+        n_learned_features: PositiveInt,
         geometric_median_dataset: Float[
             Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
         ]
         | None = None,
-        n_components: int | None = None,
+        n_components: PositiveInt | None = None,
     ) -> None:
         """Initialize the Sparse Autoencoder Model.
```
4 changes: 3 additions & 1 deletion sparse_autoencoder/loss/learned_activations_l1.py
```diff
@@ -2,6 +2,7 @@
 from typing import final
 
 from jaxtyping import Float
+from pydantic import PositiveFloat, validate_call
 import torch
 from torch import Tensor
 
@@ -37,8 +38,9 @@ def log_name(self) -> str:
         """
         return "learned_activations_l1_loss_penalty"
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
-        self, l1_coefficient: float | Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]
+        self, l1_coefficient: PositiveFloat | Float[Tensor, Axis.names(Axis.COMPONENT_OPTIONAL)]
    ) -> None:
         """Initialize the absolute error loss.
```
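The union here mixes a constrained scalar with a tensor type. In the sketch below (assumptions: pydantic v2, with plain `torch.Tensor` standing in for the jaxtyping `Float[Tensor, ...]` annotation), scalars are range-checked but tensors only pass an `isinstance()` check, so per-element positivity of a tensor coefficient is not enforced by the decorator:

```python
import torch
from pydantic import PositiveFloat, ValidationError, validate_call


@validate_call(config={"arbitrary_types_allowed": True})
def set_l1(l1_coefficient: PositiveFloat | torch.Tensor) -> None:
    ...


set_l1(1e-3)                         # scalar branch: must be > 0
set_l1(torch.tensor([1e-3, 1e-2]))   # tensor branch: isinstance() only
try:
    set_l1(-0.5)                     # fails both branches
except ValidationError:
    print("l1_coefficient must be > 0 or a Tensor")
```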
4 changes: 3 additions & 1 deletion sparse_autoencoder/metrics/train/feature_density.py
```diff
@@ -3,6 +3,7 @@
 from jaxtyping import Float
 import numpy as np
 from numpy import histogram
+from pydantic import NonNegativeFloat, validate_call
 import torch
 from torch import Tensor
 import wandb
@@ -32,9 +33,10 @@ class TrainBatchFeatureDensityMetric(AbstractTrainMetric):
 
     threshold: float
 
+    @validate_call
     def __init__(
         self,
-        threshold: float = 0.0,
+        threshold: NonNegativeFloat = 0.0,
     ) -> None:
         """Initialise the train batch feature density metric.
```
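Note the choice of `NonNegativeFloat` (>= 0) rather than `PositiveFloat` (> 0): an explicit `threshold=0.0`, matching the default, must stay legal. A sketch (hypothetical function name):

```python
from pydantic import NonNegativeFloat, ValidationError, validate_call


@validate_call
def density_metric(threshold: NonNegativeFloat = 0.0) -> None:
    ...


density_metric()      # the 0.0 default is fine: NonNegativeFloat allows 0
density_metric(0.5)   # any float >= 0 passes
try:
    density_metric(-0.1)  # rejected: below 0
except ValidationError:
    print("threshold must be >= 0")
```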
10 changes: 6 additions & 4 deletions sparse_autoencoder/source_data/abstract_dataset.py
```diff
@@ -5,6 +5,7 @@
 
 from datasets import Dataset, IterableDataset, load_dataset
 from jaxtyping import Int
+from pydantic import PositiveInt, validate_call
 from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset as TorchDataset
@@ -106,17 +107,18 @@ def preprocess(
         """
 
     @abstractmethod
+    @validate_call
     def __init__(
         self,
         dataset_path: str,
         dataset_split: str,
-        context_size: int,
-        buffer_size: int = 1000,
+        context_size: PositiveInt,
+        buffer_size: PositiveInt = 1000,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_column_name: str = "input_ids",
-        n_processes_preprocessing: int | None = None,
-        preprocess_batch_size: int = 1000,
+        n_processes_preprocessing: PositiveInt | None = None,
+        preprocess_batch_size: PositiveInt = 1000,
         *,
         pre_download: bool = False,
     ):
```
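Stacking `@abstractmethod` over `@validate_call` on `__init__` means concrete subclasses that delegate to `super().__init__()` are validated through the base class. A minimal sketch of the intent (class names are illustrative; assumes pydantic v2):

```python
from abc import ABC, abstractmethod

from pydantic import PositiveInt, ValidationError, validate_call


class SourceDatasetSketch(ABC):
    @abstractmethod
    @validate_call
    def __init__(self, context_size: PositiveInt) -> None:
        self.context_size = context_size


class ChildDataset(SourceDatasetSketch):
    def __init__(self, context_size: int = 256) -> None:
        super().__init__(context_size=context_size)  # validated by the wrapper


ChildDataset(context_size=128)  # ok
try:
    ChildDataset(context_size=-1)  # rejected in the base-class __init__
except ValidationError:
    print("context_size must be positive")
```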
8 changes: 5 additions & 3 deletions sparse_autoencoder/source_data/mock_dataset.py
```diff
@@ -7,6 +7,7 @@
 
 from datasets import IterableDataset
 from jaxtyping import Int
+from pydantic import PositiveInt, validate_call
 import torch
 from torch import Tensor
 from transformers import PreTrainedTokenizerFast
@@ -139,11 +140,12 @@ def preprocess(
         # Nothing to do here
         return source_batch
 
+    @validate_call
     def __init__(
         self,
-        context_size: int = 250,
-        buffer_size: int = 1000,  # noqa: ARG002
-        preprocess_batch_size: int = 1000,  # noqa: ARG002
+        context_size: PositiveInt = 250,
+        buffer_size: PositiveInt = 1000,  # noqa: ARG002
+        preprocess_batch_size: PositiveInt = 1000,  # noqa: ARG002
         dataset_path: str = "dummy",  # noqa: ARG002
         dataset_split: str = "train",  # noqa: ARG002
     ):
```
9 changes: 6 additions & 3 deletions sparse_autoencoder/source_data/pretokenized_dataset.py
```diff
@@ -14,6 +14,8 @@
 from collections.abc import Mapping, Sequence
 from typing import final
 
+from pydantic import PositiveInt, validate_call
+
 from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
 
 
@@ -67,16 +69,17 @@ def preprocess(
 
         return {"input_ids": context_size_prompts}
 
+    @validate_call
     def __init__(
         self,
         dataset_path: str,
-        context_size: int = 256,
-        buffer_size: int = 1000,
+        context_size: PositiveInt = 256,
+        buffer_size: PositiveInt = 1000,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
         dataset_column_name: str = "input_ids",
-        preprocess_batch_size: int = 1000,
+        preprocess_batch_size: PositiveInt = 1000,
         *,
         pre_download: bool = False,
     ):
```
13 changes: 8 additions & 5 deletions sparse_autoencoder/source_data/text_dataset.py
```diff
@@ -9,6 +9,7 @@
 from typing import TypedDict, final
 
 from datasets import IterableDataset
+from pydantic import PositiveInt, validate_call
 from transformers import PreTrainedTokenizerBase
 
 from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TokenizedPrompts
@@ -63,18 +64,19 @@ def preprocess(
 
         return {"input_ids": context_size_prompts}
 
+    @validate_call(config={"arbitrary_types_allowed": True})
     def __init__(
         self,
         dataset_path: str,
         tokenizer: PreTrainedTokenizerBase,
-        buffer_size: int = 1000,
-        context_size: int = 256,
+        buffer_size: PositiveInt = 1000,
+        context_size: PositiveInt = 256,
         dataset_dir: str | None = None,
         dataset_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
         dataset_split: str = "train",
         dataset_column_name: str = "input_ids",
-        n_processes_preprocessing: int | None = None,
-        preprocess_batch_size: int = 1000,
+        n_processes_preprocessing: PositiveInt | None = None,
+        preprocess_batch_size: PositiveInt = 1000,
         *,
         pre_download: bool = False,
     ):
@@ -116,12 +118,13 @@ def __init__(
             preprocess_batch_size=preprocess_batch_size,
         )
 
+    @validate_call
     def push_to_hugging_face_hub(
         self,
         repo_id: str,
         commit_message: str = "Upload preprocessed dataset using sparse_autoencoder.",
         max_shard_size: str | None = None,
-        n_shards: int = 64,
+        n_shards: PositiveInt = 64,
         revision: str = "main",
         *,
         private: bool = False,
```
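The decorator is not limited to constructors: `push_to_hugging_face_hub` gets the same treatment, so a bad `n_shards` fails before any network call. A stand-in sketch (hypothetical class, not the repo's `TextDataset`):

```python
from pydantic import PositiveInt, ValidationError, validate_call


class UploaderSketch:
    @validate_call
    def push_to_hub(self, repo_id: str, n_shards: PositiveInt = 64) -> None:
        print(f"would push {n_shards} shards to {repo_id}")


UploaderSketch().push_to_hub("user/dataset")              # default passes
UploaderSketch().push_to_hub("user/dataset", n_shards=8)  # explicit value
try:
    UploaderSketch().push_to_hub("user/dataset", n_shards=0)
except ValidationError:
    print("n_shards must be > 0")
```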