diff --git a/ding/torch_utils/backend_helper.py b/ding/torch_utils/backend_helper.py
new file mode 100644
index 0000000000..5962c6906a
--- /dev/null
+++ b/ding/torch_utils/backend_helper.py
@@ -0,0 +1,6 @@
+import torch
+
+
+def enable_tf32() -> None:
+    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
+    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
diff --git a/ding/torch_utils/lr_scheduler.py b/ding/torch_utils/lr_scheduler.py
new file mode 100644
index 0000000000..7c296ea180
--- /dev/null
+++ b/ding/torch_utils/lr_scheduler.py
@@ -0,0 +1,38 @@
+from functools import partial
+import math
+
+import torch.optim
+from torch.optim.lr_scheduler import LambdaLR
+
+
+def get_lr_ratio(epoch: int, warmup_epochs: int, learning_rate: float, lr_decay_epochs: int, min_lr: float) -> float:
+    # 1) linear warmup for warmup_epochs.
+    if epoch < warmup_epochs:
+        return epoch / warmup_epochs
+    # 2) if epoch > lr_decay_epochs, return min learning rate
+    if epoch > lr_decay_epochs:
+        return min_lr / learning_rate
+    # 3) in between, use cosine decay down to min learning rate
+    decay_ratio = (epoch - warmup_epochs) / (lr_decay_epochs - warmup_epochs)
+    assert 0 <= decay_ratio <= 1
+    coefficient = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+    return (min_lr + coefficient * (learning_rate - min_lr)) / learning_rate
+
+
+def cos_lr_scheduler(
+        optimizer: torch.optim.Optimizer,
+        learning_rate: float,
+        warmup_epochs: float = 5,
+        lr_decay_epochs: float = 100,
+        min_lr: float = 6e-5
+) -> torch.optim.lr_scheduler.LambdaLR:
+    return LambdaLR(
+        optimizer,
+        partial(
+            get_lr_ratio,
+            warmup_epochs=warmup_epochs,
+            lr_decay_epochs=lr_decay_epochs,
+            min_lr=min_lr,
+            learning_rate=learning_rate
+        )
+    )
diff --git a/ding/torch_utils/model_helper.py b/ding/torch_utils/model_helper.py
new file mode 100644
index 0000000000..a0234612b0
--- /dev/null
+++ b/ding/torch_utils/model_helper.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def get_num_params(model: torch.nn.Module) -> int:
+    """
+    Return the number of parameters in the model.
+    """
+    n_params = sum(p.numel() for p in model.parameters())
+    return n_params
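As a quick illustration of the scheduler added above, the sketch below (not part of the patch; the module path comes from the new file, while the model and optimizer are placeholders) attaches cos_lr_scheduler to an optimizer and prints the learning rate per epoch, making the linear warmup, the cosine decay, and the min_lr floor visible.

# --- usage sketch, not part of the patch ---
import torch
from ding.torch_utils.lr_scheduler import cos_lr_scheduler

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# learning_rate should match the optimizer's base lr, since get_lr_ratio returns a multiplier of it
scheduler = cos_lr_scheduler(opt, learning_rate=1e-3, warmup_epochs=5, lr_decay_epochs=100, min_lr=6e-5)

for epoch in range(110):
    # ... one training epoch ...
    opt.step()
    scheduler.step()
    # warms up linearly over the first 5 epochs, decays with a cosine curve until epoch 100,
    # then stays flat at min_lr = 6e-5
    print(epoch, opt.param_groups[0]['lr'])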
diff --git a/ding/torch_utils/network/activation.py b/ding/torch_utils/network/activation.py
index 56afa03ce3..989703d8a1 100644
--- a/ding/torch_utils/network/activation.py
+++ b/ding/torch_utils/network/activation.py
@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -67,17 +69,33 @@ class Swish(nn.Module):
     def __init__(self):
         super(Swish, self).__init__()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x * torch.sigmoid(x)
         return x
 
 
+class GELU(nn.Module):
+    r"""
+    Overview:
+        Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models like GPT, BERT.
+        The original paper can be viewed in: https://arxiv.org/abs/1606.08415
+    Interfaces:
+        forward
+    """
+
+    def __init__(self):
+        super(GELU, self).__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
 def build_activation(activation: str, inplace: bool = None) -> nn.Module:
     r"""
     Overview:
         Return the activation module according to the given type.
     Arguments:
-        - actvation (:obj:`str`): the type of activation module, now supports ['relu', 'glu', 'prelu']
+        - activation (:obj:`str`): the type of activation module, now supports ['relu', 'glu', 'prelu', 'swish', 'gelu']
         - inplace (:obj:`bool`): can optionally do the operation in-place in relu. Default ``None``
     Returns:
         - act_func (:obj:`nn.module`): the corresponding activation module
@@ -86,7 +104,7 @@ def build_activation(activation: str, inplace: bool = None) -> nn.Module:
         assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
     else:
         inplace = False
-    act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(), 'swish': Swish()}
+    act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(), 'swish': Swish(), 'gelu': GELU()}
     if activation in act_func.keys():
        return act_func[activation]
    else:
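For context, the forward above is the tanh approximation of GELU. A small check like the following (illustrative only; it assumes PyTorch >= 1.12 for the approximate argument of F.gelu) shows that the new module matches torch's built-in tanh-approximated GELU and how the new key is consumed through build_activation.

# --- illustrative check, not part of the patch ---
import torch
import torch.nn.functional as F
from ding.torch_utils.network.activation import build_activation

x = torch.randn(4, 16)
gelu = build_activation('gelu')
# the tanh approximation above matches F.gelu(..., approximate='tanh') on PyTorch >= 1.12
torch.testing.assert_close(gelu(x), F.gelu(x, approximate='tanh'))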
diff --git a/ding/torch_utils/optimizer_helper.py b/ding/torch_utils/optimizer_helper.py
index cbd5b3bedd..7103903442 100644
--- a/ding/torch_utils/optimizer_helper.py
+++ b/ding/torch_utils/optimizer_helper.py
@@ -1,8 +1,7 @@
 import torch
 import math
 from torch.nn.utils import clip_grad_norm_, clip_grad_value_
-from torch._six import inf
-from typing import Union, Iterable, Tuple, Callable
+from typing import Union, Iterable, Tuple, Callable, List
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
@@ -11,6 +10,8 @@
 import copy
 import random
 
+inf = math.inf
+
 
 def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
     r"""
@@ -193,10 +194,10 @@ def _state_init(self, p, amsgrad):
        # others
        if torch.__version__ < "1.12.0":
            state['step'] = 0
-            #TODO
-            #wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
+            # TODO
+            # wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
        else:
-            state['step'] = torch.zeros((1, ), dtype=torch.float, device=p.device) \
+            state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                if self.defaults['capturable'] else torch.tensor(0.)
        state['exp_avg'] = torch.zeros_like(p.data)
@@ -235,7 +236,7 @@ def step(self, closure: Union[Callable, None] = None):
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
-                    #should we use same beta group?
+                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -259,7 +260,7 @@ def step(self, closure: Union[Callable, None] = None):
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
-                    #should we use same beta group?
+                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -267,7 +268,7 @@ def step(self, closure: Union[Callable, None] = None):

                    param_norm = grad.norm(self._clip_norm_type)
                    total_norm += param_norm.item() ** self._clip_norm_type
-                    #sum momentum_norm
+                    # sum momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._clip_coef).norm(self._clip_norm_type)
                    total_momentum_norm += momentum.item() ** self._clip_norm_type
@@ -294,7 +295,7 @@ def step(self, closure: Union[Callable, None] = None):
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
-                    #should we use same beta group?
+                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -326,7 +327,7 @@ def step(self, closure: Union[Callable, None] = None):
                    if len(state) == 0:
                        self._state_init(p, group['amsgrad'])
                    grad = p.grad.data
-                    #should we use same beta group?
+                    # should we use same beta group?
                    beta1, beta2 = group['betas']
                    bias_correction2 = 1 - beta2 ** state['step']
                    state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -334,7 +335,7 @@ def step(self, closure: Union[Callable, None] = None):

                    param_norm = grad.norm(self._ignore_norm_type)
                    total_norm += param_norm.item() ** self._ignore_norm_type
-                    #sum momentum_norm
+                    # sum momentum_norm
                    momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
                                self._ignore_coef).norm(self._ignore_norm_type)
                    total_momentum_norm += momentum.item() ** self._ignore_norm_type
@@ -348,7 +349,7 @@ def step(self, closure: Union[Callable, None] = None):
                for p in group['params']:
                    p.grad.zero_()

-        #Adam optim type
+        # Adam optim type
        if self._optim_type == 'adamw':
            for group in self.param_groups:
                for p in group['params']:
@@ -517,7 +518,7 @@ def step(self, closure: Union[Callable, None] = None):

                    param_norm = grad.norm(self._clip_norm_type)
                    total_norm += param_norm.item() ** self._clip_norm_type
-                    #sum momentum_norm
+                    # sum momentum_norm
                    momentum = (state['thre_square_avg'].sqrt() * self._clip_coef).norm(self._clip_norm_type)
                    total_momentum_norm += momentum.item() ** self._clip_norm_type
                    step = min(step, state['step'])
@@ -578,7 +579,7 @@ def step(self, closure: Union[Callable, None] = None):

                    param_norm = grad.norm(self._ignore_norm_type)
                    total_norm += param_norm.item() ** self._ignore_norm_type
-                    #sum momentum_norm
+                    # sum momentum_norm
                    momentum = (state['thre_square_avg'].sqrt() * self._ignore_coef).norm(self._ignore_norm_type)
                    total_momentum_norm += momentum.item() ** self._ignore_norm_type
                    step = min(step, state['step'])
@@ -730,3 +731,58 @@ def _retrieve_grad(self):
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p).to(p.device))
        return grad, shape, has_grad
+
+
+def configure_weight_decay(model: nn.Module, weight_decay: float) -> List:
+    r"""
+    Overview:
+        Separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layer-norm or embedding weights).
+    Arguments:
+        - model (:obj:`nn.Module`): the given PyTorch model.
+        - weight_decay (:obj:`float`): weight decay value for optimizer.
+    Returns:
+        - optim groups (:obj:`List`): the parameter groups to be set in the latter optimizer.
+    """
+    decay = set()
+    no_decay = set()
+    whitelist_weight_modules = (torch.nn.Linear, )
+    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+    for mn, m in model.named_modules():
+        for pn, p in m.named_parameters():
+            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
+            # Because named_modules and named_parameters are recursive
+            # we will see the same tensors p many times. But doing it this way
+            # allows us to know which parent module any tensor p belongs to.
+            if pn.endswith('bias'):
+                # all biases will not be decayed
+                no_decay.add(fpn)
+            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                # weights of whitelist modules will be weight decayed
+                decay.add(fpn)
+            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                # weights of blacklist modules will NOT be weight decayed
+                no_decay.add(fpn)
+            else:
+                decay.add(fpn)
+
+    decay = decay - no_decay
+    # validate that we considered every parameter
+    param_dict = {pn: p for pn, p in model.named_parameters()}
+    union_params = decay | no_decay
+    assert len(
+        param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+        % (str(param_dict.keys() - union_params),)
+
+    optim_groups = [
+        {
+            "params": [param_dict[pn] for pn in sorted(list(decay))],
+            "weight_decay": weight_decay
+        },
+        {
+            "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+            "weight_decay": 0.0
+        },
+    ]
+
+    return optim_groups
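To make the intended use of configure_weight_decay concrete, a minimal sketch follows (not part of the patch; the tiny model is only a placeholder). It builds the two parameter groups and hands them to AdamW, so decay is applied to Linear weights but not to biases, LayerNorm weights, or Embedding weights.

# --- usage sketch, not part of the patch ---
import torch
import torch.nn as nn
from ding.torch_utils.optimizer_helper import configure_weight_decay

model = nn.Sequential(nn.Embedding(10, 32), nn.LayerNorm(32), nn.Linear(32, 4))
groups = configure_weight_decay(model=model, weight_decay=0.01)
# groups[0] holds the Linear weight (decayed); groups[1] holds the bias,
# LayerNorm and Embedding weights (weight_decay=0.0)
opt = torch.optim.AdamW(groups, lr=1e-3)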
diff --git a/ding/torch_utils/tests/test_backend_helper.py b/ding/torch_utils/tests/test_backend_helper.py
new file mode 100644
index 0000000000..f598884693
--- /dev/null
+++ b/ding/torch_utils/tests/test_backend_helper.py
@@ -0,0 +1,21 @@
+import pytest
+import torch
+
+from ding.torch_utils.backend_helper import enable_tf32
+
+
+@pytest.mark.cudatest
+class TestBackendHelper:
+
+    def test_tf32(self):
+        r"""
+        Overview:
+            Test the tf32.
+        """
+        enable_tf32()
+        net = torch.nn.Linear(3, 4)
+        x = torch.randn(1, 3)
+        y = torch.sum(net(x))
+        net.zero_grad()
+        y.backward()
+        assert net.weight.grad is not None
diff --git a/ding/torch_utils/tests/test_lr_scheduler.py b/ding/torch_utils/tests/test_lr_scheduler.py
new file mode 100644
index 0000000000..4ba52d9e1f
--- /dev/null
+++ b/ding/torch_utils/tests/test_lr_scheduler.py
@@ -0,0 +1,20 @@
+import pytest
+import torch
+from torch.optim import Adam
+
+from ding.torch_utils.lr_scheduler import cos_lr_scheduler
+
+
+@pytest.mark.unittest
+class TestLRSchedulerHelper:
+
+    def test_cos_lr_scheduler(self):
+        r"""
+        Overview:
+            Test the cos lr scheduler.
+        """
+        net = torch.nn.Linear(3, 4)
+        opt = Adam(net.parameters(), lr=1e-2)
+        scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, min_lr=6e-5)
+        scheduler.step(101)
+        assert opt.param_groups[0]['lr'] == 6e-5
diff --git a/ding/torch_utils/tests/test_model_helper.py b/ding/torch_utils/tests/test_model_helper.py
new file mode 100644
index 0000000000..e2dd72e54e
--- /dev/null
+++ b/ding/torch_utils/tests/test_model_helper.py
@@ -0,0 +1,19 @@
+import pytest
+import torch
+
+from ding.torch_utils.model_helper import get_num_params
+
+
+@pytest.mark.unittest
+class TestModelHelper:
+
+    def test_model_helper(self):
+        r"""
+        Overview:
+            Test the model helper.
+        """
+        net = torch.nn.Linear(3, 4, bias=False)
+        assert get_num_params(net) == 12
+
+        net = torch.nn.Conv2d(3, 3, kernel_size=3, bias=False)
+        assert get_num_params(net) == 81
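The two small helpers exercised by these tests can also be used directly; a minimal sketch (illustrative only, not part of the patch) is:

# --- usage sketch, not part of the patch ---
import torch
from ding.torch_utils.backend_helper import enable_tf32
from ding.torch_utils.model_helper import get_num_params

enable_tf32()  # opt in to TF32 matmul/cuDNN kernels; only takes effect on Ampere or newer GPUs
model = torch.nn.Linear(128, 256)
print(get_num_params(model))  # 128 * 256 + 256 = 33024 parameters (weight + bias)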
+ """ + net = torch.nn.Linear(3, 4, bias=False) + assert get_num_params(net) == 12 + + net = torch.nn.Conv2d(3, 3, kernel_size=3, bias=False) + assert get_num_params(net) == 81 diff --git a/ding/torch_utils/tests/test_optimizer.py b/ding/torch_utils/tests/test_optimizer.py index 6f985d506f..389346fe54 100644 --- a/ding/torch_utils/tests/test_optimizer.py +++ b/ding/torch_utils/tests/test_optimizer.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.optim as optim from ding.torch_utils.optimizer_helper import Adam, RMSprop, calculate_grad_norm, \ - calculate_grad_norm_without_bias_two_norm, PCGrad + calculate_grad_norm_without_bias_two_norm, PCGrad, configure_weight_decay import pytest import time @@ -177,3 +177,21 @@ def naive_test(self): pc_adam.pc_backward([loss1, loss2]) for p in net.parameters(): assert isinstance(p, torch.Tensor) + + +@pytest.mark.unittest +class TestWeightDecay: + + def test_wd(self): + net = nn.Sequential(nn.Linear(3, 4), nn.LayerNorm(4)) + x = torch.randn(1, 3) + group_params = configure_weight_decay(model=net, weight_decay=1e-4) + assert group_params[0]['weight_decay'] == 1e-4 + assert group_params[1]['weight_decay'] == 0 + assert len(group_params[0]['params']) == 1 + assert len(group_params[1]['params']) == 3 + opt = Adam(group_params, lr=1e-2) + opt.zero_grad() + y = torch.sum(net(x)) + y.backward() + opt.step() diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index 0485938d82..94cfa105ee 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -4,11 +4,11 @@ import torch import treetensor.torch as ttorch import re -from torch._six import string_classes import collections.abc as container_abcs from ding.compatibility import torch_ge_131 int_classes = int +string_classes = (str, bytes) np_str_obj_array_pattern = re.compile(r'[SaUO]') default_collate_err_msg_format = (