Dataset iterators - adds batching, collating for iterators #462

Merged 3 commits on Feb 19, 2023
Changes from all commits
63 changes: 27 additions & 36 deletions examples/06-mnist.rs
@@ -23,50 +23,36 @@ use indicatif::ProgressBar;
 use mnist::*;
 use rand::prelude::{SeedableRng, StdRng};
 
-use dfdx::{data::SubsetIterator, losses::cross_entropy_with_logits_loss, optim::Adam, prelude::*};
+use dfdx::{data::*, optim::Adam, prelude::*};
 
 #[cfg(not(feature = "cuda"))]
 type Dev = Cpu;
 
 #[cfg(feature = "cuda")]
 type Dev = Cuda;
 
-struct MnistDataset {
-    img: Vec<f32>,
-    lbl: Vec<usize>,
-}
+struct MnistTrainSet(Mnist);
 
-impl MnistDataset {
-    fn train(path: &str) -> Self {
-        let mnist: Mnist = MnistBuilder::new().base_path(path).finalize();
-        Self {
-            img: mnist.trn_img.iter().map(|&v| v as f32 / 255.0).collect(),
-            lbl: mnist.trn_lbl.iter().map(|&v| v as usize).collect(),
-        }
+impl MnistTrainSet {
+    fn new(path: &str) -> Self {
+        Self(MnistBuilder::new().base_path(path).finalize())
     }
-
-    fn len(&self) -> usize {
-        self.lbl.len()
-    }
-
-    pub fn get_batch<const B: usize>(
-        &self,
-        dev: &Dev,
-        idxs: [usize; B],
-    ) -> (
-        Tensor<Rank2<B, 784>, f32, Dev>,
-        Tensor<Rank2<B, 10>, f32, Dev>,
-    ) {
-        let mut img_data: Vec<f32> = Vec::with_capacity(B * 784);
-        let mut lbl_data: Vec<f32> = Vec::with_capacity(B * 10);
-        for (_batch_i, &img_idx) in idxs.iter().enumerate() {
-            let start = 784 * img_idx;
-            img_data.extend(&self.img[start..start + 784]);
-            let mut choices = [0.0; 10];
-            choices[self.lbl[img_idx]] = 1.0;
-            lbl_data.extend(choices);
-        }
-        (dev.tensor(img_data), dev.tensor(lbl_data))
+}
+
+impl ExactSizeDataset for MnistTrainSet {
+    type Item = (Vec<f32>, usize);
+    fn get(&self, index: usize) -> Self::Item {
+        let mut img_data: Vec<f32> = Vec::with_capacity(784);
+        let start = 784 * index;
+        img_data.extend(
+            self.0.trn_img[start..start + 784]
+                .iter()
+                .map(|x| *x as f32 / 255.0),
+        );
+        (img_data, self.0.trn_lbl[index] as usize)
+    }
+    fn len(&self) -> usize {
+        self.0.trn_lbl.len()
+    }
 }
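
The key design change in this example: the dataset no longer builds batches of tensors itself - `get` now returns a single `(Vec<f32>, usize)` example, and shuffling, batching, collating, and tensor construction all move into the training loop below. The same trait shape works for any indexable in-memory collection. A minimal sketch, assuming the trait only requires the `Item`, `get`, and `len` members used above and that it is exported via `dfdx::data` (the full `ExactSizeDataset` definition is not shown in this diff):

// Hypothetical example, not part of this PR: an ExactSizeDataset over
// plain in-memory vectors, mirroring the members implemented by MnistTrainSet.
use dfdx::data::ExactSizeDataset; // assumed path, based on `use dfdx::{data::*, ...}` above

struct InMemoryDataset {
    images: Vec<Vec<f32>>, // one flattened image per example
    labels: Vec<usize>,    // one class index per example
}

impl ExactSizeDataset for InMemoryDataset {
    type Item = (Vec<f32>, usize);
    fn get(&self, index: usize) -> Self::Item {
        (self.images[index].clone(), self.labels[index])
    }
    fn len(&self) -> usize {
        self.labels.len()
    }
}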

@@ -100,17 +86,22 @@ fn main() {
     let mut opt = Adam::new(&model, Default::default());
 
     // initialize dataset
-    let dataset = MnistDataset::train(&mnist_path);
+    let dataset = MnistTrainSet::new(&mnist_path);
     println!("Found {:?} training images", dataset.len());
 
     for i_epoch in 0..10 {
         let mut total_epoch_loss = 0.0;
         let mut num_batches = 0;
         let start = Instant::now();
         let bar = ProgressBar::new(dataset.len() as u64);
-        for (img, lbl) in SubsetIterator::<BATCH_SIZE>::shuffled(dataset.len(), &mut rng)
-            .map(|i| dataset.get_batch(&dev, i))
+        for (img, lbl) in dataset
+            .shuffled(&mut rng)
+            .batch(Const::<BATCH_SIZE>)
+            .collate()
         {
+            let img = dev.stack(img.map(|x| dev.tensor((x, (Const::<784>,)))));
+            let lbl = dev.one_hot_encode(Const::<10>, lbl);
+
             let logits = model.forward_mut(img.traced());
             let loss = cross_entropy_with_logits_loss(logits, lbl);

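
Reading the new loop: `shuffled(&mut rng)` yields single examples in random order, `batch(Const::<BATCH_SIZE>)` groups them into fixed-size batches, and `collate()` splits a batch of `(image, label)` pairs into the images and the labels, which `dev.stack(...)` and `dev.one_hot_encode(...)` then turn into the two batch tensors. Below is a standalone sketch of the collate idea only - it is not dfdx's implementation, and it returns `Vec`s for simplicity where the PR's version appears to keep the fixed-size array shape (judging from `img.map(...)` above):

// Conceptual sketch only: split a fixed-size batch of (input, label)
// pairs into separate input and label collections.
fn collate_pairs<A, B, const N: usize>(batch: [(A, B); N]) -> (Vec<A>, Vec<B>) {
    let mut inputs = Vec::with_capacity(N);
    let mut labels = Vec::with_capacity(N);
    for (x, y) in batch {
        inputs.push(x);
        labels.push(y);
    }
    (inputs, labels)
}

fn main() {
    // Two toy "images" of 784 zeros, with labels 3 and 7.
    let batch = [(vec![0.0f32; 784], 3usize), (vec![0.0f32; 784], 7usize)];
    let (imgs, lbls) = collate_pairs(batch);
    assert_eq!(imgs.len(), 2);
    assert_eq!(lbls, [3usize, 7]);
}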
158 changes: 0 additions & 158 deletions src/data.rs

This file was deleted.

37 changes: 37 additions & 0 deletions src/data/arange.rs
@@ -0,0 +1,37 @@
use crate::{
    shapes::*,
    tensor::{CopySlice, DeviceStorage, Tensor, ZerosTensor},
};

use std::vec::Vec;

/// Generates a tensor with ordered data from 0 to `N`.
pub trait Arange<E: Dtype>: DeviceStorage + ZerosTensor<E> + CopySlice<E> {
    /// Generates a tensor with ordered data from 0 to `N`.
    ///
    /// Const sized tensor:
    /// ```rust
    /// # use dfdx::{prelude::*, data::Arange};
    /// # let dev: Cpu = Default::default();
    /// let t: Tensor<Rank1<5>, f32, _> = dev.arange(Const::<5>);
    /// assert_eq!(t.array(), [0.0, 1.0, 2.0, 3.0, 4.0]);
    /// ```
    ///
    /// Runtime sized tensor:
    /// ```rust
    /// # use dfdx::{prelude::*, data::Arange};
    /// # let dev: Cpu = Default::default();
    /// let t: Tensor<(usize, ), f32, _> = dev.arange(5);
    /// assert_eq!(t.as_vec(), [0.0, 1.0, 2.0, 3.0, 4.0]);
    /// ```
    fn arange<Size: Dim>(&self, n: Size) -> Tensor<(Size,), E, Self> {
        let mut data = Vec::with_capacity(n.size());
        for i in 0..n.size() {
            data.push(E::from_usize(i).unwrap());
        }
        let mut t = self.zeros_like(&(n,));
        t.copy_from(&data);
        t
    }
}
impl<E: Dtype, D: ZerosTensor<E> + CopySlice<E>> Arange<E> for D {}
61 changes: 61 additions & 0 deletions src/data/batch.rs
@@ -0,0 +1,61 @@
use crate::shapes::{Const, Dim};

use std::vec::Vec;

pub struct Batcher<Size, I> {
    size: Size,
    iter: I,
}

impl<const N: usize, I: Iterator> Iterator for Batcher<Const<N>, I> {
    type Item = [I::Item; N];
    fn next(&mut self) -> Option<Self::Item> {
        let items = [(); N].map(|_| self.iter.next());
        if items.iter().any(Option::is_none) {
            None
        } else {
            Some(items.map(Option::unwrap))
        }
    }
}

impl<I: Iterator> Iterator for Batcher<usize, I> {
    type Item = Vec<I::Item>;
    fn next(&mut self) -> Option<Self::Item> {
        let mut batch = Vec::with_capacity(self.size);
        for _ in 0..self.size {
            batch.push(self.iter.next()?);
        }
        Some(batch)
    }
}

/// Create batches of items from an [Iterator]
pub trait IteratorBatchExt: Iterator {
    /// Return an [Iterator] where the items are either:
    /// - `[Self::Item; N]`, if `Size` is [Const<N>]
    /// - `Vec<Self::Item>`, if `Size` is [usize].
    ///
    /// **Incomplete final batches are always dropped - this only yields exactly sized batches.**
    ///
    /// Const batches:
    /// ```rust
    /// # use dfdx::{prelude::*, data::IteratorBatchExt};
    /// let items: Vec<[usize; 5]> = (0..12).batch(Const::<5>).collect();
    /// assert_eq!(&items, &[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]);
    /// ```
    ///
    /// Runtime batches:
    /// ```rust
    /// # use dfdx::{prelude::*, data::IteratorBatchExt};
    /// let items: Vec<Vec<usize>> = (0..12).batch(5).collect();
    /// assert_eq!(&items, &[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]);
    /// ```
    fn batch<Size: Dim>(self, size: Size) -> Batcher<Size, Self>
    where
        Self: Sized,
    {
        Batcher { size, iter: self }
    }
}
impl<I: Iterator> IteratorBatchExt for I {}
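
The const-sized `Batcher` above is the interesting part: `[(); N].map(|_| self.iter.next())` pulls up to `N` items into an `[Option<Item>; N]`, and the batch is yielded only when every slot is `Some`, which is exactly why a trailing partial batch is dropped. A small standalone sketch of the same trick, with made-up names, shown here only to illustrate the technique:

// Illustrative only: the same all-or-nothing const-generic batching trick.
fn next_batch<const N: usize, I: Iterator>(iter: &mut I) -> Option<[I::Item; N]> {
    // Build an [Option<Item>; N] by pulling N items from the iterator.
    let items = [(); N].map(|_| iter.next());
    // Yield the batch only if every slot was filled.
    if items.iter().any(Option::is_none) {
        None
    } else {
        Some(items.map(Option::unwrap))
    }
}

fn main() {
    let mut it = 0..12;
    assert_eq!(next_batch::<5, _>(&mut it), Some([0, 1, 2, 3, 4]));
    assert_eq!(next_batch::<5, _>(&mut it), Some([5, 6, 7, 8, 9]));
    // Only two items remain, so the final partial batch is dropped.
    assert_eq!(next_batch::<5, _>(&mut it), None);
}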