feature(whl): add rlhf pipeline. #748

Open · wants to merge 18 commits into base: main
Changes from 12 commits
17 changes: 17 additions & 0 deletions ding/bonus/config.py
@@ -167,6 +167,13 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
elif env_id == 'chat':
cfg.epoch_per_collect = 1
cfg.batch_size = 2
cfg.learning_rate = 5e-7
cfg.answers_per_question = 3
cfg.kl_penalty_weight = 0.1
cfg.ppo_param_init = False
else:
raise KeyError("not supported env type: {}".format(env_id))
else:
@@ -315,6 +322,16 @@ def get_instance_env(env_id: str) -> BaseEnv:
)
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
elif env_id == 'chat':
from dizoo.chat.env import ChatEnv
return ChatEnv(
batch_size=2,
reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
maxlen_prompt=1024,
maxlen_res=512,
)
else:
raise KeyError("not supported env type: {}".format(env_id))

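For context, a minimal usage sketch (not part of the PR diff) of how the new 'chat' entry might be driven through the bonus API. It assumes the standard PPOF constructor and train(step=...) signature from ding/bonus/ppof.py, that the hard-coded MOSS-RLHF paths above exist locally, and that 'rlhf_chat_demo' is a placeholder experiment name:

# Hedged sketch: launch the RLHF chat pipeline via the PPOF bonus API.
from ding.bonus import PPOF

agent = PPOF(env_id='chat', exp_name='rlhf_chat_demo')  # picks up the 'chat' config added above
agent.train(step=1000)                                   # runs the ChatCollector-based training loop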
24 changes: 19 additions & 5 deletions ding/bonus/ppof.py
@@ -1,3 +1,4 @@
import copy
from typing import Optional, Union, List
from ditk import logging
from easydict import EasyDict
@@ -18,6 +19,7 @@
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
from ..framework.middleware.collector import ChatCollector
Member commented: merge it into ding.framework

class PPOF:
@@ -52,6 +54,8 @@ class PPOF:
'Hopper-v3',
'HalfCheetah-v3',
'Walker2d-v3',
# rlhf
'chat'
]

def __init__(
@@ -108,6 +112,8 @@ def __init__(
action_shape = int(action_space.n)
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
action_shape = get_hybrid_shape(action_space)
elif action_space is None:
pass
else:
action_shape = action_space.shape

@@ -129,7 +135,11 @@ def __init__(
popart_head=True,
**self.cfg.model
)
self.policy = PPOFPolicy(self.cfg, model=model)
if self.cfg.chat_data:
orig_model = copy.deepcopy(model)
else:
orig_model = None
self.policy = PPOFPolicy(self.cfg, model=model, orig_model=orig_model)
if policy_state_dict is not None:
self.policy.load_state_dict(policy_state_dict)
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
@@ -158,10 +168,14 @@ def train(
pass

with task.start(ctx=OnlineRLContext()):
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(ppof_adv_estimator(self.policy))
if self.policy._cfg.chat_data:
# task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
# task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
else:
task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
task.use(
wandb_online_logger(
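The orig_model deep copy together with the kl_penalty_weight config entry suggests the usual RLHF-style KL regularization against the frozen initial policy. Below is a minimal sketch of that idea; the names (kl_penalized_reward, logp, orig_logp) are illustrative only, and the actual computation lives inside PPOFPolicy, which is not shown in this diff:

import torch

def kl_penalized_reward(reward: torch.Tensor, logp: torch.Tensor, orig_logp: torch.Tensor,
                        kl_penalty_weight: float = 0.1) -> torch.Tensor:
    # Subtract a per-token KL estimate (current policy vs. frozen original policy) from the reward.
    approx_kl = logp - orig_logp          # log-ratio as a simple KL estimator
    return reward - kl_penalty_weight * approx_kl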
61 changes: 61 additions & 0 deletions ding/framework/middleware/collector.py
@@ -1,3 +1,4 @@
import copy
from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
@@ -190,4 +191,64 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
break


class ChatCollector:
"""
Overview:
The collector for the chat (RLHF) setting: for each batch of prompts it generates \
``answers_per_question`` responses with the actor model, steps the environment to obtain rewards, \
and packs responses, rewards and masks into training data. Use the `__call__` method to execute \
the whole collection process.
Member commented: why indent here

"""

def __new__(cls, *args, **kwargs):
if task.router.is_active and not task.has_role(task.role.COLLECTOR):
return task.void()
return super(ChatCollector, cls).__new__(cls)

def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
"""
Arguments:
- seed (:obj:`int`): Random seed.
- policy (:obj:`Policy`): The policy to be collected.
- env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
its derivatives are supported.
"""
self.env = env
self.env.seed(seed)
self.env.launch()
self.env = self.env._envs[0]  # use the first sub-env of the env manager (single chat env)
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len

def __call__(self, ctx: "OnlineRLContext") -> None:
"""
Overview:
An encapsulation of inference and rollout middleware. Stop when completing \
the target number of steps.
Input of ctx:
- env_step (:obj:`int`): The env steps which will increase during collection.
"""
device = self.policy._device

obs = ttorch.as_tensor(self.env.last_batch[0]['text_vec'])
batch_size = obs.shape[0]
obs = obs.to(device)

total_action = [[] for _ in range(batch_size)] # [B, answers_per_question, T]
for _ in range(self.policy._cfg.answers_per_question):
_, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs)
for i in range(batch_size):
total_action[i].append(copy.deepcopy(inference_output[i]))

mask, resp, rew = self.env.step(total_action)
ctx.env_step += 1
ctx.env_episode += 1

train_data = {}
train_data['obs'] = resp # [B x answer-per-question, T]
train_data['reward'] = rew # [B x answer-per-question, ]
train_data['mask'] = mask # [B x answer-per-question, T]

ctx.train_data = ttorch.as_tensor(train_data)


# TODO battle collector
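To make the data layout concrete, here is a shape sketch (illustrative, not part of the PR) of what ChatCollector hands to the trainer, with B prompts, K = answers_per_question and sequence length T; the tensors are dummies:

import torch
import treetensor.torch as ttorch

B, K, T = 2, 3, 16                                        # batch_size, answers_per_question, seq len
train_data = ttorch.as_tensor({
    'obs': torch.zeros(B * K, T, dtype=torch.long),       # generated responses as token ids
    'reward': torch.zeros(B * K),                          # one scalar reward per response
    'mask': torch.ones(B * K, T, dtype=torch.long),        # marks valid (non-padding) response tokens
})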
17 changes: 17 additions & 0 deletions ding/model/common/utils.py
@@ -21,3 +21,20 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
import_module(cfg.pop('import_names', []))
# here we must use the pop operation to ensure compatibility
return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)


def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
"""
Filter a distribution of logits using nucleus (top-p) filtering
Member commented: polish comments and add a unittest

https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146
"""
cum_logits = logits.clone()
if topp > 0:
logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
mask[:, :min_topk] = False
Member commented: ..., :min_topk (i.e. index with an ellipsis rather than a single colon)

# Remove tokens with cumulative top_p above the threshold
mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
cum_logits[mask] = filter_value
cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
return cum_logits
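A quick usage sketch of top_p_logits as defined above (illustrative, not part of the PR). Note that, as used in Llama.generate, it is applied to a probability distribution (softmax output) rather than raw logits, and it renormalizes after filtering:

import torch

probs = torch.softmax(torch.randn(2, 10), dim=-1)          # [batch, vocab] probabilities
filtered = top_p_logits(probs, topp=0.9, filter_value=0)    # zero out the low-probability tail
next_token = torch.multinomial(filtered, num_samples=1)     # sample from the truncated distribution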
155 changes: 155 additions & 0 deletions ding/model/template/vac.py
@@ -1,4 +1,7 @@
from typing import Union, Dict, Optional

Member commented: move these modifications to a new single file: lm_vac.py

from transformers import LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from easydict import EasyDict
import torch
import torch.nn as nn
@@ -7,6 +10,8 @@
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
FCEncoder, ConvEncoder, IMPALAConvEncoder
from ding.torch_utils.network.dreamer import ActionHead, DenseHead
from ..common.utils import top_p_logits
from ding.reward_model import LlamaRewardModel


@MODEL_REGISTRY.register('vac')
@@ -425,3 +430,153 @@ def __init__(
outscale=0.0,
device='cuda' if torch.cuda.is_available() else 'cpu',
)


class Llama(LlamaForCausalLM):

def __init__(self, config, opt, tokenizer):
super().__init__(config)
self.opt = opt
self.tokenizer = tokenizer

def forward(self, decoder_input, incr_state=None):

attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
if incr_state is not None:
decoder_input = decoder_input[:, -1:]

output = super().forward(
input_ids=decoder_input,
attention_mask=attention_mask,
past_key_values=incr_state,
return_dict=True,
use_cache=not self.training
)

logits = output.logits
new_incr_states = output.past_key_values

return logits, new_incr_states

@torch.no_grad()
def generate(self, batch, **kwargs):
"""
Generate response
"""
maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res)
temperature = kwargs.pop('temperature', self.opt.temperature)
repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
topp = kwargs.pop('topp', self.opt.topp)

decoder_input: torch.LongTensor = batch # (bsz, ...)
assert decoder_input[:, -1].ne(
self.tokenizer.pad_token_id
).all(), 'Last token should not be a padding token (you can use left padding instead).'

dev = decoder_input.device
bsz = decoder_input.size(0)

scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16)
done = torch.zeros((bsz, ), device=dev).to(torch.bool)

inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
decoder_input = torch.index_select(decoder_input, 0, inds)
init_length = decoder_input.size(1)

incr_state = None
for _token in range(maxlen_res):
if done.all():
break
score, incr_state, *_ = self.forward(decoder_input, incr_state)
score = score.half()

# now score is bs, len, vocab_size
score = score[:, -1, :]

# calculate repetition penalty
if repetition_penalty > 1.:
penalty_tokens = decoder_input[:, init_length:]
penalty_scores = torch.gather(score, dim=1, index=penalty_tokens)
penalty_scores = torch.where(
penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty
)
score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores)

# nucleus sampling
score = torch.softmax(score.div(temperature), dim=-1)
probs = top_p_logits(score, topp=topp, filter_value=0)
tok_ids = torch.multinomial(probs, 1)[:, 0]
hyp_ids = torch.arange(probs.size(0), device=dev)
scores = scores + probs[hyp_ids, tok_ids].log() * ~done

tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1)
done = done | tok_ids.eq(self.tokenizer.eos_token_id)

incr_state = self._reorder_cache(incr_state, hyp_ids)

# get all finalized candidates for each sample
decoder_input = decoder_input[:, init_length:]
decoder_input = decoder_input.view(bsz, -1)
scores = scores.view(bsz, )

lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1)

length_penalty = torch.pow(lengths, 1.0)
scores /= length_penalty

preds_scores = []
for i in range(bsz):
seq: torch.LongTensor = decoder_input[i, :lengths[i, ]]
res_scores = (float(scores[i, ]), seq.tolist())
preds_scores.append([res_scores])

best_preds_scores = [preds[0] for preds in preds_scores]
return best_preds_scores, preds_scores


@MODEL_REGISTRY.register('llamavac')
class LlamaVAC(nn.Module):
"""
Overview:
The neural network and computation graph of the LLaMA-based Value Actor-Critic (VAC) model for RLHF,
with a causal-LM actor and a reward-model-based critic.
Interfaces:
``__init__``, ``forward``.
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

def __init__(self, actor_path: str, critic_path: str, tokenizer: LlamaTokenizer, opt: Dict) -> None:
"""
Overview:
Initialize the ``LlamaVAC`` model according to arguments.
Arguments:
- actor_path (:obj:`str`): Path to the pretrained LLaMA actor (policy) checkpoint.
- critic_path (:obj:`str`): Path to the pretrained LLaMA reward-model checkpoint used as critic.
- tokenizer (:obj:`LlamaTokenizer`): Tokenizer shared by the actor and the critic.
- opt (:obj:`Dict`): Generation options (e.g. ``maxlen_res``, ``temperature``, ``topp``, ``repetition_penalty``).
"""
super(LlamaVAC, self).__init__()
self.actor = Llama.from_pretrained(
actor_path,
opt=opt,
tokenizer=tokenizer,
)
self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer)

def forward(self, x: torch.Tensor, mode: str) -> Dict:
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
return getattr(self, mode)(x)

def compute_actor(self, x):
policy_output = self.actor(decoder_input=x)
policy_logit, *_ = policy_output
return {"logit": policy_logit}

def compute_critic(self, x):
values = self.critic(decoder_input=x, only_last=False)
return {"value": values}

def compute_actor_critic(self, x):
policy_output = self.actor(decoder_input=x)
policy_logit, *_ = policy_output
values = self.critic(decoder_input=x, only_last=False)
return {"logit": policy_logit, "value": values}