feature(whl): add gpt utils #625

Merged: 10 commits merged on Apr 6, 2023
6 changes: 6 additions & 0 deletions ding/torch_utils/backend_helper.py
@@ -0,0 +1,6 @@
import torch


def enable_tf32() -> None:
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
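
Usage note (not part of the diff): a minimal sketch of how enable_tf32 might be called, assuming an Ampere-or-newer NVIDIA GPU where TF32 kernels exist; call it once at startup, before building the model.

import torch

from ding.torch_utils.backend_helper import enable_tf32

# Sketch: enable TF32 once; subsequent matmul / cuDNN ops may then use TF32 on supported hardware.
enable_tf32()
if torch.cuda.is_available():
    net = torch.nn.Linear(1024, 1024).cuda()
    y = net(torch.randn(8, 1024, device='cuda'))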
38 changes: 38 additions & 0 deletions ding/torch_utils/lr_scheduler.py
@@ -0,0 +1,38 @@
from functools import partial
import math

import torch.optim
from torch.optim.lr_scheduler import LambdaLR


def get_lr_ratio(epoch: int, warmup_epochs: int, learning_rate: float, lr_decay_epochs: int, min_lr: float) -> float:
    # 1) linear warmup for warmup_epochs.
    if epoch < warmup_epochs:
        return epoch / warmup_epochs
    # 2) if epoch > lr_decay_epochs, return min learning rate
    if epoch > lr_decay_epochs:
        return min_lr / learning_rate
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (epoch - warmup_epochs) / (lr_decay_epochs - warmup_epochs)
    assert 0 <= decay_ratio <= 1
    coefficient = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return (min_lr + coefficient * (learning_rate - min_lr)) / learning_rate


def cos_lr_scheduler(
        optimizer: torch.optim.Optimizer,
        learning_rate: float,
        warmup_epochs: float = 5,
        lr_decay_epochs: float = 100,
        min_lr: float = 6e-5
) -> torch.optim.lr_scheduler.LambdaLR:
    return LambdaLR(
        optimizer,
        partial(
            get_lr_ratio,
            warmup_epochs=warmup_epochs,
            lr_decay_epochs=lr_decay_epochs,
            min_lr=min_lr,
            learning_rate=learning_rate
        )
    )
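
Usage note (values are illustrative, not from the PR): the returned LambdaLR multiplies the optimizer's base lr by get_lr_ratio(epoch), so the learning_rate argument should match the lr passed to the optimizer; stepping once per epoch gives linear warmup followed by cosine decay down to min_lr.

import torch
from torch.optim import Adam

from ding.torch_utils.lr_scheduler import cos_lr_scheduler

net = torch.nn.Linear(3, 4)
opt = Adam(net.parameters(), lr=1e-2)
scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, warmup_epochs=5, lr_decay_epochs=100, min_lr=6e-5)
for epoch in range(110):
    # ... one training epoch would run here ...
    scheduler.step()  # lr warms up for 5 epochs, then cosine-decays; past epoch 100 it stays at 6e-5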
9 changes: 9 additions & 0 deletions ding/torch_utils/model_helper.py
@@ -0,0 +1,9 @@
import torch


def get_num_params(model: torch.nn.Module) -> int:
    """
    Return the number of parameters in the model.
    """
    n_params = sum(p.numel() for p in model.parameters())
    return n_params
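
Quick sanity check (illustrative, mirrors the unit test below): a bias-free Linear(3, 4) holds only its 4 x 3 weight matrix, i.e. 12 parameters.

import torch

from ding.torch_utils.model_helper import get_num_params

net = torch.nn.Linear(3, 4, bias=False)
print(get_num_params(net))  # 12: the 4 x 3 weight matrix, no bias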
24 changes: 21 additions & 3 deletions ding/torch_utils/network/activation.py
@@ -1,3 +1,5 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -67,17 +69,33 @@ class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * torch.sigmoid(x)
        return x


class GELU(nn.Module):
    r"""
    Overview:
        Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models like GPT, BERT.
        The original paper can be viewed in: <link https://arxiv.org/pdf/1606.08415.pdf link>
    Interfaces:
        forward
    """

    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


def build_activation(activation: str, inplace: bool = None) -> nn.Module:
    r"""
    Overview:
        Return the activation module according to the given type.
    Arguments:
        - actvation (:obj:`str`): the type of activation module, now supports ['relu', 'glu', 'prelu']
        - activation (:obj:`str`): the type of activation module, now supports ['relu', 'glu', 'prelu']
        - inplace (:obj:`bool`): can optionally do the operation in-place in relu. Default ``None``
    Returns:
        - act_func (:obj:`nn.module`): the corresponding activation module
@@ -86,7 +104,7 @@ def build_activation(activation: str, inplace: bool = None) -> nn.Module:
        assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
    else:
        inplace = False
    act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(), 'swish': Swish()}
    act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(), 'swish': Swish(), 'gelu': GELU()}
    if activation in act_func.keys():
        return act_func[activation]
    else:
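
Usage note (sketch, not part of the diff): build_activation('gelu') returns the new GELU module; the tanh-based formula above is the same approximation PyTorch exposes as F.gelu(x, approximate='tanh') in versions 1.12+, which the comparison below assumes.

import torch
import torch.nn.functional as F

from ding.torch_utils.network.activation import build_activation

act = build_activation('gelu')
x = torch.randn(4, 8)
y = act(x)
# Assumes PyTorch >= 1.12 for the approximate argument; values should agree closely.
print(torch.allclose(y, F.gelu(x, approximate='tanh'), atol=1e-6))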
84 changes: 70 additions & 14 deletions ding/torch_utils/optimizer_helper.py
@@ -1,8 +1,7 @@
import torch
import math
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch._six import inf
from typing import Union, Iterable, Tuple, Callable
from typing import Union, Iterable, Tuple, Callable, List
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
@@ -11,6 +10,8 @@
import copy
import random

inf = math.inf


def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
r"""
Expand Down Expand Up @@ -193,10 +194,10 @@ def _state_init(self, p, amsgrad):
# others
if torch.__version__ < "1.12.0":
state['step'] = 0
#TODO
#wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
# TODO
# wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
else:
state['step'] = torch.zeros((1, ), dtype=torch.float, device=p.device) \
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
if self.defaults['capturable'] else torch.tensor(0.)

state['exp_avg'] = torch.zeros_like(p.data)
@@ -235,7 +236,7 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -259,15 +260,15 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
# sum total_norm
param_norm = grad.norm(self._clip_norm_type)
total_norm += param_norm.item() ** self._clip_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
self._clip_coef).norm(self._clip_norm_type)
total_momentum_norm += momentum.item() ** self._clip_norm_type
@@ -294,7 +295,7 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -326,15 +327,15 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
# sum total_norm
param_norm = grad.norm(self._ignore_norm_type)
total_norm += param_norm.item() ** self._ignore_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
self._ignore_coef).norm(self._ignore_norm_type)
total_momentum_norm += momentum.item() ** self._ignore_norm_type
@@ -348,7 +349,7 @@ def step(self, closure: Union[Callable, None] = None):
for p in group['params']:
p.grad.zero_()

#Adam optim type
# Adam optim type
if self._optim_type == 'adamw':
for group in self.param_groups:
for p in group['params']:
@@ -517,7 +518,7 @@ def step(self, closure: Union[Callable, None] = None):
param_norm = grad.norm(self._clip_norm_type)
total_norm += param_norm.item() ** self._clip_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = (state['thre_square_avg'].sqrt() * self._clip_coef).norm(self._clip_norm_type)
total_momentum_norm += momentum.item() ** self._clip_norm_type
step = min(step, state['step'])
@@ -578,7 +579,7 @@ def step(self, closure: Union[Callable, None] = None):
param_norm = grad.norm(self._ignore_norm_type)
total_norm += param_norm.item() ** self._ignore_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = (state['thre_square_avg'].sqrt() * self._ignore_coef).norm(self._ignore_norm_type)
total_momentum_norm += momentum.item() ** self._ignore_norm_type
step = min(step, state['step'])
@@ -730,3 +731,58 @@ def _retrieve_grad(self):
grad.append(p.grad.clone())
has_grad.append(torch.ones_like(p).to(p.device))
return grad, shape, has_grad


def configure_weight_decay(model: nn.Module, weight_decay: float) -> List:
    r"""
    Overview:
        Separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layer-norm or embedding weights).
    Arguments:
        - model (:obj:`nn.Module`): the given PyTorch model.
        - weight_decay (:obj:`float`): weight decay value for optimizer.
    Returns:
        - optim groups (:obj:`List`): the parameter groups to be set in the latter optimizer.
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            # Because named_modules and named_parameters are recursive
            # we will see the same tensors p many times. But doing it this way
            # allows us to know which parent module any tensor p belongs to.
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
            else:
                decay.add(fpn)

    decay = decay - no_decay
    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in model.named_parameters()}
    union_params = decay | no_decay
    assert len(
        param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(param_dict.keys() - union_params),)

    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(list(decay))],
            "weight_decay": weight_decay
        },
        {
            "params": [param_dict[pn] for pn in sorted(list(no_decay))],
            "weight_decay": 0.0
        },
    ]

    return optim_groups
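
Usage note (the AdamW hyperparameters are illustrative, not prescribed by this PR): the two returned groups can be passed directly to an optimizer, so Linear weights receive weight decay while biases and LayerNorm/Embedding weights do not.

import torch
import torch.nn as nn

from ding.torch_utils.optimizer_helper import configure_weight_decay

model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8))
groups = configure_weight_decay(model=model, weight_decay=0.1)
opt = torch.optim.AdamW(groups, lr=3e-4, betas=(0.9, 0.95))  # only the Linear weight is decayed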
21 changes: 21 additions & 0 deletions ding/torch_utils/tests/test_backend_helper.py
@@ -0,0 +1,21 @@
import pytest
import torch

from ding.torch_utils.backend_helper import enable_tf32


@pytest.mark.cudatest
class TestBackendHelper:

    def test_tf32(self):
        r"""
        Overview:
            Test the tf32.
        """
        enable_tf32()
        net = torch.nn.Linear(3, 4)
        x = torch.randn(1, 3)
        y = torch.sum(net(x))
        net.zero_grad()
        y.backward()
        assert net.weight.grad is not None
20 changes: 20 additions & 0 deletions ding/torch_utils/tests/test_lr_scheduler.py
@@ -0,0 +1,20 @@
import pytest
import torch
from torch.optim import Adam

from ding.torch_utils.lr_scheduler import cos_lr_scheduler


@pytest.mark.unittest
class TestLRSchedulerHelper:

    def test_cos_lr_scheduler(self):
        r"""
        Overview:
            Test the cos lr scheduler.
        """
        net = torch.nn.Linear(3, 4)
        opt = Adam(net.parameters(), lr=1e-2)
        scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, min_lr=6e-5)
        scheduler.step(101)
        assert opt.param_groups[0]['lr'] == 6e-5
19 changes: 19 additions & 0 deletions ding/torch_utils/tests/test_model_helper.py
@@ -0,0 +1,19 @@
import pytest
import torch

from ding.torch_utils.model_helper import get_num_params


@pytest.mark.unittest
class TestModelHelper:

    def test_model_helper(self):
        r"""
        Overview:
            Test the model helper.
        """
        net = torch.nn.Linear(3, 4, bias=False)
        assert get_num_params(net) == 12

        net = torch.nn.Conv2d(3, 3, kernel_size=3, bias=False)
        assert get_num_params(net) == 81
20 changes: 19 additions & 1 deletion ding/torch_utils/tests/test_optimizer.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.optim as optim
from ding.torch_utils.optimizer_helper import Adam, RMSprop, calculate_grad_norm, \
    calculate_grad_norm_without_bias_two_norm, PCGrad
    calculate_grad_norm_without_bias_two_norm, PCGrad, configure_weight_decay
import pytest
import time

@@ -177,3 +177,21 @@ def naive_test(self):
        pc_adam.pc_backward([loss1, loss2])
        for p in net.parameters():
            assert isinstance(p, torch.Tensor)


@pytest.mark.unittest
class TestWeightDecay:

    def test_wd(self):
        net = nn.Sequential(nn.Linear(3, 4), nn.LayerNorm(4))
        x = torch.randn(1, 3)
        group_params = configure_weight_decay(model=net, weight_decay=1e-4)
        assert group_params[0]['weight_decay'] == 1e-4
        assert group_params[1]['weight_decay'] == 0
        assert len(group_params[0]['params']) == 1
        assert len(group_params[1]['params']) == 3
        opt = Adam(group_params, lr=1e-2)
        opt.zero_grad()
        y = torch.sum(net(x))
        y.backward()
        opt.step()
2 changes: 1 addition & 1 deletion ding/utils/data/collate_fn.py
@@ -4,11 +4,11 @@
import torch
import treetensor.torch as ttorch
import re
from torch._six import string_classes
import collections.abc as container_abcs
from ding.compatibility import torch_ge_131

int_classes = int
string_classes = (str, bytes)
np_str_obj_array_pattern = re.compile(r'[SaUO]')

default_collate_err_msg_format = (
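
Side note on the torch._six changes in optimizer_helper.py and collate_fn.py: torch._six has been removed in recent PyTorch releases, so the PR replaces those imports with local definitions. A hedged compatibility sketch (not part of the PR) for code that must run on both old and new PyTorch:

import math

# Sketch only: fall back to local definitions when torch._six is unavailable.
try:
    from torch._six import inf, string_classes  # older PyTorch
except ImportError:
    inf = math.inf
    string_classes = (str, bytes)

int_classes = int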