diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cccc63f..592d965 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
       - id: double-quote-string-fixer

   - repo: https://github.com/psf/black
-    rev: 21.12b0
+    rev: 22.1.0
     hooks:
       - id: black
        args: ["--line-length", "80", "--skip-string-normalization"]
@@ -37,3 +37,16 @@ repos:
     hooks:
       - id: prettier
        language_version: system
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.931
+    hooks:
+      - id: mypy
+        additional_dependencies: [
+          # Type stubs
+          types-setuptools,
+          types-pkg_resources,
+          # Dependencies that are typed
+          numpy,
+          xarray,
+        ]
diff --git a/conftest.py b/conftest.py
index 44ad179..acc08d3 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,3 +1,4 @@
+# type: ignore
 import pytest
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 0aeec8b..34f20d8 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,4 +1,6 @@
 pytest
+torch
+coverage
 pytest-cov
 adlfs
 -r requirements.txt
diff --git a/doc/api.rst b/doc/api.rst
index f400b2c..f9f424c 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -5,12 +5,6 @@ API reference

 This page provides an auto-generated summary of Xbatcher's API.

-Core
-====
-
-.. autoclass:: xbatcher.BatchGenerator
-    :members:
-
 Dataset.batch and DataArray.batch
 =================================
@@ -22,3 +16,17 @@ Dataset.batch and DataArray.batch

     Dataset.batch.generator
     DataArray.batch.generator
+
+Core
+====
+
+.. autoclass:: xbatcher.BatchGenerator
+    :members:
+
+Dataloaders
+===========
+.. autoclass:: xbatcher.loaders.torch.MapDataset
+    :members:
+
+.. autoclass:: xbatcher.loaders.torch.IterableDataset
+    :members:
diff --git a/doc/conf.py b/doc/conf.py
index 5d7674c..ef5de36 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -12,6 +12,8 @@
 # All configuration values have a default; values that are commented out
 # serve to show the default.

+# type: ignore
+
 import os
 import sys
diff --git a/setup.cfg b/setup.cfg
index 959a2fb..84899d3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,7 +7,7 @@ select = B,C,E,F,W,T4,B9

 [isort]
 known_first_party=xbatcher
-known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,xarray
+known_third_party=numpy,pkg_resources,pytest,setuptools,sphinx_autosummary_accessors,torch,xarray
 multi_line_output=3
 include_trailing_comma=True
 force_grid_wrap=0
diff --git a/setup.py b/setup.py
index 685f7b9..50c7aa7 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python
+# type: ignore
 import os

 from setuptools import find_packages, setup
diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py
index 4a92bf8..44c2e84 100644
--- a/xbatcher/accessors.py
+++ b/xbatcher/accessors.py
@@ -24,3 +24,21 @@ def generator(self, *args, **kwargs):
             Keyword arguments to pass to the `BatchGenerator` constructor.
         '''
         return BatchGenerator(self._obj, *args, **kwargs)
+
+
+@xr.register_dataarray_accessor('torch')
+class TorchAccessor:
+    def __init__(self, xarray_obj):
+        self._obj = xarray_obj
+
+    def to_tensor(self):
+        """Convert this DataArray to a torch.Tensor"""
+        import torch
+
+        return torch.tensor(self._obj.data)
+
+    def to_named_tensor(self):
+        """Convert this DataArray to a torch.Tensor with named dimensions"""
+        import torch
+
+        return torch.tensor(self._obj.data, names=self._obj.dims)
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 612be61..da80995 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -2,6 +2,7 @@

 import itertools
 from collections import OrderedDict
+from typing import Any, Dict, Hashable, Iterator

 import xarray as xr
@@ -99,12 +100,12 @@ class BatchGenerator:

     def __init__(
         self,
-        ds,
-        input_dims,
-        input_overlap={},
-        batch_dims={},
-        concat_input_dims=False,
-        preload_batch=True,
+        ds: xr.Dataset,
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = {},
+        batch_dims: Dict[Hashable, int] = {},
+        concat_input_dims: bool = False,
+        preload_batch: bool = True,
     ):

         self.ds = _as_xarray_dataset(ds)
@@ -115,7 +116,38 @@ def __init__(
         self.concat_input_dims = concat_input_dims
         self.preload_batch = preload_batch

-    def __iter__(self):
+        self._batches: Dict[
+            int, Any
+        ] = self._gen_batches()  # dict cache for batches
+        # in the future, we can make this a lru cache or similar thing (cachey?)
+
+    def __iter__(self) -> Iterator[xr.Dataset]:
+        for batch in self._batches.values():
+            yield batch
+
+    def __len__(self) -> int:
+        return len(self._batches)
+
+    def __getitem__(self, idx: int) -> xr.Dataset:
+
+        if not isinstance(idx, int):
+            raise NotImplementedError(
+                f'{type(self).__name__}.__getitem__ currently requires a single integer key'
+            )
+
+        if idx < 0:
+            idx = list(self._batches)[idx]
+
+        if idx in self._batches:
+            return self._batches[idx]
+        else:
+            raise IndexError('list index out of range')
+
+    def _gen_batches(self) -> dict:
+        # in the future, we will want to do the batch generation lazily
+        # going the eager route for now is allowing me to fill out the loader api
+        # but it is likely to perform poorly.
+        batches = []
         for ds_batch in self._iterate_batch_dims(self.ds):
             if self.preload_batch:
                 ds_batch.load()
@@ -130,15 +162,17 @@ def __iter__(self):
                 ]
                 dsc = xr.concat(all_dsets, dim='input_batch')
                 new_input_dims = [
-                    dim + new_dim_suffix for dim in self.input_dims
+                    str(dim) + new_dim_suffix for dim in self.input_dims
                 ]
-                yield _maybe_stack_batch_dims(dsc, new_input_dims)
+                batches.append(_maybe_stack_batch_dims(dsc, new_input_dims))
             else:
                 for ds_input in input_generator:
-                    yield _maybe_stack_batch_dims(
-                        ds_input, list(self.input_dims)
+                    batches.append(
+                        _maybe_stack_batch_dims(ds_input, list(self.input_dims))
                     )

+        return dict(zip(range(len(batches)), batches))
+
     def _iterate_batch_dims(self, ds):
         return _iterate_through_dataset(ds, self.batch_dims)
diff --git a/xbatcher/loaders/__init__.py b/xbatcher/loaders/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py
new file mode 100644
index 0000000..f68f63a
--- /dev/null
+++ b/xbatcher/loaders/torch.py
@@ -0,0 +1,88 @@
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+
+# Notes:
+# This module includes two PyTorch datasets.
+#  - The MapDataset provides an indexable interface
+#  - The IterableDataset provides a simple iterable interface
+# Both can be provided as arguments to the Torch DataLoader
+# Assumptions made:
+#  - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
+# TODOs:
+#  - sort out xarray -> numpy pattern. Currently there is a hardcoded variable name for x/y
+#  - need to test with additional dataset parameters (e.g. transforms)
+
+
+class MapDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        X_generator,
+        y_generator,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+    ) -> None:
+        '''
+        PyTorch Dataset adapter for Xbatcher
+
+        Parameters
+        ----------
+        X_generator : xbatcher.BatchGenerator
+        y_generator : xbatcher.BatchGenerator
+        transform : callable, optional
+            A function/transform that takes in an array and returns a transformed version.
+        target_transform : callable, optional
+            A function/transform that takes in the target and transforms it.
+        '''
+        self.X_generator = X_generator
+        self.y_generator = y_generator
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __len__(self) -> int:
+        return len(self.X_generator)
+
+    def __getitem__(self, idx) -> Tuple[Any, Any]:
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+            if len(idx) == 1:
+                idx = idx[0]
+            else:
+                raise NotImplementedError(
+                    f'{type(self).__name__}.__getitem__ currently requires a single integer key'
+                )
+
+        # TODO: figure out the dataset -> array workflow
+        # currently hardcoding a variable name
+        X_batch = self.X_generator[idx]['x'].torch.to_tensor()
+        y_batch = self.y_generator[idx]['y'].torch.to_tensor()
+
+        if self.transform:
+            X_batch = self.transform(X_batch)
+
+        if self.target_transform:
+            y_batch = self.target_transform(y_batch)
+        return X_batch, y_batch
+
+
+class IterableDataset(torch.utils.data.IterableDataset):
+    def __init__(
+        self,
+        X_generator,
+        y_generator,
+    ) -> None:
+        '''
+        PyTorch Dataset adapter for Xbatcher
+
+        Parameters
+        ----------
+        X_generator : xbatcher.BatchGenerator
+        y_generator : xbatcher.BatchGenerator
+        '''
+
+        self.X_generator = X_generator
+        self.y_generator = y_generator
+
+    def __iter__(self):
+        for xb, yb in zip(self.X_generator, self.y_generator):
+            yield (xb['x'].torch.to_tensor(), yb['y'].torch.to_tensor())
diff --git a/xbatcher/tests/test_accessors.py b/xbatcher/tests/test_accessors.py
index d9be321..4860803 100644
--- a/xbatcher/tests/test_accessors.py
+++ b/xbatcher/tests/test_accessors.py
@@ -38,3 +38,25 @@ def test_batch_accessor_da(sample_ds_3d):
     assert isinstance(bg_acc, BatchGenerator)
     for batch_class, batch_acc in zip(bg_class, bg_acc):
         assert batch_class.equals(batch_acc)
+
+
+def test_torch_to_tensor(sample_ds_3d):
+    torch = pytest.importorskip('torch')
+
+    da = sample_ds_3d['foo']
+    t = da.torch.to_tensor()
+    assert isinstance(t, torch.Tensor)
+    assert t.names == (None, None, None)
+    assert t.shape == da.shape
+    np.testing.assert_array_equal(t, da.values)
+
+
+def test_torch_to_named_tensor(sample_ds_3d):
+    torch = pytest.importorskip('torch')
+
+    da = sample_ds_3d['foo']
+    t = da.torch.to_named_tensor()
+    assert isinstance(t, torch.Tensor)
+    assert t.names == da.dims
+    assert t.shape == da.shape
+    np.testing.assert_array_equal(t, da.values)
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index 38acae9..23f9448 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -41,6 +41,28 @@ def test_constructor_coerces_to_dataset():
     assert bg.ds.equals(da.to_dataset())


+@pytest.mark.parametrize('bsize', [5, 6])
+def test_batcher_length(sample_ds_1d, bsize):
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': bsize})
+    assert len(bg) == sample_ds_1d.dims['x'] // bsize
+
+
+def test_batcher_getitem(sample_ds_1d):
+    bg = BatchGenerator(sample_ds_1d, input_dims={'x': 10})
+
+    # first batch
+    assert bg[0].dims['x'] == 10
+    # last batch
+    assert bg[-1].dims['x'] == 10
+    # raises IndexError for out of range index
+    with pytest.raises(IndexError, match=r'list index out of range'):
+        bg[9999999]
+
+    # raises NotImplementedError for iterable index
+    with pytest.raises(NotImplementedError):
+        bg[[1, 2, 3]]
+
+
 # TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension
 # Should we enforce that each batch size always has to be the same
 @pytest.mark.parametrize('bsize', [5, 10])
diff --git a/xbatcher/tests/test_torch_loaders.py b/xbatcher/tests/test_torch_loaders.py
new file mode 100644
index 0000000..4e4a412
--- /dev/null
+++ b/xbatcher/tests/test_torch_loaders.py
@@ -0,0 +1,111 @@
+import numpy as np
+import pytest
+import xarray as xr
+
+torch = pytest.importorskip('torch')
+
+from xbatcher import BatchGenerator
+from xbatcher.loaders.torch import IterableDataset, MapDataset
+
+
+@pytest.fixture(scope='module')
+def ds_xy():
+    n_samples = 100
+    n_features = 5
+    ds = xr.Dataset(
+        {
+            'x': (
+                ['sample', 'feature'],
+                np.random.random((n_samples, n_features)),
+            ),
+            'y': (['sample'], np.random.random(n_samples)),
+        },
+    )
+    return ds
+
+
+def test_map_dataset(ds_xy):
+
+    x = ds_xy['x']
+    y = ds_xy['y']
+
+    x_gen = BatchGenerator(x, {'sample': 10})
+    y_gen = BatchGenerator(y, {'sample': 10})
+
+    dataset = MapDataset(x_gen, y_gen)
+
+    # test __getitem__
+    x_batch, y_batch = dataset[0]
+    assert len(x_batch) == len(y_batch)
+    assert isinstance(x_batch, torch.Tensor)
+
+    idx = torch.tensor([0])
+    x_batch, y_batch = dataset[idx]
+    assert len(x_batch) == len(y_batch)
+    assert isinstance(x_batch, torch.Tensor)
+
+    with pytest.raises(NotImplementedError):
+        idx = torch.tensor([0, 1])
+        x_batch, y_batch = dataset[idx]
+
+    # test __len__
+    assert len(dataset) == len(x_gen)
+
+    # test integration with torch DataLoader
+    loader = torch.utils.data.DataLoader(dataset)
+
+    for x_batch, y_batch in loader:
+        assert len(x_batch) == len(y_batch)
+        assert isinstance(x_batch, torch.Tensor)
+
+    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
+    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
+    # TODO: also need to revisit the variable extraction bits here
+    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
+
+
+def test_map_dataset_with_transform(ds_xy):
+
+    x = ds_xy['x']
+    y = ds_xy['y']
+
+    x_gen = BatchGenerator(x, {'sample': 10})
+    y_gen = BatchGenerator(y, {'sample': 10})
+
+    def x_transform(batch):
+        return batch * 0 + 1
+
+    def y_transform(batch):
+        return batch * 0 - 1
+
+    dataset = MapDataset(
+        x_gen, y_gen, transform=x_transform, target_transform=y_transform
+    )
+    x_batch, y_batch = dataset[0]
+    assert len(x_batch) == len(y_batch)
+    assert isinstance(x_batch, torch.Tensor)
+    assert (x_batch == 1).all()
+    assert (y_batch == -1).all()
+
+
+def test_iterable_dataset(ds_xy):
+
+    x = ds_xy['x']
+    y = ds_xy['y']
+
+    x_gen = BatchGenerator(x, {'sample': 10})
+    y_gen = BatchGenerator(y, {'sample': 10})
+
+    dataset = IterableDataset(x_gen, y_gen)
+
+    # test integration with torch DataLoader
+    loader = torch.utils.data.DataLoader(dataset)
+
+    for x_batch, y_batch in loader:
+        assert len(x_batch) == len(y_batch)
+        assert isinstance(x_batch, torch.Tensor)
+
+    # TODO: why does pytorch add an extra dimension (length 1) to x_batch
+    assert x_gen[-1]['x'].shape == x_batch.shape[1:]
+    # TODO: also need to revisit the variable extraction bits here
+    assert np.array_equal(x_gen[-1]['x'], x_batch[0, :, :])
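
Example usage (not part of the diff above): a minimal sketch of how the MapDataset added in xbatcher/loaders/torch.py is intended to be wired into a torch DataLoader. The toy dataset, the 'x'/'y' variable names, and the batch size of 10 are illustrative assumptions mirrored from the ds_xy fixture and the hardcoded variable names in the tests.

import numpy as np
import torch
import xarray as xr

from xbatcher import BatchGenerator
from xbatcher.loaders.torch import MapDataset

# Toy dataset with an 'x' feature array and a 'y' target array, mirroring
# the ds_xy fixture in xbatcher/tests/test_torch_loaders.py.
ds = xr.Dataset(
    {
        'x': (['sample', 'feature'], np.random.random((100, 5))),
        'y': (['sample'], np.random.random(100)),
    }
)

# One BatchGenerator per variable; each batch covers 10 samples.
x_gen = BatchGenerator(ds['x'], {'sample': 10})
y_gen = BatchGenerator(ds['y'], {'sample': 10})

# MapDataset adapts the two generators to the torch Dataset protocol, so a
# stock DataLoader can iterate over (X, y) tensor pairs.
dataset = MapDataset(x_gen, y_gen)
loader = torch.utils.data.DataLoader(dataset)

for x_batch, y_batch in loader:
    # The DataLoader prepends a collation dimension of length 1 (see the
    # TODOs in the tests above), so x_batch has shape (1, 10, 5).
    print(x_batch.shape, y_batch.shape)

The IterableDataset adapter is used the same way, but is consumed purely by iteration and does not support __getitem__ or __len__.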