From 7ba2125758467725edfb5d398d05f52a77091b79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 24 Oct 2023 09:24:47 +0800 Subject: [PATCH 01/17] init commit --- ding/bonus/config.py | 6 + ding/bonus/ppof.py | 22 +- ding/framework/middleware/collector.py | 60 ++++- ding/model/common/utils.py | 16 ++ ding/model/template/vac.py | 150 +++++++++++++ ding/policy/ppof.py | 146 ++++++++----- ding/reward_model/__init__.py | 2 + ding/reward_model/language_reward_model.py | 27 +++ ding/rl_utils/gae.py | 20 ++ dizoo/chat/__init__.py | 1 + dizoo/chat/env.py | 53 +++++ dizoo/chat/utils.py | 243 +++++++++++++++++++++ launch_ppof.py | 20 ++ 13 files changed, 700 insertions(+), 66 deletions(-) create mode 100644 ding/reward_model/language_reward_model.py create mode 100644 dizoo/chat/__init__.py create mode 100644 dizoo/chat/env.py create mode 100644 dizoo/chat/utils.py create mode 100644 launch_ppof.py diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 285eff6586..3b5fe33463 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -167,6 +167,12 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict: cfg.batch_size = 320 cfg.epoch_per_collect = 10 cfg.learning_rate = 3e-4 + elif env_id == 'chat': + cfg.epoch_per_collect = 1 + cfg.batch_size = 2 + cfg.learning_rate = 5e-7 + cfg.answers_per_question = 3 + cfg.kl_penalty_weight = 0.1 else: raise KeyError("not supported env type: {}".format(env_id)) else: diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py index bf6012240f..297e842844 100644 --- a/ding/bonus/ppof.py +++ b/ding/bonus/ppof.py @@ -1,3 +1,4 @@ +import copy from typing import Optional, Union, List from ditk import logging from easydict import EasyDict @@ -18,6 +19,7 @@ from .model import PPOFModel from .config import get_instance_config, get_instance_env, get_hybrid_shape from ding.bonus.common import TrainingReturn, EvalReturn +from ..framework.middleware.collector import ChatCollector class PPOF: @@ -52,6 +54,8 @@ class PPOF: 'Hopper-v3', 'HalfCheetah-v3', 'Walker2d-v3', + # rlhf + 'chat' ] def __init__( @@ -129,7 +133,11 @@ def __init__( popart_head=True, **self.cfg.model ) - self.policy = PPOFPolicy(self.cfg, model=model) + if self.cfg.chat_data: + orig_model = copy.deepcopy(model) + else: + orig_model = None + self.policy = PPOFPolicy(self.cfg, model=model, orig_model=orig_model) if policy_state_dict is not None: self.policy.load_state_dict(policy_state_dict) self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") @@ -158,10 +166,14 @@ def train( pass with task.start(ctx=OnlineRLContext()): - task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) - task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt)) - task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample)) - task.use(ppof_adv_estimator(self.policy)) + if not self.policy._cfg.chat_data: + # task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) + # task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt)) + task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample)) + else: + task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) + task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt)) + task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample)) task.use(multistep_trainer(self.policy, 
log_freq=n_iter_log_show)) task.use( wandb_online_logger( diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index beb4894ad9..5e0c0811d4 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -1,3 +1,4 @@ +import copy from typing import TYPE_CHECKING from easydict import EasyDict import treetensor.torch as ttorch @@ -93,7 +94,8 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll self.policy = policy self.n_sample = n_sample self.unroll_len = unroll_len - self._transitions = TransitionList(self.env.env_num) + self._transitions = Transiti + onList(self.env.env_num) self._env_episode_id = [_ for _ in range(env.env_num)] self._current_id = env.env_num @@ -190,4 +192,60 @@ def __call__(self, ctx: "OnlineRLContext") -> None: break +class ChatCollector: + """ + Overview: + The class of the collector running by steps, including model inference and transition \ + process. Use the `__call__` method to execute the whole collection process. + """ + + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() + return super(ChatCollector, cls).__new__(cls) + + def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None: + """ + Arguments: + - seed (:obj:`int`): Random seed. + - policy (:obj:`Policy`): The policy to be collected. + - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ + its derivatives are supported. + """ + self.env = env + self.env.seed(seed) + self.env.launch() + self.policy = policy + self.n_sample = n_sample + self.unroll_len = unroll_len + + def __call__(self, ctx: "OnlineRLContext") -> None: + """ + Overview: + An encapsulation of inference and rollout middleware. Stop when completing \ + the target number of steps. + Input of ctx: + - env_step (:obj:`int`): The env steps which will increase during collection. 
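+            Output of ctx:
+                - train_data (:obj:`ttorch.Tensor`): The collected batch with keys ``obs`` (padded \
+                    prompt-response token ids), ``reward`` and ``mask``, consumed by the PPO policy update.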
+ """ + device = self.policy._device + + obs = ttorch.as_tensor(self.env._env[0].last_batch) + obs = obs.to(device) + + total_action = [] + for _ in range(self.policy._cfg.answers_per_question): + _, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs) + total_action.append(copy.deepcopy(inference_output)) + + mask, resp, rew = self.env.step(total_action) + ctx.env_step += 1 + ctx.env_episode += 1 + + train_data = {} + train_data['obs'] = resp # [B x answer-per-question, max_len] + train_data['reward'] = rew # [B x answer-per-question, ] + train_data['mask'] = mask # [B x answer-per-question, max_len] + + ctx.train_data = ttorch.as_tensor(train_data) + # TODO battle collector diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py index 0f508de0b8..c1012276c6 100644 --- a/ding/model/common/utils.py +++ b/ding/model/common/utils.py @@ -21,3 +21,19 @@ def create_model(cfg: EasyDict) -> torch.nn.Module: import_module(cfg.pop('import_names', [])) # here we must use the pop opeartion to ensure compatibility return MODEL_REGISTRY.build(cfg.pop("type"), **cfg) + + +def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1): + """ + Filter a distribution of logits using nucleus (top-p) filtering + https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146 + """ + cum_logits = logits.clone() + if topp > 0: + logits_sorted, inds = torch.sort(logits, dim=-1, descending=True) + mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp + mask[:, :min_topk] = False + # Remove tokens with cumulative top_p above the threshold + mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask) + cum_logits[mask] = filter_value + cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True)) diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 29363d3570..74bb4041b3 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -1,4 +1,5 @@ from typing import Union, Dict, Optional +from transformers.models.llama.modeling_llama import LlamaForCausalLM from easydict import EasyDict import torch import torch.nn as nn @@ -7,6 +8,8 @@ from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \ FCEncoder, ConvEncoder, IMPALAConvEncoder from ding.torch_utils.network.dreamer import ActionHead, DenseHead +from ..common.utils import top_p_logits +from ding.reward_model import LlamaRewardModel @MODEL_REGISTRY.register('vac') @@ -425,3 +428,150 @@ def __init__( outscale=0.0, device='cuda' if torch.cuda.is_available() else 'cpu', ) + + +class Llama(LlamaForCausalLM): + def __init__(self, config, opt, tokenizer): + super().__init__(config) + self.opt = opt + self.tokenizer = tokenizer + + def forward(self, decoder_input, incr_state=None): + + attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) + if incr_state is not None: + decoder_input = decoder_input[:, -1:] + + output = super().forward( + input_ids=decoder_input, + attention_mask=attention_mask, + past_key_values=incr_state, + return_dict=True, + use_cache=not self.training + ) + + logits = output.logits + new_incr_states = output.past_key_values + + return logits, new_incr_states + + @torch.no_grad() + def generate(self, batch, **kwargs): + """ + Generate response + """ + maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res) + temperature = kwargs.pop('temperature', self.opt.temperature) + repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty) + topp = 
kwargs.pop('topp', self.opt.topp) + + decoder_input: torch.LongTensor = batch['text_vec'] # (bsz, ...) + assert decoder_input[:, -1].ne( + self.tokenizer.pad_token_id).all(), 'Last token should not be a padding token (you can use left padding instead).' + + dev = decoder_input.device + bsz = decoder_input.size(0) + + scores = torch.zeros((bsz,), device=dev, dtype=torch.float16) + done = torch.zeros((bsz,), device=dev).to(torch.bool) + + inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1) + decoder_input = torch.index_select(decoder_input, 0, inds) + init_length = decoder_input.size(1) + + incr_state = None + for _token in range(maxlen_res): + if done.all(): + break + score, incr_state, *_ = self.forward(decoder_input, incr_state) + score = score.half() + + # now score is bs, len, vocab_size + score = score[:, -1, :] + + # calculate repetition penalty + if repetition_penalty > 1.: + penalty_tokens = decoder_input[:, init_length:] + penalty_scores = torch.gather(score, dim=1, index=penalty_tokens) + penalty_scores = torch.where(penalty_scores < 0., penalty_scores * repetition_penalty, + penalty_scores / repetition_penalty) + score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores) + + # nucleus sampling + score = torch.softmax(score.div(temperature), dim=-1) + probs = top_p_logits(score, topp=topp, filter_value=0) + tok_ids = torch.multinomial(probs, 1)[:, 0] + hyp_ids = torch.arange(probs.size(0), device=dev) + scores = scores + probs[hyp_ids, tok_ids].log() * ~done + + tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids) + decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1) + done = done | tok_ids.eq(self.tokenizer.eos_token_id) + + incr_state = self._reorder_cache(incr_state, hyp_ids) + + # get all finalized candidates for each sample + decoder_input = decoder_input[:, init_length:] + decoder_input = decoder_input.view(bsz, -1) + scores = scores.view(bsz, ) + + lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1) + + length_penalty = torch.pow(lengths, 1.0) + scores /= length_penalty + + preds_scores = [] + for i in range(bsz): + seq: torch.LongTensor = decoder_input[i, :lengths[i,]] + res_scores = (float(scores[i,]), seq.tolist()) + preds_scores.append([res_scores]) + + best_preds_scores = [preds[0] for preds in preds_scores] + return best_preds_scores, preds_scores + + +@MODEL_REGISTRY.register('llamavac') +class LlamaVAC(nn.Module): + """ + Overview: + The neural network and computation graph of DreamerV3 (state) Value Actor-Critic (VAC). + This model now supports discrete, continuous action space. + Interfaces: + ``__init__``, ``forward``. + """ + mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] + + def __init__( + self, + actor_path: str, + critic_path: str + ) -> None: + """ + Overview: + Initialize the ``DREAMERVAC`` model according to arguments. + Arguments: + - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. + - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. 
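+            Note: the two entries above are carried over from the generic VAC docstring; the actual \
+                constructor arguments are:
+            - actor_path (:obj:`str`): Path of the pretrained LLaMA checkpoint loaded as the actor.
+            - critic_path (:obj:`str`): Path of the pretrained LLaMA reward-model checkpoint loaded as the critic.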
+ """ + super(LlamaVAC, self).__init__() + self.actor = Llama.from_pretrained(actor_path) + self.critic = LlamaRewardModel.from_pretrained(critic_path) + + def forward(self, x: torch.Tensor, mode: str) -> Dict: + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(x) + + def compute_actor(self, x): + policy_output = self.actor(decoder_input=x) + policy_logit, *_ = policy_output + return {"logit": policy_logit} + + def compute_critic(self, x): + values = self.critic(decoder_input=x, only_last=False) + return {"value": values} + + def compute_actor_critic(self, x): + policy_output = self.actor(decoder_input=x) + policy_logit, *_ = policy_output + values = self.critic(decoder_input=x, only_last=False) + return {"logit": policy_logit,"value": values} diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index 81e605384c..e746b9c617 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Tuple, Union, Callable, Optional +from typing import List, Dict, Any, Callable, Optional from collections import namedtuple from easydict import EasyDict import copy @@ -11,6 +11,7 @@ from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \ get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \ HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog +from ding.rl_utils.gae import episodic_gae_data, episodic_gae from ding.utils import POLICY_REGISTRY, RunningMeanStd @@ -37,6 +38,7 @@ class PPOFPolicy: value_norm='baseline', ppo_param_init=True, grad_norm=0.5, + chat_data=True, # collect n_sample=128, unroll_len=1, @@ -58,8 +60,9 @@ def default_model(cls: type) -> Callable: from .model import PPOFModel return PPOFModel - def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None) -> None: + def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None, orig_model: torch.nn.Module = None) -> None: self._cfg = cfg + self._orig_model = orig_model if model is None: self._model = self.default_model() else: @@ -151,63 +154,78 @@ def _model_param_init(self): def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: return_infos = [] self._model.train() - bs = self._cfg.batch_size - data = data[:self._cfg.n_sample // bs * bs] # rounding + if not self._cfg.chat_data: + bs = self._cfg.batch_size + data = data[:self._cfg.n_sample // bs * bs] # rounding # outer training loop for epoch in range(self._cfg.epoch_per_collect): # recompute adv with torch.no_grad(): - # get the value dictionary - # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred' - value = self._model.compute_critic(data.obs) - next_value = self._model.compute_critic(data.next_obs) - reward = data.reward - - assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'],\ - 'Not supported value normalization! 
Value normalization supported: \ - popart, value rescale, symlog, baseline' - - if self._cfg.value_norm == 'popart': - unnormalized_value = value['unnormalized_pred'] - unnormalized_next_value = value['unnormalized_pred'] - - mu = self._model.critic_head.popart.mu - sigma = self._model.critic_head.popart.sigma - reward = (reward - mu) / sigma - - value = value['pred'] - next_value = next_value['pred'] - elif self._cfg.value_norm == 'value_rescale': - value = value_inv_transform(value['pred']) - next_value = value_inv_transform(next_value['pred']) - elif self._cfg.value_norm == 'symlog': - value = inv_symlog(value['pred']) - next_value = inv_symlog(next_value['pred']) - elif self._cfg.value_norm == 'baseline': - value = value['pred'] * self._running_mean_std.std - next_value = next_value['pred'] * self._running_mean_std.std - - traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory - adv_data = gae_data(value, next_value, reward, data.done, traj_flag) - data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda) - - unnormalized_returns = value + data.adv # In popart, this return is normalized - - if self._cfg.value_norm == 'popart': - self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1)) - elif self._cfg.value_norm == 'value_rescale': - value = value_transform(value) - unnormalized_returns = value_transform(unnormalized_returns) - elif self._cfg.value_norm == 'symlog': - value = symlog(value) - unnormalized_returns = symlog(unnormalized_returns) - elif self._cfg.value_norm == 'baseline': - value /= self._running_mean_std.std - unnormalized_returns /= self._running_mean_std.std - self._running_mean_std.update(unnormalized_returns.cpu().numpy()) - data.value = value - data.return_ = unnormalized_returns + if self._cfg.chat_data: + # [B, T] + value = self._model.compute_critic(data.obs) + data.orig_logit = self._orig_model.compute_actor(data.obs) + data.value = value + reward = data.reward + + traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory + adv_data = episodic_gae_data(value, data.mask, reward, data.done, traj_flag) + data.adv = episodic_gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda) + + unnormalized_returns = data.value + data.adv + data.return_ = unnormalized_returns + else: + # get the value dictionary + # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred' + value = self._model.compute_critic(data.obs) + next_value = self._model.compute_critic(data.next_obs) + reward = data.reward + + assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'], \ + 'Not supported value normalization! 
Value normalization supported: \ + popart, value rescale, symlog, baseline' + + if self._cfg.value_norm == 'popart': + unnormalized_value = value['unnormalized_pred'] + unnormalized_next_value = value['unnormalized_pred'] + + mu = self._model.critic_head.popart.mu + sigma = self._model.critic_head.popart.sigma + reward = (reward - mu) / sigma + + value = value['pred'] + next_value = next_value['pred'] + elif self._cfg.value_norm == 'value_rescale': + value = value_inv_transform(value['pred']) + next_value = value_inv_transform(next_value['pred']) + elif self._cfg.value_norm == 'symlog': + value = inv_symlog(value['pred']) + next_value = inv_symlog(next_value['pred']) + elif self._cfg.value_norm == 'baseline': + value = value['pred'] * self._running_mean_std.std + next_value = next_value['pred'] * self._running_mean_std.std + + traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory + adv_data = gae_data(value, next_value, reward, data.done, traj_flag) + data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda) + + unnormalized_returns = value + data.adv # In popart, this return is normalized + + if self._cfg.value_norm == 'popart': + self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1)) + elif self._cfg.value_norm == 'value_rescale': + value = value_transform(value) + unnormalized_returns = value_transform(unnormalized_returns) + elif self._cfg.value_norm == 'symlog': + value = symlog(value) + unnormalized_returns = symlog(unnormalized_returns) + elif self._cfg.value_norm == 'baseline': + value /= self._running_mean_std.std + unnormalized_returns /= self._running_mean_std.std + self._running_mean_std.update(unnormalized_returns.cpu().numpy()) + data.value = value + data.return_ = unnormalized_returns # inner training loop split_data = ttorch.split(data, self._cfg.batch_size) @@ -215,6 +233,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: for batch in split_data: output = self._model.compute_actor_critic(batch.obs) adv = batch.adv + mask = batch.mask if self._cfg.adv_norm: # Normalize advantage in a train_batch adv = (adv - adv.mean()) / (adv.std() + 1e-8) @@ -226,10 +245,16 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: ) ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._cfg.clip_ratio) elif self._action_space == 'discrete': - ppo_batch = ppo_data( - output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None - ) - ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) + if not self._cfg.chat_data: + ppo_batch = ppo_data( + output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, mask + ) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) + else: + ppo_batch = ppo_data( + output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None + ) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) elif self._action_space == 'hybrid': # discrete part (discrete policy loss and entropy loss) ppo_discrete_batch = ppo_policy_data( @@ -253,8 +278,9 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl), max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac) ) - wv, we = self._cfg.value_weight, self._cfg.entropy_weight - total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, 
self._cfg.kl_penalty_weight + kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(data.orig_logit, dim=-1), reduction=None) * mask).mean() + total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss self._optimizer.zero_grad() total_loss.backward() diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py index 4538102861..5b197af25d 100644 --- a/ding/reward_model/__init__.py +++ b/ding/reward_model/__init__.py @@ -13,3 +13,5 @@ from .guided_cost_reward_model import GuidedCostRewardModel from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel from .icm_reward_model import ICMRewardModel +# RLHF +from .language_reward_model import LlamaRewardModel diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py new file mode 100644 index 0000000000..54c6cee23c --- /dev/null +++ b/ding/reward_model/language_reward_model.py @@ -0,0 +1,27 @@ +import torch +from utils import * +from transformers.models.llama.modeling_llama import LlamaForCausalLM + + +class LlamaRewardModel(LlamaForCausalLM): + def __init__(self, config, opt, tokenizer): + super().__init__(config) + self.opt = opt + self.tokenizer = tokenizer + self.reward_head = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def forward(self, decoder_input, only_last=True): + attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) + output = self.model.forward( + input_ids=decoder_input, + attention_mask=attention_mask, + return_dict=True, + use_cache=False + ) + + if only_last: + logits = self.reward_head(output.last_hidden_state[:, -1, :]).squeeze(-1) + else: + logits = self.reward_head(output.last_hidden_state).squeeze(-1) + + return (logits,) diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py index 800fcae354..487f2c93ef 100644 --- a/ding/rl_utils/gae.py +++ b/ding/rl_utils/gae.py @@ -3,6 +3,7 @@ from ding.hpc_rl import hpc_wrapper gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done', 'traj_flag']) +episodic_gae_data = namedtuple('episodic_gae_data', ['value', 'mask', 'reward', 'done', 'traj_flag']) def shape_fn_gae(args, kwargs): @@ -68,3 +69,22 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F gae_item = delta[t] + factor[t] * gae_item adv[t] = gae_item return adv + + +def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97): + value, mask, reward, done, traj_flag = data + if done is None: + done = torch.zeros_like(value) + if traj_flag is None: + traj_flag = done + advs = [] + bsz = value.shape[0] + for i in range(bsz): + val, mas, rew, don, traj = value[i], mask[i], reward[i], done[i], traj_flag[i] + assert val.shape[0] == rew.shape[0] + next_val = torch.zeros_like(val) + next_val[:-1] = val[1:] + gd = gae_data(val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1), + traj.unsqueeze(-1)) + advs.append(gae(gd, gamma, lambda_)) + return torch.stack(advs, dim=0) diff --git a/dizoo/chat/__init__.py b/dizoo/chat/__init__.py new file mode 100644 index 0000000000..eb1eb48abb --- /dev/null +++ b/dizoo/chat/__init__.py @@ -0,0 +1 @@ +from .env import ChatEnv diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py new file mode 100644 index 0000000000..a7af5bfd28 --- /dev/null +++ b/dizoo/chat/env.py @@ -0,0 +1,53 @@ +import gym +import torch + +from ding.reward_model import LlamaRewardModel +from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, 
pad_sequences + + +class ChatEnv(gym.Env): + def __init__( + self, + batch_size, + reward_model_path, + tokenizer_path, + data_path, + maxlen_prompt, + maxlen_res, + ): + self.batch_size = batch_size + self.rm = LlamaRewardModel.from_pretrained(reward_model_path) + self.tokenizer = get_tokenizer(tokenizer_path) + + self.dataset = OnlyPromptDataset( + data_path=data_path, + tokenizer=self.tokenizer, + batch_size=batch_size, + maxlen_prompt=maxlen_prompt, + maxlen_res=maxlen_res, + mode='train', + ) + self.generator = self.dataset.final_generator() + self.last_batch = None + + def reset(self): + self.last_batch = next(self.generator) + return self.last_batch + + def step(self, action): + """ + For each step, this env will return a batch of prompts. These prompts a vectorized by using tokenizer, and are \ + padded into the same length. + """ + output_mask, output_vec = concat_context_and_response(self.tokenizer, self.last_batch['text_vec'].tolist(), action) + rm_input = torch.tensor(pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left'), dtype=torch.long) + output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left') + with torch.no_grad(): + rew, *_ = self.rm(rm_input) + + self.last_batch = next(self.generator) + if self.last_batch is None: + self.generator = self.dataset.final_generator() + self.last_batch = next(self.generator) + + return output_mask, output_vec, rew diff --git a/dizoo/chat/utils.py b/dizoo/chat/utils.py new file mode 100644 index 0000000000..574381fdc8 --- /dev/null +++ b/dizoo/chat/utils.py @@ -0,0 +1,243 @@ +import json +import os +from typing import List, Dict, Any, Tuple +import warnings + +from transformers.models.llama.tokenization_llama import LlamaTokenizer +from torch.utils.data.dataset import IterableDataset +import torch +import random + + +# Prefix of human sentence and assistant sentence. +HUMAN_PROMPT = "Human:" +ASSISTANT_PROMPT = "Assistant:" + + +def strip_pad_token_id(tokenizer: LlamaTokenizer, seq: List[int]): + """ + Overview: + Remove ``pad_token_id`` in a sequence. + """ + return [tok for tok in seq if tok != tokenizer.pad_token_id] + + +def concat_context_and_response( + tokenizer: LlamaTokenizer, + context: List[List[int]], + responses: List[List[Tuple[float, List[int]]]] +): + """ + Overview: + Given the batched input prompts and responses, concatenate them together. + """ + assert len(context) == len(responses), f'Size not match: {len(context)} and {len(responses)}' + + total_context, total_response = [], [] + total_context_mask, total_response_mask = [], [] + for _context, _response in zip(context, responses): + # Each ``_context`` is a single input prompt. + _context = strip_pad_token_id(tokenizer, _context) + for _, resp in _response: + # Each ``resp`` is a single response. 
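+                # A response that hit the generation length limit will not end with ``eos_token_id``;
+                # the warning below flags such truncated responses.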
+ resp = strip_pad_token_id(tokenizer, resp) + if resp[-1] != tokenizer.eos_token_id: + warnings.warn( + f'Generated response is too long: {tokenizer.decode(_context + resp, skip_special_tokens=False)}') + + total_context.append(_context.copy()) + total_context_mask.append([0] * len(_context)) + total_response.append(resp) + total_response_mask.append([1] * len(resp)) + + total_gene_samples_vec = [c + r for c, r in zip(total_context, total_response)] + total_gene_samples_mask = [c + r for c, r in zip(total_context_mask, total_response_mask)] + return total_gene_samples_mask, total_gene_samples_vec + + +def pad_sequences( + seqs: List[List[int]], + pad_value: int, + padding: str = 'right'): + """ + Overview: + Padding sequence to the same length + """ + max_len = max(len(seq) for seq in seqs) + if padding == 'right': + padded_seqs = [seq + [pad_value] * (max_len - len(seq)) for seq in seqs] + elif padding == 'left': + padded_seqs = [[pad_value] * (max_len - len(seq)) + seq for seq in seqs] + else: + raise ValueError + return padded_seqs + + +def get_special_prompt(i: int): + return HUMAN_PROMPT if i % 2 == 0 else ASSISTANT_PROMPT + + +def get_model_prompt(context: List[str], eos_token=""): + human_prompt, assistant_prompt = HUMAN_PROMPT, ASSISTANT_PROMPT + if context[-1].startswith(human_prompt): + end_prompt = assistant_prompt + elif context[-1].startswith(assistant_prompt): + end_prompt = human_prompt + else: + raise ValueError + + context = eos_token.join(context) + return f'{context}{eos_token}{end_prompt}' + + +def get_tokenizer(path: str): + """ + Overview: + Return the pretrained tokenizer using the given path. + """ + tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer.bos_token = '' + tokenizer.eos_token = '' + tokenizer.pad_token = '' + tokenizer.pad_token_id = 0 + tokenizer.unk_token = tokenizer.pad_token + tokenizer.unk_token_id = tokenizer.pad_token_id + + return tokenizer + + +class OnlyPromptDataset(IterableDataset): + """ + Overview: + Dataset that only contains the prompts of the raw data (no answer). + """ + def __init__( + self, + data_path: os.PathLike, + tokenizer, + batch_size: int, + maxlen_prompt: int, + maxlen_res: int, + mode: str = 'train', + ) -> None: + super().__init__() + self.mode = mode + self.tokenizer = tokenizer + self.maxlen_prompt = maxlen_prompt + self.maxlen_res = maxlen_res + self.batch_size = batch_size + + # Load data. + self.data = [] + files = sorted([file for file in os.listdir(data_path) if file.endswith(f'{mode}.json')]) + for file in files: + file_path = os.path.join(data_path, file) + tmp_data = [] + try: + tmp_data = self.load_data(file_path) + except Exception as e: + pass + self.data.extend(tmp_data) + + # Set the length of this dataset. + self.size = len(self.data) + + def __len__(self): + return self.size + + def load_data(self, file_path: str): + """ + Overview: + Load raw data from given file_path. + """ + with open(file_path, 'r') as f: + data: List[List[str]] = json.load(f) + + output: List[List[str]] = [sample for sample in data if all(sample)] + del data + + return output + + def final_generator(self): + data_generator = self.batch_generator() + for batch_samples in data_generator: + batch = self.batchify(batch_samples) + yield batch + + def __iter__(self): + return self.final_generator() + + def format(self, sample: List[str]) -> Dict[str, Any]: + """ + Overview: + Convert one data sample in to string. 
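+            Each turn is prefixed with its role prompt (``Human:`` / ``Assistant:``), the turns are \
+                joined with the eos token and tokenized; the context is then truncated from the left \
+                until the prompt leaves room for ``maxlen_res`` response tokens within ``maxlen_prompt``.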
+ """ + context = sample + context = [get_special_prompt(i + (len(context) + 1) % 2) + s for i, s in enumerate(context)] + context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token), + add_special_tokens=True) + + # truncate to max_len + while len(context_vec) > self.maxlen_prompt - self.maxlen_res and len(context) > 1: + context = context[1:] + context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token), + add_special_tokens=True) + + output = { + 'text': self.tokenizer.decode(context_vec, skip_special_tokens=False), + 'text_vec': context_vec + } + + return output + + def batchify(self, batch_samples: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Batchify a list of ids by padding their shape to be the same. + """ + batch_text_vec = torch.tensor(pad_sequences( + [sample['text_vec'] for sample in batch_samples], pad_value=self.tokenizer.pad_token_id, padding='left' + ), dtype=torch.long) + return { + 'text_vec': batch_text_vec, + 'text': [sample['text'] for sample in batch_samples] + } + + def sample_generator(self): + """ + Overview: + Generate a single data sample. + """ + random.seed(None) + if self.mode == 'train': + random.shuffle(self.data) + + for sample in self.data: + yield self.format(sample) + + def _batch_generator(self): + """ + Overview: + Generate a batch of samples. + """ + batch = [] + # Generate a sample. + for sample in self.sample_generator(): + sample_len = len(sample['text_vec']) + if sample_len > self.maxlen_prompt: + continue + + batch.append(sample) + if len(batch) >= self.batch_size: + yield batch[:self.batch_size] + batch = batch[self.batch_size:] + if batch: + yield batch + + def batch_generator(self): + while True: + for batch in self._batch_generator(): + if len(batch) == self.batch_size: + yield batch + if self.mode != 'train': + break diff --git a/launch_ppof.py b/launch_ppof.py new file mode 100644 index 0000000000..67dc26ad1e --- /dev/null +++ b/launch_ppof.py @@ -0,0 +1,20 @@ +from ding.bonus.ppof import PPOF +from ding.model.template.vac import LlamaVAC + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--actor_path', type=str) + parser.add_argument('--critic_path', type=str) + args = parser.parse_args() + model = LlamaVAC( + actor_path=args.actor_path, + critic_path=args.critic_path + ) + + policy = PPOF( + env_id="prompt-generator", + exp_name="rlhf-ppo", + model=model + ) + policy.train() From eb5d8c83fdf14115799396611e1288302139433d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 25 Oct 2023 15:56:34 +0800 Subject: [PATCH 02/17] debug --- ding/model/template/vac.py | 7 +++-- ding/reward_model/language_reward_model.py | 1 - launch_ppof.py | 32 +++++++++++++++++++++- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 74bb4041b3..ed704385ae 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -1,4 +1,6 @@ from typing import Union, Dict, Optional + +from transformers import LlamaTokenizer from transformers.models.llama.modeling_llama import LlamaForCausalLM from easydict import EasyDict import torch @@ -544,7 +546,8 @@ class LlamaVAC(nn.Module): def __init__( self, actor_path: str, - critic_path: str + critic_path: str, + tokenizer: LlamaTokenizer ) -> None: """ Overview: @@ -554,7 +557,7 @@ def __init__( - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 
6 or [2, 3, 3]. """ super(LlamaVAC, self).__init__() - self.actor = Llama.from_pretrained(actor_path) + self.actor = Llama.from_pretrained(actor_path, tokenizer=tokenizer) self.critic = LlamaRewardModel.from_pretrained(critic_path) def forward(self, x: torch.Tensor, mode: str) -> Dict: diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py index 54c6cee23c..a2a5544b1d 100644 --- a/ding/reward_model/language_reward_model.py +++ b/ding/reward_model/language_reward_model.py @@ -1,5 +1,4 @@ import torch -from utils import * from transformers.models.llama.modeling_llama import LlamaForCausalLM diff --git a/launch_ppof.py b/launch_ppof.py index 67dc26ad1e..5192e8dbb2 100644 --- a/launch_ppof.py +++ b/launch_ppof.py @@ -1,15 +1,45 @@ +from easydict import EasyDict +from transformers import LlamaTokenizer + from ding.bonus.ppof import PPOF from ding.model.template.vac import LlamaVAC + +def get_tokenizer(path: str): + """ + Overview: + Return the pretrained tokenizer using the given path. + """ + tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer.bos_token = '' + tokenizer.eos_token = '' + tokenizer.pad_token = '' + tokenizer.pad_token_id = 0 + tokenizer.unk_token = tokenizer.pad_token + tokenizer.unk_token_id = tokenizer.pad_token_id + + return tokenizer + + if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('--actor_path', type=str) parser.add_argument('--critic_path', type=str) args = parser.parse_args() + + opt = EasyDict({ + "maxlen_res": 512, + "temperature": 1, + "repetition_penalty": 1, + "topp": 0.8 + }) + model = LlamaVAC( actor_path=args.actor_path, - critic_path=args.critic_path + critic_path=args.critic_path, + tokenizer=get_tokenizer(args.actor_path), + opt=opt ) policy = PPOF( From f45b7427210b7e9d51700a286e76958f05c0aa6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 25 Oct 2023 16:25:57 +0800 Subject: [PATCH 03/17] debug --- ding/bonus/config.py | 10 ++++++++++ ding/model/template/vac.py | 9 +++++---- launch_ppof.py | 4 ++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 3b5fe33463..86f1d227da 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -321,6 +321,16 @@ def get_instance_env(env_id: str) -> BaseEnv: ) cfg = EasyDict(cfg) return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg)) + elif env_id == 'chat': + from dizoo.chat.env import ChatEnv + return ChatEnv( + batch_size=2, + reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en/recover", + tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en", + data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data", + maxlen_prompt=1024, + maxlen_res=512, + ) else: raise KeyError("not supported env type: {}".format(env_id)) diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index ed704385ae..ec4990b750 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -547,7 +547,8 @@ def __init__( self, actor_path: str, critic_path: str, - tokenizer: LlamaTokenizer + tokenizer: LlamaTokenizer, + opt: Dict ) -> None: """ Overview: @@ -557,8 +558,8 @@ def __init__( - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. 
""" super(LlamaVAC, self).__init__() - self.actor = Llama.from_pretrained(actor_path, tokenizer=tokenizer) - self.critic = LlamaRewardModel.from_pretrained(critic_path) + self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer) + self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer) def forward(self, x: torch.Tensor, mode: str) -> Dict: assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) @@ -577,4 +578,4 @@ def compute_actor_critic(self, x): policy_output = self.actor(decoder_input=x) policy_logit, *_ = policy_output values = self.critic(decoder_input=x, only_last=False) - return {"logit": policy_logit,"value": values} + return {"logit": policy_logit, "value": values} diff --git a/launch_ppof.py b/launch_ppof.py index 5192e8dbb2..a58ac10ae2 100644 --- a/launch_ppof.py +++ b/launch_ppof.py @@ -38,12 +38,12 @@ def get_tokenizer(path: str): model = LlamaVAC( actor_path=args.actor_path, critic_path=args.critic_path, - tokenizer=get_tokenizer(args.actor_path), + tokenizer=get_tokenizer("/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en"), opt=opt ) policy = PPOF( - env_id="prompt-generator", + env_id="chat", exp_name="rlhf-ppo", model=model ) From 1ca316fbd40cce97d1a48a3fc4b3118945d5158b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 25 Oct 2023 16:42:54 +0800 Subject: [PATCH 04/17] debug --- ding/bonus/config.py | 4 ++-- dizoo/chat/env.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 86f1d227da..45b3624a6b 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -325,8 +325,8 @@ def get_instance_env(env_id: str) -> BaseEnv: from dizoo.chat.env import ChatEnv return ChatEnv( batch_size=2, - reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en/recover", - tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en", + reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover", + tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en", data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data", maxlen_prompt=1024, maxlen_res=512, diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py index a7af5bfd28..ec9f1aa48c 100644 --- a/dizoo/chat/env.py +++ b/dizoo/chat/env.py @@ -16,8 +16,8 @@ def __init__( maxlen_res, ): self.batch_size = batch_size - self.rm = LlamaRewardModel.from_pretrained(reward_model_path) self.tokenizer = get_tokenizer(tokenizer_path) + self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer) self.dataset = OnlyPromptDataset( data_path=data_path, From db72c2c0ebc104e423e9c5f009541bfae2c12f18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 25 Oct 2023 17:35:45 +0800 Subject: [PATCH 05/17] debug --- ding/bonus/ppof.py | 2 ++ dizoo/chat/env.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py index 297e842844..3d10d32696 100644 --- a/ding/bonus/ppof.py +++ b/ding/bonus/ppof.py @@ -112,6 +112,8 @@ def __init__( action_shape = int(action_space.n) elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)): action_shape = get_hybrid_shape(action_space) + elif action_space is None: + pass else: action_shape = action_space.shape diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py index ec9f1aa48c..843c6bbd38 100644 
--- a/dizoo/chat/env.py +++ b/dizoo/chat/env.py @@ -1,5 +1,6 @@ import gym import torch +from easydict import EasyDict from ding.reward_model import LlamaRewardModel from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, pad_sequences @@ -17,7 +18,8 @@ def __init__( ): self.batch_size = batch_size self.tokenizer = get_tokenizer(tokenizer_path) - self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer) + self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None) + self.action_space = None self.dataset = OnlyPromptDataset( data_path=data_path, From 3f1e47b84300495c882d547de2f012ad5a3361bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 26 Oct 2023 11:38:16 +0800 Subject: [PATCH 06/17] debug --- ding/bonus/config.py | 1 + dizoo/chat/env.py | 4 ++-- launch_ppof.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 45b3624a6b..113eaf0943 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -173,6 +173,7 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict: cfg.learning_rate = 5e-7 cfg.answers_per_question = 3 cfg.kl_penalty_weight = 0.1 + cfg.ppo_param_init = False else: raise KeyError("not supported env type: {}".format(env_id)) else: diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py index 843c6bbd38..3d03049ac7 100644 --- a/dizoo/chat/env.py +++ b/dizoo/chat/env.py @@ -1,12 +1,12 @@ import gym import torch -from easydict import EasyDict +from ding.envs import BaseEnv from ding.reward_model import LlamaRewardModel from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, pad_sequences -class ChatEnv(gym.Env): +class ChatEnv(BaseEnv): def __init__( self, batch_size, diff --git a/launch_ppof.py b/launch_ppof.py index a58ac10ae2..e360324f11 100644 --- a/launch_ppof.py +++ b/launch_ppof.py @@ -47,4 +47,4 @@ def get_tokenizer(path: str): exp_name="rlhf-ppo", model=model ) - policy.train() + policy.train(collector_env_num=1, evaluator_env_num=1) From dc4ceceaffc51b7d431b24b5516bc7ae9f1d16c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 26 Oct 2023 11:46:47 +0800 Subject: [PATCH 07/17] debug --- dizoo/chat/env.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py index 3d03049ac7..5fb5a1ebc9 100644 --- a/dizoo/chat/env.py +++ b/dizoo/chat/env.py @@ -20,6 +20,8 @@ def __init__( self.tokenizer = get_tokenizer(tokenizer_path) self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None) self.action_space = None + self._init_flag = False + self._seed = None self.dataset = OnlyPromptDataset( data_path=data_path, @@ -32,10 +34,20 @@ def __init__( self.generator = self.dataset.final_generator() self.last_batch = None + def close(self) -> None: + self._init_flag = False + def reset(self): self.last_batch = next(self.generator) + self._init_flag = True return self.last_batch + def __repr__(self) -> str: + return "DI-engine Chat Env" + + def seed(self, seed): + self._seed = 0 + def step(self, action): """ For each step, this env will return a batch of prompts. 
These prompts a vectorized by using tokenizer, and are \ From 30f29949b6d77c39338373d3b9c03bd19511a3fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 26 Oct 2023 13:37:56 +0800 Subject: [PATCH 08/17] debug --- ding/bonus/ppof.py | 2 +- ding/framework/middleware/collector.py | 5 ++--- dizoo/chat/env.py | 6 ++++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py index 3d10d32696..fc755bd25a 100644 --- a/ding/bonus/ppof.py +++ b/ding/bonus/ppof.py @@ -168,7 +168,7 @@ def train( pass with task.start(ctx=OnlineRLContext()): - if not self.policy._cfg.chat_data: + if self.policy._cfg.chat_data: # task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env)) # task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt)) task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample)) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index 5e0c0811d4..8d816eced1 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -94,8 +94,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll self.policy = policy self.n_sample = n_sample self.unroll_len = unroll_len - self._transitions = Transiti - onList(self.env.env_num) + self._transitions = TransitionList(self.env.env_num) self._env_episode_id = [_ for _ in range(env.env_num)] self._current_id = env.env_num @@ -214,7 +213,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll """ self.env = env self.env.seed(seed) - self.env.launch() + self.env.reset() self.policy = policy self.n_sample = n_sample self.unroll_len = unroll_len diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py index 5fb5a1ebc9..e92af3125a 100644 --- a/dizoo/chat/env.py +++ b/dizoo/chat/env.py @@ -20,6 +20,9 @@ def __init__( self.tokenizer = get_tokenizer(tokenizer_path) self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None) self.action_space = None + self.observation_space = None + self.reward_space = None + self._init_flag = False self._seed = None @@ -48,6 +51,9 @@ def __repr__(self) -> str: def seed(self, seed): self._seed = 0 + def clone(self, caller): + return self + def step(self, action): """ For each step, this env will return a batch of prompts. 
These prompts a vectorized by using tokenizer, and are \ From c1cc4545c80556e843a43a08f4a87d8f4bc62081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 26 Oct 2023 14:08:55 +0800 Subject: [PATCH 09/17] debug --- ding/bonus/config.py | 4 ++-- ding/framework/middleware/collector.py | 1 + launch_ppof.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 113eaf0943..a55ec1bb44 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -324,14 +324,14 @@ def get_instance_env(env_id: str) -> BaseEnv: return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg)) elif env_id == 'chat': from dizoo.chat.env import ChatEnv - return ChatEnv( + return DingEnvWrapper(ChatEnv( batch_size=2, reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover", tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en", data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data", maxlen_prompt=1024, maxlen_res=512, - ) + )) else: raise KeyError("not supported env type: {}".format(env_id)) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index 8d816eced1..69e8bc04f3 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -213,6 +213,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll """ self.env = env self.env.seed(seed) + self._closed = False self.env.reset() self.policy = policy self.n_sample = n_sample diff --git a/launch_ppof.py b/launch_ppof.py index e360324f11..53825dcba2 100644 --- a/launch_ppof.py +++ b/launch_ppof.py @@ -47,4 +47,4 @@ def get_tokenizer(path: str): exp_name="rlhf-ppo", model=model ) - policy.train(collector_env_num=1, evaluator_env_num=1) + policy.train(collector_env_num=1, evaluator_env_num=1, debug=True) From 24de047f2831b09d90215c8283c245a014324d36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Thu, 26 Oct 2023 14:19:52 +0800 Subject: [PATCH 10/17] debug --- ding/bonus/config.py | 4 ++-- ding/framework/middleware/collector.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ding/bonus/config.py b/ding/bonus/config.py index a55ec1bb44..113eaf0943 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -324,14 +324,14 @@ def get_instance_env(env_id: str) -> BaseEnv: return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg)) elif env_id == 'chat': from dizoo.chat.env import ChatEnv - return DingEnvWrapper(ChatEnv( + return ChatEnv( batch_size=2, reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover", tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en", data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data", maxlen_prompt=1024, maxlen_res=512, - )) + ) else: raise KeyError("not supported env type: {}".format(env_id)) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index 69e8bc04f3..6a3e81735a 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -211,9 +211,8 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ its derivatives are supported. 
""" - self.env = env + self.env = env._env[0] self.env.seed(seed) - self._closed = False self.env.reset() self.policy = policy self.n_sample = n_sample @@ -229,7 +228,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: """ device = self.policy._device - obs = ttorch.as_tensor(self.env._env[0].last_batch) + obs = ttorch.as_tensor(self.env.last_batch) obs = obs.to(device) total_action = [] From c9b71eec63ed66bd2b47a4ea8eb796c40dd7f452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 6 Nov 2023 10:13:32 +0800 Subject: [PATCH 11/17] debug --- ding/framework/middleware/collector.py | 17 ++++++++++------- ding/model/common/utils.py | 1 + ding/model/template/vac.py | 4 ++-- ding/policy/ppof.py | 15 ++++++++++----- ding/rl_utils/gae.py | 2 +- dizoo/chat/env.py | 4 ++-- dizoo/chat/utils.py | 3 ++- 7 files changed, 28 insertions(+), 18 deletions(-) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index 6a3e81735a..261e8fbe66 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -211,9 +211,10 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ its derivatives are supported. """ - self.env = env._env[0] + self.env = env self.env.seed(seed) - self.env.reset() + self.env.launch() + self.env = self._envs[0] self.policy = policy self.n_sample = n_sample self.unroll_len = unroll_len @@ -228,22 +229,24 @@ def __call__(self, ctx: "OnlineRLContext") -> None: """ device = self.policy._device - obs = ttorch.as_tensor(self.env.last_batch) + obs = ttorch.as_tensor(self.env.last_batch[0]['text_vec']) + batch_size = obs.shape[0] obs = obs.to(device) - total_action = [] + total_action = [[] for _ in range(batch_size)] # [B, answers_per_question, T] for _ in range(self.policy._cfg.answers_per_question): _, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs) - total_action.append(copy.deepcopy(inference_output)) + for i in range(batch_size): + total_action[i].append(copy.deepcopy(inference_output[i])) mask, resp, rew = self.env.step(total_action) ctx.env_step += 1 ctx.env_episode += 1 train_data = {} - train_data['obs'] = resp # [B x answer-per-question, max_len] + train_data['obs'] = resp # [B x answer-per-question, T] train_data['reward'] = rew # [B x answer-per-question, ] - train_data['mask'] = mask # [B x answer-per-question, max_len] + train_data['mask'] = mask # [B x answer-per-question, T] ctx.train_data = ttorch.as_tensor(train_data) diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py index c1012276c6..340749eb0c 100644 --- a/ding/model/common/utils.py +++ b/ding/model/common/utils.py @@ -37,3 +37,4 @@ def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1): mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask) cum_logits[mask] = filter_value cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True)) + return cum_logits diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index ec4990b750..34e0f37198 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -467,7 +467,7 @@ def generate(self, batch, **kwargs): repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty) topp = kwargs.pop('topp', self.opt.topp) - decoder_input: torch.LongTensor = batch['text_vec'] # (bsz, ...) 
+ decoder_input: torch.LongTensor = batch # (bsz, ...) assert decoder_input[:, -1].ne( self.tokenizer.pad_token_id).all(), 'Last token should not be a padding token (you can use left padding instead).' @@ -558,7 +558,7 @@ def __init__( - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. """ super(LlamaVAC, self).__init__() - self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer) + self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer,) self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer) def forward(self, x: torch.Tensor, mode: str) -> Dict: diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index e746b9c617..a445413ae1 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -164,13 +164,18 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: with torch.no_grad(): if self._cfg.chat_data: # [B, T] - value = self._model.compute_critic(data.obs) - data.orig_logit = self._orig_model.compute_actor(data.obs) + value = self._model.compute_critic(data.obs)['value'][0] + self._model.cpu() + self._orig_model.cuda() + data.orig_logit = self._orig_model.compute_actor(data.obs)['logit'] + self._orig_model.cpu() + self._model.cuda() data.value = value reward = data.reward traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory - adv_data = episodic_gae_data(value, data.mask, reward, data.done, traj_flag) + done = data.get('done', None) + adv_data = episodic_gae_data(value, data.mask, reward, done, traj_flag) data.adv = episodic_gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda) unnormalized_returns = data.value + data.adv @@ -252,7 +257,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) else: ppo_batch = ppo_data( - output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None + output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv, batch.return_, None ) ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) elif self._action_space == 'hybrid': @@ -279,7 +284,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac) ) wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight - kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(data.orig_logit, dim=-1), reduction=None) * mask).mean() + kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(batch.orig_logit, dim=-1), reduction='none') * mask).mean() total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss self._optimizer.zero_grad() diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py index 487f2c93ef..1d90a89d12 100644 --- a/ding/rl_utils/gae.py +++ b/ding/rl_utils/gae.py @@ -86,5 +86,5 @@ def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97): next_val[:-1] = val[1:] gd = gae_data(val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1)) - advs.append(gae(gd, gamma, lambda_)) + advs.append(gae(gd, gamma, lambda_).squeeze(-1)) return torch.stack(advs, dim=0) diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py index e92af3125a..01f0637a9b 100644 --- a/dizoo/chat/env.py +++ b/dizoo/chat/env.py @@ -1,4 +1,3 @@ -import gym import torch from ding.envs 
import BaseEnv @@ -60,7 +59,8 @@ def step(self, action): padded into the same length. """ output_mask, output_vec = concat_context_and_response(self.tokenizer, self.last_batch['text_vec'].tolist(), action) - rm_input = torch.tensor(pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left'), dtype=torch.long) + output_vec = pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left') + rm_input = torch.tensor(output_vec, dtype=torch.long) output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left') with torch.no_grad(): rew, *_ = self.rm(rm_input) diff --git a/dizoo/chat/utils.py b/dizoo/chat/utils.py index 574381fdc8..be05e76805 100644 --- a/dizoo/chat/utils.py +++ b/dizoo/chat/utils.py @@ -38,8 +38,9 @@ def concat_context_and_response( for _context, _response in zip(context, responses): # Each ``_context`` is a single input prompt. _context = strip_pad_token_id(tokenizer, _context) - for _, resp in _response: + for resp in _response: # Each ``resp`` is a single response. + resp = resp[0][1] resp = strip_pad_token_id(tokenizer, resp) if resp[-1] != tokenizer.eos_token_id: warnings.warn( From 4418c68cc369c5f79b7a38c12e1fd145d46026f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 6 Nov 2023 10:45:11 +0800 Subject: [PATCH 12/17] reformat --- ding/framework/middleware/collector.py | 1 + ding/model/template/vac.py | 31 +++++++++++----------- ding/policy/ppof.py | 29 +++++++++++++++----- ding/reward_model/language_reward_model.py | 8 +++--- ding/rl_utils/gae.py | 5 ++-- 5 files changed, 46 insertions(+), 28 deletions(-) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index 261e8fbe66..7ad8650dbe 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -250,4 +250,5 @@ def __call__(self, ctx: "OnlineRLContext") -> None: ctx.train_data = ttorch.as_tensor(train_data) + # TODO battle collector diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 34e0f37198..60cb596e55 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -433,6 +433,7 @@ def __init__( class Llama(LlamaForCausalLM): + def __init__(self, config, opt, tokenizer): super().__init__(config) self.opt = opt @@ -469,13 +470,14 @@ def generate(self, batch, **kwargs): decoder_input: torch.LongTensor = batch # (bsz, ...) assert decoder_input[:, -1].ne( - self.tokenizer.pad_token_id).all(), 'Last token should not be a padding token (you can use left padding instead).' + self.tokenizer.pad_token_id + ).all(), 'Last token should not be a padding token (you can use left padding instead).' 
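# The assert above is also why the chat env pads on the left: decoder-only generation
# appends new tokens on the right, so the last position of every prompt must be a real
# token, not padding. A minimal illustration of left padding (pad id 0 is chosen only
# for this example):
#
#   def pad_left(seqs, pad_id=0):
#       max_len = max(len(s) for s in seqs)
#       return [[pad_id] * (max_len - len(s)) + list(s) for s in seqs]
#
#   pad_left([[5, 6], [7, 8, 9]])  # -> [[0, 5, 6], [7, 8, 9]]
#
# pad_sequences(..., padding='left') in dizoo/chat/utils.py follows the same idea when
# batching concatenated prompt/response token ids for the reward model.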
dev = decoder_input.device bsz = decoder_input.size(0) - scores = torch.zeros((bsz,), device=dev, dtype=torch.float16) - done = torch.zeros((bsz,), device=dev).to(torch.bool) + scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16) + done = torch.zeros((bsz, ), device=dev).to(torch.bool) inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1) decoder_input = torch.index_select(decoder_input, 0, inds) @@ -495,8 +497,9 @@ def generate(self, batch, **kwargs): if repetition_penalty > 1.: penalty_tokens = decoder_input[:, init_length:] penalty_scores = torch.gather(score, dim=1, index=penalty_tokens) - penalty_scores = torch.where(penalty_scores < 0., penalty_scores * repetition_penalty, - penalty_scores / repetition_penalty) + penalty_scores = torch.where( + penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty + ) score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores) # nucleus sampling @@ -524,8 +527,8 @@ def generate(self, batch, **kwargs): preds_scores = [] for i in range(bsz): - seq: torch.LongTensor = decoder_input[i, :lengths[i,]] - res_scores = (float(scores[i,]), seq.tolist()) + seq: torch.LongTensor = decoder_input[i, :lengths[i, ]] + res_scores = (float(scores[i, ]), seq.tolist()) preds_scores.append([res_scores]) best_preds_scores = [preds[0] for preds in preds_scores] @@ -543,13 +546,7 @@ class LlamaVAC(nn.Module): """ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] - def __init__( - self, - actor_path: str, - critic_path: str, - tokenizer: LlamaTokenizer, - opt: Dict - ) -> None: + def __init__(self, actor_path: str, critic_path: str, tokenizer: LlamaTokenizer, opt: Dict) -> None: """ Overview: Initialize the ``DREAMERVAC`` model according to arguments. @@ -558,7 +555,11 @@ def __init__( - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. 
""" super(LlamaVAC, self).__init__() - self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer,) + self.actor = Llama.from_pretrained( + actor_path, + opt=opt, + tokenizer=tokenizer, + ) self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer) def forward(self, x: torch.Tensor, mode: str) -> Dict: diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index a445413ae1..10b7876494 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -60,7 +60,13 @@ def default_model(cls: type) -> Callable: from .model import PPOFModel return PPOFModel - def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None, orig_model: torch.nn.Module = None) -> None: + def __init__( + self, + cfg: "EasyDict", + model: torch.nn.Module, + enable_mode: List[str] = None, + orig_model: torch.nn.Module = None + ) -> None: self._cfg = cfg self._orig_model = orig_model if model is None: @@ -238,7 +244,6 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: for batch in split_data: output = self._model.compute_actor_critic(batch.obs) adv = batch.adv - mask = batch.mask if self._cfg.adv_norm: # Normalize advantage in a train_batch adv = (adv - adv.mean()) / (adv.std() + 1e-8) @@ -256,8 +261,10 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: ) ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) else: + mask = batch.mask ppo_batch = ppo_data( - output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv, batch.return_, None + output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv, + batch.return_, None ) ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) elif self._action_space == 'hybrid': @@ -283,9 +290,19 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl), max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac) ) - wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight - kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(batch.orig_logit, dim=-1), reduction='none') * mask).mean() - total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss + if not self._cfg.chat_data: + wv, we = self._cfg.value_weight, self._cfg.entropy_weight + total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + else: + wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight + kl_loss = ( + torch.nn.functional.kl_div( + torch.softmax(output.logit, dim=-1), + torch.softmax(batch.orig_logit, dim=-1), + reduction='none' + ) * mask + ).mean() + total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss self._optimizer.zero_grad() total_loss.backward() diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py index a2a5544b1d..595348bfa3 100644 --- a/ding/reward_model/language_reward_model.py +++ b/ding/reward_model/language_reward_model.py @@ -3,6 +3,7 @@ class LlamaRewardModel(LlamaForCausalLM): + def __init__(self, config, opt, tokenizer): super().__init__(config) self.opt = opt @@ -12,10 +13,7 @@ def __init__(self, config, opt, tokenizer): def forward(self, decoder_input, only_last=True): attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) output = self.model.forward( - input_ids=decoder_input, - 
attention_mask=attention_mask, - return_dict=True, - use_cache=False + input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False ) if only_last: @@ -23,4 +21,4 @@ def forward(self, decoder_input, only_last=True): else: logits = self.reward_head(output.last_hidden_state).squeeze(-1) - return (logits,) + return (logits, ) diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py index 1d90a89d12..d7081ad969 100644 --- a/ding/rl_utils/gae.py +++ b/ding/rl_utils/gae.py @@ -84,7 +84,8 @@ def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97): assert val.shape[0] == rew.shape[0] next_val = torch.zeros_like(val) next_val[:-1] = val[1:] - gd = gae_data(val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1), - traj.unsqueeze(-1)) + gd = gae_data( + val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1) + ) advs.append(gae(gd, gamma, lambda_).squeeze(-1)) return torch.stack(advs, dim=0) From d5f931be577ca119f0ca24cd406a0a60454cb06a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Fri, 24 Nov 2023 15:46:07 +0800 Subject: [PATCH 13/17] polish --- ding/bonus/ppof.py | 3 +- ding/framework/middleware/__init__.py | 2 +- ding/framework/middleware/collector.py | 2 +- ding/model/common/__init__.py | 2 +- ding/model/common/tests/test_utils.py | 21 +++ ding/model/common/utils.py | 17 +- ding/model/template/__init__.py | 1 + ding/model/template/lm_vac.py | 184 +++++++++++++++++++++ ding/model/template/vac.py | 155 ----------------- ding/policy/ppof.py | 2 +- ding/reward_model/language_reward_model.py | 11 +- dizoo/chat/entry.py | 34 ++++ dizoo/chat/env.py | 22 ++- 13 files changed, 278 insertions(+), 178 deletions(-) create mode 100644 ding/model/common/tests/test_utils.py create mode 100644 ding/model/template/lm_vac.py create mode 100644 dizoo/chat/entry.py diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py index fc755bd25a..63b0334024 100644 --- a/ding/bonus/ppof.py +++ b/ding/bonus/ppof.py @@ -10,7 +10,7 @@ import torch from ding.framework import task, OnlineRLContext from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \ - wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator + wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator, ChatCollector from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2 from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch from ding.utils import set_pkg_seed @@ -19,7 +19,6 @@ from .model import PPOFModel from .config import get_instance_config, get_instance_env, get_hybrid_shape from ding.bonus.common import TrainingReturn, EvalReturn -from ..framework.middleware.collector import ChatCollector class PPOF: diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index b9e3c5005d..43cea6883b 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -1,5 +1,5 @@ from .functional import * -from .collector import StepCollector, EpisodeCollector, PPOFStepCollector +from .collector import StepCollector, EpisodeCollector, PPOFStepCollector, ChatCollector from .learner import OffPolicyLearner, HERLearner from .ckpt_handler import CkptSaver from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger diff --git a/ding/framework/middleware/collector.py 
b/ding/framework/middleware/collector.py index 7ad8650dbe..24533e5983 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -195,7 +195,7 @@ class ChatCollector: """ Overview: The class of the collector running by steps, including model inference and transition \ - process. Use the `__call__` method to execute the whole collection process. + process. Use the `__call__` method to execute the whole collection process. """ def __new__(cls, *args, **kwargs): diff --git a/ding/model/common/__init__.py b/ding/model/common/__init__.py index 4bf7d8be5a..5bf1fba4d5 100755 --- a/ding/model/common/__init__.py +++ b/ding/model/common/__init__.py @@ -2,4 +2,4 @@ QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \ independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder -from .utils import create_model +from .utils import create_model, top_p_logits diff --git a/ding/model/common/tests/test_utils.py b/ding/model/common/tests/test_utils.py new file mode 100644 index 0000000000..b12e529e9b --- /dev/null +++ b/ding/model/common/tests/test_utils.py @@ -0,0 +1,21 @@ +import pytest +import torch +from ding.model.common.utils import top_p_logits + + +@pytest.mark.unittest +class TestUtils: + + def test_top_p_logits(self): + test_logit = torch.Tensor([ + [0., 0.91, 0.05, 0.04], + [0.04, 0.46, 0.46, 0.04] + ]) + + gt_logit = torch.Tensor([ + [0., 1., 0., 0.], + [0., 0.5, 0.5, 0.] + ]) + + pred_logit = top_p_logits(test_logit) + assert torch.sum((gt_logit - pred_logit)**2).item() < 1e-8 diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py index 340749eb0c..1b4b41e59a 100644 --- a/ding/model/common/utils.py +++ b/ding/model/common/utils.py @@ -23,16 +23,25 @@ def create_model(cfg: EasyDict) -> torch.nn.Module: return MODEL_REGISTRY.build(cfg.pop("type"), **cfg) -def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1): +def top_p_logits(logits: torch.Tensor, topp: float = 0.9, filter_value: float = 0, min_topk: int = 1): """ - Filter a distribution of logits using nucleus (top-p) filtering - https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146 + Overview: + Filter a distribution of logits using nucleus (top-p) filtering. The output is also logit tensors but some \ + values are masked. + Arguments: + - logits (:obj:`torch.Tensor`): The input logits for top-p sampling. + - topp (:obj:`float`): The top-p value, such as 0.9. + - filter_value (:obj:`float`): The value for masked logits in output, default as 0. + - min_topk (:obj:`int`): The min number of sampled logit, default as 1 (which means that at least one sample \ + will not be masked.) + Returns: + - cum_logits (:obj:`torch.Tensor`): The output logits after masking. 
""" cum_logits = logits.clone() if topp > 0: logits_sorted, inds = torch.sort(logits, dim=-1, descending=True) mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp - mask[:, :min_topk] = False + mask[..., :min_topk] = False # Remove tokens with cumulative top_p above the threshold mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask) cum_logits[mask] = filter_value diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index b2dd815287..cbf31e72ae 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -5,6 +5,7 @@ from .vac import VAC, DREAMERVAC from .bc import DiscreteBC, ContinuousBC from .language_transformer import LanguageTransformer +from .lm_vac import LlamaVAC # algorithm-specific from .pg import PG from .ppg import PPG diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py new file mode 100644 index 0000000000..61ef4f7856 --- /dev/null +++ b/ding/model/template/lm_vac.py @@ -0,0 +1,184 @@ +from typing import Dict +import torch +import torch.nn as nn +try: + from transformers import LlamaTokenizer + from transformers.models.llama.modeling_llama import LlamaForCausalLM +except ImportError: + from ditk import logging + logging.warning("Not found transformer, please install it using: pip install transformers") + +from ding.model.common import top_p_logits +from ding.reward_model import LlamaRewardModel +from ding.utils import MODEL_REGISTRY + + +def get_tokenizer(path: str): + """ + Overview: + Return the pretrained tokenizer using the given path. + """ + tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer.bos_token = '' + tokenizer.eos_token = '' + tokenizer.pad_token = '' + tokenizer.pad_token_id = 0 + tokenizer.unk_token = tokenizer.pad_token + tokenizer.unk_token_id = tokenizer.pad_token_id + + return tokenizer + + +class Llama(LlamaForCausalLM): + + def __init__(self, config, opt, tokenizer): + super().__init__(config) + self.opt = opt + self.tokenizer = tokenizer + + def forward(self, decoder_input, incr_state=None): + + attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) + if incr_state is not None: + decoder_input = decoder_input[:, -1:] + + output = super().forward( + input_ids=decoder_input, + attention_mask=attention_mask, + past_key_values=incr_state, + return_dict=True, + use_cache=not self.training + ) + + logits = output.logits + new_incr_states = output.past_key_values + + return logits, new_incr_states + + @torch.no_grad() + def generate(self, batch, **kwargs): + """ + Generate response + """ + maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res) + temperature = kwargs.pop('temperature', self.opt.temperature) + repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty) + topp = kwargs.pop('topp', self.opt.topp) + + decoder_input: torch.LongTensor = batch # (bsz, ...) + assert decoder_input[:, -1].ne( + self.tokenizer.pad_token_id + ).all(), 'Last token should not be a padding token (you can use left padding instead).' 
+ + dev = decoder_input.device + bsz = decoder_input.size(0) + + scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16) + done = torch.zeros((bsz, ), device=dev).to(torch.bool) + + inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1) + decoder_input = torch.index_select(decoder_input, 0, inds) + init_length = decoder_input.size(1) + + incr_state = None + for _token in range(maxlen_res): + if done.all(): + break + score, incr_state, *_ = self.forward(decoder_input, incr_state) + score = score.half() + + # now score is bs, len, vocab_size + score = score[:, -1, :] + + # calculate repetition penalty + if repetition_penalty > 1.: + penalty_tokens = decoder_input[:, init_length:] + penalty_scores = torch.gather(score, dim=1, index=penalty_tokens) + penalty_scores = torch.where( + penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty + ) + score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores) + + # nucleus sampling + score = torch.softmax(score.div(temperature), dim=-1) + probs = top_p_logits(score, topp=topp, filter_value=0) + tok_ids = torch.multinomial(probs, 1)[:, 0] + hyp_ids = torch.arange(probs.size(0), device=dev) + scores = scores + probs[hyp_ids, tok_ids].log() * ~done + + tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids) + decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1) + done = done | tok_ids.eq(self.tokenizer.eos_token_id) + + incr_state = self._reorder_cache(incr_state, hyp_ids) + + # get all finalized candidates for each sample + decoder_input = decoder_input[:, init_length:] + decoder_input = decoder_input.view(bsz, -1) + scores = scores.view(bsz, ) + + lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1) + + length_penalty = torch.pow(lengths, 1.0) + scores /= length_penalty + + preds_scores = [] + for i in range(bsz): + seq: torch.LongTensor = decoder_input[i, :lengths[i, ]] + res_scores = (float(scores[i, ]), seq.tolist()) + preds_scores.append([res_scores]) + + best_preds_scores = [preds[0] for preds in preds_scores] + return best_preds_scores, preds_scores + + +@MODEL_REGISTRY.register('llamavac') +class LlamaVAC(nn.Module): + """ + Overview: + The neural network and computation graph of Llama VAC. The actor and critic of this model are respectively \ + a Llama Pretrained Model. + Interfaces: + ``__init__``, ``forward``. + """ + mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] + + def __init__(self, actor_path: str, critic_path: str, opt: Dict, enable_checkpointing: bool = True) -> None: + """ + Overview: + Initialize the ``LlamaVAC`` model according to arguments. + Arguments: + - actor_path (:obj:`str`): Pretrained model path for actor. + - critic_path (:obj:`str`): Pretrained model path for critic. + - opt (:obj:`Dict`): Options for this model. 
+ """ + super(LlamaVAC, self).__init__() + tokenizer = get_tokenizer(actor_path) + self.actor = Llama.from_pretrained( + actor_path, + opt=opt, + tokenizer=tokenizer, + ) + self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer) + if enable_checkpointing: + self.actor.gradient_checkpointing_enable() + self.critic.gradient_checkpointing_enable() + + def forward(self, x: torch.Tensor, mode: str) -> Dict: + assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) + return getattr(self, mode)(x) + + def compute_actor(self, x): + policy_output = self.actor(decoder_input=x) + policy_logit, *_ = policy_output + return {"logit": policy_logit} + + def compute_critic(self, x): + values = self.critic(decoder_input=x, only_last=False) + return {"value": values} + + def compute_actor_critic(self, x): + policy_output = self.actor(decoder_input=x) + policy_logit, *_ = policy_output + values = self.critic(decoder_input=x, only_last=False) + return {"logit": policy_logit, "value": values} diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 60cb596e55..29363d3570 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -1,7 +1,4 @@ from typing import Union, Dict, Optional - -from transformers import LlamaTokenizer -from transformers.models.llama.modeling_llama import LlamaForCausalLM from easydict import EasyDict import torch import torch.nn as nn @@ -10,8 +7,6 @@ from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \ FCEncoder, ConvEncoder, IMPALAConvEncoder from ding.torch_utils.network.dreamer import ActionHead, DenseHead -from ..common.utils import top_p_logits -from ding.reward_model import LlamaRewardModel @MODEL_REGISTRY.register('vac') @@ -430,153 +425,3 @@ def __init__( outscale=0.0, device='cuda' if torch.cuda.is_available() else 'cpu', ) - - -class Llama(LlamaForCausalLM): - - def __init__(self, config, opt, tokenizer): - super().__init__(config) - self.opt = opt - self.tokenizer = tokenizer - - def forward(self, decoder_input, incr_state=None): - - attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) - if incr_state is not None: - decoder_input = decoder_input[:, -1:] - - output = super().forward( - input_ids=decoder_input, - attention_mask=attention_mask, - past_key_values=incr_state, - return_dict=True, - use_cache=not self.training - ) - - logits = output.logits - new_incr_states = output.past_key_values - - return logits, new_incr_states - - @torch.no_grad() - def generate(self, batch, **kwargs): - """ - Generate response - """ - maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res) - temperature = kwargs.pop('temperature', self.opt.temperature) - repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty) - topp = kwargs.pop('topp', self.opt.topp) - - decoder_input: torch.LongTensor = batch # (bsz, ...) - assert decoder_input[:, -1].ne( - self.tokenizer.pad_token_id - ).all(), 'Last token should not be a padding token (you can use left padding instead).' 
- - dev = decoder_input.device - bsz = decoder_input.size(0) - - scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16) - done = torch.zeros((bsz, ), device=dev).to(torch.bool) - - inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1) - decoder_input = torch.index_select(decoder_input, 0, inds) - init_length = decoder_input.size(1) - - incr_state = None - for _token in range(maxlen_res): - if done.all(): - break - score, incr_state, *_ = self.forward(decoder_input, incr_state) - score = score.half() - - # now score is bs, len, vocab_size - score = score[:, -1, :] - - # calculate repetition penalty - if repetition_penalty > 1.: - penalty_tokens = decoder_input[:, init_length:] - penalty_scores = torch.gather(score, dim=1, index=penalty_tokens) - penalty_scores = torch.where( - penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty - ) - score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores) - - # nucleus sampling - score = torch.softmax(score.div(temperature), dim=-1) - probs = top_p_logits(score, topp=topp, filter_value=0) - tok_ids = torch.multinomial(probs, 1)[:, 0] - hyp_ids = torch.arange(probs.size(0), device=dev) - scores = scores + probs[hyp_ids, tok_ids].log() * ~done - - tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids) - decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1) - done = done | tok_ids.eq(self.tokenizer.eos_token_id) - - incr_state = self._reorder_cache(incr_state, hyp_ids) - - # get all finalized candidates for each sample - decoder_input = decoder_input[:, init_length:] - decoder_input = decoder_input.view(bsz, -1) - scores = scores.view(bsz, ) - - lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1) - - length_penalty = torch.pow(lengths, 1.0) - scores /= length_penalty - - preds_scores = [] - for i in range(bsz): - seq: torch.LongTensor = decoder_input[i, :lengths[i, ]] - res_scores = (float(scores[i, ]), seq.tolist()) - preds_scores.append([res_scores]) - - best_preds_scores = [preds[0] for preds in preds_scores] - return best_preds_scores, preds_scores - - -@MODEL_REGISTRY.register('llamavac') -class LlamaVAC(nn.Module): - """ - Overview: - The neural network and computation graph of DreamerV3 (state) Value Actor-Critic (VAC). - This model now supports discrete, continuous action space. - Interfaces: - ``__init__``, ``forward``. - """ - mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] - - def __init__(self, actor_path: str, critic_path: str, tokenizer: LlamaTokenizer, opt: Dict) -> None: - """ - Overview: - Initialize the ``DREAMERVAC`` model according to arguments. - Arguments: - - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84]. - - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3]. 
- """ - super(LlamaVAC, self).__init__() - self.actor = Llama.from_pretrained( - actor_path, - opt=opt, - tokenizer=tokenizer, - ) - self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer) - - def forward(self, x: torch.Tensor, mode: str) -> Dict: - assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) - return getattr(self, mode)(x) - - def compute_actor(self, x): - policy_output = self.actor(decoder_input=x) - policy_logit, *_ = policy_output - return {"logit": policy_logit} - - def compute_critic(self, x): - values = self.critic(decoder_input=x, only_last=False) - return {"value": values} - - def compute_actor_critic(self, x): - policy_output = self.actor(decoder_input=x) - policy_logit, *_ = policy_output - values = self.critic(decoder_input=x, only_last=False) - return {"logit": policy_logit, "value": values} diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index 10b7876494..d519853835 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -263,7 +263,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: else: mask = batch.mask ppo_batch = ppo_data( - output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv, + output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv, batch.return_, None ) ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py index 595348bfa3..79d3d0c149 100644 --- a/ding/reward_model/language_reward_model.py +++ b/ding/reward_model/language_reward_model.py @@ -1,12 +1,15 @@ import torch -from transformers.models.llama.modeling_llama import LlamaForCausalLM +try: + from transformers.models.llama.modeling_llama import LlamaForCausalLM +except ImportError: + from ditk import logging + logging.warning("Not found transformer, please install it using: pip install transformers") class LlamaRewardModel(LlamaForCausalLM): - def __init__(self, config, opt, tokenizer): + def __init__(self, config, tokenizer): super().__init__(config) - self.opt = opt self.tokenizer = tokenizer self.reward_head = torch.nn.Linear(config.hidden_size, 1, bias=False) @@ -21,4 +24,4 @@ def forward(self, decoder_input, only_last=True): else: logits = self.reward_head(output.last_hidden_state).squeeze(-1) - return (logits, ) + return logits diff --git a/dizoo/chat/entry.py b/dizoo/chat/entry.py new file mode 100644 index 0000000000..ecb47c53f9 --- /dev/null +++ b/dizoo/chat/entry.py @@ -0,0 +1,34 @@ +from easydict import EasyDict + +from ding.bonus.ppof import PPOF +from ding.model.template import LlamaVAC + + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--actor_path', type=str) + parser.add_argument('--critic_path', type=str) + args = parser.parse_args() + + opt = EasyDict({ + "maxlen_res": 512, + "temperature": 1, + "repetition_penalty": 1, + "topp": 0.8 + }) + + model = LlamaVAC( + actor_path=args.actor_path, + critic_path=args.critic_path, + tokenizer=get_tokenizer("/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en"), + opt=opt + ) + + policy = PPOF( + env_id="chat", + exp_name="rlhf-ppo", + model=model + ) + policy.train(collector_env_num=1, evaluator_env_num=1, debug=True) diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py index 01f0637a9b..01d0f183cb 100644 --- a/dizoo/chat/env.py +++ b/dizoo/chat/env.py @@ -8,16 +8,16 @@ class ChatEnv(BaseEnv): def __init__( self, - batch_size, - 
reward_model_path, - tokenizer_path, - data_path, - maxlen_prompt, - maxlen_res, + batch_size: int, + reward_model_path: str, + tokenizer_path: str, + data_path: str, + maxlen_prompt: int, + maxlen_res: int, ): self.batch_size = batch_size self.tokenizer = get_tokenizer(tokenizer_path) - self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None) + self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer) self.action_space = None self.observation_space = None self.reward_space = None @@ -41,6 +41,9 @@ def close(self) -> None: def reset(self): self.last_batch = next(self.generator) + if self.last_batch is None: + self.generator = self.dataset.final_generator() + self.last_batch = next(self.generator) self._init_flag = True return self.last_batch @@ -48,9 +51,10 @@ def __repr__(self) -> str: return "DI-engine Chat Env" def seed(self, seed): - self._seed = 0 + self._seed = seed def clone(self, caller): + # It should not create a new copy, since the language model is initialized. return self def step(self, action): @@ -63,7 +67,7 @@ def step(self, action): rm_input = torch.tensor(output_vec, dtype=torch.long) output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left') with torch.no_grad(): - rew, *_ = self.rm(rm_input) + rew = self.rm(rm_input) self.last_batch = next(self.generator) if self.last_batch is None: From 11d3c480689ac39978c08b577e88191cd0c1e462 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Mon, 4 Dec 2023 22:18:41 +0800 Subject: [PATCH 14/17] add mix precision --- ding/bonus/config.py | 8 ++-- ding/framework/middleware/collector.py | 4 +- ding/model/template/lm_vac.py | 8 +++- ding/policy/ppof.py | 55 ++++++++++++++-------- ding/reward_model/language_reward_model.py | 7 +-- ding/rl_utils/gae.py | 5 +- dizoo/chat/entry.py | 4 +- 7 files changed, 56 insertions(+), 35 deletions(-) diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 113eaf0943..ddb73fb235 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -169,7 +169,7 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict: cfg.learning_rate = 3e-4 elif env_id == 'chat': cfg.epoch_per_collect = 1 - cfg.batch_size = 2 + cfg.batch_size = 1 cfg.learning_rate = 5e-7 cfg.answers_per_question = 3 cfg.kl_penalty_weight = 0.1 @@ -325,12 +325,12 @@ def get_instance_env(env_id: str) -> BaseEnv: elif env_id == 'chat': from dizoo.chat.env import ChatEnv return ChatEnv( - batch_size=2, + batch_size=1, reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover", tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en", data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data", - maxlen_prompt=1024, - maxlen_res=512, + maxlen_prompt=128, + maxlen_res=128, ) else: raise KeyError("not supported env type: {}".format(env_id)) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index 24533e5983..6017940396 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -214,7 +214,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll self.env = env self.env.seed(seed) self.env.launch() - self.env = self._envs[0] + self.env = self.env._envs[0] self.policy = policy self.n_sample = n_sample self.unroll_len = unroll_len @@ -229,7 +229,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: """ device = self.policy._device - 
obs = ttorch.as_tensor(self.env.last_batch[0]['text_vec']) + obs = ttorch.as_tensor(self.env.last_batch['text_vec']) batch_size = obs.shape[0] obs = obs.to(device) diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py index 61ef4f7856..e2099e705a 100644 --- a/ding/model/template/lm_vac.py +++ b/ding/model/template/lm_vac.py @@ -143,7 +143,8 @@ class LlamaVAC(nn.Module): """ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] - def __init__(self, actor_path: str, critic_path: str, opt: Dict, enable_checkpointing: bool = True) -> None: + def __init__(self, actor_path: str, critic_path: str, + tokenizer_path: str, opt: Dict, enable_checkpointing: bool = True) -> None: """ Overview: Initialize the ``LlamaVAC`` model according to arguments. @@ -153,13 +154,16 @@ def __init__(self, actor_path: str, critic_path: str, opt: Dict, enable_checkpoi - opt (:obj:`Dict`): Options for this model. """ super(LlamaVAC, self).__init__() - tokenizer = get_tokenizer(actor_path) + tokenizer = get_tokenizer(tokenizer_path) + self.actor = Llama.from_pretrained( actor_path, opt=opt, tokenizer=tokenizer, ) + self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer) + if enable_checkpointing: self.actor.gradient_checkpointing_enable() self.critic.gradient_checkpointing_enable() diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index d519853835..ad1afb1c1d 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -7,6 +7,8 @@ import torch import treetensor.torch as ttorch from torch.optim import AdamW +from torch.cuda.amp import GradScaler +from torch import autocast from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \ get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \ @@ -69,6 +71,8 @@ def __init__( ) -> None: self._cfg = cfg self._orig_model = orig_model + if self._orig_model is not None: + self.scalar = GradScaler() if model is None: self._model = self.default_model() else: @@ -170,7 +174,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: with torch.no_grad(): if self._cfg.chat_data: # [B, T] - value = self._model.compute_critic(data.obs)['value'][0] + value = self._model.compute_critic(data.obs)['value'] self._model.cpu() self._orig_model.cuda() data.orig_logit = self._orig_model.compute_actor(data.obs)['logit'] @@ -242,7 +246,8 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: split_data = ttorch.split(data, self._cfg.batch_size) random.shuffle(list(split_data)) for batch in split_data: - output = self._model.compute_actor_critic(batch.obs) + if not self._cfg.chat_data: + output = self._model.compute_actor_critic(batch.obs) adv = batch.adv if self._cfg.adv_norm: # Normalize advantage in a train_batch @@ -261,12 +266,21 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: ) ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) else: - mask = batch.mask - ppo_batch = ppo_data( - output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv, - batch.return_, None - ) - ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) + with autocast(device_type='cuda', dtype=torch.float16): + output = self._model.compute_actor_critic(batch.obs) + mask = batch.mask + ppo_batch = ppo_data( + output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv, + batch.return_, None + ) + ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) + 
kl_loss = ( + torch.nn.functional.kl_div( + torch.softmax(output["logit"], dim=-1), + torch.softmax(batch.orig_logit, dim=-1), + reduction='none' + ) * mask.unsqueeze(-1) + ).mean() elif self._action_space == 'hybrid': # discrete part (discrete policy loss and entropy loss) ppo_discrete_batch = ppo_policy_data( @@ -293,21 +307,22 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: if not self._cfg.chat_data: wv, we = self._cfg.value_weight, self._cfg.entropy_weight total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + self._optimizer.zero_grad() + total_loss.backward() + torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm) + self._optimizer.step() else: wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight - kl_loss = ( - torch.nn.functional.kl_div( - torch.softmax(output.logit, dim=-1), - torch.softmax(batch.orig_logit, dim=-1), - reduction='none' - ) * mask - ).mean() total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss - - self._optimizer.zero_grad() - total_loss.backward() - torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm) - self._optimizer.step() + output = ttorch.as_tensor(output) + self._optimizer.zero_grad() + self.scaler.scale(total_loss).backward() + # scaler.step() first unscales the gradients of the optimizer's assigned params. + # If these gradients do not contain infs or NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(self._optimizer) + # Updates the scale for next iteration. + scaler.update() return_info = { 'cur_lr': self._optimizer.defaults['lr'], diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py index 79d3d0c149..ff2558099a 100644 --- a/ding/reward_model/language_reward_model.py +++ b/ding/reward_model/language_reward_model.py @@ -15,9 +15,10 @@ def __init__(self, config, tokenizer): def forward(self, decoder_input, only_last=True): attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) - output = self.model.forward( - input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False - ) + with torch.no_grad(): + output = self.model.forward( + input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False + ) if only_last: logits = self.reward_head(output.last_hidden_state[:, -1, :]).squeeze(-1) diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py index d7081ad969..16b8313750 100644 --- a/ding/rl_utils/gae.py +++ b/ding/rl_utils/gae.py @@ -81,11 +81,12 @@ def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97): bsz = value.shape[0] for i in range(bsz): val, mas, rew, don, traj = value[i], mask[i], reward[i], done[i], traj_flag[i] - assert val.shape[0] == rew.shape[0] next_val = torch.zeros_like(val) next_val[:-1] = val[1:] + reward = torch.zeros_like(val) + reward[-1] = rew gd = gae_data( - val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1) + val.unsqueeze(-1), next_val.unsqueeze(-1), reward.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1) ) advs.append(gae(gd, gamma, lambda_).squeeze(-1)) return torch.stack(advs, dim=0) diff --git a/dizoo/chat/entry.py b/dizoo/chat/entry.py index ecb47c53f9..5da6ac620d 100644 --- a/dizoo/chat/entry.py +++ b/dizoo/chat/entry.py @@ -4,12 +4,12 @@ from ding.model.template import LlamaVAC - if __name__ == '__main__': 
import argparse parser = argparse.ArgumentParser() parser.add_argument('--actor_path', type=str) parser.add_argument('--critic_path', type=str) + parser.add_argument('--tokenizer_path', type=str) args = parser.parse_args() opt = EasyDict({ @@ -22,7 +22,7 @@ model = LlamaVAC( actor_path=args.actor_path, critic_path=args.critic_path, - tokenizer=get_tokenizer("/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en"), + tokenizer_path=args.tokenizer_path, opt=opt ) From f95529ffc6ab22714f3b113273541ddc11430280 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 6 Dec 2023 17:47:04 +0800 Subject: [PATCH 15/17] fix quant bug --- ding/model/common/tests/test_utils.py | 12 +++--------- ding/model/template/lm_vac.py | 24 +++++++++++++----------- ding/policy/ppof.py | 21 +++++++-------------- 3 files changed, 23 insertions(+), 34 deletions(-) diff --git a/ding/model/common/tests/test_utils.py b/ding/model/common/tests/test_utils.py index b12e529e9b..9a6f688e2d 100644 --- a/ding/model/common/tests/test_utils.py +++ b/ding/model/common/tests/test_utils.py @@ -7,15 +7,9 @@ class TestUtils: def test_top_p_logits(self): - test_logit = torch.Tensor([ - [0., 0.91, 0.05, 0.04], - [0.04, 0.46, 0.46, 0.04] - ]) + test_logit = torch.Tensor([[0., 0.91, 0.05, 0.04], [0.04, 0.46, 0.46, 0.04]]) - gt_logit = torch.Tensor([ - [0., 1., 0., 0.], - [0., 0.5, 0.5, 0.] - ]) + gt_logit = torch.Tensor([[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]]) pred_logit = top_p_logits(test_logit) - assert torch.sum((gt_logit - pred_logit)**2).item() < 1e-8 + assert torch.sum((gt_logit - pred_logit) ** 2).item() < 1e-8 diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py index e2099e705a..2dd96c6a0a 100644 --- a/ding/model/template/lm_vac.py +++ b/ding/model/template/lm_vac.py @@ -36,7 +36,7 @@ def __init__(self, config, opt, tokenizer): self.opt = opt self.tokenizer = tokenizer - def forward(self, decoder_input, incr_state=None): + def forward(self, decoder_input, incr_state=None, is_train=True): attention_mask = decoder_input.ne(self.tokenizer.pad_token_id) if incr_state is not None: @@ -47,7 +47,7 @@ def forward(self, decoder_input, incr_state=None): attention_mask=attention_mask, past_key_values=incr_state, return_dict=True, - use_cache=not self.training + use_cache=not is_train ) logits = output.logits @@ -84,7 +84,7 @@ def generate(self, batch, **kwargs): for _token in range(maxlen_res): if done.all(): break - score, incr_state, *_ = self.forward(decoder_input, incr_state) + score, incr_state, *_ = self.forward(decoder_input, incr_state, is_train=False) score = score.half() # now score is bs, len, vocab_size @@ -143,8 +143,14 @@ class LlamaVAC(nn.Module): """ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] - def __init__(self, actor_path: str, critic_path: str, - tokenizer_path: str, opt: Dict, enable_checkpointing: bool = True) -> None: + def __init__( + self, + actor_path: str, + critic_path: str, + tokenizer_path: str, + opt: Dict, + enable_checkpointing: bool = True + ) -> None: """ Overview: Initialize the ``LlamaVAC`` model according to arguments. 
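The next hunk loads the actor and critic checkpoints directly in bfloat16, which roughly halves the weight memory compared with fp32. A minimal sketch of that loading pattern, with a placeholder checkpoint path (not one of the paths used in this patch series):

    import torch
    from transformers.models.llama.modeling_llama import LlamaForCausalLM

    # Placeholder path; any local Llama checkpoint directory would be loaded the same way.
    model = LlamaForCausalLM.from_pretrained("path/to/llama-sft-checkpoint", torch_dtype=torch.bfloat16)
    # Optionally trade extra forward compute for lower activation memory,
    # as LlamaVAC does when enable_checkpointing is True.
    model.gradient_checkpointing_enable()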
@@ -156,13 +162,9 @@ def __init__(self, actor_path: str, critic_path: str, super(LlamaVAC, self).__init__() tokenizer = get_tokenizer(tokenizer_path) - self.actor = Llama.from_pretrained( - actor_path, - opt=opt, - tokenizer=tokenizer, - ) + self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16) - self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer) + self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16) if enable_checkpointing: self.actor.gradient_checkpointing_enable() diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index ad1afb1c1d..d2da40772d 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -71,8 +71,6 @@ def __init__( ) -> None: self._cfg = cfg self._orig_model = orig_model - if self._orig_model is not None: - self.scalar = GradScaler() if model is None: self._model = self.default_model() else: @@ -275,11 +273,11 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: ) ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio) kl_loss = ( - torch.nn.functional.kl_div( - torch.softmax(output["logit"], dim=-1), - torch.softmax(batch.orig_logit, dim=-1), - reduction='none' - ) * mask.unsqueeze(-1) + torch.nn.functional.kl_div( + torch.softmax(output["logit"], dim=-1), + torch.softmax(batch.orig_logit, dim=-1), + reduction='none' + ) * mask.unsqueeze(-1) ).mean() elif self._action_space == 'hybrid': # discrete part (discrete policy loss and entropy loss) @@ -316,13 +314,8 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss output = ttorch.as_tensor(output) self._optimizer.zero_grad() - self.scaler.scale(total_loss).backward() - # scaler.step() first unscales the gradients of the optimizer's assigned params. - # If these gradients do not contain infs or NaNs, optimizer.step() is then called, - # otherwise, optimizer.step() is skipped. - scaler.step(self._optimizer) - # Updates the scale for next iteration. 
- scaler.update() + total_loss.backward() + self._optimizer.step() return_info = { 'cur_lr': self._optimizer.defaults['lr'], From 31a61914c16f13a184cf341fecbdf59752dd98d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 6 Dec 2023 21:30:18 +0800 Subject: [PATCH 16/17] fix imcopatible problem --- ding/model/template/lm_vac.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py index 2dd96c6a0a..2931fc3fdf 100644 --- a/ding/model/template/lm_vac.py +++ b/ding/model/template/lm_vac.py @@ -31,10 +31,11 @@ def get_tokenizer(path: str): class Llama(LlamaForCausalLM): - def __init__(self, config, opt, tokenizer): + def __init__(self, config, opt, tokenizer, enable_checkpointing): super().__init__(config) self.opt = opt self.tokenizer = tokenizer + self.enable_checkpointing = enable_checkpointing def forward(self, decoder_input, incr_state=None, is_train=True): @@ -60,6 +61,8 @@ def generate(self, batch, **kwargs): """ Generate response """ + if self.enable_checkpointing: + self.gradient_checkpointing_disable() maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res) temperature = kwargs.pop('temperature', self.opt.temperature) repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty) @@ -129,6 +132,8 @@ def generate(self, batch, **kwargs): preds_scores.append([res_scores]) best_preds_scores = [preds[0] for preds in preds_scores] + if self.enable_checkpointing: + self.gradient_checkpointing_enable() return best_preds_scores, preds_scores @@ -161,8 +166,10 @@ def __init__( """ super(LlamaVAC, self).__init__() tokenizer = get_tokenizer(tokenizer_path) + self.enable_checkpointing = enable_checkpointing - self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16) + self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16, + enable_checkpointing=enable_checkpointing) self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16) From 538ddd27a7e241880b4f30727cff1903c6ae6329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Wed, 3 Jan 2024 15:40:35 +0800 Subject: [PATCH 17/17] reformat --- ding/model/template/lm_vac.py | 9 +++++++-- ding/policy/ppof.py | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py index 2931fc3fdf..747b1d1d8f 100644 --- a/ding/model/template/lm_vac.py +++ b/ding/model/template/lm_vac.py @@ -168,8 +168,13 @@ def __init__( tokenizer = get_tokenizer(tokenizer_path) self.enable_checkpointing = enable_checkpointing - self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16, - enable_checkpointing=enable_checkpointing) + self.actor = Llama.from_pretrained( + actor_path, + opt=opt, + tokenizer=tokenizer, + torch_dtype=torch.bfloat16, + enable_checkpointing=enable_checkpointing + ) self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16) diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py index d2da40772d..3dbfb05581 100644 --- a/ding/policy/ppof.py +++ b/ding/policy/ppof.py @@ -323,6 +323,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]: 'policy_loss': ppo_loss.policy_loss.item(), 'value_loss': ppo_loss.value_loss.item(), 'entropy_loss': 
ppo_loss.entropy_loss.item(), + 'kl_loss': kl_loss.item(), 'adv_max': adv.max().item(), 'adv_mean': adv.mean().item(), 'value_mean': output.value.mean().item(),
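Taken together, the chat branch of ding/policy/ppof.py optimizes the usual PPO terms plus a masked KL penalty toward the frozen original (SFT) actor. A condensed sketch of that objective with simplified shapes: only kl_penalty_weight=0.1 appears in the config in this series, so the other default weights below are placeholders.

    import torch
    import torch.nn.functional as F

    def rlhf_ppo_total_loss(ppo_policy_loss, ppo_value_loss, ppo_entropy_loss,
                            logit, orig_logit, mask, wv=0.5, we=0.01, wk=0.1):
        # Token-level KL between the current actor and the frozen SFT actor,
        # masked so that only generated (response) positions contribute.
        # mask has shape [B, T]; logit and orig_logit have shape [B, T, vocab].
        kl = F.kl_div(
            torch.softmax(logit, dim=-1),       # mirrors the patch: probabilities as input
            torch.softmax(orig_logit, dim=-1),
            reduction='none'
        ) * mask.unsqueeze(-1)
        kl_loss = kl.mean()
        return ppo_policy_loss + wv * ppo_value_loss - we * ppo_entropy_loss + wk * kl_loss

The KL term is what keeps the fine-tuned policy from drifting too far from the supervised model while it chases reward-model scores, which is the standard PPO-with-KL-penalty recipe for RLHF.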