Add unit tests #27

Merged · 5 commits · Apr 19, 2020
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -2,3 +2,4 @@ matplotlib
numpy
torch>=0.4.1
tqdm
pytest
66 changes: 66 additions & 0 deletions tests/README.md
@@ -0,0 +1,66 @@
## Requirements
- pytest

## Run tests
- normal (uses the GPU if one is available)
```bash
# in root directory of this package
$ python -m pytest ./tests
```

- forcibly run all tests on CPU
```bash
# in root directory of this package
$ python -m pytest --cpu_only ./tests
```

## How to add new test cases
To make it convenient to create test cases and reuse settings, the basic elements needed to run a training task are packaged into classes that inherit `BaseTask` in `task.py`.

A `BaseTask` is composed of the following members:
- `batch_size`
- `model`
- `optimizer`
- `criterion` (loss function)
- `device` (`cpu`, `cuda`, etc.)
- `train_loader` (`torch.utils.data.DataLoader` for training set)
- `val_loader` (`torch.utils.data.DataLoader` for validation set)

If you want to create a new task, write a new class that inherits `BaseTask` and set up its configuration in `__init__`.
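A minimal sketch of such a task, as it could be added to `tests/task.py` (the class name `DeepXORTask` and its exact settings are illustrative, not part of this PR; the imports already exist at the top of `task.py`, where `BaseTask` is defined):

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from .dataset import XORDataset
from .model import LinearMLP


class DeepXORTask(BaseTask):
    def __init__(self):
        super(DeepXORTask, self).__init__()
        bs, steps = 4, 32
        self.batch_size = bs
        self.train_loader = DataLoader(XORDataset(bs * steps))
        self.model = LinearMLP([8, 8, 4, 1])
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        self.device = torch.device("cuda")
```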

Note-1: Any task inheriting `BaseTask` in `task.py` will be collected by the function `test_lr_finder.py::collect_task_classes()`.

Note-2: The model and dataset are instantiated when a task class is **initialized**, so it is not recommended to create many task **objects** at once.


### Directly use specific task in a test case
```python
from . import task as mod_task
def test_run():
task = mod_task.FooTask()
...
```

### Use `pytest.mark.parametrize`
- Use specified tasks in a test case
```python
@pytest.mark.parametrize(
'cls_task, arg', # names of parameters (see also the signature of the following function)
[
(task.FooTask, 'foo'),
(task.BarTask, 'bar'),
], # list of parameters
)
def test_run(cls_task, arg):
...
```

- Use all existing tasks in a test case
```python
@pytest.mark.parametrize(
'cls_task',
collect_task_classes(),
)
def test_run(cls_task):
...
```
Empty file added tests/__init__.py
41 changes: 41 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,41 @@
import pytest


class CustomCommandLineOption(object):
"""An object for storing command line options parsed by pytest.

Since `pytest.config` global object is deprecated and removed in version
5.0, this class is made to work as a store of command line options for
those components which are not able to access them via `request.config`.
"""

def __init__(self):
self._content = {}

def __str__(self):
return str(self._content)

def add(self, key, value):
self._content.update({key: value})

def delete(self, key):
del self._content[key]

    def __getattr__(self, key):
        if key in self._content:
            return self._content[key]
        else:
            # `object` has no `__getattr__`, so raise `AttributeError`
            # directly instead of delegating to the base class.
            raise AttributeError(key)


def pytest_addoption(parser):
parser.addoption(
"--cpu_only", action="store_true", help="Forcibly run all tests on CPU."
)


def pytest_configure(config):
# Bind a config object to `pytest` module instance
pytest.custom_cmdopt = CustomCommandLineOption()

pytest.custom_cmdopt.add("cpu_only", config.getoption("--cpu_only"))
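Other test modules can then read the flag straight off the `pytest` module; a minimal sketch, mirroring how `tests/task.py` consumes it later in this PR:

```python
import pytest
import torch


def use_cuda():
    # Honor --cpu_only first; otherwise fall back to runtime CUDA detection.
    # `pytest.custom_cmdopt` is attached by `pytest_configure()` above.
    if pytest.custom_cmdopt.cpu_only:
        return False
    return torch.cuda.is_available()
```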
58 changes: 58 additions & 0 deletions tests/dataset.py
@@ -0,0 +1,58 @@
import numpy as np
import torch
from torch.utils.data import Dataset


class XORDataset(Dataset):
def __init__(self, length, shape=None):
"""
Arguments:
length (int): length of dataset, which equals `len(self)`.
            shape (list, tuple, optional): shape of a single sample (the
                length dimension is prepended automatically). If it isn't
                specified, the dataset shape defaults to `(length, 8)`.
                Default: None.
        """
        # Parentheses make the precedence explicit: the conditional expression
        # binds more loosely than tuple concatenation.
        _shape = ((length,) + tuple(shape)) if shape else (length, 8)
raw = np.random.randint(0, 2, _shape)
self.data = torch.FloatTensor(raw)
self.label = (
torch.tensor(np.bitwise_xor.reduce(raw, axis=1)).unsqueeze(dim=1).float()
)

def __getitem__(self, index):
return self.data[index], self.label[index]

def __len__(self):
return len(self.data)


class ExtraXORDataset(XORDataset):
""" A XOR dataset which is able to return extra values. """

def __init__(self, length, shape=None, extra_dims=1):
"""
Arguments:
length (int): length of dataset, which equals `len(self)`.
            shape (list, tuple, optional): shape of a single sample (the
                length dimension is prepended automatically). If it isn't
                specified, the dataset shape defaults to `(length, 8)`.
                Default: None.
            extra_dims (int, optional): number of extra values returned per
                sample. Default: 1.
"""
super(ExtraXORDataset, self).__init__(length, shape=shape)
if extra_dims:
_extra_shape = (length, extra_dims)
self.extras = torch.randint(0, 2, _extra_shape)
else:
self.extras = None

def __getitem__(self, index):
if self.extras is not None:
retval = [self.data[index], self.label[index]]
            retval.extend(self.extras[index])
return retval
else:
return self.data[index], self.label[index]

def __len__(self):
return len(self.data)
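A quick sanity check of the labeling rule (a usage sketch, not part of the PR): each label is the XOR of the bits in its sample, which for 0/1 values equals the parity of their sum.

```python
from tests.dataset import XORDataset

ds = XORDataset(length=4)
x, y = ds[0]
# XOR-reducing 0/1 bits is the same as taking the sum modulo 2.
assert y.item() == float(int(x.sum().item()) % 2)
```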
15 changes: 15 additions & 0 deletions tests/model.py
@@ -0,0 +1,15 @@
import torch.nn as nn


class LinearMLP(nn.Module):
    """A multilayer perceptron built purely from `nn.Linear` layers (no
    nonlinear activations), with layer sizes given by `layer_dim`."""

    def __init__(self, layer_dim):
        super(LinearMLP, self).__init__()
        # Pair consecutive entries of `layer_dim` as (input_dim, output_dim).
        io_pairs = zip(layer_dim[:-1], layer_dim[1:])
layers = [nn.Linear(idim, odim) for idim, odim in io_pairs]
self.net = nn.Sequential(*layers)

def forward(self, x):
return self.net(x)
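For instance, `LinearMLP([8, 4, 1])` (the configuration used by the tasks below) stacks `Linear(8, 4)` and `Linear(4, 1)`; a usage sketch:

```python
import torch
from tests.model import LinearMLP

model = LinearMLP([8, 4, 1])
out = model(torch.rand(2, 8))  # a batch of 2 samples with 8 features each
assert out.shape == (2, 1)
```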
112 changes: 112 additions & 0 deletions tests/task.py
@@ -0,0 +1,112 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import pytest

from .model import LinearMLP
from .dataset import XORDataset, ExtraXORDataset


def use_cuda():
if pytest.custom_cmdopt.cpu_only:
return False
else:
return torch.cuda.is_available()


class TaskTemplate(type):
    """Metaclass that calls `__post_init__()` (if defined) right after an
    instance is constructed, so subclasses can share common setup such as
    device resolution without repeating it in every `__init__`."""

    def __call__(cls, *args, **kwargs):
        obj = type.__call__(cls, *args, **kwargs)
        if hasattr(obj, "__post_init__"):
            obj.__post_init__()
        return obj


class BaseTask(metaclass=TaskTemplate):
def __init__(self):
self.batch_size = -1
self.model = None
self.optimizer = None
self.criterion = None
self.device = None
self.train_loader = None
self.val_loader = None

def __post_init__(self):
        # Check whether CUDA is available, and cast `self.device` to
        # `torch.device` so that tensor-moving operations work correctly later.
if not use_cuda():
self.device = None
if self.device is None:
return

if isinstance(self.device, str):
self.device = torch.device(self.device)
elif not isinstance(self.device, torch.device):
raise TypeError("Invalid type of device.")


class XORTask(BaseTask):
def __init__(self, validate=False):
super(XORTask, self).__init__()
bs, steps = 8, 64
dataset = XORDataset(bs * steps)
if validate:
self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
else:
self.train_loader = DataLoader(dataset)
self.val_loader = None

self.batch_size = bs
self.model = LinearMLP([8, 4, 1])
self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
self.criterion = nn.MSELoss()
self.device = torch.device("cuda")


class ExtraXORTask(BaseTask):
def __init__(self, validate=False):
super(ExtraXORTask, self).__init__()
bs, steps = 8, 64
dataset = ExtraXORDataset(bs * steps, extra_dims=2)
if validate:
self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
else:
self.train_loader = DataLoader(dataset)
self.val_loader = None

        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
self.criterion = nn.MSELoss()
self.device = torch.device("cuda")


class DiscriminativeLearningRateTask(BaseTask):
def __init__(self, validate=False):
super(DiscriminativeLearningRateTask, self).__init__()
bs, steps = 8, 64
dataset = XORDataset(bs * steps)
if validate:
self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
else:
self.train_loader = DataLoader(dataset)
self.val_loader = None

        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
self.optimizer = optim.SGD(
[
{"params": self.model.net[0].parameters(), "lr": 0.01},
{"params": self.model.net[1].parameters(), "lr": 0.001},
],
lr=1e-3,
momentum=0.5,
)
self.criterion = nn.MSELoss()
self.device = torch.device("cuda")
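A sketch of what the metaclass buys us: constructing any task runs `__post_init__` automatically, so `device` is already resolved by the time a test sees it (note this only works inside a pytest session, since `use_cuda()` reads `pytest.custom_cmdopt`):

```python
from tests.task import XORTask

task = XORTask()
# `__post_init__` has already run: `task.device` is either
# `torch.device("cuda")` or None when running CPU-only.
print(task.device)
```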
72 changes: 72 additions & 0 deletions tests/test_lr_finder.py
@@ -0,0 +1,72 @@
import pytest
from torch_lr_finder import LRFinder

from . import task as mod_task


def collect_task_classes():
names = [v for v in dir(mod_task) if v.endswith("Task") and v != "BaseTask"]
attrs = [getattr(mod_task, v) for v in names]
classes = [v for v in attrs if issubclass(v, mod_task.BaseTask)]
return classes


def prepare_lr_finder(task, **kwargs):
model = task.model
optimizer = task.optimizer
criterion = task.criterion
config = {
"device": kwargs.get("device", None),
"memory_cache": kwargs.get("memory_cache", True),
"cache_dir": kwargs.get("cache_dir", None),
}
lr_finder = LRFinder(model, optimizer, criterion, **config)
return lr_finder


def get_optim_lr(optimizer):
return [grp["lr"] for grp in optimizer.param_groups]


class TestRangeTest:
@pytest.mark.parametrize("cls_task", collect_task_classes())
def test_run(self, cls_task):
task = cls_task()
init_lrs = get_optim_lr(task.optimizer)

lr_finder = prepare_lr_finder(task)
lr_finder.range_test(task.train_loader)

        # check that lr actually increased during the range test
        assert max(lr_finder.history["lr"]) > init_lrs[0]

@pytest.mark.parametrize("cls_task", collect_task_classes())
def test_run_with_val_loader(self, cls_task):
task = cls_task(validate=True)
init_lrs = get_optim_lr(task.optimizer)

lr_finder = prepare_lr_finder(task)
lr_finder.range_test(task.train_loader, val_loader=task.val_loader)

        # check that lr actually increased during the range test
        assert max(lr_finder.history["lr"]) > init_lrs[0]


class TestReset:
@pytest.mark.parametrize(
"cls_task",
[
mod_task.XORTask,
mod_task.DiscriminativeLearningRateTask,
],
)
def test_reset(self, cls_task):
task = cls_task()
init_lrs = get_optim_lr(task.optimizer)

lr_finder = prepare_lr_finder(task)
lr_finder.range_test(task.train_loader, val_loader=task.val_loader)
lr_finder.reset()

restored_lrs = get_optim_lr(task.optimizer)
assert init_lrs == restored_lrs
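A hedged sketch of how a new test could reuse these helpers (the test name and assertion are illustrative, not part of this PR):

```python
@pytest.mark.parametrize("cls_task", collect_task_classes())
def test_history_lengths_match(cls_task):
    task = cls_task()
    lr_finder = prepare_lr_finder(task)
    lr_finder.range_test(task.train_loader)
    # Each recorded lr should have a corresponding loss value.
    assert len(lr_finder.history["lr"]) == len(lr_finder.history["loss"])
```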