Add get_dataloader method to source dataset (#41)
alan-cooney authored Nov 7, 2023
1 parent 13158f7 commit e3d5e86
Showing 3 changed files with 62 additions and 10 deletions.
37 changes: 32 additions & 5 deletions sparse_autoencoder/source_data/abstract_dataset.py
@@ -3,6 +3,8 @@
 from typing import Any, Generic, TypedDict, TypeVar, final
 
 from datasets import IterableDataset, load_dataset
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset as TorchDataset
 
 
 TokenizedPrompt = list[int]
@@ -55,9 +57,6 @@ class SourceDataset(ABC, Generic[HuggingFaceDatasetItem]):
 
     Warning:
         Hugging Face `Dataset` objects are confusingly not the same as PyTorch `Dataset` objects.
-        They can still be wrapped with the PyTorch `DataLoader` class, as the Hugging Face class
-        extends the PyTorch class, but this may lead to performance issues and it's generally best
-        to use the Hugging Face `IterableDataset` directly.
     """
 
     @abstractmethod
@@ -125,7 +124,7 @@ def __init__(
 
         # Setup preprocessing
         existing_columns: list[str] = list(next(iter(dataset)).keys())
-        self.dataset = dataset.map(
+        mapped_dataset = dataset.map(
             self.preprocess,
             batched=True,
             batch_size=preprocess_batch_size,
@@ -136,7 +135,7 @@ def __init__(
         # Setup approximate shuffling. As the dataset is streamed, this just pre-downloads at least
         # `buffer_size` items and then shuffles just that buffer.
         # https://huggingface.co/docs/datasets/v2.14.5/stream#shuffle
-        self.dataset.shuffle(buffer_size=buffer_size)
+        self.dataset = mapped_dataset.shuffle(buffer_size=buffer_size)
 
     @final
     def __iter__(self) -> Any:  # noqa: ANN401
@@ -145,3 +144,31 @@ def __iter__(self) -> Any:  # noqa: ANN401
             Enables direct access to :attr:`dataset` with e.g. `for` loops.
         """
         return self.dataset.__iter__()
+
+    @final
+    def __next__(self) -> Any:  # noqa: ANN401
+        """Next Dunder Method.
+
+        Enables direct access to :attr:`dataset` with e.g. `next` calls.
+        """
+        return next(iter(self))
+
+    @final
+    def get_dataloader(self, batch_size: int) -> DataLoader:
+        """Get a PyTorch DataLoader.
+
+        Args:
+            batch_size: The batch size to use.
+
+        Returns:
+            PyTorch DataLoader.
+        """
+        torch_dataset: TorchDataset = self.dataset.with_format("torch")  # type: ignore
+
+        return DataLoader(
+            torch_dataset,
+            batch_size=batch_size,
+            # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not
+            # here.
+            shuffle=False,
+        )
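Two details in this file are easy to miss. Hugging Face's streaming `IterableDataset.shuffle` returns a new dataset rather than shuffling in place, so assigning its result (`self.dataset = mapped_dataset.shuffle(...)`) is the actual fix over the previous discarded call. And `get_dataloader` passes `shuffle=False` because that buffer-based shuffle already handles ordering. A minimal usage sketch (`source_dataset` stands in for any concrete `SourceDataset` subclass instance; the batch size is illustrative):

    dataloader = source_dataset.get_dataloader(batch_size=16)

    for batch in dataloader:
        # `with_format("torch")` plus the default collate function yield a dict
        # of stacked tensors, e.g. `input_ids` of shape (batch_size, context_size).
        input_ids = batch["input_ids"]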
18 changes: 15 additions & 3 deletions sparse_autoencoder/source_data/random_int.py
@@ -5,7 +5,7 @@
 import random
 from typing import TypedDict, final
 
-from datasets import IterableDataset
+from torch.utils.data import DataLoader, Dataset
 from transformers import PreTrainedTokenizerFast
 
 from sparse_autoencoder.source_data.abstract_dataset import (
@@ -20,7 +20,7 @@ class RandomIntSourceData(TypedDict):
     input_ids: list[list[int]]
 
 
-class RandomIntHuggingFaceDataset(IterableDataset):
+class RandomIntHuggingFaceDataset(Dataset):
     """Dummy Hugging Face Dataset."""
 
     def __init__(self, vocab_size: int, context_size: int):
@@ -42,6 +42,14 @@ def __next__(self) -> dict[str, list[int]]:
         data = [random.randint(0, self.vocab_size) for _ in range(self.context_size)]  # noqa: S311
         return {"input_ids": data}
 
+    def __len__(self) -> int:
+        """Len Dunder Method."""
+        return 1000
+
+    def __getitem__(self, index: int) -> dict[str, list[int]]:
+        """Get Item."""
+        return self.__next__()
+
 
 @final
 class RandomIntDummyDataset(SourceDataset[RandomIntSourceData]):
@@ -97,4 +105,8 @@ def __init__(
             dataset_path: The path to the dataset on Hugging Face.
             dataset_split: Dataset split (e.g. `train`).
         """
-        self.dataset = RandomIntHuggingFaceDataset(50000, context_size=context_size)
+        self.dataset = RandomIntHuggingFaceDataset(50000, context_size=context_size)  # type: ignore
+
+    def get_dataloader(self, batch_size: int) -> DataLoader:  # type: ignore
+        """Get Dataloader."""
+        return DataLoader(self.dataset, batch_size=batch_size)  # type: ignore
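Subclassing `torch.utils.data.Dataset` instead of the Hugging Face `IterableDataset` makes this a map-style dataset: with `__len__` and `__getitem__` defined, the stock PyTorch `DataLoader` can index and batch it directly, so the `get_dataloader` override only needs to forward `batch_size`. A minimal sketch of the resulting behaviour (sizes are illustrative):

    from torch.utils.data import DataLoader

    dataset = RandomIntHuggingFaceDataset(vocab_size=50000, context_size=128)
    dataloader = DataLoader(dataset, batch_size=4)

    batch = next(iter(dataloader))
    # Because __getitem__ returns plain Python lists, PyTorch's default collate
    # transposes them into a list of `context_size` tensors of shape (batch_size,),
    # rather than one (batch_size, context_size) tensor.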
17 changes: 15 additions & 2 deletions sparse_autoencoder/source_data/tests/test_abstract_dataset.py
@@ -4,6 +4,7 @@
 
 from datasets import IterableDataset, load_dataset
 import pytest
+import torch
 
 from sparse_autoencoder.source_data.abstract_dataset import (
     PreprocessTokenizedPrompts,
@@ -31,8 +32,10 @@ def preprocess(
         context_size: int,  # noqa: ARG002
     ) -> PreprocessTokenizedPrompts:
        """Preprocess a batch of prompts."""
-        # Assuming a very simple preprocess that just tokenizes the 'text' field
-        tokenized_texts = [[1, 5, 23, 2], [1, 2, 61, 12]]
+        preprocess_batch = 100
+        tokenized_texts = torch.randint(
+            low=0, high=50000, size=(preprocess_batch, TEST_CONTEXT_SIZE)
+        ).tolist()
         return {"input_ids": tokenized_texts}
 
     def __init__(
@@ -87,3 +90,13 @@ def test_extended_dataset_iterator(mock_hugging_face_load_dataset: pytest.Function) -> None:
 
     first_item = next(iterator)
     assert len(first_item["input_ids"]) == TEST_CONTEXT_SIZE
+
+
+def test_get_dataloader(mock_hugging_face_load_dataset: pytest.Function) -> None:
+    """Test the get_dataloader method of the extended dataset."""
+    data = MockSourceDataset()
+    batch_size = 3
+    dataloader = data.get_dataloader(batch_size=batch_size)
+    first_item = next(iter(dataloader))["input_ids"]
+    assert first_item.shape[0] == batch_size
+    assert first_item.shape[-1] == TEST_CONTEXT_SIZE
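The shape assertions hold because `get_dataloader` calls `with_format("torch")` on the mock dataset, so each item's `input_ids` arrives as a tensor of shape `(TEST_CONTEXT_SIZE,)` and the default collate function stacks a batch of them into one `(batch_size, TEST_CONTEXT_SIZE)` tensor. A quick way to see this in isolation (values are illustrative):

    import torch
    from torch.utils.data import default_collate

    items = [{"input_ids": torch.randint(0, 50000, (128,))} for _ in range(3)]
    batch = default_collate(items)
    assert batch["input_ids"].shape == (3, 128)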
