Feature: optimizer.zero_grad() speed up option #927

Merged 5 commits on Sep 4, 2020. Changes shown from 2 commits.
26 changes: 26 additions & 0 deletions bin/tests/check_dl_core_callbacks.sh
@@ -822,5 +822,31 @@ check_line_counts ${EXP_OUTPUT} "loaded state .*/last_full.pth (global epoch 2,

rm -rf ./tests/logs/_tests_dl_callbacks ${EXP_OUTPUT}

################################ pipeline 23 ################################
# checking optimizer use_fast_zero_grad
LOG_MSG='pipeline 23'
echo ${LOG_MSG}

PYTHONPATH=./examples:./catalyst:${PYTHONPATH} \
python catalyst/dl/scripts/run.py \
--stages/stage1/stage_params/num_epochs='2:int' \
--stages/stage1/callbacks_params/optimizer/use_fast_zero_grad='1:bool' \
--expdir=${EXPDIR} \
--config=${EXPDIR}/config0.yml \
--logdir=${LOGDIR} > ${EXP_OUTPUT}

cat ${EXP_OUTPUT}
check_line_counts ${EXP_OUTPUT} "=> Loading" 0

check_file_existence ${LOGFILE}
cat ${LOGFILE}
echo ${LOG_MSG}

check_checkpoints "${CHECKPOINTS}/best" 1
check_checkpoints "${CHECKPOINTS}/last" 1
check_checkpoints "${CHECKPOINTS}/stage1\.[[:digit:]]" 1
check_num_files ${CHECKPOINTS} 7 # 3x2 checkpoints + metrics.json

rm -rf ${LOGDIR} ${EXP_OUTPUT}

rm -rf ${LOGDIR}
17 changes: 15 additions & 2 deletions catalyst/core/callbacks/optimizer.py
Expand Up @@ -27,6 +27,7 @@ def __init__(
grad_clip_params: Dict = None,
decouple_weight_decay: bool = True,
loss_key: str = None,
use_fast_zero_grad: bool = False,
xla_barrier: bool = True,
):
"""
@@ -37,8 +38,10 @@ def __init__(
accumulation_steps (int): number of steps before
``model.zero_grad()``
grad_clip_params (dict): params for gradient clipping
decouple_weight_decay (bool): if ``True``, decouple weight decay
regularization.
use_fast_zero_grad (bool): speed up ``optimizer.zero_grad()`` by setting
gradients to ``None`` instead of zeroing them; default is ``False``.
xla_barrier (bool): barrier option for xla. Here you can find
more about usage of `barrier flag
<https://pytorch.org/xla/release/1.5/index.html?
@@ -72,6 +75,7 @@ def __init__(
self._optimizer_wd: List[float] = [0.0]
self._optimizer_step_fn: Callable = None
self.is_xla = False
self.use_fast_zero_grad = use_fast_zero_grad
self.use_xla_barrier = xla_barrier

def grad_step(
@@ -217,8 +221,17 @@ def on_batch_end(self, runner: IRunner) -> None:
optimizer_wds=self._optimizer_wd,
grad_clip_fn=self.grad_clip_fn,
)
if not self.use_fast_zero_grad:
    utils.maybe_recursive_call(self._optimizer, "zero_grad")
else:

    def zero_grad(optimizer):
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad = None

    utils.maybe_recursive_call(self._optimizer, zero_grad)
self._accumulation_counter = 0


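For context on what the fast path buys: the default `optimizer.zero_grad()` (at the time of this PR) keeps every gradient tensor allocated and fills it with zeros in place, while setting `p.grad = None` drops the buffers so autograd allocates fresh ones on the next `backward()`, skipping the zero fill. A minimal standalone sketch in plain PyTorch, independent of Catalyst (model and optimizer here are illustrative, not from the PR):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Standard reset: zero_grad() historically kept each gradient buffer
# allocated and overwrote it with zeros in place.
model(torch.randn(4, 10)).sum().backward()
optimizer.zero_grad()

# "Fast" reset: drop the gradient tensors instead; the next backward()
# allocates fresh buffers, so the in-place zero fill is skipped.
model(torch.randn(4, 10)).sum().backward()
for group in optimizer.param_groups:
    for p in group["params"]:
        p.grad = None
assert all(p.grad is None for p in model.parameters())
```

Later PyTorch releases adopted the same idea natively as `optimizer.zero_grad(set_to_none=True)`.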
80 changes: 80 additions & 0 deletions catalyst/core/tests/test_optimizer_callback.py
@@ -0,0 +1,80 @@
# flake8: noqa
import random

import torch
import torch.nn as nn

from catalyst.core.callbacks import OptimizerCallback


class DummyRunner:
    def __init__(
        self, loss_value: torch.Tensor, optimizer: torch.optim.Optimizer
    ):
        self.batch_metrics = {"loss": loss_value}
        self.is_train_loader = True
        self.optimizer = optimizer
        self.device = torch.device("cpu")

    def get_attr(self, key, *args, **kwargs):
        return getattr(self, key)


def test_zero_grad():
    model = nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    batch_size = 3
    inp = torch.randn(batch_size, 10)
    target = torch.FloatTensor(batch_size, 2).uniform_()

    callback = OptimizerCallback(metric_key="loss", use_fast_zero_grad=False)

    loss1 = criterion(model(inp), target)
    loss1_value = loss1.detach().item()

    runner = DummyRunner(loss1, optimizer)

    callback.on_stage_start(runner)
    callback.on_epoch_start(runner)
    callback.on_batch_end(runner)

    loss2 = criterion(model(inp), target)
    loss2_value = loss2.detach().item()

    runner.batch_metrics = {"loss": loss2}
    callback.on_epoch_start(runner)
    callback.on_batch_end(runner)

    assert loss1_value > loss2_value


def test_fast_zero_grad():
    model = nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()

    batch_size = 3
    inp = torch.randn(batch_size, 10)
    target = torch.FloatTensor(batch_size, 2).uniform_()

    callback = OptimizerCallback(metric_key="loss", use_fast_zero_grad=True)

    loss1 = criterion(model(inp), target)
    loss1_value = loss1.detach().item()

    runner = DummyRunner(loss1, optimizer)

    callback.on_stage_start(runner)
    callback.on_epoch_start(runner)
    callback.on_batch_end(runner)

    loss2 = criterion(model(inp), target)
    loss2_value = loss2.detach().item()

    runner.batch_metrics = {"loss": loss2}
    callback.on_epoch_start(runner)
    callback.on_batch_end(runner)

    assert loss1_value > loss2_value
9 changes: 6 additions & 3 deletions catalyst/utils/misc.py
@@ -1,6 +1,6 @@
# flake8: noqa
# @TODO: code formatting issue for 20.07 release
from typing import Any, Callable, List, Union
from datetime import datetime
import inspect
from pathlib import Path
@@ -9,7 +9,7 @@

def maybe_recursive_call(
    object_or_dict,
    method: Union[str, Callable],
    recursive_args=None,
    recursive_kwargs=None,
    **kwargs,
@@ -46,7 +46,10 @@ def maybe_recursive_call(
    if not isinstance(r_args, (list, tuple)):
        r_args = [r_args]
    r_kwargs = recursive_kwargs or {}
    if isinstance(method, str):
        return getattr(object_or_dict, method)(*r_args, **r_kwargs, **kwargs)
    else:
        return method(object_or_dict, *r_args, **r_kwargs, **kwargs)


def is_exception(ex: Any) -> bool:
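With this change `maybe_recursive_call` dispatches either by attribute name or by calling a plain function with the object as its first argument. A small usage sketch; the `Counter` class is a toy introduced here for illustration, not part of the repository:

```python
from catalyst.utils.misc import maybe_recursive_call


class Counter:
    def __init__(self):
        self.value = 3

    def reset(self):
        self.value = 0


counter = Counter()

# String form: the method is looked up on the object via getattr.
maybe_recursive_call(counter, "reset")
assert counter.value == 0

# Callable form (added in this PR): the callable receives the object itself.
maybe_recursive_call(counter, lambda obj: setattr(obj, "value", -1))
assert counter.value == -1
```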
1 change: 1 addition & 0 deletions tests/_tests_dl_callbacks/config0.yml
@@ -38,6 +38,7 @@ stages:
  callback: CriterionCallback
optimizer:
  callback: OptimizerCallback
  use_fast_zero_grad: false
accuracy:
  callback: AccuracyCallback
  accuracy_args: [1, 3, 5]
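For reference, the same option can also be set when constructing the callback directly in Python rather than through the YAML config. A minimal sketch based on the constructor signature and the tests above; the surrounding experiment setup is omitted:

```python
from catalyst.core.callbacks import OptimizerCallback

# Enable the fast gradient reset: the callback then sets p.grad = None
# instead of calling optimizer.zero_grad() after each optimizer step.
callbacks = {
    "optimizer": OptimizerCallback(metric_key="loss", use_fast_zero_grad=True),
}
```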