Skip to content

Commit

Permalink
feature: clip grad (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnstark authored Jul 13, 2022
1 parent 676aa1e commit 580c7ff
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
5 changes: 5 additions & 0 deletions easytorch/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ def init_training(self, cfg: Dict):
self.start_epoch = 0
self.ckpt_save_strategy = cfg['TRAIN'].get('CKPT_SAVE_STRATEGY')
self.best_metrics = {}
self.clip_grad_param = cfg['TRAIN'].get('CLIP_GRAD_PARAM')
if self.clip_grad_param is not None:
self.logger.info('Set clip grad, param: {}'.format(self.clip_grad_param))

# train data loader
self.train_data_loader = self.build_train_data_loader(cfg)
Expand Down Expand Up @@ -491,6 +494,8 @@ def backward(self, loss: torch.Tensor):

self.optim.zero_grad()
loss.backward()
if self.clip_grad_param is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), **self.clip_grad_param)
self.optim.step()

@torch.no_grad()
Expand Down
2 changes: 1 addition & 1 deletion easytorch/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '1.2.8'
__version__ = '1.2.9'
__all__ = ['__version__']
69 changes: 69 additions & 0 deletions examples/imagenet/configs/resnet50_clip_grad_8x_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
from easydict import EasyDict

from imagenet_runner import ImagenetRunner

CFG = EasyDict()

CFG.DESC = 'imagenet resnet50'
CFG.RUNNER = ImagenetRunner
CFG.GPU_NUM = 8

CFG.MODEL = EasyDict()
CFG.MODEL.NAME = 'resnet50'

CFG.TRAIN = EasyDict()

CFG.TRAIN.NUM_EPOCHS = 90
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
'checkpoints',
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
CFG.TRAIN.CKPT_SAVE_STRATEGY = None

CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = 'SGD'
CFG.TRAIN.OPTIM.PARAM = {
'lr': 0.1,
'momentum': 0.9,
'weight_decay': 1e-4
}

CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = 'StepLR'
CFG.TRAIN.LR_SCHEDULER.PARAM = {
'step_size': 30,
'gamma': 0.1
}

CFG.TRAIN.CLIP_GRAD_PARAM = {
'max_norm': 1.0
}

IMAGENET_PATH = 'datasets/imagenet/jpegs'

CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 32
CFG.TRAIN.DATA.NUM_WORKERS = 4
CFG.TRAIN.DATA.SHUFFLE = True

CFG.TRAIN.DATA.DIR = os.path.join(IMAGENET_PATH, 'train')
CFG.TRAIN.DATA.CROP_SIZE = 224
CFG.TRAIN.DATA.NORMALIZE = {
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225]
}

CFG.VAL = EasyDict()

CFG.VAL.INTERVAL = 1

CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 32
CFG.VAL.DATA.DIR = os.path.join(IMAGENET_PATH, 'val')
CFG.VAL.DATA.CROP_SIZE = 224
CFG.VAL.DATA.RESIZE = 256
CFG.VAL.DATA.NORMALIZE = {
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225]
}

0 comments on commit 580c7ff

Please sign in to comment.