feature(whl): add rlhf pipeline. #748
base: main
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
@@ -190,4 +191,64 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
                break

class ChatCollector:
    """
    Overview:
        The class of the collector that runs by steps, including model inference and the transition \
        process. Use the `__call__` method to execute the whole collection process.
    """

Review comment: why the extra indentation here?

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(ChatCollector, cls).__new__(cls)
    def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
        """
        Arguments:
            - seed (:obj:`int`): Random seed.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env used for collection; a BaseEnvManager object or \
                one of its derivatives is supported.
        """
        self.env = env
        self.env.seed(seed)
        self.env.launch()
        # use the first sub-environment of the launched env manager as the chat env
        self.env = self.env._envs[0]
        self.policy = policy
        self.n_sample = n_sample
        self.unroll_len = unroll_len

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when the target \
            number of steps is reached.
        Input of ctx:
            - env_step (:obj:`int`): The env step count, which increases during collection.
        """
        device = self.policy._device

        obs = ttorch.as_tensor(self.env.last_batch[0]['text_vec'])
        batch_size = obs.shape[0]
        obs = obs.to(device)

        # sample several answers for each question with the actor model
        total_action = [[] for _ in range(batch_size)]  # [B, answers_per_question, T]
        for _ in range(self.policy._cfg.answers_per_question):
            _, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs)
            for i in range(batch_size):
                total_action[i].append(copy.deepcopy(inference_output[i]))

        mask, resp, rew = self.env.step(total_action)
        ctx.env_step += 1
        ctx.env_episode += 1

        train_data = {}
        train_data['obs'] = resp  # [B x answers_per_question, T]
        train_data['reward'] = rew  # [B x answers_per_question, ]
        train_data['mask'] = mask  # [B x answers_per_question, T]

        ctx.train_data = ttorch.as_tensor(train_data)


# TODO battle collector
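For reference, a minimal sketch of how `ChatCollector` might be wired into a `ding.framework` task loop; the `cfg`, `policy`, `env_manager` and `ppo_trainer` objects named here are placeholders assumed to be built elsewhere in the pipeline, not part of this diff:

```python
from ding.framework import task
from ding.framework.context import OnlineRLContext

# Hypothetical wiring: `cfg`, `policy`, `env_manager` and `ppo_trainer` are placeholders;
# ChatCollector is the middleware added in this diff.
with task.start(ctx=OnlineRLContext()):
    collector = ChatCollector(
        seed=cfg.seed,
        policy=policy,          # exposes `_model.actor.generate` and `_cfg.answers_per_question`
        env=env_manager,        # a BaseEnvManager wrapping the chat environment
        n_sample=cfg.policy.collect.n_sample,
    )
    task.use(collector)         # each call fills ctx.train_data with obs / reward / mask
    task.use(ppo_trainer)       # downstream middleware consuming ctx.train_data (placeholder)
    task.run(max_step=cfg.max_step)
```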
@@ -21,3 +21,20 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
    import_module(cfg.pop('import_names', []))
    # here we must use the pop operation to ensure compatibility
    return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)


def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
    """
    Filter a distribution of logits using nucleus (top-p) filtering.
    Reference: https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146
    """

Review comment: polish the comments and add a unit test.

    cum_logits = logits.clone()
    if topp > 0:
        logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
        mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
        mask[:, :min_topk] = False
        # remove tokens whose cumulative probability lies above the top-p threshold
        mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
        cum_logits[mask] = filter_value
        cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
    return cum_logits
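In the spirit of the reviewer's "add a unit test" request, a minimal pytest-style sketch for `top_p_logits`; the import path is inferred from the `from ..common.utils import top_p_logits` line later in this diff, and the input sizes are arbitrary:

```python
import torch

from ding.model.common.utils import top_p_logits  # import path assumed from this diff


def test_top_p_logits_keeps_a_valid_distribution():
    torch.manual_seed(0)
    probs = torch.softmax(torch.randn(4, 32), dim=-1)  # the function expects probabilities
    filtered = top_p_logits(probs, topp=0.9, filter_value=0, min_topk=1)

    # shape is preserved and every row is renormalized to sum to 1
    assert filtered.shape == probs.shape
    assert torch.allclose(filtered.sum(dim=-1), torch.ones(4), atol=1e-5)

    # the most likely token of each row is never filtered out (min_topk=1)
    top1 = probs.argmax(dim=-1)
    assert (filtered[torch.arange(4), top1] > 0).all()

    # some low-probability tail tokens are zeroed out
    assert (filtered == 0).any()
```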
@@ -1,4 +1,7 @@
from typing import Union, Dict, Optional

Review comment: move these modifications to a new single file.

from transformers import LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from easydict import EasyDict
import torch
import torch.nn as nn

@@ -7,6 +10,8 @@
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
    FCEncoder, ConvEncoder, IMPALAConvEncoder
from ding.torch_utils.network.dreamer import ActionHead, DenseHead
from ..common.utils import top_p_logits
from ding.reward_model import LlamaRewardModel


@MODEL_REGISTRY.register('vac')

@@ -425,3 +430,153 @@ def __init__(
            outscale=0.0,
            device='cuda' if torch.cuda.is_available() else 'cpu',
        )


class Llama(LlamaForCausalLM):

    def __init__(self, config, opt, tokenizer):
        super().__init__(config)
        self.opt = opt
        self.tokenizer = tokenizer

    def forward(self, decoder_input, incr_state=None):
        attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
        if incr_state is not None:
            decoder_input = decoder_input[:, -1:]

        output = super().forward(
            input_ids=decoder_input,
            attention_mask=attention_mask,
            past_key_values=incr_state,
            return_dict=True,
            use_cache=not self.training
        )

        logits = output.logits
        new_incr_states = output.past_key_values

        return logits, new_incr_states

    @torch.no_grad()
    def generate(self, batch, **kwargs):
        """
        Generate response
        """
        maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res)
        temperature = kwargs.pop('temperature', self.opt.temperature)
        repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
        topp = kwargs.pop('topp', self.opt.topp)

        decoder_input: torch.LongTensor = batch  # (bsz, ...)
        assert decoder_input[:, -1].ne(
            self.tokenizer.pad_token_id
        ).all(), 'Last token should not be a padding token (you can use left padding instead).'

        dev = decoder_input.device
        bsz = decoder_input.size(0)

        scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16)
        done = torch.zeros((bsz, ), device=dev).to(torch.bool)

        inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
        decoder_input = torch.index_select(decoder_input, 0, inds)
        init_length = decoder_input.size(1)

        incr_state = None
        for _token in range(maxlen_res):
            if done.all():
                break
            score, incr_state, *_ = self.forward(decoder_input, incr_state)
            score = score.half()

            # now score is (bsz, len, vocab_size); keep the last step only
            score = score[:, -1, :]

            # calculate repetition penalty
            if repetition_penalty > 1.:
                penalty_tokens = decoder_input[:, init_length:]
                penalty_scores = torch.gather(score, dim=1, index=penalty_tokens)
                penalty_scores = torch.where(
                    penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty
                )
                score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores)

            # nucleus sampling
            score = torch.softmax(score.div(temperature), dim=-1)
            probs = top_p_logits(score, topp=topp, filter_value=0)
            tok_ids = torch.multinomial(probs, 1)[:, 0]
            hyp_ids = torch.arange(probs.size(0), device=dev)
            scores = scores + probs[hyp_ids, tok_ids].log() * ~done

            tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
            decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1)
            done = done | tok_ids.eq(self.tokenizer.eos_token_id)

            incr_state = self._reorder_cache(incr_state, hyp_ids)

        # get all finalized candidates for each sample
        decoder_input = decoder_input[:, init_length:]
        decoder_input = decoder_input.view(bsz, -1)
        scores = scores.view(bsz, )

        lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1)

        length_penalty = torch.pow(lengths, 1.0)
        scores /= length_penalty

        preds_scores = []
        for i in range(bsz):
            seq: torch.LongTensor = decoder_input[i, :lengths[i, ]]
            res_scores = (float(scores[i, ]), seq.tolist())
            preds_scores.append([res_scores])

        best_preds_scores = [preds[0] for preds in preds_scores]
        return best_preds_scores, preds_scores


@MODEL_REGISTRY.register('llamavac')
class LlamaVAC(nn.Module):
    """
    Overview:
        The neural network and computation graph of the Llama-based (state) Value Actor-Critic (VAC) \
        model used in the RLHF pipeline: the actor is a ``Llama`` causal language model and the \
        critic is a ``LlamaRewardModel``.
    Interfaces:
        ``__init__``, ``forward``.
    """
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(self, actor_path: str, critic_path: str, tokenizer: LlamaTokenizer, opt: Dict) -> None:
        """
        Overview:
            Initialize the ``LlamaVAC`` model according to the arguments.
        Arguments:
            - actor_path (:obj:`str`): Path to the pretrained actor (Llama causal LM) checkpoint.
            - critic_path (:obj:`str`): Path to the pretrained critic (Llama reward model) checkpoint.
            - tokenizer (:obj:`LlamaTokenizer`): The tokenizer shared by actor and critic.
            - opt (:obj:`Dict`): Generation options such as ``maxlen_res``, ``temperature``, \
                ``repetition_penalty`` and ``topp``.
        """
        super(LlamaVAC, self).__init__()
        self.actor = Llama.from_pretrained(
            actor_path,
            opt=opt,
            tokenizer=tokenizer,
        )
        self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer)

    def forward(self, x: torch.Tensor, mode: str) -> Dict:
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(x)

    def compute_actor(self, x):
        policy_output = self.actor(decoder_input=x)
        policy_logit, *_ = policy_output
        return {"logit": policy_logit}

    def compute_critic(self, x):
        values = self.critic(decoder_input=x, only_last=False)
        return {"value": values}

    def compute_actor_critic(self, x):
        policy_output = self.actor(decoder_input=x)
        policy_logit, *_ = policy_output
        values = self.critic(decoder_input=x, only_last=False)
        return {"logit": policy_logit, "value": values}
Review comment: merge it into ding.framework.
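For illustration, a hedged sketch of how `LlamaVAC` might be constructed and queried in its three forward modes; the checkpoint paths and `opt` values are placeholders (the field names are taken from the defaults read in `Llama.generate`), not values from this PR:

```python
from easydict import EasyDict
from transformers import LlamaTokenizer

# Placeholder paths and generation options; real values come from the pipeline config.
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-tokenizer")
opt = EasyDict(maxlen_res=128, temperature=0.8, repetition_penalty=1.1, topp=0.9)

model = LlamaVAC(
    actor_path="path/to/actor-checkpoint",
    critic_path="path/to/critic-checkpoint",
    tokenizer=tokenizer,
    opt=opt,
)

# `x` is a batch of token ids, shape (batch, seq_len), left-padded if needed.
x = tokenizer(["Hello, how are you?"], return_tensors="pt").input_ids

actor_out = model(x, mode='compute_actor')            # {"logit": ...}
critic_out = model(x, mode='compute_critic')          # {"value": ...}
both = model(x, mode='compute_actor_critic')          # {"logit": ..., "value": ...}

# sampling rollouts goes through the actor's nucleus-sampling generate()
best, candidates = model.actor.generate(x, maxlen_res=64)
```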