Remove message of checking apex.amp module and add tests for features of gradient accumulation/mixed precision training (#46)

* MAINT: remove message of checking `apex.amp` module

The original purpose of that message was to let users know that
gradient accumulation and mixed precision training are supported
but `apex` is required.

Prompted by issue #45, the following points were confirmed:

- Gradient accumulation still works properly without `apex.amp`,
  because the implementation falls back to a plain `loss.backward()`
  when `apex.amp` is not available or `amp.initialize()` was not called.

- When mixed precision training is requested, i.e. the model and
  optimizer have been wrapped by `amp.initialize()`, the current
  implementation adopts `amp.scale_loss()` automatically.

Therefore, the message about checking for the `apex.amp` module is
no longer necessary.
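
For reference, here is a minimal sketch of the fallback described
above (illustrative only, not the exact code: the helper name
`backward_pass` is made up, while `IS_AMP_AVAILABLE` and the
`_amp_stash` check mirror names used in this repository):

    try:
        from apex import amp
        IS_AMP_AVAILABLE = True
    except ImportError:
        IS_AMP_AVAILABLE = False

    def backward_pass(loss, optimizer):
        # Use amp only when apex is importable and the optimizer has
        # actually been wrapped by `amp.initialize()` (it then carries
        # the `_amp_stash` attribute); otherwise do a plain backward
        # pass, so gradient accumulation keeps working without apex.
        if IS_AMP_AVAILABLE and hasattr(optimizer, "_amp_stash"):
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()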

* TST: parameter `batch_size` was not passed into `DataLoader`

This mistake made the batch size of every data loader fall back to
the default value of 1. Although it does not affect the correctness
of any test case, it still needs to be corrected.

However, the `batch_size` of a `DataLoader` cannot be modified
after it is initialized. Therefore, it can only be determined while
generating tasks for the tests, which is why `batch_size` and
`steps` are moved into the `__init__` signature of each `Task`.
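
For illustration, a hypothetical snippet (not part of this change)
showing both the default value and the construction-time behaviour
of `batch_size`:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))

    # `batch_size` defaults to 1 when it is not passed explicitly ...
    loader = DataLoader(dataset)
    assert next(iter(loader))[0].shape[0] == 1

    # ... and it can only be set while constructing the DataLoader.
    loader = DataLoader(dataset, batch_size=8)
    assert next(iter(loader))[0].shape[0] == 8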

* TST: move model to device while post-initializing tasks

This step was missing before, which made all tests run on the CPU
even when the pytest argument `--cpu_only` was not specified.

* TST: add tests for gradient accumulation and mixed precision training

* BUILD: add pytest-mock to dependencies for tests

* TST: change the import statement to import the `amp` submodule directly
NaleRaphael authored Jun 7, 2020
1 parent 98c4004 commit 23a23cf
Showing 5 changed files with 127 additions and 32 deletions.
README.md (5 changes: 4 additions & 1 deletion)
@@ -87,12 +87,15 @@ accumulation_steps = desired_batch_size // real_batch_size
 dataset = ...
 
 # Beware of the `batch_size` used by `DataLoader`
-trainloader = DataLoader(dataset, batch_size=real_bs, shuffle=True)
+trainloader = DataLoader(dataset, batch_size=real_batch_size, shuffle=True)
 
 model = ...
 criterion = ...
 optimizer = ...
 
+# (Optional) With this setting, `amp.scale_loss()` will be adopted automatically.
+# model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
+
 lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
 lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode="exp", accumulation_steps=accumulation_steps)
 lr_finder.plot()
setup.py (3 changes: 2 additions & 1 deletion)
@@ -57,10 +57,11 @@
     python_requires=">=3.5.9",
     install_requires=["matplotlib", "numpy", "torch>=0.4.1", "tqdm", "packaging"],
     extras_require={
-        "tests": ["pytest", "pytest-cov"],
+        "tests": ["pytest", "pytest-cov", "pytest-mock"],
        "dev": [
            "pytest",
            "pytest-cov",
+           "pytest-mock",
            "flake8",
            "black",
            "pep8-naming",
tests/task.py (64 changes: 44 additions & 20 deletions)
@@ -48,57 +48,81 @@ def __post_init__(self):
         elif not isinstance(self.device, torch.device):
             raise TypeError("Invalid type of device.")
 
+        self.model.to(self.device)
+
 
 class XORTask(BaseTask):
-    def __init__(self, validate=False):
+    def __init__(self, batch_size=8, steps=100, validate=False):
         super(XORTask, self).__init__()
-        bs, steps = 8, 64
-        dataset = XORDataset(bs * steps)
+        n_total = batch_size * steps
+        dataset = XORDataset(n_total)
         if validate:
-            self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
-            self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
+            n_train = int(n_total * 0.9)
+            self.train_loader = DataLoader(
+                Subset(dataset, range(n_train)),
+                batch_size=batch_size
+            )
+            self.val_loader = DataLoader(
+                Subset(dataset, range(n_train, n_total)),
+                batch_size=batch_size
+            )
         else:
-            self.train_loader = DataLoader(dataset)
+            self.train_loader = DataLoader(dataset, batch_size=batch_size)
             self.val_loader = None
 
-        self.batch_size = bs
+        self.batch_size = batch_size
         self.model = LinearMLP([8, 4, 1])
         self.optimizer = optim.SGD(self.model.parameters(), lr=1e-5)
         self.criterion = nn.MSELoss()
         self.device = torch.device("cuda")
 
 
 class ExtraXORTask(BaseTask):
-    def __init__(self, validate=False):
+    def __init__(self, batch_size=8, steps=100, validate=False):
         super(ExtraXORTask, self).__init__()
-        bs, steps = 8, 64
-        dataset = ExtraXORDataset(bs * steps, extra_dims=2)
+        n_total = batch_size * steps
+        dataset = ExtraXORDataset(n_total, extra_dims=2)
         if validate:
-            self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
-            self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
+            n_train = int(n_total * 0.9)
+            self.train_loader = DataLoader(
+                Subset(dataset, range(n_train)),
+                batch_size=batch_size
+            )
+            self.val_loader = DataLoader(
+                Subset(dataset, range(n_train, n_total)),
+                batch_size=batch_size
+            )
         else:
-            self.train_loader = DataLoader(dataset)
+            self.train_loader = DataLoader(dataset, batch_size=batch_size)
             self.val_loader = None
 
+        self.batch_size = batch_size
         self.model = LinearMLP([8, 4, 1])
         self.optimizer = optim.SGD(self.model.parameters(), lr=1e-5)
         self.criterion = nn.MSELoss()
         self.device = torch.device("cuda")
 
 
 class DiscriminativeLearningRateTask(BaseTask):
-    def __init__(self, validate=False):
+    def __init__(self, batch_size=8, steps=100, validate=False):
         super(DiscriminativeLearningRateTask, self).__init__()
-        bs, steps = 8, 64
-        dataset = XORDataset(bs * steps)
+        n_total = batch_size * steps
+        dataset = XORDataset(n_total)
         if validate:
-            self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
-            self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
+            n_train = int(n_total * 0.9)
+            self.train_loader = DataLoader(
+                Subset(dataset, range(n_train)),
+                batch_size=batch_size
+            )
+            self.val_loader = DataLoader(
+                Subset(dataset, range(n_train, n_total)),
+                batch_size=batch_size
+            )
         else:
-            self.train_loader = DataLoader(dataset)
+            self.train_loader = DataLoader(dataset, batch_size=batch_size)
             self.val_loader = None
 
-        dataset = XORDataset(128)
+        self.batch_size = batch_size
         self.model = LinearMLP([8, 4, 1])
         self.optimizer = optim.SGD(
             [
tests/test_lr_finder.py (77 changes: 77 additions & 0 deletions)
@@ -4,6 +4,14 @@
 import task as mod_task
 
 
+try:
+    from apex import amp
+
+    IS_AMP_AVAILABLE = True
+except ImportError:
+    IS_AMP_AVAILABLE = False
+
+
 def collect_task_classes():
     names = [v for v in dir(mod_task) if v.endswith("Task") and v != "BaseTask"]
     attrs = [getattr(mod_task, v) for v in names]
@@ -90,6 +98,75 @@ def test_exponential_lr_history(self):
         assert lr_finder.history["lr"] == pytest.approx([1e-5, 1e-4, 1e-3, 1e-2, 0.1])
 
 
+class TestGradientAccumulation:
+    def test_gradient_accumulation(self, mocker):
+        desired_bs, accum_steps = 32, 4
+        real_bs = desired_bs // accum_steps
+        num_iter = 10
+        task = mod_task.XORTask(batch_size=real_bs)
+
+        lr_finder = prepare_lr_finder(task)
+        spy = mocker.spy(lr_finder, "criterion")
+
+        lr_finder.range_test(
+            task.train_loader, num_iter=num_iter, accumulation_steps=accum_steps
+        )
+        # NOTE: We are using smaller batch size to simulate a large batch.
+        # So that the actual times of model/criterion called should be
+        # `(desired_bs/real_bs) * num_iter` == `accum_steps * num_iter`
+        assert spy.call_count == accum_steps * num_iter
+
+    @pytest.mark.skipif(
+        not (IS_AMP_AVAILABLE and mod_task.use_cuda()),
+        reason="`apex` module and gpu is required to run this test."
+    )
+    def test_gradient_accumulation_with_apex_amp(self, mocker):
+        desired_bs, accum_steps = 32, 4
+        real_bs = desired_bs // accum_steps
+        num_iter = 10
+        task = mod_task.XORTask(batch_size=real_bs)
+
+        # Wrap model and optimizer by `amp.initialize`. Beside, `amp` requires
+        # CUDA GPU. So we have to move model to GPU first.
+        model, optimizer, device = task.model, task.optimizer, task.device
+        model = model.to(device)
+        task.model, task.optimizer = amp.initialize(model, optimizer)
+
+        lr_finder = prepare_lr_finder(task)
+        spy = mocker.spy(amp, "scale_loss")
+
+        lr_finder.range_test(
+            task.train_loader, num_iter=num_iter, accumulation_steps=accum_steps
+        )
+        assert spy.call_count == accum_steps * num_iter
+
+
+@pytest.mark.skipif(
+    not (IS_AMP_AVAILABLE and mod_task.use_cuda()),
+    reason="`apex` module and gpu is required to run these tests."
+)
+class TestMixedPrecision:
+    def test_mixed_precision(self, mocker):
+        batch_size = 32
+        num_iter = 10
+        task = mod_task.XORTask(batch_size=batch_size)
+
+        # Wrap model and optimizer by `amp.initialize`. Beside, `amp` requires
+        # CUDA GPU. So we have to move model to GPU first.
+        model, optimizer, device = task.model, task.optimizer, task.device
+        model = model.to(device)
+        task.model, task.optimizer = amp.initialize(model, optimizer)
+        assert hasattr(task.optimizer, "_amp_stash")
+
+        lr_finder = prepare_lr_finder(task)
+        spy = mocker.spy(amp, "scale_loss")
+
+        lr_finder.range_test(task.train_loader, num_iter=num_iter)
+        # NOTE: Here we did not perform gradient accumulation, so that call count
+        # of `amp.scale_loss` should equal to `num_iter`.
+        assert spy.call_count == num_iter
+
+
 @pytest.mark.parametrize("num_iter", [0, 1])
 @pytest.mark.parametrize("scheduler", ["exp", "linear"])
 def test_scheduler_and_num_iter(num_iter, scheduler):
torch_lr_finder/lr_finder.py (10 changes: 0 additions & 10 deletions)
@@ -15,17 +15,7 @@
 
     IS_AMP_AVAILABLE = True
 except ImportError:
-    import logging
-
-    logging.basicConfig()
-    logger = logging.getLogger(__name__)
-    logger.warning(
-        "To enable mixed precision training, please install `apex`. "
-        "Or you can re-install this package by the following command:\n"
-        ' pip install torch-lr-finder -v --global-option="amp"'
-    )
     IS_AMP_AVAILABLE = False
-    del logging
 
 
 class DataLoaderIter(object):
