## Integrating bs-scheduler with already batched data

When the dataset already yields batched data, the automatic batching features of the PyTorch `DataLoader` must be disabled by passing `batch_size=None`.
In this case, the end user must provide a `BatchSizeManager` so that the batch size schedulers can still change the batch size.
The `CustomBatchSizeManager` can be used whenever the dataset class implements the `get_batch_size` and `change_batch_size`
methods, as in the example below:

```python
from bs_scheduler import StepBS
from bs_scheduler.batch_size_schedulers import CustomBatchSizeManager
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader


class BatchedDataset(Dataset):
    """A dataset that returns whole batches instead of individual samples."""

    def __init__(self, data: Tensor, batch_size: int):
        self.data = data
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of full batches; a trailing partial batch is dropped.
        return len(self.data) // self.batch_size

    def __getitem__(self, i: int) -> Tensor:
        # Return the i-th batch as a contiguous slice of the data.
        return self.data[i * self.batch_size: (i + 1) * self.batch_size]

    def get_batch_size(self) -> int:
        return self.batch_size

    def change_batch_size(self, batch_size: int):
        self.batch_size = batch_size


dataset = BatchedDataset(torch.rand(10000, 128), batch_size=100)
# batch_size=None disables the DataLoader's automatic batching.
dataloader = DataLoader(dataset, batch_size=None)
scheduler = StepBS(dataloader,
                   step_size=2,  # multiply the batch size by gamma every 2 epochs
                   gamma=2.0,
                   batch_size_manager=CustomBatchSizeManager(dataset),
                   max_batch_size=10000)

for epoch in range(10):
    for batched_data in dataloader:
        pass  # the training step would go here
    print(f"There are {len(dataloader)} batches in epoch {epoch}.")
    scheduler.step()
```
Output:
```
There are 100 batches in epoch 0.
There are 100 batches in epoch 1.
There are 50 batches in epoch 2.
There are 50 batches in epoch 3.
There are 25 batches in epoch 4.
There are 25 batches in epoch 5.
There are 12 batches in epoch 6.
There are 12 batches in epoch 7.
There are 6 batches in epoch 8.
There are 6 batches in epoch 9.
```
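
The epoch lengths follow directly from the scheduler parameters: with `step_size=2` and `gamma=2.0`, the batch size doubles every two epochs (100 → 200 → 400 → 800 → 1600), and each epoch contains `10000 // batch_size` full batches, hence 100, 50, 25, 12 and 6.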
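
If the dataset cannot expose `get_batch_size` and `change_batch_size`, a batch size manager can be written by hand instead. The following is only a minimal sketch under assumptions not confirmed above: it presumes a `BatchSizeManager` base class is importable from `bs_scheduler.batch_size_schedulers` and that the scheduler interacts with it through `get_current_batch_size` and `set_batch_size` methods, and the `chunk_size` attribute is a hypothetical stand-in for however the dataset actually stores its batch size. Check the library source for the exact interface before relying on it.

```python
# Minimal sketch, not the confirmed bs_scheduler API: the base class location
# and the get_current_batch_size/set_batch_size method names are assumptions,
# and `chunk_size` is a hypothetical attribute of the user's dataset.
from bs_scheduler.batch_size_schedulers import BatchSizeManager


class ChunkSizeManager(BatchSizeManager):
    """Adapts a dataset that stores its batch size in a `chunk_size` attribute."""

    def __init__(self, dataset):
        self.dataset = dataset

    def get_current_batch_size(self) -> int:
        # Report the batch size the dataset is currently using.
        return self.dataset.chunk_size

    def set_batch_size(self, batch_size: int):
        # Called by the scheduler whenever it adjusts the batch size.
        self.dataset.chunk_size = batch_size
```

Such a manager would then be passed through the same `batch_size_manager` argument used with `CustomBatchSizeManager` above.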
