feature(whl): add tabmwp env and prompt pg policy #667

Merged: 61 commits, merged on Sep 4, 2023 (changes shown are from 43 of the 61 commits).

Commits
c0f0cac wrong (May 11, 2023)
918171c update config (May 11, 2023)
de33a2c update config (May 11, 2023)
8ef6e1a update command policy (May 15, 2023)
fdb74c0 debug (May 15, 2023)
353de6d debug (May 15, 2023)
6223146 debug (May 15, 2023)
c2ecc48 debug (May 15, 2023)
6773bee debug (May 15, 2023)
25f6b3a debug (May 19, 2023)
9abda20 debug (May 19, 2023)
bfdc122 debug (May 19, 2023)
0510c83 debug (May 19, 2023)
2385df5 debug (May 19, 2023)
4cef99b debug (May 19, 2023)
f18fafd add glm (May 19, 2023)
c12b2a2 add glm (May 19, 2023)
dd9589e add glm model (May 20, 2023)
a783416 add glm model (May 20, 2023)
0bb2df2 add glm model (May 20, 2023)
4335a3f add glm model (May 20, 2023)
79b2598 add eval return (May 22, 2023)
61e4694 reformat (May 22, 2023)
59f4098 modify action space (May 23, 2023)
c6afc5d modify action space (May 23, 2023)
9345de6 polish answer process (May 24, 2023)
d89e39a update policy (May 24, 2023)
e805a0a update rwkv (May 24, 2023)
1b3f2b4 update policy (May 24, 2023)
40b6c46 polish (May 25, 2023)
e1f7cac polish (May 25, 2023)
0213f32 Merge branch 'main' of https://github.com/kxzxvbk/DI-engine (May 25, 2023)
c1c22fd resolve conflict (May 25, 2023)
39e520d debug prompt pg (Jun 15, 2023)
11bc0ad add parse (Jun 25, 2023)
8c9c40d update load env (Jun 26, 2023)
9cac14e add merge files (Jul 5, 2023)
ff5ad2d add merge files (Jul 5, 2023)
f6d6ac4 feature(whl): add internlm (Jul 10, 2023)
c716308 feature(whl): add internlm (Jul 10, 2023)
43a8168 update fix parse (Jul 11, 2023)
56068b1 add new dataset (Jul 25, 2023)
d32b64b fix datafiles (Jul 26, 2023)
1063fbd polish code (Jul 28, 2023)
286d976 polish env (Jul 28, 2023)
b229216 polish (Aug 1, 2023)
a8fc87b polish (Aug 1, 2023)
c0ad294 add model wrapper (Aug 1, 2023)
eff4155 polish wrapper (Aug 1, 2023)
00a64e5 polish (Aug 1, 2023)
e17f9d5 remove redundant files (Aug 1, 2023)
73fb2ee reformat (Aug 1, 2023)
476babf polish (Aug 11, 2023)
b277505 Merge branch 'main' into gpt3_env (Sep 2, 2023)
8ad299f polish (Sep 2, 2023)
5e13da9 merge main (Sep 2, 2023)
2f15217 debug (Sep 2, 2023)
f9c2e73 polish readme (Sep 3, 2023)
34044a1 reformat (Sep 3, 2023)
16b826b polish tabmwp (Sep 4, 2023)
239ff18 test (Sep 4, 2023)
1 change: 1 addition & 0 deletions ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC
from .bc import DiscreteBC, ContinuousBC
from .pg import PG
from .language_transformer import LanguageTransformer
# algorithm-specific
from .ppg import PPG
from .qmix import Mixer, QMix
63 changes: 63 additions & 0 deletions ding/model/template/language_transformer.py
@@ -0,0 +1,63 @@
import torch

from ding.utils import MODEL_REGISTRY
from torch import nn
try:
from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
import sys
from ditk import logging
logging.warning("transformers package not found, please install it with: pip install transformers")
sys.exit(1)


@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):

def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = False,
embedding_size: int = 128,
freeze_encoder: bool = True
) -> None:
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)

# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
for param in self.model.parameters():
param.requires_grad = False

if add_linear:
# Add an additional small, adjustable linear layer on top of BERT tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(
self.model.config.hidden_size, embedding_size
) # 768 for bert-base-uncased, distilbert-base-uncased
else:
self.linear = None

def _calc_embedding(self, x: list) -> torch.Tensor:
# ``truncation=True`` truncates any prompt longer than the tokenizer's ``max_length``; ``padding=True`` pads
# shorter prompts so that every sequence in the batch has the same length. Together these settings yield a
# rectangular token tensor, which enables batch-wise computation.
input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
output = self.model(**input, output_hidden_states=True)
# Get last layer hidden states
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size

if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size

return sentence_embedding

def forward(self, train_samples: list, candidate_samples: list) -> dict:
prompt_embedding = self._calc_embedding(train_samples)
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
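
Note (not part of the diff): a minimal usage sketch of the LanguageTransformer added above, assuming the transformers package and the bert-base-uncased weights are available locally; the question text and candidate prompts are illustrative placeholders.

from ding.model.template.language_transformer import LanguageTransformer

# Score three candidate prompts against one question.
# ``output['logit']`` has shape (num_questions, num_candidates), here (1, 3).
model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
question = ["Kate has 3 apples and buys 4 more. How many apples does she have?"]
candidates = ["Example prompt A", "Example prompt B", "Example prompt C"]
output = model(question, candidates)
assert output['logit'].shape == (1, 3)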
21 changes: 21 additions & 0 deletions ding/model/template/tests/test_language_transformer.py
@@ -0,0 +1,21 @@
import pytest

from ding.model.template.language_transformer import LanguageTransformer


@pytest.mark.unittest
class TestNLPPretrainedModel:

def check_model(self):
test_pids = [1]
cand_pids = [0, 2, 4]
problems = [
"This is problem 0", "This is the first question", "Second problem is here", "Another problem",
"This is the last problem"
]
ctxt_list = [problems[pid] for pid in test_pids]
cands_list = [problems[pid] for pid in cand_pids]

model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
scores = model(ctxt_list, cands_list)['logit']
assert scores.shape == (1, 3)
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -51,3 +51,4 @@

# new-type policy
from .ppof import PPOFPolicy
from .prompt_pg import PromptPGPolicy
6 changes: 6 additions & 0 deletions ding/policy/command_mode_policy_instance.py
100755 → 100644
@@ -48,6 +48,7 @@
from .madqn import MADQNPolicy
from .bdq import BDQPolicy
from .edac import EDACPolicy
from .prompt_pg import PromptPGPolicy


class EpsCommandModePolicy(CommandModePolicy):
@@ -426,3 +427,8 @@ def _get_setting_learn(self, command_info: dict) -> dict:

def _get_setting_eval(self, command_info: dict) -> dict:
return {}


@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
pass
257 changes: 257 additions & 0 deletions ding/policy/prompt_pg.py
@@ -0,0 +1,257 @@
from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch

from ding.rl_utils import get_train_sample
from ding.torch_utils import Adam, to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy


@POLICY_REGISTRY.register('prompt_pg')
class PromptPGPolicy(Policy):
r"""
Overview:
Policy class of Prompt Policy Gradient (PromptPG) algorithm.
"""
config = dict(
# (string) RL policy register name (refer to function "register_policy").
type='prompt_pg',
# (bool) whether to use cuda for network.
cuda=True,
# (bool) whether to use the on-policy training pipeline (behaviour policy and training policy are the same)
on_policy=True, # PG is a strictly on-policy algorithm, so users should not modify this field
# (bool) whether to use deterministic action for evaluation.
deterministic_eval=True,
learn=dict(
# (int) the number of samples for one update.
batch_size=64,
# (float) the step size of one gradient descent update, i.e., the learning rate.
learning_rate=0.001,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (float) loss weight of the entropy regularization; the weight of the policy loss is set to 1
entropy_weight=0.01,
# (float) max grad norm value.
grad_norm=5,
# (bool) whether to ignore done signal for non-termination env.
ignore_done=False,
),
collect=dict(
# (int) the number of episodes collected in one collection procedure
# n_episode=8,
# (int) trajectory unroll length
unroll_len=1,
# ==============================================================
# The following configs are algorithm-specific
# ==============================================================
# (float) discount factor for future reward, in the range [0, 1]
discount_factor=0,
collector=dict(get_train_sample=True),
),
eval=dict(),
)

def default_model(self) -> Tuple[str, List[str]]:
return 'language_transformer', ['ding.model.template.language_transformer']

def _init_learn(self) -> None:
r"""
Overview:
Learn mode init method. Called by ``self.__init__``.
Init the optimizer, algorithm config and the main model.
"""
# Optimizer
self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)

self._entropy_weight = self._cfg.learn.entropy_weight
self._grad_norm = self._cfg.learn.grad_norm
self._learn_model = self._model # for compatibility

def _forward_learn(self, data: dict) -> Dict[str, Any]:
r"""
Overview:
Forward and backward function of learn mode.
Arguments:
- data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'return']
Returns:
- info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
"""
self._model.train()
if self._cuda:
data = to_device(data, self._device)

return_infos = []
for i in range(0, len(data), self._cfg.learn.batch_size):
batch = default_collate(data[i:i + self._cfg.learn.batch_size])
# Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
for ii in range(len(cand_samples)):
cand_samples[ii] = cand_samples[ii][0]
output = self._learn_model.forward(train_samples, cand_samples)
return_ = batch['return']
if self._cuda:
return_ = return_.to(self._device)

# calculate PG loss
real_act = []
for b in range(batch['action'].shape[0]):
tmp_act = []
act = batch['action'][b].item()
# The action is a combination of indexes of all selected prompts.
# For example, if [3, 6] is selected, action = 2 ** 3 + 2 ** 6 = 8 + 64 = 72.
# In this step, we calculate all the indexes.
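# For illustration: act = 72 = 0b1001000, so bits 3 and 6 are set and ``tmp_act`` ends up as [3, 6].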
idx = 0
while act > 0:
if act % 2 != 0:
tmp_act.append(idx)
act = act // 2
idx += 1
assert len(tmp_act) == self._cfg.shot_number
real_act.append(tmp_act)
real_act = torch.tensor(real_act, device=self._device) # shape: (B, shot_number)
# Calculate loss.
total_loss = 0
total_policy_loss, total_entropy_loss = 0, 0
for ii in range(self._cfg.shot_number):
log_prob = output['dist'].log_prob(real_act[:, ii])
policy_loss = -(log_prob * return_).mean()
total_policy_loss += policy_loss
total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean()
total_loss = total_entropy_loss + total_policy_loss

# update
self._optimizer.zero_grad()
total_loss.backward()

grad_norm = torch.nn.utils.clip_grad_norm_(
list(self._learn_model.parameters()),
max_norm=self._grad_norm,
)
self._optimizer.step()

# only record the information of the last update in the logger
return_info = {
'cur_lr': self._optimizer.param_groups[0]['lr'],
'total_loss': total_loss.item(),
'policy_loss': total_policy_loss.item(),
'entropy_loss': total_entropy_loss.item(),
'return_abs_max': return_.abs().max().item(),
'grad_norm': grad_norm,
}
return_infos.append(return_info)
return return_infos

def _init_collect(self) -> None:
self._unroll_len = self._cfg.collect.unroll_len
self._gamma = self._cfg.collect.discount_factor

def _forward_collect(self, data: dict) -> dict:
data_id = list(data.keys())
data = default_collate(list(data.values()))
self._model.eval()
with torch.no_grad():
# Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
for ii in range(len(data['candidate_samples'])):
data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
output = self._model.forward(data['train_sample'], data['candidate_samples'])
# Generate actions.
act = []
mask = torch.zeros_like(output['logit'])
for ii in range(self._cfg.shot_number):
dist = torch.distributions.Categorical(logits=output['logit'] + mask)
actions = dist.sample()
act.append(actions)
for jj in range(actions.shape[0]):
mask[jj][actions[jj]] = -1e30
# `act` is shaped (shot_num, B)
real_act = []
for b in range(act[0].shape[0]):
tmp_act = 0  # accumulate 2 ** index as a Python int so that ``real_act`` ends up shaped (B, )
for shot in act:
tmp_act += 2 ** shot[b].item()
real_act.append(tmp_act)
real_act = torch.tensor(real_act)
# `real_act` is shaped (B)
output['action'] = real_act
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
r"""
Overview:
Generate dict type transition data from inputs.
Arguments:
- obs (:obj:`Any`): Env observation
- model_output (:obj:`dict`): Output of collect model, including at least ['action']
- timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
(here 'obs' indicates obs after env step).
Returns:
- transition (:obj:`dict`): Dict type transition data.
"""
return {
'obs': obs,
'action': model_output['action'],
'reward': timestep.reward,
'done': timestep.done,
}

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
r"""
Overview:
Compute the discounted cumulative return for each transition in the trajectory, then split it into training samples.
Arguments:
- data (:obj:`list`): The trajectory's buffer list
Returns:
- samples (:obj:`dict`): The training samples generated
"""
if self._cfg.learn.ignore_done:
raise NotImplementedError

R = 0.
for i in reversed(range(len(data))):
R = self._gamma * R + data[i]['reward']
data[i]['return'] = R
return get_train_sample(data, self._unroll_len)

def _init_eval(self) -> None:
pass

def _forward_eval(self, data: dict) -> dict:
data_id = list(data.keys())
data = default_collate(list(data.values()))
self._model.eval()
with torch.no_grad():
# Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
for ii in range(len(data['candidate_samples'])):
data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
output = self._model.forward(data['train_sample'], data['candidate_samples'])
# Generate actions.
act = []
mask = torch.zeros_like(output['logit'])
for ii in range(self._cfg.shot_number):
actions = torch.argmax(output['logit'] + mask, dim=-1)
act.append(actions)
for jj in range(actions.shape[0]):
mask[jj][actions[jj]] = -1e30
# `act` is shaped (shot_num, B)
real_act = []
for b in range(act[0].shape[0]):
tmp_act = 0  # accumulate 2 ** index as a Python int so that ``real_act`` ends up shaped (B, )
for shot in act:
tmp_act += 2 ** shot[b].item()
real_act.append(tmp_act)
real_act = torch.tensor(real_act)
# `real_act` is shaped (B)
output['action'] = real_act
if self._cuda:
output = to_device(output, 'cpu')
output = default_decollate(output)
return {i: d for i, d in zip(data_id, output)}

def _monitor_vars_learn(self) -> List[str]:
return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']
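
Note (not part of the diff): ``_forward_learn``, ``_forward_collect`` and ``_forward_eval`` all read ``self._cfg.shot_number``, which is absent from the default config above, so the user config has to supply it. Below is a minimal, hypothetical config sketch; apart from ``type``, ``shot_number`` and the fields defined in ``PromptPGPolicy.config`` and ``LanguageTransformer.__init__``, the values are illustrative placeholders.

from easydict import EasyDict

tabmwp_prompt_pg_config = EasyDict(dict(
    policy=dict(
        type='prompt_pg',
        cuda=True,
        # Number of prompts selected per question; the env action is the sum of 2 ** index over selected prompts.
        shot_number=2,
        model=dict(
            model_name='bert-base-uncased',
            add_linear=True,
            embedding_size=128,
            freeze_encoder=True,
        ),
        learn=dict(
            batch_size=10,
            learning_rate=0.001,
            entropy_weight=0.001,
        ),
        collect=dict(
            n_episode=8,
            discount_factor=0.,
        ),
        eval=dict(),
    ),
))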
Empty file added dizoo/tabmwp/__init__.py
Empty file.