From ff09b00959e4fe785c45da5b4da32d1b21851388 Mon Sep 17 00:00:00 2001
From: unknown
Date: Fri, 6 Dec 2024 11:44:39 +0200
Subject: [PATCH] Added example in the tutorial page for batched datasets

---
 docs/tutorials.md | 63 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/docs/tutorials.md b/docs/tutorials.md
index 5033c8a..35537b8 100644
--- a/docs/tutorials.md
+++ b/docs/tutorials.md
@@ -113,3 +113,66 @@ def main():
 if __name__ == "__main__":
     main()
 ```
+
+## Integrating bs-scheduler with already batched data
+
+Using already batched data disables the automatic batching features of the PyTorch `DataLoader`.
+In this case, the end user must provide a `BatchSizeManager` to enable the batch size schedulers to change the
+batch size. The `CustomBatchSizeManager` can be used when the dataset class implements the `get_batch_size` and
+`change_batch_size` methods, as in the example below:
+
+```python
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader, Dataset
+
+from bs_scheduler import StepBS
+from bs_scheduler.batch_size_schedulers import CustomBatchSizeManager
+
+
+class BatchedDataset(Dataset):
+    """Dataset which returns whole batches instead of single samples."""
+
+    def __init__(self, data: Tensor, batch_size: int):
+        self.data = data
+        self.batch_size = batch_size
+
+    def __len__(self) -> int:
+        # The length is the number of full batches, not the number of samples.
+        return len(self.data) // self.batch_size
+
+    def __getitem__(self, i: int) -> Tensor:
+        # Return the i-th batch as a slice of the underlying data.
+        return self.data[i * self.batch_size:(i + 1) * self.batch_size]
+
+    def get_batch_size(self) -> int:
+        return self.batch_size
+
+    def change_batch_size(self, batch_size: int):
+        self.batch_size = batch_size
+
+
+dataset = BatchedDataset(torch.rand(10000, 128), batch_size=100)
+# batch_size=None disables the DataLoader's automatic batching.
+dataloader = DataLoader(dataset, batch_size=None)
+scheduler = StepBS(dataloader,
+                   step_size=2,
+                   gamma=2.0,
+                   batch_size_manager=CustomBatchSizeManager(dataset),
+                   max_batch_size=10000)
+
+for epoch in range(10):
+    for batched_data in dataloader:
+        pass  # The training loop goes here.
+    print(f"There are {len(dataloader)} batches in epoch {epoch}.")
+    scheduler.step()  # Doubles the batch size every 2 epochs.
+```
+
+Output:
+
+```
+There are 100 batches in epoch 0.
+There are 100 batches in epoch 1.
+There are 50 batches in epoch 2.
+There are 50 batches in epoch 3.
+There are 25 batches in epoch 4.
+There are 25 batches in epoch 5.
+There are 12 batches in epoch 6.
+There are 12 batches in epoch 7.
+There are 6 batches in epoch 8.
+There are 6 batches in epoch 9.
+```
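+
+If the dataset cannot be modified to expose `get_batch_size` and `change_batch_size`, a `BatchSizeManager` can be
+implemented directly instead. The sketch below is one possible implementation, assuming the `BatchSizeManager` base
+class uses the `get_current_batch_size` and `set_batch_size` hooks; check the API reference for the exact interface.
+
+```python
+from bs_scheduler import BatchSizeManager  # import path assumed
+
+
+class AttributeBatchSizeManager(BatchSizeManager):
+    """Hypothetical manager for datasets that expose a plain `batch_size`
+    attribute instead of the `get_batch_size`/`change_batch_size` methods."""
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def get_current_batch_size(self) -> int:
+        # Read the batch size straight from the dataset attribute.
+        return self.dataset.batch_size
+
+    def set_batch_size(self, new_bs: int):
+        # Write the new batch size back; the next epoch picks it up.
+        self.dataset.batch_size = new_bs
+
+
+# Usage: pass it to any scheduler in place of CustomBatchSizeManager, e.g.
+# StepBS(dataloader, step_size=2, gamma=2.0,
+#        batch_size_manager=AttributeBatchSizeManager(dataset),
+#        max_batch_size=10000)
+```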