Add support for multi-worker data loading
alan-cooney committed Jan 6, 2024
1 parent ba9d136 commit 4d727d3
Showing 4 changed files with 16 additions and 3 deletions.
8 changes: 6 additions & 2 deletions sparse_autoencoder/source_data/abstract_dataset.py
@@ -5,7 +5,7 @@

 from datasets import Dataset, IterableDataset, VerificationMode, load_dataset
 from jaxtyping import Int
-from pydantic import PositiveInt, validate_call
+from pydantic import NonNegativeInt, PositiveInt, validate_call
 from torch import Tensor
 from torch.utils.data import DataLoader
 from torch.utils.data import Dataset as TorchDataset
@@ -217,11 +217,14 @@ def __iter__(self) -> Any:  # noqa: ANN401
         return self.dataset.__iter__()

     @final
-    def get_dataloader(self, batch_size: int) -> DataLoader[TorchTokenizedPrompts]:
+    def get_dataloader(
+        self, batch_size: int, num_workers: NonNegativeInt = 0
+    ) -> DataLoader[TorchTokenizedPrompts]:
         """Get a PyTorch DataLoader.

         Args:
             batch_size: The batch size to use.
+            num_workers: Number of CPU workers.

         Returns:
             PyTorch DataLoader.
@@ -234,4 +237,5 @@ def get_dataloader(self, batch_size: int) -> DataLoader[TorchTokenizedPrompts]:
             # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not
             # here.
             shuffle=False,
+            num_workers=num_workers,
         )
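For context, a minimal usage sketch of the new `get_dataloader` signature (not part of the commit; the `PreTokenizedDataset` class, its constructor arguments, the dataset path, and the `input_ids` batch key are assumptions about the surrounding library and may differ in your version):

```python
# Hedged sketch: assumes a concrete SourceDataset subclass named
# PreTokenizedDataset with (dataset_path, context_size) constructor arguments.
from sparse_autoencoder import PreTokenizedDataset

source_data = PreTokenizedDataset(
    dataset_path="NeelNanda/c4-code-tokenized-2b",  # example pre-tokenized dataset
    context_size=256,
)

# num_workers is forwarded to torch.utils.data.DataLoader: 0 (the default)
# keeps the previous single-process behaviour, while a positive value prepares
# batches in that many background worker processes.
dataloader = source_data.get_dataloader(batch_size=16, num_workers=4)

for batch in dataloader:
    tokens = batch["input_ids"]  # assumed batch key; shape (batch, context_size)
    break
```

Shuffling still happens on the dataset itself (see the comment above `shuffle=False`), so the extra workers only parallelise batch preparation, not shuffling.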
6 changes: 5 additions & 1 deletion sparse_autoencoder/train/pipeline.py
@@ -107,6 +107,7 @@ def __init__(
         checkpoint_directory: Path = DEFAULT_CHECKPOINT_DIRECTORY,
         log_frequency: PositiveInt = 100,
         metrics: MetricsContainer = default_metrics,
+        num_workers_data_loading: NonNegativeInt = 0,
         source_data_batch_size: PositiveInt = 12,
     ) -> None:
         """Initialize the pipeline.
@@ -125,6 +126,7 @@ def __init__(
             checkpoint_directory: Directory to save checkpoints to.
             log_frequency: Frequency at which to log metrics (in steps)
             metrics: Metrics to use.
+            num_workers_data_loading: Number of CPU workers for the dataloader.
             source_data_batch_size: Batch size for the source data.
         """
         self.activation_resampler = activation_resampler
@@ -142,7 +144,9 @@ def __init__(
         self.source_model = source_model

         # Create a stateful iterator
-        source_dataloader = source_dataset.get_dataloader(source_data_batch_size)
+        source_dataloader = source_dataset.get_dataloader(
+            source_data_batch_size, num_workers=num_workers_data_loading
+        )
         self.source_data = iter(source_dataloader)

     @validate_call
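The training pipeline exposes the same knob as `num_workers_data_loading`. A sketch of wiring it through the constructor is below; only the parameter names visible in this diff (`activation_resampler`, `source_model`, `source_dataset`, `run_name`, `log_frequency`, `num_workers_data_loading`, `source_data_batch_size`) are confirmed by the commit, and the remaining arguments are assumptions standing in for whatever the real signature requires:

```python
# Hedged sketch only: autoencoder, cache_names, layer, loss and optimizer are
# placeholder arguments assumed to be constructed elsewhere; they may not match
# the real Pipeline signature.
from sparse_autoencoder import Pipeline  # assumed top-level export


def build_pipeline(activation_resampler, autoencoder, cache_names, layer, loss,
                   optimizer, source_dataset, source_model):
    """Build a training pipeline that loads source data with 4 CPU workers."""
    return Pipeline(
        activation_resampler=activation_resampler,
        autoencoder=autoencoder,
        cache_names=cache_names,
        layer=layer,
        loss=loss,
        optimizer=optimizer,
        source_dataset=source_dataset,
        source_model=source_model,
        run_name="multi-worker-demo",
        log_frequency=100,
        num_workers_data_loading=4,  # new argument; 0 keeps single-process loading
        source_data_batch_size=12,
    )
```

Note that the DataLoader is created once in `__init__` and immediately wrapped in `iter(...)`, so the worker count is fixed for the lifetime of the pipeline.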
1 change: 1 addition & 0 deletions sparse_autoencoder/train/sweep.py
@@ -282,6 +282,7 @@ def run_training_pipeline(
         source_model=source_model,
         log_frequency=hyperparameters["pipeline"]["log_frequency"],
         run_name=run_name,
+        num_workers_data_loading=hyperparameters["pipeline"]["num_workers_data_loading"],
     )

     pipeline.run_pipeline(
4 changes: 4 additions & 0 deletions sparse_autoencoder/train/sweep_config.py
@@ -270,6 +270,9 @@ class PipelineHyperparameters(NestedParameter):
     )
     """Max activations."""

+    num_workers_data_loading: Parameter[int] = field(default=Parameter(0))
+    """Number of CPU workers for data loading."""
+
     checkpoint_frequency: Parameter[int] = field(
         default=Parameter(round_to_multiple(5e7, DEFAULT_STORE_SIZE))
     )
@@ -294,6 +297,7 @@ class PipelineRuntimeHyperparameters(TypedDict):
     train_batch_size: int
     max_store_size: int
     max_activations: int
+    num_workers_data_loading: int
     checkpoint_frequency: int
     validation_frequency: int
     validation_n_activations: int
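For sweeps, the new hyperparameter defaults to `Parameter(0)` and is read back in `sweep.py` as `hyperparameters["pipeline"]["num_workers_data_loading"]`. A sketch of overriding it (assuming `PipelineHyperparameters` accepts keyword arguments, with the remaining fields falling back to their defaults):

```python
# Hedged sketch: Parameter and PipelineHyperparameters are taken from the diff
# above; constructing the class from a single keyword argument assumes all
# other fields have defaults, as the ones shown here do.
from sparse_autoencoder.train.sweep_config import (
    Parameter,
    PipelineHyperparameters,
)

pipeline_hyperparameters = PipelineHyperparameters(
    num_workers_data_loading=Parameter(4),  # four dataloader worker processes per run
)
```

At runtime the nested parameters flatten into the `PipelineRuntimeHyperparameters` TypedDict, which appears to be why the plain `num_workers_data_loading: int` entry is added there as well.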
