From 7ba2125758467725edfb5d398d05f52a77091b79 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Tue, 24 Oct 2023 09:24:47 +0800
Subject: [PATCH 01/17] init commit
---
ding/bonus/config.py | 6 +
ding/bonus/ppof.py | 22 +-
ding/framework/middleware/collector.py | 60 ++++-
ding/model/common/utils.py | 16 ++
ding/model/template/vac.py | 150 +++++++++++++
ding/policy/ppof.py | 146 ++++++++-----
ding/reward_model/__init__.py | 2 +
ding/reward_model/language_reward_model.py | 27 +++
ding/rl_utils/gae.py | 20 ++
dizoo/chat/__init__.py | 1 +
dizoo/chat/env.py | 53 +++++
dizoo/chat/utils.py | 243 +++++++++++++++++++++
launch_ppof.py | 20 ++
13 files changed, 700 insertions(+), 66 deletions(-)
create mode 100644 ding/reward_model/language_reward_model.py
create mode 100644 dizoo/chat/__init__.py
create mode 100644 dizoo/chat/env.py
create mode 100644 dizoo/chat/utils.py
create mode 100644 launch_ppof.py
diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index 285eff6586..3b5fe33463 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -167,6 +167,12 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
cfg.batch_size = 320
cfg.epoch_per_collect = 10
cfg.learning_rate = 3e-4
+ elif env_id == 'chat':
+ cfg.epoch_per_collect = 1
+ cfg.batch_size = 2
+ cfg.learning_rate = 5e-7
+ cfg.answers_per_question = 3
+ cfg.kl_penalty_weight = 0.1
else:
raise KeyError("not supported env type: {}".format(env_id))
else:
diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py
index bf6012240f..297e842844 100644
--- a/ding/bonus/ppof.py
+++ b/ding/bonus/ppof.py
@@ -1,3 +1,4 @@
+import copy
from typing import Optional, Union, List
from ditk import logging
from easydict import EasyDict
@@ -18,6 +19,7 @@
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
+from ..framework.middleware.collector import ChatCollector
class PPOF:
@@ -52,6 +54,8 @@ class PPOF:
'Hopper-v3',
'HalfCheetah-v3',
'Walker2d-v3',
+ # rlhf
+ 'chat'
]
def __init__(
@@ -129,7 +133,11 @@ def __init__(
popart_head=True,
**self.cfg.model
)
- self.policy = PPOFPolicy(self.cfg, model=model)
+ if self.cfg.chat_data:
+ orig_model = copy.deepcopy(model)
+ else:
+ orig_model = None
+ self.policy = PPOFPolicy(self.cfg, model=model, orig_model=orig_model)
if policy_state_dict is not None:
self.policy.load_state_dict(policy_state_dict)
self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
@@ -158,10 +166,14 @@ def train(
pass
with task.start(ctx=OnlineRLContext()):
- task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
- task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
- task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
- task.use(ppof_adv_estimator(self.policy))
+ if not self.policy._cfg.chat_data:
+ # task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
+ # task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
+ else:
+ task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
+ task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
task.use(
wandb_online_logger(
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index beb4894ad9..5e0c0811d4 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -1,3 +1,4 @@
+import copy
from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
@@ -93,7 +94,8 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len
- self._transitions = TransitionList(self.env.env_num)
+ self._transitions = Transiti
+ onList(self.env.env_num)
self._env_episode_id = [_ for _ in range(env.env_num)]
self._current_id = env.env_num
@@ -190,4 +192,60 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
break
+class ChatCollector:
+ """
+ Overview:
+ The collector for chat (RLHF) data: it runs the actor model to generate candidate responses \
+ for each prompt and steps the env to score them. Use the `__call__` method to execute the whole collection process.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ if task.router.is_active and not task.has_role(task.role.COLLECTOR):
+ return task.void()
+ return super(ChatCollector, cls).__new__(cls)
+
+ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
+ """
+ Arguments:
+ - seed (:obj:`int`): Random seed.
+ - policy (:obj:`Policy`): The policy to be collected.
+ - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
+ its derivatives are supported.
+ """
+ self.env = env
+ self.env.seed(seed)
+ self.env.launch()
+ self.policy = policy
+ self.n_sample = n_sample
+ self.unroll_len = unroll_len
+
+ def __call__(self, ctx: "OnlineRLContext") -> None:
+ """
+ Overview:
+ Generate ``answers_per_question`` candidate responses for the current batch of prompts, \
+ score them with the reward model through the env, and pack the result into ``ctx.train_data``.
+ Input of ctx:
+ - env_step (:obj:`int`): The env steps which will increase during collection.
+ """
+ device = self.policy._device
+
+ obs = ttorch.as_tensor(self.env._env[0].last_batch)
+ obs = obs.to(device)
+
+ total_action = []
+ for _ in range(self.policy._cfg.answers_per_question):
+ _, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs)
+ total_action.append(copy.deepcopy(inference_output))
+
+ mask, resp, rew = self.env.step(total_action)
+ ctx.env_step += 1
+ ctx.env_episode += 1
+
+ train_data = {}
+ train_data['obs'] = resp # [B x answer-per-question, max_len]
+ train_data['reward'] = rew # [B x answer-per-question, ]
+ train_data['mask'] = mask # [B x answer-per-question, max_len]
+
+ ctx.train_data = ttorch.as_tensor(train_data)
+
# TODO battle collector
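Editor's note: for shape intuition, a toy sketch of the sampling loop in `ChatCollector.__call__`: each of the B prompts in the batch is answered `answers_per_question` times, so the env and reward model see B x K candidate responses per collection step (the `fake_generate` helper below is hypothetical, standing in for `policy._model.actor.generate`):

    import random

    def fake_generate(prompts):
        # hypothetical stand-in for the actor's generate(): one sampled
        # continuation (a list of token ids) per prompt
        return [p + [random.randint(100, 109)] for p in prompts]

    prompts = [[5, 6, 7], [8, 9]]   # B = 2 tokenized prompts
    K = 3                           # answers_per_question
    total_action = []
    for _ in range(K):
        total_action.append(fake_generate(prompts))
    # K sampling passes over B prompts -> B * K candidate responses in total
    flat = [resp for one_pass in total_action for resp in one_pass]
    print(len(flat))                # 6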
diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py
index 0f508de0b8..c1012276c6 100644
--- a/ding/model/common/utils.py
+++ b/ding/model/common/utils.py
@@ -21,3 +21,19 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
import_module(cfg.pop('import_names', []))
# here we must use the pop operation to ensure compatibility
return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)
+
+
+def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
+ """
+ Filter a distribution of logits using nucleus (top-p) filtering
+ https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146
+ """
+ cum_logits = logits.clone()
+ if topp > 0:
+ logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
+ mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
+ mask[:, :min_topk] = False
+ # Remove tokens with cumulative top_p above the threshold
+ mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
+ cum_logits[mask] = filter_value
+ cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
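Editor's note: as a point of reference, a self-contained toy sketch of the same nucleus (top-p) filtering idea, operating on a row of probabilities and renormalizing the kept mass (the `top_p_filter` name and the toy logits are illustrative only, not part of the patch):

    import torch

    def top_p_filter(probs: torch.Tensor, topp: float = 0.9) -> torch.Tensor:
        # probs: [batch, vocab], rows sum to 1. Keep the smallest set of tokens
        # whose cumulative probability reaches `topp`, zero the rest, renormalize.
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        drop_sorted = (cum - sorted_probs) >= topp        # tokens outside the nucleus
        drop_sorted[:, 0] = False                         # always keep the top-1 token
        drop = torch.zeros_like(drop_sorted).scatter_(-1, sorted_idx, drop_sorted)
        filtered = probs.masked_fill(drop, 0.0)
        return filtered / filtered.sum(dim=-1, keepdim=True)

    probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), dim=-1)
    print(top_p_filter(probs, topp=0.8))   # low-probability tail is zeroed and the rest renormalized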
diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index 29363d3570..74bb4041b3 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -1,4 +1,5 @@
from typing import Union, Dict, Optional
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
from easydict import EasyDict
import torch
import torch.nn as nn
@@ -7,6 +8,8 @@
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
FCEncoder, ConvEncoder, IMPALAConvEncoder
from ding.torch_utils.network.dreamer import ActionHead, DenseHead
+from ..common.utils import top_p_logits
+from ding.reward_model import LlamaRewardModel
@MODEL_REGISTRY.register('vac')
@@ -425,3 +428,150 @@ def __init__(
outscale=0.0,
device='cuda' if torch.cuda.is_available() else 'cpu',
)
+
+
+class Llama(LlamaForCausalLM):
+ def __init__(self, config, opt, tokenizer):
+ super().__init__(config)
+ self.opt = opt
+ self.tokenizer = tokenizer
+
+ def forward(self, decoder_input, incr_state=None):
+
+ attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
+ if incr_state is not None:
+ decoder_input = decoder_input[:, -1:]
+
+ output = super().forward(
+ input_ids=decoder_input,
+ attention_mask=attention_mask,
+ past_key_values=incr_state,
+ return_dict=True,
+ use_cache=not self.training
+ )
+
+ logits = output.logits
+ new_incr_states = output.past_key_values
+
+ return logits, new_incr_states
+
+ @torch.no_grad()
+ def generate(self, batch, **kwargs):
+ """
+ Generate response
+ """
+ maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res)
+ temperature = kwargs.pop('temperature', self.opt.temperature)
+ repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
+ topp = kwargs.pop('topp', self.opt.topp)
+
+ decoder_input: torch.LongTensor = batch['text_vec'] # (bsz, ...)
+ assert decoder_input[:, -1].ne(
+ self.tokenizer.pad_token_id).all(), 'Last token should not be a padding token (you can use left padding instead).'
+
+ dev = decoder_input.device
+ bsz = decoder_input.size(0)
+
+ scores = torch.zeros((bsz,), device=dev, dtype=torch.float16)
+ done = torch.zeros((bsz,), device=dev).to(torch.bool)
+
+ inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
+ decoder_input = torch.index_select(decoder_input, 0, inds)
+ init_length = decoder_input.size(1)
+
+ incr_state = None
+ for _token in range(maxlen_res):
+ if done.all():
+ break
+ score, incr_state, *_ = self.forward(decoder_input, incr_state)
+ score = score.half()
+
+ # now score is bs, len, vocab_size
+ score = score[:, -1, :]
+
+ # calculate repetition penalty
+ if repetition_penalty > 1.:
+ penalty_tokens = decoder_input[:, init_length:]
+ penalty_scores = torch.gather(score, dim=1, index=penalty_tokens)
+ penalty_scores = torch.where(penalty_scores < 0., penalty_scores * repetition_penalty,
+ penalty_scores / repetition_penalty)
+ score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores)
+
+ # nucleus sampling
+ score = torch.softmax(score.div(temperature), dim=-1)
+ probs = top_p_logits(score, topp=topp, filter_value=0)
+ tok_ids = torch.multinomial(probs, 1)[:, 0]
+ hyp_ids = torch.arange(probs.size(0), device=dev)
+ scores = scores + probs[hyp_ids, tok_ids].log() * ~done
+
+ tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
+ decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1)
+ done = done | tok_ids.eq(self.tokenizer.eos_token_id)
+
+ incr_state = self._reorder_cache(incr_state, hyp_ids)
+
+ # get all finalized candidates for each sample
+ decoder_input = decoder_input[:, init_length:]
+ decoder_input = decoder_input.view(bsz, -1)
+ scores = scores.view(bsz, )
+
+ lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1)
+
+ length_penalty = torch.pow(lengths, 1.0)
+ scores /= length_penalty
+
+ preds_scores = []
+ for i in range(bsz):
+ seq: torch.LongTensor = decoder_input[i, :lengths[i,]]
+ res_scores = (float(scores[i,]), seq.tolist())
+ preds_scores.append([res_scores])
+
+ best_preds_scores = [preds[0] for preds in preds_scores]
+ return best_preds_scores, preds_scores
+
+
+@MODEL_REGISTRY.register('llamavac')
+class LlamaVAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of a LLaMA-based (state) Value Actor-Critic (VAC) model \
+ for RLHF, with a causal-LM actor and a reward-model-style critic.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ actor_path: str,
+ critic_path: str
+ ) -> None:
+ """
+ Overview:
+ Initialize the ``LlamaVAC`` model according to arguments.
+ Arguments:
+ - actor_path (:obj:`str`): Path of the pretrained LLaMA actor (policy) model.
+ - critic_path (:obj:`str`): Path of the pretrained LLaMA reward model used as the critic.
+ """
+ super(LlamaVAC, self).__init__()
+ self.actor = Llama.from_pretrained(actor_path)
+ self.critic = LlamaRewardModel.from_pretrained(critic_path)
+
+ def forward(self, x: torch.Tensor, mode: str) -> Dict:
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(x)
+
+ def compute_actor(self, x):
+ policy_output = self.actor(decoder_input=x)
+ policy_logit, *_ = policy_output
+ return {"logit": policy_logit}
+
+ def compute_critic(self, x):
+ values = self.critic(decoder_input=x, only_last=False)
+ return {"value": values}
+
+ def compute_actor_critic(self, x):
+ policy_output = self.actor(decoder_input=x)
+ policy_logit, *_ = policy_output
+ values = self.critic(decoder_input=x, only_last=False)
+ return {"logit": policy_logit,"value": values}
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index 81e605384c..e746b9c617 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Any, Tuple, Union, Callable, Optional
+from typing import List, Dict, Any, Callable, Optional
from collections import namedtuple
from easydict import EasyDict
import copy
@@ -11,6 +11,7 @@
from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog
+from ding.rl_utils.gae import episodic_gae_data, episodic_gae
from ding.utils import POLICY_REGISTRY, RunningMeanStd
@@ -37,6 +38,7 @@ class PPOFPolicy:
value_norm='baseline',
ppo_param_init=True,
grad_norm=0.5,
+ chat_data=True,
# collect
n_sample=128,
unroll_len=1,
@@ -58,8 +60,9 @@ def default_model(cls: type) -> Callable:
from .model import PPOFModel
return PPOFModel
- def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None) -> None:
+ def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None, orig_model: torch.nn.Module = None) -> None:
self._cfg = cfg
+ self._orig_model = orig_model
if model is None:
self._model = self.default_model()
else:
@@ -151,63 +154,78 @@ def _model_param_init(self):
def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
return_infos = []
self._model.train()
- bs = self._cfg.batch_size
- data = data[:self._cfg.n_sample // bs * bs] # rounding
+ if not self._cfg.chat_data:
+ bs = self._cfg.batch_size
+ data = data[:self._cfg.n_sample // bs * bs] # rounding
# outer training loop
for epoch in range(self._cfg.epoch_per_collect):
# recompute adv
with torch.no_grad():
- # get the value dictionary
- # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
- value = self._model.compute_critic(data.obs)
- next_value = self._model.compute_critic(data.next_obs)
- reward = data.reward
-
- assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'],\
- 'Not supported value normalization! Value normalization supported: \
- popart, value rescale, symlog, baseline'
-
- if self._cfg.value_norm == 'popart':
- unnormalized_value = value['unnormalized_pred']
- unnormalized_next_value = value['unnormalized_pred']
-
- mu = self._model.critic_head.popart.mu
- sigma = self._model.critic_head.popart.sigma
- reward = (reward - mu) / sigma
-
- value = value['pred']
- next_value = next_value['pred']
- elif self._cfg.value_norm == 'value_rescale':
- value = value_inv_transform(value['pred'])
- next_value = value_inv_transform(next_value['pred'])
- elif self._cfg.value_norm == 'symlog':
- value = inv_symlog(value['pred'])
- next_value = inv_symlog(next_value['pred'])
- elif self._cfg.value_norm == 'baseline':
- value = value['pred'] * self._running_mean_std.std
- next_value = next_value['pred'] * self._running_mean_std.std
-
- traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
- adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
- data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
-
- unnormalized_returns = value + data.adv # In popart, this return is normalized
-
- if self._cfg.value_norm == 'popart':
- self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
- elif self._cfg.value_norm == 'value_rescale':
- value = value_transform(value)
- unnormalized_returns = value_transform(unnormalized_returns)
- elif self._cfg.value_norm == 'symlog':
- value = symlog(value)
- unnormalized_returns = symlog(unnormalized_returns)
- elif self._cfg.value_norm == 'baseline':
- value /= self._running_mean_std.std
- unnormalized_returns /= self._running_mean_std.std
- self._running_mean_std.update(unnormalized_returns.cpu().numpy())
- data.value = value
- data.return_ = unnormalized_returns
+ if self._cfg.chat_data:
+ # [B, T]
+ value = self._model.compute_critic(data.obs)
+ data.orig_logit = self._orig_model.compute_actor(data.obs)
+ data.value = value
+ reward = data.reward
+
+ traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ adv_data = episodic_gae_data(value, data.mask, reward, data.done, traj_flag)
+ data.adv = episodic_gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
+
+ unnormalized_returns = data.value + data.adv
+ data.return_ = unnormalized_returns
+ else:
+ # get the value dictionary
+ # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
+ value = self._model.compute_critic(data.obs)
+ next_value = self._model.compute_critic(data.next_obs)
+ reward = data.reward
+
+ assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'], \
+ 'Not supported value normalization! Value normalization supported: \
+ popart, value rescale, symlog, baseline'
+
+ if self._cfg.value_norm == 'popart':
+ unnormalized_value = value['unnormalized_pred']
+ unnormalized_next_value = value['unnormalized_pred']
+
+ mu = self._model.critic_head.popart.mu
+ sigma = self._model.critic_head.popart.sigma
+ reward = (reward - mu) / sigma
+
+ value = value['pred']
+ next_value = next_value['pred']
+ elif self._cfg.value_norm == 'value_rescale':
+ value = value_inv_transform(value['pred'])
+ next_value = value_inv_transform(next_value['pred'])
+ elif self._cfg.value_norm == 'symlog':
+ value = inv_symlog(value['pred'])
+ next_value = inv_symlog(next_value['pred'])
+ elif self._cfg.value_norm == 'baseline':
+ value = value['pred'] * self._running_mean_std.std
+ next_value = next_value['pred'] * self._running_mean_std.std
+
+ traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
+ data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
+
+ unnormalized_returns = value + data.adv # In popart, this return is normalized
+
+ if self._cfg.value_norm == 'popart':
+ self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
+ elif self._cfg.value_norm == 'value_rescale':
+ value = value_transform(value)
+ unnormalized_returns = value_transform(unnormalized_returns)
+ elif self._cfg.value_norm == 'symlog':
+ value = symlog(value)
+ unnormalized_returns = symlog(unnormalized_returns)
+ elif self._cfg.value_norm == 'baseline':
+ value /= self._running_mean_std.std
+ unnormalized_returns /= self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_returns.cpu().numpy())
+ data.value = value
+ data.return_ = unnormalized_returns
# inner training loop
split_data = ttorch.split(data, self._cfg.batch_size)
@@ -215,6 +233,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
for batch in split_data:
output = self._model.compute_actor_critic(batch.obs)
adv = batch.adv
+ mask = batch.mask
if self._cfg.adv_norm:
# Normalize advantage in a train_batch
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
@@ -226,10 +245,16 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
)
ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._cfg.clip_ratio)
elif self._action_space == 'discrete':
- ppo_batch = ppo_data(
- output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
- )
- ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ if not self._cfg.chat_data:
+ ppo_batch = ppo_data(
+ output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, mask
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ else:
+ ppo_batch = ppo_data(
+ output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
elif self._action_space == 'hybrid':
# discrete part (discrete policy loss and entropy loss)
ppo_discrete_batch = ppo_policy_data(
@@ -253,8 +278,9 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
)
- wv, we = self._cfg.value_weight, self._cfg.entropy_weight
- total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+ wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight
+ kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(data.orig_logit, dim=-1), reduction=None) * mask).mean()
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss
self._optimizer.zero_grad()
total_loss.backward()
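Editor's note: for intuition, a self-contained sketch of the masked KL-penalty idea used above: only response tokens (mask = 1) contribute, and the penalty measures drift of the trained policy away from the frozen reference model. The sketch passes `log_softmax` as the first argument because `torch.nn.functional.kl_div` expects log-probabilities there; the `masked_kl_penalty` name and toy shapes are illustrative, not part of the patch:

    import torch
    import torch.nn.functional as F

    def masked_kl_penalty(logit_new, logit_ref, mask):
        # logit_new/logit_ref: [B, T, vocab]; mask: [B, T] with 1 on response tokens.
        # Per-token KL(ref || new), summed over the vocab, averaged over masked positions.
        log_p_new = F.log_softmax(logit_new, dim=-1)
        p_ref = F.softmax(logit_ref, dim=-1)
        kl = F.kl_div(log_p_new, p_ref, reduction='none').sum(dim=-1)   # [B, T]
        return (kl * mask).sum() / mask.sum().clamp(min=1)

    B, T, V = 2, 4, 10
    logit_new, logit_ref = torch.randn(B, T, V), torch.randn(B, T, V)
    mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]], dtype=torch.float)
    print(masked_kl_penalty(logit_new, logit_ref, mask))

Whether to penalize KL(ref || pi) or KL(pi || ref) is a design choice; the sketch mirrors the argument order used in the diff above.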
diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py
index 4538102861..5b197af25d 100644
--- a/ding/reward_model/__init__.py
+++ b/ding/reward_model/__init__.py
@@ -13,3 +13,5 @@
from .guided_cost_reward_model import GuidedCostRewardModel
from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
from .icm_reward_model import ICMRewardModel
+# RLHF
+from .language_reward_model import LlamaRewardModel
diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py
new file mode 100644
index 0000000000..54c6cee23c
--- /dev/null
+++ b/ding/reward_model/language_reward_model.py
@@ -0,0 +1,27 @@
+import torch
+from utils import *
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+
+
+class LlamaRewardModel(LlamaForCausalLM):
+ def __init__(self, config, opt, tokenizer):
+ super().__init__(config)
+ self.opt = opt
+ self.tokenizer = tokenizer
+ self.reward_head = torch.nn.Linear(config.hidden_size, 1, bias=False)
+
+ def forward(self, decoder_input, only_last=True):
+ attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
+ output = self.model.forward(
+ input_ids=decoder_input,
+ attention_mask=attention_mask,
+ return_dict=True,
+ use_cache=False
+ )
+
+ if only_last:
+ logits = self.reward_head(output.last_hidden_state[:, -1, :]).squeeze(-1)
+ else:
+ logits = self.reward_head(output.last_hidden_state).squeeze(-1)
+
+ return (logits,)
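Editor's note: in toy form, the reward-head idea above: take the decoder's hidden states and project them to scalar scores, either only at the last position or at every token (the `ToyRewardHead` module and sizes are illustrative; the real model reuses LLaMA's transformer as `self.model`):

    import torch
    import torch.nn as nn

    class ToyRewardHead(nn.Module):
        def __init__(self, hidden=8):
            super().__init__()
            self.reward_head = nn.Linear(hidden, 1, bias=False)

        def forward(self, hidden_states, only_last=True):
            # hidden_states: [B, T, hidden] from a decoder backbone
            if only_last:
                return self.reward_head(hidden_states[:, -1, :]).squeeze(-1)   # [B]
            return self.reward_head(hidden_states).squeeze(-1)                  # [B, T]

    h = torch.randn(2, 5, 8)
    head = ToyRewardHead()
    print(head(h).shape, head(h, only_last=False).shape)   # [2] and [2, 5]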
diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py
index 800fcae354..487f2c93ef 100644
--- a/ding/rl_utils/gae.py
+++ b/ding/rl_utils/gae.py
@@ -3,6 +3,7 @@
from ding.hpc_rl import hpc_wrapper
gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done', 'traj_flag'])
+episodic_gae_data = namedtuple('episodic_gae_data', ['value', 'mask', 'reward', 'done', 'traj_flag'])
def shape_fn_gae(args, kwargs):
@@ -68,3 +69,22 @@ def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.F
gae_item = delta[t] + factor[t] * gae_item
adv[t] = gae_item
return adv
+
+
+def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97):
+ value, mask, reward, done, traj_flag = data
+ if done is None:
+ done = torch.zeros_like(value)
+ if traj_flag is None:
+ traj_flag = done
+ advs = []
+ bsz = value.shape[0]
+ for i in range(bsz):
+ val, mas, rew, don, traj = value[i], mask[i], reward[i], done[i], traj_flag[i]
+ assert val.shape[0] == rew.shape[0]
+ next_val = torch.zeros_like(val)
+ next_val[:-1] = val[1:]
+ gd = gae_data(val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1),
+ traj.unsqueeze(-1))
+ advs.append(gae(gd, gamma, lambda_))
+ return torch.stack(advs, dim=0)
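Editor's note: the per-episode helper above reuses the standard GAE recursion A_t = delta_t + gamma * lambda * A_{t+1}, with delta_t = r_t + gamma * V_{t+1} - V_t, applied row by row over a [B, T] batch of responses. A toy, self-contained sketch (independent of ding's `gae`/`gae_data`, assuming a zero bootstrap value after the last token) for shape intuition:

    import torch

    def toy_episodic_gae(value, reward, gamma=0.99, lam=0.97):
        # value, reward: [B, T]; each row is one full response, so the
        # bootstrap value after the last token is taken as 0.
        B, T = value.shape
        next_value = torch.zeros_like(value)
        next_value[:, :-1] = value[:, 1:]
        delta = reward + gamma * next_value - value          # [B, T]
        adv = torch.zeros_like(value)
        running = torch.zeros(B)
        for t in reversed(range(T)):
            running = delta[:, t] + gamma * lam * running
            adv[:, t] = running
        return adv

    value = torch.tensor([[0.1, 0.2, 0.3, 0.4]])
    reward = torch.tensor([[0.0, 0.0, 0.0, 1.0]])            # reward only at the final token
    print(toy_episodic_gae(value, reward))                   # [1, 4] advantages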
diff --git a/dizoo/chat/__init__.py b/dizoo/chat/__init__.py
new file mode 100644
index 0000000000..eb1eb48abb
--- /dev/null
+++ b/dizoo/chat/__init__.py
@@ -0,0 +1 @@
+from .env import ChatEnv
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
new file mode 100644
index 0000000000..a7af5bfd28
--- /dev/null
+++ b/dizoo/chat/env.py
@@ -0,0 +1,53 @@
+import gym
+import torch
+
+from ding.reward_model import LlamaRewardModel
+from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, pad_sequences
+
+
+class ChatEnv(gym.Env):
+ def __init__(
+ self,
+ batch_size,
+ reward_model_path,
+ tokenizer_path,
+ data_path,
+ maxlen_prompt,
+ maxlen_res,
+ ):
+ self.batch_size = batch_size
+ self.rm = LlamaRewardModel.from_pretrained(reward_model_path)
+ self.tokenizer = get_tokenizer(tokenizer_path)
+
+ self.dataset = OnlyPromptDataset(
+ data_path=data_path,
+ tokenizer=self.tokenizer,
+ batch_size=batch_size,
+ maxlen_prompt=maxlen_prompt,
+ maxlen_res=maxlen_res,
+ mode='train',
+ )
+ self.generator = self.dataset.final_generator()
+ self.last_batch = None
+
+ def reset(self):
+ self.last_batch = next(self.generator)
+ return self.last_batch
+
+ def step(self, action):
+ """
+ For each step, this env will return a batch of prompts. These prompts are vectorized using the tokenizer and \
+ padded to the same length.
+ """
+ output_mask, output_vec = concat_context_and_response(self.tokenizer, self.last_batch['text_vec'].tolist(), action)
+ rm_input = torch.tensor(pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left'), dtype=torch.long)
+ output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left')
+ with torch.no_grad():
+ rew, *_ = self.rm(rm_input)
+
+ self.last_batch = next(self.generator)
+ if self.last_batch is None:
+ self.generator = self.dataset.final_generator()
+ self.last_batch = next(self.generator)
+
+ return output_mask, output_vec, rew
diff --git a/dizoo/chat/utils.py b/dizoo/chat/utils.py
new file mode 100644
index 0000000000..574381fdc8
--- /dev/null
+++ b/dizoo/chat/utils.py
@@ -0,0 +1,243 @@
+import json
+import os
+from typing import List, Dict, Any, Tuple
+import warnings
+
+from transformers.models.llama.tokenization_llama import LlamaTokenizer
+from torch.utils.data.dataset import IterableDataset
+import torch
+import random
+
+
+# Prefix of human sentence and assistant sentence.
+HUMAN_PROMPT = "Human:"
+ASSISTANT_PROMPT = "Assistant:"
+
+
+def strip_pad_token_id(tokenizer: LlamaTokenizer, seq: List[int]):
+ """
+ Overview:
+ Remove ``pad_token_id`` in a sequence.
+ """
+ return [tok for tok in seq if tok != tokenizer.pad_token_id]
+
+
+def concat_context_and_response(
+ tokenizer: LlamaTokenizer,
+ context: List[List[int]],
+ responses: List[List[Tuple[float, List[int]]]]
+):
+ """
+ Overview:
+ Given the batched input prompts and responses, concatenate them together.
+ """
+ assert len(context) == len(responses), f'Size not match: {len(context)} and {len(responses)}'
+
+ total_context, total_response = [], []
+ total_context_mask, total_response_mask = [], []
+ for _context, _response in zip(context, responses):
+ # Each ``_context`` is a single input prompt.
+ _context = strip_pad_token_id(tokenizer, _context)
+ for _, resp in _response:
+ # Each ``resp`` is a single response.
+ resp = strip_pad_token_id(tokenizer, resp)
+ if resp[-1] != tokenizer.eos_token_id:
+ warnings.warn(
+ f'Generated response is too long: {tokenizer.decode(_context + resp, skip_special_tokens=False)}')
+
+ total_context.append(_context.copy())
+ total_context_mask.append([0] * len(_context))
+ total_response.append(resp)
+ total_response_mask.append([1] * len(resp))
+
+ total_gene_samples_vec = [c + r for c, r in zip(total_context, total_response)]
+ total_gene_samples_mask = [c + r for c, r in zip(total_context_mask, total_response_mask)]
+ return total_gene_samples_mask, total_gene_samples_vec
+
+
+def pad_sequences(
+ seqs: List[List[int]],
+ pad_value: int,
+ padding: str = 'right'):
+ """
+ Overview:
+ Pad sequences to the same length.
+ """
+ max_len = max(len(seq) for seq in seqs)
+ if padding == 'right':
+ padded_seqs = [seq + [pad_value] * (max_len - len(seq)) for seq in seqs]
+ elif padding == 'left':
+ padded_seqs = [[pad_value] * (max_len - len(seq)) + seq for seq in seqs]
+ else:
+ raise ValueError
+ return padded_seqs
+
+
+def get_special_prompt(i: int):
+ return HUMAN_PROMPT if i % 2 == 0 else ASSISTANT_PROMPT
+
+
+def get_model_prompt(context: List[str], eos_token=""):
+ human_prompt, assistant_prompt = HUMAN_PROMPT, ASSISTANT_PROMPT
+ if context[-1].startswith(human_prompt):
+ end_prompt = assistant_prompt
+ elif context[-1].startswith(assistant_prompt):
+ end_prompt = human_prompt
+ else:
+ raise ValueError
+
+ context = eos_token.join(context)
+ return f'{context}{eos_token}{end_prompt}'
+
+
+def get_tokenizer(path: str):
+ """
+ Overview:
+ Return the pretrained tokenizer using the given path.
+ """
+ tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True)
+ tokenizer.bos_token = ''
+ tokenizer.eos_token = ''
+ tokenizer.pad_token = ''
+ tokenizer.pad_token_id = 0
+ tokenizer.unk_token = tokenizer.pad_token
+ tokenizer.unk_token_id = tokenizer.pad_token_id
+
+ return tokenizer
+
+
+class OnlyPromptDataset(IterableDataset):
+ """
+ Overview:
+ Dataset that only contains the prompts of the raw data (no answer).
+ """
+ def __init__(
+ self,
+ data_path: os.PathLike,
+ tokenizer,
+ batch_size: int,
+ maxlen_prompt: int,
+ maxlen_res: int,
+ mode: str = 'train',
+ ) -> None:
+ super().__init__()
+ self.mode = mode
+ self.tokenizer = tokenizer
+ self.maxlen_prompt = maxlen_prompt
+ self.maxlen_res = maxlen_res
+ self.batch_size = batch_size
+
+ # Load data.
+ self.data = []
+ files = sorted([file for file in os.listdir(data_path) if file.endswith(f'{mode}.json')])
+ for file in files:
+ file_path = os.path.join(data_path, file)
+ tmp_data = []
+ try:
+ tmp_data = self.load_data(file_path)
+ except Exception:
+ pass  # skip files that cannot be loaded or parsed
+ self.data.extend(tmp_data)
+
+ # Set the length of this dataset.
+ self.size = len(self.data)
+
+ def __len__(self):
+ return self.size
+
+ def load_data(self, file_path: str):
+ """
+ Overview:
+ Load raw data from given file_path.
+ """
+ with open(file_path, 'r') as f:
+ data: List[List[str]] = json.load(f)
+
+ output: List[List[str]] = [sample for sample in data if all(sample)]
+ del data
+
+ return output
+
+ def final_generator(self):
+ data_generator = self.batch_generator()
+ for batch_samples in data_generator:
+ batch = self.batchify(batch_samples)
+ yield batch
+
+ def __iter__(self):
+ return self.final_generator()
+
+ def format(self, sample: List[str]) -> Dict[str, Any]:
+ """
+ Overview:
+ Convert one data sample into a prompt string and its token ids.
+ """
+ context = sample
+ context = [get_special_prompt(i + (len(context) + 1) % 2) + s for i, s in enumerate(context)]
+ context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token),
+ add_special_tokens=True)
+
+ # truncate to max_len
+ while len(context_vec) > self.maxlen_prompt - self.maxlen_res and len(context) > 1:
+ context = context[1:]
+ context_vec = self.tokenizer.encode(get_model_prompt(context, self.tokenizer.eos_token),
+ add_special_tokens=True)
+
+ output = {
+ 'text': self.tokenizer.decode(context_vec, skip_special_tokens=False),
+ 'text_vec': context_vec
+ }
+
+ return output
+
+ def batchify(self, batch_samples: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Batchify a list of samples by left-padding their token ids to the same length.
+ """
+ batch_text_vec = torch.tensor(pad_sequences(
+ [sample['text_vec'] for sample in batch_samples], pad_value=self.tokenizer.pad_token_id, padding='left'
+ ), dtype=torch.long)
+ return {
+ 'text_vec': batch_text_vec,
+ 'text': [sample['text'] for sample in batch_samples]
+ }
+
+ def sample_generator(self):
+ """
+ Overview:
+ Generate a single data sample.
+ """
+ random.seed(None)
+ if self.mode == 'train':
+ random.shuffle(self.data)
+
+ for sample in self.data:
+ yield self.format(sample)
+
+ def _batch_generator(self):
+ """
+ Overview:
+ Generate a batch of samples.
+ """
+ batch = []
+ # Generate a sample.
+ for sample in self.sample_generator():
+ sample_len = len(sample['text_vec'])
+ if sample_len > self.maxlen_prompt:
+ continue
+
+ batch.append(sample)
+ if len(batch) >= self.batch_size:
+ yield batch[:self.batch_size]
+ batch = batch[self.batch_size:]
+ if batch:
+ yield batch
+
+ def batch_generator(self):
+ while True:
+ for batch in self._batch_generator():
+ if len(batch) == self.batch_size:
+ yield batch
+ if self.mode != 'train':
+ break
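Editor's note: to make the data layout produced by `concat_context_and_response` and `pad_sequences` concrete, a toy sketch with hand-written token ids (assuming pad id 0, as configured in `get_tokenizer` above; the `toy_*` helpers are illustrative, not part of the module):

    PAD = 0  # assumed pad token id, matching get_tokenizer above

    def toy_concat(contexts, responses):
        # contexts/responses: lists of token-id lists; mask is 0 on prompt tokens, 1 on response tokens
        vecs = [c + r for c, r in zip(contexts, responses)]
        masks = [[0] * len(c) + [1] * len(r) for c, r in zip(contexts, responses)]
        return masks, vecs

    def toy_pad_left(seqs, pad_value=PAD):
        max_len = max(len(s) for s in seqs)
        return [[pad_value] * (max_len - len(s)) + s for s in seqs]

    masks, vecs = toy_concat([[5, 6, 7], [5, 6]], [[11, 12, 2], [13, 2]])
    print(toy_pad_left(vecs))    # [[5, 6, 7, 11, 12, 2], [0, 0, 5, 6, 13, 2]]
    print(toy_pad_left(masks))   # [[0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1]]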
diff --git a/launch_ppof.py b/launch_ppof.py
new file mode 100644
index 0000000000..67dc26ad1e
--- /dev/null
+++ b/launch_ppof.py
@@ -0,0 +1,20 @@
+from ding.bonus.ppof import PPOF
+from ding.model.template.vac import LlamaVAC
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--actor_path', type=str)
+ parser.add_argument('--critic_path', type=str)
+ args = parser.parse_args()
+ model = LlamaVAC(
+ actor_path=args.actor_path,
+ critic_path=args.critic_path
+ )
+
+ policy = PPOF(
+ env_id="prompt-generator",
+ exp_name="rlhf-ppo",
+ model=model
+ )
+ policy.train()
From eb5d8c83fdf14115799396611e1288302139433d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 25 Oct 2023 15:56:34 +0800
Subject: [PATCH 02/17] debug
---
ding/model/template/vac.py | 7 +++--
ding/reward_model/language_reward_model.py | 1 -
launch_ppof.py | 32 +++++++++++++++++++++-
3 files changed, 36 insertions(+), 4 deletions(-)
diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index 74bb4041b3..ed704385ae 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -1,4 +1,6 @@
from typing import Union, Dict, Optional
+
+from transformers import LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from easydict import EasyDict
import torch
@@ -544,7 +546,8 @@ class LlamaVAC(nn.Module):
def __init__(
self,
actor_path: str,
- critic_path: str
+ critic_path: str,
+ tokenizer: LlamaTokenizer
) -> None:
"""
Overview:
@@ -554,7 +557,7 @@ def __init__(
- critic_path (:obj:`str`): Path of the pretrained LLaMA reward model used as the critic.
"""
super(LlamaVAC, self).__init__()
- self.actor = Llama.from_pretrained(actor_path)
+ self.actor = Llama.from_pretrained(actor_path, tokenizer=tokenizer)
self.critic = LlamaRewardModel.from_pretrained(critic_path)
def forward(self, x: torch.Tensor, mode: str) -> Dict:
diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py
index 54c6cee23c..a2a5544b1d 100644
--- a/ding/reward_model/language_reward_model.py
+++ b/ding/reward_model/language_reward_model.py
@@ -1,5 +1,4 @@
import torch
-from utils import *
from transformers.models.llama.modeling_llama import LlamaForCausalLM
diff --git a/launch_ppof.py b/launch_ppof.py
index 67dc26ad1e..5192e8dbb2 100644
--- a/launch_ppof.py
+++ b/launch_ppof.py
@@ -1,15 +1,45 @@
+from easydict import EasyDict
+from transformers import LlamaTokenizer
+
from ding.bonus.ppof import PPOF
from ding.model.template.vac import LlamaVAC
+
+def get_tokenizer(path: str):
+ """
+ Overview:
+ Return the pretrained tokenizer using the given path.
+ """
+ tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True)
+ tokenizer.bos_token = ''
+ tokenizer.eos_token = ''
+ tokenizer.pad_token = ''
+ tokenizer.pad_token_id = 0
+ tokenizer.unk_token = tokenizer.pad_token
+ tokenizer.unk_token_id = tokenizer.pad_token_id
+
+ return tokenizer
+
+
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--actor_path', type=str)
parser.add_argument('--critic_path', type=str)
args = parser.parse_args()
+
+ opt = EasyDict({
+ "maxlen_res": 512,
+ "temperature": 1,
+ "repetition_penalty": 1,
+ "topp": 0.8
+ })
+
model = LlamaVAC(
actor_path=args.actor_path,
- critic_path=args.critic_path
+ critic_path=args.critic_path,
+ tokenizer=get_tokenizer(args.actor_path),
+ opt=opt
)
policy = PPOF(
From f45b7427210b7e9d51700a286e76958f05c0aa6f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 25 Oct 2023 16:25:57 +0800
Subject: [PATCH 03/17] debug
---
ding/bonus/config.py | 10 ++++++++++
ding/model/template/vac.py | 9 +++++----
launch_ppof.py | 4 ++--
3 files changed, 17 insertions(+), 6 deletions(-)
diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index 3b5fe33463..86f1d227da 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -321,6 +321,16 @@ def get_instance_env(env_id: str) -> BaseEnv:
)
cfg = EasyDict(cfg)
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
+ elif env_id == 'chat':
+ from dizoo.chat.env import ChatEnv
+ return ChatEnv(
+ batch_size=2,
+ reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en/recover",
+ tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en",
+ data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
+ maxlen_prompt=1024,
+ maxlen_res=512,
+ )
else:
raise KeyError("not supported env type: {}".format(env_id))
diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index ed704385ae..ec4990b750 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -547,7 +547,8 @@ def __init__(
self,
actor_path: str,
critic_path: str,
- tokenizer: LlamaTokenizer
+ tokenizer: LlamaTokenizer,
+ opt: Dict
) -> None:
"""
Overview:
@@ -557,8 +558,8 @@ def __init__(
- critic_path (:obj:`str`): Path of the pretrained LLaMA reward model used as the critic.
"""
super(LlamaVAC, self).__init__()
- self.actor = Llama.from_pretrained(actor_path, tokenizer=tokenizer)
- self.critic = LlamaRewardModel.from_pretrained(critic_path)
+ self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer)
+ self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer)
def forward(self, x: torch.Tensor, mode: str) -> Dict:
assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
@@ -577,4 +578,4 @@ def compute_actor_critic(self, x):
policy_output = self.actor(decoder_input=x)
policy_logit, *_ = policy_output
values = self.critic(decoder_input=x, only_last=False)
- return {"logit": policy_logit,"value": values}
+ return {"logit": policy_logit, "value": values}
diff --git a/launch_ppof.py b/launch_ppof.py
index 5192e8dbb2..a58ac10ae2 100644
--- a/launch_ppof.py
+++ b/launch_ppof.py
@@ -38,12 +38,12 @@ def get_tokenizer(path: str):
model = LlamaVAC(
actor_path=args.actor_path,
critic_path=args.critic_path,
- tokenizer=get_tokenizer(args.actor_path),
+ tokenizer=get_tokenizer("/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en"),
opt=opt
)
policy = PPOF(
- env_id="prompt-generator",
+ env_id="chat",
exp_name="rlhf-ppo",
model=model
)
From 1ca316fbd40cce97d1a48a3fc4b3118945d5158b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 25 Oct 2023 16:42:54 +0800
Subject: [PATCH 04/17] debug
---
ding/bonus/config.py | 4 ++--
dizoo/chat/env.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index 86f1d227da..45b3624a6b 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -325,8 +325,8 @@ def get_instance_env(env_id: str) -> BaseEnv:
from dizoo.chat.env import ChatEnv
return ChatEnv(
batch_size=2,
- reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en/recover",
- tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-rm-model-7B-en",
+ reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
+ tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
maxlen_prompt=1024,
maxlen_res=512,
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
index a7af5bfd28..ec9f1aa48c 100644
--- a/dizoo/chat/env.py
+++ b/dizoo/chat/env.py
@@ -16,8 +16,8 @@ def __init__(
maxlen_res,
):
self.batch_size = batch_size
- self.rm = LlamaRewardModel.from_pretrained(reward_model_path)
self.tokenizer = get_tokenizer(tokenizer_path)
+ self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer)
self.dataset = OnlyPromptDataset(
data_path=data_path,
From db72c2c0ebc104e423e9c5f009541bfae2c12f18 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 25 Oct 2023 17:35:45 +0800
Subject: [PATCH 05/17] debug
---
ding/bonus/ppof.py | 2 ++
dizoo/chat/env.py | 4 +++-
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py
index 297e842844..3d10d32696 100644
--- a/ding/bonus/ppof.py
+++ b/ding/bonus/ppof.py
@@ -112,6 +112,8 @@ def __init__(
action_shape = int(action_space.n)
elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
action_shape = get_hybrid_shape(action_space)
+ elif action_space is None:
+ pass
else:
action_shape = action_space.shape
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
index ec9f1aa48c..843c6bbd38 100644
--- a/dizoo/chat/env.py
+++ b/dizoo/chat/env.py
@@ -1,5 +1,6 @@
import gym
import torch
+from easydict import EasyDict
from ding.reward_model import LlamaRewardModel
from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, pad_sequences
@@ -17,7 +18,8 @@ def __init__(
):
self.batch_size = batch_size
self.tokenizer = get_tokenizer(tokenizer_path)
- self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer)
+ self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None)
+ self.action_space = None
self.dataset = OnlyPromptDataset(
data_path=data_path,
From 3f1e47b84300495c882d547de2f012ad5a3361bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Thu, 26 Oct 2023 11:38:16 +0800
Subject: [PATCH 06/17] debug
---
ding/bonus/config.py | 1 +
dizoo/chat/env.py | 4 ++--
launch_ppof.py | 2 +-
3 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index 45b3624a6b..113eaf0943 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -173,6 +173,7 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
cfg.learning_rate = 5e-7
cfg.answers_per_question = 3
cfg.kl_penalty_weight = 0.1
+ cfg.ppo_param_init = False
else:
raise KeyError("not supported env type: {}".format(env_id))
else:
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
index 843c6bbd38..3d03049ac7 100644
--- a/dizoo/chat/env.py
+++ b/dizoo/chat/env.py
@@ -1,12 +1,12 @@
import gym
import torch
-from easydict import EasyDict
+from ding.envs import BaseEnv
from ding.reward_model import LlamaRewardModel
from .utils import OnlyPromptDataset, concat_context_and_response, get_tokenizer, pad_sequences
-class ChatEnv(gym.Env):
+class ChatEnv(BaseEnv):
def __init__(
self,
batch_size,
diff --git a/launch_ppof.py b/launch_ppof.py
index a58ac10ae2..e360324f11 100644
--- a/launch_ppof.py
+++ b/launch_ppof.py
@@ -47,4 +47,4 @@ def get_tokenizer(path: str):
exp_name="rlhf-ppo",
model=model
)
- policy.train()
+ policy.train(collector_env_num=1, evaluator_env_num=1)
From dc4ceceaffc51b7d431b24b5516bc7ae9f1d16c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Thu, 26 Oct 2023 11:46:47 +0800
Subject: [PATCH 07/17] debug
---
dizoo/chat/env.py | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
index 3d03049ac7..5fb5a1ebc9 100644
--- a/dizoo/chat/env.py
+++ b/dizoo/chat/env.py
@@ -20,6 +20,8 @@ def __init__(
self.tokenizer = get_tokenizer(tokenizer_path)
self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None)
self.action_space = None
+ self._init_flag = False
+ self._seed = None
self.dataset = OnlyPromptDataset(
data_path=data_path,
@@ -32,10 +34,20 @@ def __init__(
self.generator = self.dataset.final_generator()
self.last_batch = None
+ def close(self) -> None:
+ self._init_flag = False
+
def reset(self):
self.last_batch = next(self.generator)
+ self._init_flag = True
return self.last_batch
+ def __repr__(self) -> str:
+ return "DI-engine Chat Env"
+
+ def seed(self, seed):
+ self._seed = 0
+
def step(self, action):
"""
For each step, this env will return a batch of prompts. These prompts are vectorized using the tokenizer and \
From 30f29949b6d77c39338373d3b9c03bd19511a3fc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Thu, 26 Oct 2023 13:37:56 +0800
Subject: [PATCH 08/17] debug
---
ding/bonus/ppof.py | 2 +-
ding/framework/middleware/collector.py | 5 ++---
dizoo/chat/env.py | 6 ++++++
3 files changed, 9 insertions(+), 4 deletions(-)
diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py
index 3d10d32696..fc755bd25a 100644
--- a/ding/bonus/ppof.py
+++ b/ding/bonus/ppof.py
@@ -168,7 +168,7 @@ def train(
pass
with task.start(ctx=OnlineRLContext()):
- if not self.policy._cfg.chat_data:
+ if self.policy._cfg.chat_data:
# task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
# task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
task.use(ChatCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index 5e0c0811d4..8d816eced1 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -94,8 +94,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len
- self._transitions = Transiti
- onList(self.env.env_num)
+ self._transitions = TransitionList(self.env.env_num)
self._env_episode_id = [_ for _ in range(env.env_num)]
self._current_id = env.env_num
@@ -214,7 +213,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll
"""
self.env = env
self.env.seed(seed)
- self.env.launch()
+ self.env.reset()
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
index 5fb5a1ebc9..e92af3125a 100644
--- a/dizoo/chat/env.py
+++ b/dizoo/chat/env.py
@@ -20,6 +20,9 @@ def __init__(
self.tokenizer = get_tokenizer(tokenizer_path)
self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None)
self.action_space = None
+ self.observation_space = None
+ self.reward_space = None
+
self._init_flag = False
self._seed = None
@@ -48,6 +51,9 @@ def __repr__(self) -> str:
def seed(self, seed):
self._seed = 0
+ def clone(self, caller):
+ return self
+
def step(self, action):
"""
For each step, this env will return a batch of prompts. These prompts are vectorized using the tokenizer and \
From c1cc4545c80556e843a43a08f4a87d8f4bc62081 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Thu, 26 Oct 2023 14:08:55 +0800
Subject: [PATCH 09/17] debug
---
ding/bonus/config.py | 4 ++--
ding/framework/middleware/collector.py | 1 +
launch_ppof.py | 2 +-
3 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index 113eaf0943..a55ec1bb44 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -324,14 +324,14 @@ def get_instance_env(env_id: str) -> BaseEnv:
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
elif env_id == 'chat':
from dizoo.chat.env import ChatEnv
- return ChatEnv(
+ return DingEnvWrapper(ChatEnv(
batch_size=2,
reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
maxlen_prompt=1024,
maxlen_res=512,
- )
+ ))
else:
raise KeyError("not supported env type: {}".format(env_id))
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index 8d816eced1..69e8bc04f3 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -213,6 +213,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll
"""
self.env = env
self.env.seed(seed)
+ self._closed = False
self.env.reset()
self.policy = policy
self.n_sample = n_sample
diff --git a/launch_ppof.py b/launch_ppof.py
index e360324f11..53825dcba2 100644
--- a/launch_ppof.py
+++ b/launch_ppof.py
@@ -47,4 +47,4 @@ def get_tokenizer(path: str):
exp_name="rlhf-ppo",
model=model
)
- policy.train(collector_env_num=1, evaluator_env_num=1)
+ policy.train(collector_env_num=1, evaluator_env_num=1, debug=True)
From 24de047f2831b09d90215c8283c245a014324d36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Thu, 26 Oct 2023 14:19:52 +0800
Subject: [PATCH 10/17] debug
---
ding/bonus/config.py | 4 ++--
ding/framework/middleware/collector.py | 5 ++---
2 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index a55ec1bb44..113eaf0943 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -324,14 +324,14 @@ def get_instance_env(env_id: str) -> BaseEnv:
return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
elif env_id == 'chat':
from dizoo.chat.env import ChatEnv
- return DingEnvWrapper(ChatEnv(
+ return ChatEnv(
batch_size=2,
reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
maxlen_prompt=1024,
maxlen_res=512,
- ))
+ )
else:
raise KeyError("not supported env type: {}".format(env_id))
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index 69e8bc04f3..6a3e81735a 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -211,9 +211,8 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll
- env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
its derivatives are supported.
"""
- self.env = env
+ self.env = env._env[0]
self.env.seed(seed)
- self._closed = False
self.env.reset()
self.policy = policy
self.n_sample = n_sample
@@ -229,7 +228,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
"""
device = self.policy._device
- obs = ttorch.as_tensor(self.env._env[0].last_batch)
+ obs = ttorch.as_tensor(self.env.last_batch)
obs = obs.to(device)
total_action = []
From c9b71eec63ed66bd2b47a4ea8eb796c40dd7f452 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Mon, 6 Nov 2023 10:13:32 +0800
Subject: [PATCH 11/17] debug
---
ding/framework/middleware/collector.py | 17 ++++++++++-------
ding/model/common/utils.py | 1 +
ding/model/template/vac.py | 4 ++--
ding/policy/ppof.py | 15 ++++++++++-----
ding/rl_utils/gae.py | 2 +-
dizoo/chat/env.py | 4 ++--
dizoo/chat/utils.py | 3 ++-
7 files changed, 28 insertions(+), 18 deletions(-)
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index 6a3e81735a..261e8fbe66 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -211,9 +211,10 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll
- env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
its derivatives are supported.
"""
- self.env = env._env[0]
+ self.env = env
self.env.seed(seed)
- self.env.reset()
+ self.env.launch()
+ self.env = self._envs[0]
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len
@@ -228,22 +229,24 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
"""
device = self.policy._device
- obs = ttorch.as_tensor(self.env.last_batch)
+ obs = ttorch.as_tensor(self.env.last_batch[0]['text_vec'])
+ batch_size = obs.shape[0]
obs = obs.to(device)
- total_action = []
+ total_action = [[] for _ in range(batch_size)] # [B, answers_per_question, T]
for _ in range(self.policy._cfg.answers_per_question):
_, inference_output = self.policy._model.actor.generate(obs, **ctx.collect_kwargs)
- total_action.append(copy.deepcopy(inference_output))
+ for i in range(batch_size):
+ total_action[i].append(copy.deepcopy(inference_output[i]))
mask, resp, rew = self.env.step(total_action)
ctx.env_step += 1
ctx.env_episode += 1
train_data = {}
- train_data['obs'] = resp # [B x answer-per-question, max_len]
+ train_data['obs'] = resp # [B x answer-per-question, T]
train_data['reward'] = rew # [B x answer-per-question, ]
- train_data['mask'] = mask # [B x answer-per-question, max_len]
+ train_data['mask'] = mask # [B x answer-per-question, T]
ctx.train_data = ttorch.as_tensor(train_data)
diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py
index c1012276c6..340749eb0c 100644
--- a/ding/model/common/utils.py
+++ b/ding/model/common/utils.py
@@ -37,3 +37,4 @@ def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
cum_logits[mask] = filter_value
cum_logits.div_(cum_logits.sum(dim=-1, keepdim=True))
+ return cum_logits
diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index ec4990b750..34e0f37198 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -467,7 +467,7 @@ def generate(self, batch, **kwargs):
repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
topp = kwargs.pop('topp', self.opt.topp)
- decoder_input: torch.LongTensor = batch['text_vec'] # (bsz, ...)
+ decoder_input: torch.LongTensor = batch # (bsz, ...)
assert decoder_input[:, -1].ne(
self.tokenizer.pad_token_id).all(), 'Last token should not be a padding token (you can use left padding instead).'
@@ -558,7 +558,7 @@ def __init__(
- critic_path (:obj:`str`): Path of the pretrained LLaMA reward model used as the critic.
"""
super(LlamaVAC, self).__init__()
- self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer)
+ self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer,)
self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer)
def forward(self, x: torch.Tensor, mode: str) -> Dict:
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index e746b9c617..a445413ae1 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -164,13 +164,18 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
with torch.no_grad():
if self._cfg.chat_data:
# [B, T]
- value = self._model.compute_critic(data.obs)
- data.orig_logit = self._orig_model.compute_actor(data.obs)
+ value = self._model.compute_critic(data.obs)['value'][0]
+ self._model.cpu()
+ self._orig_model.cuda()
+ data.orig_logit = self._orig_model.compute_actor(data.obs)['logit']
+ self._orig_model.cpu()
+ self._model.cuda()
data.value = value
reward = data.reward
traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
- adv_data = episodic_gae_data(value, data.mask, reward, data.done, traj_flag)
+ done = data.get('done', None)
+ adv_data = episodic_gae_data(value, data.mask, reward, done, traj_flag)
data.adv = episodic_gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
unnormalized_returns = data.value + data.adv
@@ -252,7 +257,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
else:
ppo_batch = ppo_data(
- output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
+ output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv, batch.return_, None
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
elif self._action_space == 'hybrid':
@@ -279,7 +284,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
)
wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight
- kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(data.orig_logit, dim=-1), reduction=None) * mask).mean()
+ kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(batch.orig_logit, dim=-1), reduction='none') * mask).mean()
total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss
self._optimizer.zero_grad()
diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py
index 487f2c93ef..1d90a89d12 100644
--- a/ding/rl_utils/gae.py
+++ b/ding/rl_utils/gae.py
@@ -86,5 +86,5 @@ def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97):
next_val[:-1] = val[1:]
gd = gae_data(val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1),
traj.unsqueeze(-1))
- advs.append(gae(gd, gamma, lambda_))
+ advs.append(gae(gd, gamma, lambda_).squeeze(-1))
return torch.stack(advs, dim=0)
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
index e92af3125a..01f0637a9b 100644
--- a/dizoo/chat/env.py
+++ b/dizoo/chat/env.py
@@ -1,4 +1,3 @@
-import gym
import torch
from ding.envs import BaseEnv
@@ -60,7 +59,8 @@ def step(self, action):
padded into the same length.
"""
output_mask, output_vec = concat_context_and_response(self.tokenizer, self.last_batch['text_vec'].tolist(), action)
- rm_input = torch.tensor(pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left'), dtype=torch.long)
+ output_vec = pad_sequences(output_vec, self.tokenizer.pad_token_id, padding='left')
+ rm_input = torch.tensor(output_vec, dtype=torch.long)
output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left')
with torch.no_grad():
rew, *_ = self.rm(rm_input)
diff --git a/dizoo/chat/utils.py b/dizoo/chat/utils.py
index 574381fdc8..be05e76805 100644
--- a/dizoo/chat/utils.py
+++ b/dizoo/chat/utils.py
@@ -38,8 +38,9 @@ def concat_context_and_response(
for _context, _response in zip(context, responses):
# Each ``_context`` is a single input prompt.
_context = strip_pad_token_id(tokenizer, _context)
- for _, resp in _response:
+ for resp in _response:
# Each ``resp`` is a single response.
+ resp = resp[0][1]
resp = strip_pad_token_id(tokenizer, resp)
if resp[-1] != tokenizer.eos_token_id:
warnings.warn(
From 4418c68cc369c5f79b7a38c12e1fd145d46026f0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Mon, 6 Nov 2023 10:45:11 +0800
Subject: [PATCH 12/17] reformat
---
ding/framework/middleware/collector.py | 1 +
ding/model/template/vac.py | 31 +++++++++++-----------
ding/policy/ppof.py | 29 +++++++++++++++-----
ding/reward_model/language_reward_model.py | 8 +++---
ding/rl_utils/gae.py | 5 ++--
5 files changed, 46 insertions(+), 28 deletions(-)
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index 261e8fbe66..7ad8650dbe 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -250,4 +250,5 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
ctx.train_data = ttorch.as_tensor(train_data)
+
# TODO battle collector
diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index 34e0f37198..60cb596e55 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -433,6 +433,7 @@ def __init__(
class Llama(LlamaForCausalLM):
+
def __init__(self, config, opt, tokenizer):
super().__init__(config)
self.opt = opt
@@ -469,13 +470,14 @@ def generate(self, batch, **kwargs):
decoder_input: torch.LongTensor = batch # (bsz, ...)
assert decoder_input[:, -1].ne(
- self.tokenizer.pad_token_id).all(), 'Last token should not be a padding token (you can use left padding instead).'
+ self.tokenizer.pad_token_id
+ ).all(), 'Last token should not be a padding token (you can use left padding instead).'
dev = decoder_input.device
bsz = decoder_input.size(0)
- scores = torch.zeros((bsz,), device=dev, dtype=torch.float16)
- done = torch.zeros((bsz,), device=dev).to(torch.bool)
+ scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16)
+ done = torch.zeros((bsz, ), device=dev).to(torch.bool)
inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
decoder_input = torch.index_select(decoder_input, 0, inds)
@@ -495,8 +497,9 @@ def generate(self, batch, **kwargs):
if repetition_penalty > 1.:
penalty_tokens = decoder_input[:, init_length:]
penalty_scores = torch.gather(score, dim=1, index=penalty_tokens)
- penalty_scores = torch.where(penalty_scores < 0., penalty_scores * repetition_penalty,
- penalty_scores / repetition_penalty)
+ penalty_scores = torch.where(
+ penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty
+ )
score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores)
# nucleus sampling
@@ -524,8 +527,8 @@ def generate(self, batch, **kwargs):
preds_scores = []
for i in range(bsz):
- seq: torch.LongTensor = decoder_input[i, :lengths[i,]]
- res_scores = (float(scores[i,]), seq.tolist())
+ seq: torch.LongTensor = decoder_input[i, :lengths[i, ]]
+ res_scores = (float(scores[i, ]), seq.tolist())
preds_scores.append([res_scores])
best_preds_scores = [preds[0] for preds in preds_scores]
@@ -543,13 +546,7 @@ class LlamaVAC(nn.Module):
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
- def __init__(
- self,
- actor_path: str,
- critic_path: str,
- tokenizer: LlamaTokenizer,
- opt: Dict
- ) -> None:
+ def __init__(self, actor_path: str, critic_path: str, tokenizer: LlamaTokenizer, opt: Dict) -> None:
"""
Overview:
Initialize the ``DREAMERVAC`` model according to arguments.
@@ -558,7 +555,11 @@ def __init__(
- action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
"""
super(LlamaVAC, self).__init__()
- self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer,)
+ self.actor = Llama.from_pretrained(
+ actor_path,
+ opt=opt,
+ tokenizer=tokenizer,
+ )
self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer)
def forward(self, x: torch.Tensor, mode: str) -> Dict:
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index a445413ae1..10b7876494 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -60,7 +60,13 @@ def default_model(cls: type) -> Callable:
from .model import PPOFModel
return PPOFModel
- def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None, orig_model: torch.nn.Module = None) -> None:
+ def __init__(
+ self,
+ cfg: "EasyDict",
+ model: torch.nn.Module,
+ enable_mode: List[str] = None,
+ orig_model: torch.nn.Module = None
+ ) -> None:
self._cfg = cfg
self._orig_model = orig_model
if model is None:
@@ -238,7 +244,6 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
for batch in split_data:
output = self._model.compute_actor_critic(batch.obs)
adv = batch.adv
- mask = batch.mask
if self._cfg.adv_norm:
# Normalize advantage in a train_batch
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
@@ -256,8 +261,10 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
else:
+ mask = batch.mask
ppo_batch = ppo_data(
- output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv, batch.return_, None
+ output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv,
+ batch.return_, None
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
elif self._action_space == 'hybrid':
@@ -283,9 +290,19 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
)
- wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight
- kl_loss = (torch.nn.functional.kl_div(torch.softmax(output.logit, dim=-1), torch.softmax(batch.orig_logit, dim=-1), reduction='none') * mask).mean()
- total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss
+ if not self._cfg.chat_data:
+ wv, we = self._cfg.value_weight, self._cfg.entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+ else:
+ wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight
+ kl_loss = (
+ torch.nn.functional.kl_div(
+ torch.softmax(output.logit, dim=-1),
+ torch.softmax(batch.orig_logit, dim=-1),
+ reduction='none'
+ ) * mask
+ ).mean()
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss
self._optimizer.zero_grad()
total_loss.backward()
diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py
index a2a5544b1d..595348bfa3 100644
--- a/ding/reward_model/language_reward_model.py
+++ b/ding/reward_model/language_reward_model.py
@@ -3,6 +3,7 @@
class LlamaRewardModel(LlamaForCausalLM):
+
def __init__(self, config, opt, tokenizer):
super().__init__(config)
self.opt = opt
@@ -12,10 +13,7 @@ def __init__(self, config, opt, tokenizer):
def forward(self, decoder_input, only_last=True):
attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
output = self.model.forward(
- input_ids=decoder_input,
- attention_mask=attention_mask,
- return_dict=True,
- use_cache=False
+ input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False
)
if only_last:
@@ -23,4 +21,4 @@ def forward(self, decoder_input, only_last=True):
else:
logits = self.reward_head(output.last_hidden_state).squeeze(-1)
- return (logits,)
+ return (logits, )
diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py
index 1d90a89d12..d7081ad969 100644
--- a/ding/rl_utils/gae.py
+++ b/ding/rl_utils/gae.py
@@ -84,7 +84,8 @@ def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97):
assert val.shape[0] == rew.shape[0]
next_val = torch.zeros_like(val)
next_val[:-1] = val[1:]
- gd = gae_data(val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1),
- traj.unsqueeze(-1))
+ gd = gae_data(
+ val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1)
+ )
advs.append(gae(gd, gamma, lambda_).squeeze(-1))
return torch.stack(advs, dim=0)
From d5f931be577ca119f0ca24cd406a0a60454cb06a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Fri, 24 Nov 2023 15:46:07 +0800
Subject: [PATCH 13/17] polish
---
ding/bonus/ppof.py | 3 +-
ding/framework/middleware/__init__.py | 2 +-
ding/framework/middleware/collector.py | 2 +-
ding/model/common/__init__.py | 2 +-
ding/model/common/tests/test_utils.py | 21 +++
ding/model/common/utils.py | 17 +-
ding/model/template/__init__.py | 1 +
ding/model/template/lm_vac.py | 184 +++++++++++++++++++++
ding/model/template/vac.py | 155 -----------------
ding/policy/ppof.py | 2 +-
ding/reward_model/language_reward_model.py | 11 +-
dizoo/chat/entry.py | 34 ++++
dizoo/chat/env.py | 22 ++-
13 files changed, 278 insertions(+), 178 deletions(-)
create mode 100644 ding/model/common/tests/test_utils.py
create mode 100644 ding/model/template/lm_vac.py
create mode 100644 dizoo/chat/entry.py
diff --git a/ding/bonus/ppof.py b/ding/bonus/ppof.py
index fc755bd25a..63b0334024 100644
--- a/ding/bonus/ppof.py
+++ b/ding/bonus/ppof.py
@@ -10,7 +10,7 @@
import torch
from ding.framework import task, OnlineRLContext
from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
- wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
+ wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator, ChatCollector
from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
from ding.utils import set_pkg_seed
@@ -19,7 +19,6 @@
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
-from ..framework.middleware.collector import ChatCollector
class PPOF:
diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py
index b9e3c5005d..43cea6883b 100644
--- a/ding/framework/middleware/__init__.py
+++ b/ding/framework/middleware/__init__.py
@@ -1,5 +1,5 @@
from .functional import *
-from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
+from .collector import StepCollector, EpisodeCollector, PPOFStepCollector, ChatCollector
from .learner import OffPolicyLearner, HERLearner
from .ckpt_handler import CkptSaver
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index 7ad8650dbe..24533e5983 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -195,7 +195,7 @@ class ChatCollector:
"""
Overview:
The class of the collector running by steps, including model inference and transition \
- process. Use the `__call__` method to execute the whole collection process.
+ process. Use the `__call__` method to execute the whole collection process.
"""
def __new__(cls, *args, **kwargs):
diff --git a/ding/model/common/__init__.py b/ding/model/common/__init__.py
index 4bf7d8be5a..5bf1fba4d5 100755
--- a/ding/model/common/__init__.py
+++ b/ding/model/common/__init__.py
@@ -2,4 +2,4 @@
QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
-from .utils import create_model
+from .utils import create_model, top_p_logits
diff --git a/ding/model/common/tests/test_utils.py b/ding/model/common/tests/test_utils.py
new file mode 100644
index 0000000000..b12e529e9b
--- /dev/null
+++ b/ding/model/common/tests/test_utils.py
@@ -0,0 +1,21 @@
+import pytest
+import torch
+from ding.model.common.utils import top_p_logits
+
+
+@pytest.mark.unittest
+class TestUtils:
+
+ def test_top_p_logits(self):
+ test_logit = torch.Tensor([
+ [0., 0.91, 0.05, 0.04],
+ [0.04, 0.46, 0.46, 0.04]
+ ])
+
+ gt_logit = torch.Tensor([
+ [0., 1., 0., 0.],
+ [0., 0.5, 0.5, 0.]
+ ])
+
+ pred_logit = top_p_logits(test_logit)
+ assert torch.sum((gt_logit - pred_logit)**2).item() < 1e-8
diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py
index 340749eb0c..1b4b41e59a 100644
--- a/ding/model/common/utils.py
+++ b/ding/model/common/utils.py
@@ -23,16 +23,25 @@ def create_model(cfg: EasyDict) -> torch.nn.Module:
return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)
-def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
+def top_p_logits(logits: torch.Tensor, topp: float = 0.9, filter_value: float = 0, min_topk: int = 1):
"""
- Filter a distribution of logits using nucleus (top-p) filtering
- https://github.com/OpenLMLab/MOSS/blob/e088f438d1a95d424c6dffef0d73134ebe62cb72/models_jittor/generation.py#L146
+ Overview:
+ Filter a distribution of logits using nucleus (top-p) filtering. The output is also a logit tensor, but \
+ some of its values are masked out.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): The input logits for top-p sampling.
+ - topp (:obj:`float`): The top-p value, such as 0.9.
+ - filter_value (:obj:`float`): The value assigned to masked logits in the output, defaults to 0.
+ - min_topk (:obj:`int`): The minimum number of logits to keep, defaults to 1 (which means that at least one \
+ logit will not be masked).
+ Returns:
+ - cum_logits (:obj:`torch.Tensor`): The output logits after masking.
"""
cum_logits = logits.clone()
if topp > 0:
logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
- mask[:, :min_topk] = False
+ mask[..., :min_topk] = False
# Remove tokens with cumulative top_p above the threshold
mask = torch.zeros_like(mask).to(torch.bool).scatter_(dim=-1, index=inds, src=mask)
cum_logits[mask] = filter_value
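For reference (not part of the patch), a minimal usage sketch of the documented top-p behaviour, mirroring the values used in the new unit test above:

    import torch
    from ding.model.common import top_p_logits

    # Rows are toy probability distributions. top_p_logits keeps the smallest set of
    # highest-probability entries whose cumulative mass reaches topp, zeroes the rest
    # (filter_value=0), and renormalizes each row to sum to 1.
    probs = torch.tensor([[0.00, 0.91, 0.05, 0.04],
                          [0.04, 0.46, 0.46, 0.04]])
    print(top_p_logits(probs, topp=0.9))
    # expected: [[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]]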
diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index b2dd815287..cbf31e72ae 100755
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .language_transformer import LanguageTransformer
+from .lm_vac import LlamaVAC
# algorithm-specific
from .pg import PG
from .ppg import PPG
diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py
new file mode 100644
index 0000000000..61ef4f7856
--- /dev/null
+++ b/ding/model/template/lm_vac.py
@@ -0,0 +1,184 @@
+from typing import Dict
+import torch
+import torch.nn as nn
+try:
+ from transformers import LlamaTokenizer
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
+except ImportError:
+ from ditk import logging
+ logging.warning("transformers not found, please install it with: pip install transformers")
+
+from ding.model.common import top_p_logits
+from ding.reward_model import LlamaRewardModel
+from ding.utils import MODEL_REGISTRY
+
+
+def get_tokenizer(path: str):
+ """
+ Overview:
+ Return the pretrained tokenizer using the given path.
+ """
+ tokenizer = LlamaTokenizer.from_pretrained(path, trust_remote_code=True)
+ tokenizer.bos_token = ''
+ tokenizer.eos_token = ''
+ tokenizer.pad_token = ''
+ tokenizer.pad_token_id = 0
+ tokenizer.unk_token = tokenizer.pad_token
+ tokenizer.unk_token_id = tokenizer.pad_token_id
+
+ return tokenizer
+
+
+class Llama(LlamaForCausalLM):
+
+ def __init__(self, config, opt, tokenizer):
+ super().__init__(config)
+ self.opt = opt
+ self.tokenizer = tokenizer
+
+ def forward(self, decoder_input, incr_state=None):
+
+ attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
+ if incr_state is not None:
+ decoder_input = decoder_input[:, -1:]
+
+ output = super().forward(
+ input_ids=decoder_input,
+ attention_mask=attention_mask,
+ past_key_values=incr_state,
+ return_dict=True,
+ use_cache=not self.training
+ )
+
+ logits = output.logits
+ new_incr_states = output.past_key_values
+
+ return logits, new_incr_states
+
+ @torch.no_grad()
+ def generate(self, batch, **kwargs):
+ """
+ Generate response
+ """
+ maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res)
+ temperature = kwargs.pop('temperature', self.opt.temperature)
+ repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
+ topp = kwargs.pop('topp', self.opt.topp)
+
+ decoder_input: torch.LongTensor = batch # (bsz, ...)
+ assert decoder_input[:, -1].ne(
+ self.tokenizer.pad_token_id
+ ).all(), 'Last token should not be a padding token (you can use left padding instead).'
+
+ dev = decoder_input.device
+ bsz = decoder_input.size(0)
+
+ scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16)
+ done = torch.zeros((bsz, ), device=dev).to(torch.bool)
+
+ inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
+ decoder_input = torch.index_select(decoder_input, 0, inds)
+ init_length = decoder_input.size(1)
+
+ incr_state = None
+ for _token in range(maxlen_res):
+ if done.all():
+ break
+ score, incr_state, *_ = self.forward(decoder_input, incr_state)
+ score = score.half()
+
+ # now score is bs, len, vocab_size
+ score = score[:, -1, :]
+
+ # calculate repetition penalty
+ if repetition_penalty > 1.:
+ penalty_tokens = decoder_input[:, init_length:]
+ penalty_scores = torch.gather(score, dim=1, index=penalty_tokens)
+ penalty_scores = torch.where(
+ penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty
+ )
+ score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores)
+
+ # nucleus sampling
+ score = torch.softmax(score.div(temperature), dim=-1)
+ probs = top_p_logits(score, topp=topp, filter_value=0)
+ tok_ids = torch.multinomial(probs, 1)[:, 0]
+ hyp_ids = torch.arange(probs.size(0), device=dev)
+ scores = scores + probs[hyp_ids, tok_ids].log() * ~done
+
+ tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
+ decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1)
+ done = done | tok_ids.eq(self.tokenizer.eos_token_id)
+
+ incr_state = self._reorder_cache(incr_state, hyp_ids)
+
+ # get all finalized candidates for each sample
+ decoder_input = decoder_input[:, init_length:]
+ decoder_input = decoder_input.view(bsz, -1)
+ scores = scores.view(bsz, )
+
+ lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1)
+
+ length_penalty = torch.pow(lengths, 1.0)
+ scores /= length_penalty
+
+ preds_scores = []
+ for i in range(bsz):
+ seq: torch.LongTensor = decoder_input[i, :lengths[i, ]]
+ res_scores = (float(scores[i, ]), seq.tolist())
+ preds_scores.append([res_scores])
+
+ best_preds_scores = [preds[0] for preds in preds_scores]
+ return best_preds_scores, preds_scores
+
+
+@MODEL_REGISTRY.register('llamavac')
+class LlamaVAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of Llama VAC. The actor and the critic of this model are each \
+ a pretrained Llama model.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(self, actor_path: str, critic_path: str, opt: Dict, enable_checkpointing: bool = True) -> None:
+ """
+ Overview:
+ Initialize the ``LlamaVAC`` model according to arguments.
+ Arguments:
+ - actor_path (:obj:`str`): Pretrained model path for actor.
+ - critic_path (:obj:`str`): Pretrained model path for critic.
+ - opt (:obj:`Dict`): Options for this model.
+ """
+ super(LlamaVAC, self).__init__()
+ tokenizer = get_tokenizer(actor_path)
+ self.actor = Llama.from_pretrained(
+ actor_path,
+ opt=opt,
+ tokenizer=tokenizer,
+ )
+ self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer)
+ if enable_checkpointing:
+ self.actor.gradient_checkpointing_enable()
+ self.critic.gradient_checkpointing_enable()
+
+ def forward(self, x: torch.Tensor, mode: str) -> Dict:
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(x)
+
+ def compute_actor(self, x):
+ policy_output = self.actor(decoder_input=x)
+ policy_logit, *_ = policy_output
+ return {"logit": policy_logit}
+
+ def compute_critic(self, x):
+ values = self.critic(decoder_input=x, only_last=False)
+ return {"value": values}
+
+ def compute_actor_critic(self, x):
+ policy_output = self.actor(decoder_input=x)
+ policy_logit, *_ = policy_output
+ values = self.critic(decoder_input=x, only_last=False)
+ return {"logit": policy_logit, "value": values}
diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py
index 60cb596e55..29363d3570 100644
--- a/ding/model/template/vac.py
+++ b/ding/model/template/vac.py
@@ -1,7 +1,4 @@
from typing import Union, Dict, Optional
-
-from transformers import LlamaTokenizer
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
from easydict import EasyDict
import torch
import torch.nn as nn
@@ -10,8 +7,6 @@
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
FCEncoder, ConvEncoder, IMPALAConvEncoder
from ding.torch_utils.network.dreamer import ActionHead, DenseHead
-from ..common.utils import top_p_logits
-from ding.reward_model import LlamaRewardModel
@MODEL_REGISTRY.register('vac')
@@ -430,153 +425,3 @@ def __init__(
outscale=0.0,
device='cuda' if torch.cuda.is_available() else 'cpu',
)
-
-
-class Llama(LlamaForCausalLM):
-
- def __init__(self, config, opt, tokenizer):
- super().__init__(config)
- self.opt = opt
- self.tokenizer = tokenizer
-
- def forward(self, decoder_input, incr_state=None):
-
- attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
- if incr_state is not None:
- decoder_input = decoder_input[:, -1:]
-
- output = super().forward(
- input_ids=decoder_input,
- attention_mask=attention_mask,
- past_key_values=incr_state,
- return_dict=True,
- use_cache=not self.training
- )
-
- logits = output.logits
- new_incr_states = output.past_key_values
-
- return logits, new_incr_states
-
- @torch.no_grad()
- def generate(self, batch, **kwargs):
- """
- Generate response
- """
- maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res)
- temperature = kwargs.pop('temperature', self.opt.temperature)
- repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
- topp = kwargs.pop('topp', self.opt.topp)
-
- decoder_input: torch.LongTensor = batch # (bsz, ...)
- assert decoder_input[:, -1].ne(
- self.tokenizer.pad_token_id
- ).all(), 'Last token should not be a padding token (you can use left padding instead).'
-
- dev = decoder_input.device
- bsz = decoder_input.size(0)
-
- scores = torch.zeros((bsz, ), device=dev, dtype=torch.float16)
- done = torch.zeros((bsz, ), device=dev).to(torch.bool)
-
- inds = torch.arange(bsz).to(dev).unsqueeze(1).view(-1)
- decoder_input = torch.index_select(decoder_input, 0, inds)
- init_length = decoder_input.size(1)
-
- incr_state = None
- for _token in range(maxlen_res):
- if done.all():
- break
- score, incr_state, *_ = self.forward(decoder_input, incr_state)
- score = score.half()
-
- # now score is bs, len, vocab_size
- score = score[:, -1, :]
-
- # calculate repetition penalty
- if repetition_penalty > 1.:
- penalty_tokens = decoder_input[:, init_length:]
- penalty_scores = torch.gather(score, dim=1, index=penalty_tokens)
- penalty_scores = torch.where(
- penalty_scores < 0., penalty_scores * repetition_penalty, penalty_scores / repetition_penalty
- )
- score = score.scatter_(dim=1, index=penalty_tokens, src=penalty_scores)
-
- # nucleus sampling
- score = torch.softmax(score.div(temperature), dim=-1)
- probs = top_p_logits(score, topp=topp, filter_value=0)
- tok_ids = torch.multinomial(probs, 1)[:, 0]
- hyp_ids = torch.arange(probs.size(0), device=dev)
- scores = scores + probs[hyp_ids, tok_ids].log() * ~done
-
- tok_ids = torch.where(done, self.tokenizer.pad_token_id, tok_ids)
- decoder_input = torch.cat((decoder_input, tok_ids.unsqueeze(-1)), dim=-1)
- done = done | tok_ids.eq(self.tokenizer.eos_token_id)
-
- incr_state = self._reorder_cache(incr_state, hyp_ids)
-
- # get all finalized candidates for each sample
- decoder_input = decoder_input[:, init_length:]
- decoder_input = decoder_input.view(bsz, -1)
- scores = scores.view(bsz, )
-
- lengths = decoder_input.ne(self.tokenizer.pad_token_id).sum(dim=-1)
-
- length_penalty = torch.pow(lengths, 1.0)
- scores /= length_penalty
-
- preds_scores = []
- for i in range(bsz):
- seq: torch.LongTensor = decoder_input[i, :lengths[i, ]]
- res_scores = (float(scores[i, ]), seq.tolist())
- preds_scores.append([res_scores])
-
- best_preds_scores = [preds[0] for preds in preds_scores]
- return best_preds_scores, preds_scores
-
-
-@MODEL_REGISTRY.register('llamavac')
-class LlamaVAC(nn.Module):
- """
- Overview:
- The neural network and computation graph of DreamerV3 (state) Value Actor-Critic (VAC).
- This model now supports discrete, continuous action space.
- Interfaces:
- ``__init__``, ``forward``.
- """
- mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
-
- def __init__(self, actor_path: str, critic_path: str, tokenizer: LlamaTokenizer, opt: Dict) -> None:
- """
- Overview:
- Initialize the ``DREAMERVAC`` model according to arguments.
- Arguments:
- - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
- - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
- """
- super(LlamaVAC, self).__init__()
- self.actor = Llama.from_pretrained(
- actor_path,
- opt=opt,
- tokenizer=tokenizer,
- )
- self.critic = LlamaRewardModel.from_pretrained(critic_path, opt=opt, tokenizer=tokenizer)
-
- def forward(self, x: torch.Tensor, mode: str) -> Dict:
- assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
- return getattr(self, mode)(x)
-
- def compute_actor(self, x):
- policy_output = self.actor(decoder_input=x)
- policy_logit, *_ = policy_output
- return {"logit": policy_logit}
-
- def compute_critic(self, x):
- values = self.critic(decoder_input=x, only_last=False)
- return {"value": values}
-
- def compute_actor_critic(self, x):
- policy_output = self.actor(decoder_input=x)
- policy_logit, *_ = policy_output
- values = self.critic(decoder_input=x, only_last=False)
- return {"logit": policy_logit, "value": values}
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index 10b7876494..d519853835 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -263,7 +263,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
else:
mask = batch.mask
ppo_batch = ppo_data(
- output['logit'], batch.orig_logit, batch.obs, output['value'][0], batch.value, adv,
+ output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv,
batch.return_, None
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py
index 595348bfa3..79d3d0c149 100644
--- a/ding/reward_model/language_reward_model.py
+++ b/ding/reward_model/language_reward_model.py
@@ -1,12 +1,15 @@
import torch
-from transformers.models.llama.modeling_llama import LlamaForCausalLM
+try:
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
+except ImportError:
+ from ditk import logging
+ logging.warning("transformers not found, please install it with: pip install transformers")
class LlamaRewardModel(LlamaForCausalLM):
- def __init__(self, config, opt, tokenizer):
+ def __init__(self, config, tokenizer):
super().__init__(config)
- self.opt = opt
self.tokenizer = tokenizer
self.reward_head = torch.nn.Linear(config.hidden_size, 1, bias=False)
@@ -21,4 +24,4 @@ def forward(self, decoder_input, only_last=True):
else:
logits = self.reward_head(output.last_hidden_state).squeeze(-1)
- return (logits, )
+ return logits
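For context, a short sketch of how the reworked reward model is meant to be queried (placeholder paths; this assumes the ``get_tokenizer`` helper added in ``lm_vac.py``):

    import torch
    from ding.reward_model import LlamaRewardModel
    from ding.model.template.lm_vac import get_tokenizer

    tokenizer = get_tokenizer("<tokenizer_dir>")                        # placeholder path
    rm = LlamaRewardModel.from_pretrained("<reward_model_ckpt>", tokenizer=tokenizer)

    batch = torch.randint(1, 1000, (2, 24))     # [B, T] token ids
    seq_reward = rm(batch)                       # only_last=True: one scalar per sequence, shape [B]
    token_values = rm(batch, only_last=False)    # shape [B, T], used as per-token critic values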
diff --git a/dizoo/chat/entry.py b/dizoo/chat/entry.py
new file mode 100644
index 0000000000..ecb47c53f9
--- /dev/null
+++ b/dizoo/chat/entry.py
@@ -0,0 +1,34 @@
+from easydict import EasyDict
+
+from ding.bonus.ppof import PPOF
+from ding.model.template import LlamaVAC
+
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--actor_path', type=str)
+ parser.add_argument('--critic_path', type=str)
+ args = parser.parse_args()
+
+ opt = EasyDict({
+ "maxlen_res": 512,
+ "temperature": 1,
+ "repetition_penalty": 1,
+ "topp": 0.8
+ })
+
+ model = LlamaVAC(
+ actor_path=args.actor_path,
+ critic_path=args.critic_path,
+ tokenizer=get_tokenizer("/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en"),
+ opt=opt
+ )
+
+ policy = PPOF(
+ env_id="chat",
+ exp_name="rlhf-ppo",
+ model=model
+ )
+ policy.train(collector_env_num=1, evaluator_env_num=1, debug=True)
diff --git a/dizoo/chat/env.py b/dizoo/chat/env.py
index 01f0637a9b..01d0f183cb 100644
--- a/dizoo/chat/env.py
+++ b/dizoo/chat/env.py
@@ -8,16 +8,16 @@
class ChatEnv(BaseEnv):
def __init__(
self,
- batch_size,
- reward_model_path,
- tokenizer_path,
- data_path,
- maxlen_prompt,
- maxlen_res,
+ batch_size: int,
+ reward_model_path: str,
+ tokenizer_path: str,
+ data_path: str,
+ maxlen_prompt: int,
+ maxlen_res: int,
):
self.batch_size = batch_size
self.tokenizer = get_tokenizer(tokenizer_path)
- self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer, opt=None)
+ self.rm = LlamaRewardModel.from_pretrained(reward_model_path, tokenizer=self.tokenizer)
self.action_space = None
self.observation_space = None
self.reward_space = None
@@ -41,6 +41,9 @@ def close(self) -> None:
def reset(self):
self.last_batch = next(self.generator)
+ if self.last_batch is None:
+ self.generator = self.dataset.final_generator()
+ self.last_batch = next(self.generator)
self._init_flag = True
return self.last_batch
@@ -48,9 +51,10 @@ def __repr__(self) -> str:
return "DI-engine Chat Env"
def seed(self, seed):
- self._seed = 0
+ self._seed = seed
def clone(self, caller):
+ # Do not create a new copy, since the language model has already been initialized.
return self
def step(self, action):
@@ -63,7 +67,7 @@ def step(self, action):
rm_input = torch.tensor(output_vec, dtype=torch.long)
output_mask = pad_sequences(output_mask, self.tokenizer.pad_token_id, padding='left')
with torch.no_grad():
- rew, *_ = self.rm(rm_input)
+ rew = self.rm(rm_input)
self.last_batch = next(self.generator)
if self.last_batch is None:
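To make the data flow around ``ChatEnv`` easier to follow, a rough sketch of the intended cycle (placeholder paths; ``model`` is assumed to be a ``LlamaVAC`` instance as sketched earlier, and the exact return value of ``step`` is defined outside the hunks shown here):

    from dizoo.chat.env import ChatEnv

    env = ChatEnv(
        batch_size=1,
        reward_model_path="<rm_ckpt>",
        tokenizer_path="<tokenizer_dir>",
        data_path="<ppo_data_dir>",
        maxlen_prompt=128,
        maxlen_res=128,
    )
    env.seed(0)
    batch = env.reset()                                   # prompt batch; 'text_vec' holds [B, T] token ids
    action, _ = model.actor.generate(batch['text_vec'])   # responses sampled from the actor (assumption)
    transition = env.step(action)                         # concatenates prompt + response, pads left, scores with self.rm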
From 11d3c480689ac39978c08b577e88191cd0c1e462 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Mon, 4 Dec 2023 22:18:41 +0800
Subject: [PATCH 14/17] add mixed precision
---
ding/bonus/config.py | 8 ++--
ding/framework/middleware/collector.py | 4 +-
ding/model/template/lm_vac.py | 8 +++-
ding/policy/ppof.py | 55 ++++++++++++++--------
ding/reward_model/language_reward_model.py | 7 +--
ding/rl_utils/gae.py | 5 +-
dizoo/chat/entry.py | 4 +-
7 files changed, 56 insertions(+), 35 deletions(-)
diff --git a/ding/bonus/config.py b/ding/bonus/config.py
index 113eaf0943..ddb73fb235 100644
--- a/ding/bonus/config.py
+++ b/ding/bonus/config.py
@@ -169,7 +169,7 @@ def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
cfg.learning_rate = 3e-4
elif env_id == 'chat':
cfg.epoch_per_collect = 1
- cfg.batch_size = 2
+ cfg.batch_size = 1
cfg.learning_rate = 5e-7
cfg.answers_per_question = 3
cfg.kl_penalty_weight = 0.1
@@ -325,12 +325,12 @@ def get_instance_env(env_id: str) -> BaseEnv:
elif env_id == 'chat':
from dizoo.chat.env import ChatEnv
return ChatEnv(
- batch_size=2,
+ batch_size=1,
reward_model_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en/recover",
tokenizer_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-reward-model-7B-en",
data_path="/mnt/nfs/whl/rlhf/MOSS-RLHF/data/ppo_data",
- maxlen_prompt=1024,
- maxlen_res=512,
+ maxlen_prompt=128,
+ maxlen_res=128,
)
else:
raise KeyError("not supported env type: {}".format(env_id))
diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py
index 24533e5983..6017940396 100644
--- a/ding/framework/middleware/collector.py
+++ b/ding/framework/middleware/collector.py
@@ -214,7 +214,7 @@ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll
self.env = env
self.env.seed(seed)
self.env.launch()
- self.env = self._envs[0]
+ self.env = self.env._envs[0]
self.policy = policy
self.n_sample = n_sample
self.unroll_len = unroll_len
@@ -229,7 +229,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
"""
device = self.policy._device
- obs = ttorch.as_tensor(self.env.last_batch[0]['text_vec'])
+ obs = ttorch.as_tensor(self.env.last_batch['text_vec'])
batch_size = obs.shape[0]
obs = obs.to(device)
diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py
index 61ef4f7856..e2099e705a 100644
--- a/ding/model/template/lm_vac.py
+++ b/ding/model/template/lm_vac.py
@@ -143,7 +143,8 @@ class LlamaVAC(nn.Module):
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
- def __init__(self, actor_path: str, critic_path: str, opt: Dict, enable_checkpointing: bool = True) -> None:
+ def __init__(self, actor_path: str, critic_path: str,
+ tokenizer_path: str, opt: Dict, enable_checkpointing: bool = True) -> None:
"""
Overview:
Initialize the ``LlamaVAC`` model according to arguments.
@@ -153,13 +154,16 @@ def __init__(self, actor_path: str, critic_path: str, opt: Dict, enable_checkpoi
- opt (:obj:`Dict`): Options for this model.
"""
super(LlamaVAC, self).__init__()
- tokenizer = get_tokenizer(actor_path)
+ tokenizer = get_tokenizer(tokenizer_path)
+
self.actor = Llama.from_pretrained(
actor_path,
opt=opt,
tokenizer=tokenizer,
)
+
self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer)
+
if enable_checkpointing:
self.actor.gradient_checkpointing_enable()
self.critic.gradient_checkpointing_enable()
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index d519853835..ad1afb1c1d 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -7,6 +7,8 @@
import torch
import treetensor.torch as ttorch
from torch.optim import AdamW
+from torch.cuda.amp import GradScaler
+from torch import autocast
from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
@@ -69,6 +71,8 @@ def __init__(
) -> None:
self._cfg = cfg
self._orig_model = orig_model
+ if self._orig_model is not None:
+ self.scaler = GradScaler()
if model is None:
self._model = self.default_model()
else:
@@ -170,7 +174,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
with torch.no_grad():
if self._cfg.chat_data:
# [B, T]
- value = self._model.compute_critic(data.obs)['value'][0]
+ value = self._model.compute_critic(data.obs)['value']
self._model.cpu()
self._orig_model.cuda()
data.orig_logit = self._orig_model.compute_actor(data.obs)['logit']
@@ -242,7 +246,8 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
split_data = ttorch.split(data, self._cfg.batch_size)
random.shuffle(list(split_data))
for batch in split_data:
- output = self._model.compute_actor_critic(batch.obs)
+ if not self._cfg.chat_data:
+ output = self._model.compute_actor_critic(batch.obs)
adv = batch.adv
if self._cfg.adv_norm:
# Normalize advantage in a train_batch
@@ -261,12 +266,21 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
else:
- mask = batch.mask
- ppo_batch = ppo_data(
- output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv,
- batch.return_, None
- )
- ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ with autocast(device_type='cuda', dtype=torch.float16):
+ output = self._model.compute_actor_critic(batch.obs)
+ mask = batch.mask
+ ppo_batch = ppo_data(
+ output['logit'], batch.orig_logit, batch.obs, output['value'], batch.value, adv,
+ batch.return_, None
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ kl_loss = (
+ torch.nn.functional.kl_div(
+ torch.softmax(output["logit"], dim=-1),
+ torch.softmax(batch.orig_logit, dim=-1),
+ reduction='none'
+ ) * mask.unsqueeze(-1)
+ ).mean()
elif self._action_space == 'hybrid':
# discrete part (discrete policy loss and entropy loss)
ppo_discrete_batch = ppo_policy_data(
@@ -293,21 +307,22 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
if not self._cfg.chat_data:
wv, we = self._cfg.value_weight, self._cfg.entropy_weight
total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
+ self._optimizer.step()
else:
wv, we, wk = self._cfg.value_weight, self._cfg.entropy_weight, self._cfg.kl_penalty_weight
- kl_loss = (
- torch.nn.functional.kl_div(
- torch.softmax(output.logit, dim=-1),
- torch.softmax(batch.orig_logit, dim=-1),
- reduction='none'
- ) * mask
- ).mean()
total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss
-
- self._optimizer.zero_grad()
- total_loss.backward()
- torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
- self._optimizer.step()
+ output = ttorch.as_tensor(output)
+ self._optimizer.zero_grad()
+ self.scaler.scale(total_loss).backward()
+ # scaler.step() first unscales the gradients of the optimizer's assigned params.
+ # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
+ # otherwise, optimizer.step() is skipped.
+ self.scaler.step(self._optimizer)
+ # Updates the scale for next iteration.
+ self.scaler.update()
return_info = {
'cur_lr': self._optimizer.defaults['lr'],
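Since this commit introduces ``autocast`` plus ``GradScaler``, here is a self-contained reference of the standard PyTorch mixed-precision loop (generic model and data, assuming a CUDA device; it mirrors the structure used in the chat branch above):

    import torch
    from torch import autocast
    from torch.cuda.amp import GradScaler

    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = GradScaler()
    x, y = torch.randn(4, 8, device='cuda'), torch.randn(4, 1, device='cuda')

    optimizer.zero_grad()
    with autocast(device_type='cuda', dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)   # forward pass runs in fp16 where safe
    scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)          # unscale, then step only if the gradients are finite
    scaler.update()                 # adjust the scale factor for the next iteration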
diff --git a/ding/reward_model/language_reward_model.py b/ding/reward_model/language_reward_model.py
index 79d3d0c149..ff2558099a 100644
--- a/ding/reward_model/language_reward_model.py
+++ b/ding/reward_model/language_reward_model.py
@@ -15,9 +15,10 @@ def __init__(self, config, tokenizer):
def forward(self, decoder_input, only_last=True):
attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
- output = self.model.forward(
- input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False
- )
+ with torch.no_grad():
+ output = self.model.forward(
+ input_ids=decoder_input, attention_mask=attention_mask, return_dict=True, use_cache=False
+ )
if only_last:
logits = self.reward_head(output.last_hidden_state[:, -1, :]).squeeze(-1)
diff --git a/ding/rl_utils/gae.py b/ding/rl_utils/gae.py
index d7081ad969..16b8313750 100644
--- a/ding/rl_utils/gae.py
+++ b/ding/rl_utils/gae.py
@@ -81,11 +81,12 @@ def episodic_gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97):
bsz = value.shape[0]
for i in range(bsz):
val, mas, rew, don, traj = value[i], mask[i], reward[i], done[i], traj_flag[i]
- assert val.shape[0] == rew.shape[0]
next_val = torch.zeros_like(val)
next_val[:-1] = val[1:]
+ step_rew = torch.zeros_like(val)  # sparse per-token reward: only the last token gets the sequence-level reward
+ step_rew[-1] = rew
gd = gae_data(
- val.unsqueeze(-1), next_val.unsqueeze(-1), rew.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1)
+ val.unsqueeze(-1), next_val.unsqueeze(-1), step_rew.unsqueeze(-1), don.unsqueeze(-1), traj.unsqueeze(-1)
)
advs.append(gae(gd, gamma, lambda_).squeeze(-1))
return torch.stack(advs, dim=0)
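The gae change above places the single sequence-level reward on the final token; a toy, dependency-free illustration of how that propagates back through GAE (assuming no mid-sequence dones, which matches the chat setting):

    import torch

    val = torch.tensor([0.2, 0.1, 0.3, 0.4])   # per-token critic values of one response
    step_rew = torch.zeros_like(val)
    step_rew[-1] = 1.0                          # scalar reward from the reward model, last token only

    gamma, lam = 0.99, 0.97
    next_val = torch.zeros_like(val)
    next_val[:-1] = val[1:]                     # bootstrap value of the next token, 0 after the last one
    delta = step_rew + gamma * next_val - val   # TD residuals
    adv = torch.zeros_like(val)
    running = 0.
    for t in reversed(range(len(val))):
        running = delta[t] + gamma * lam * running
        adv[t] = running                        # every token receives credit from the terminal reward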
diff --git a/dizoo/chat/entry.py b/dizoo/chat/entry.py
index ecb47c53f9..5da6ac620d 100644
--- a/dizoo/chat/entry.py
+++ b/dizoo/chat/entry.py
@@ -4,12 +4,12 @@
from ding.model.template import LlamaVAC
-
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--actor_path', type=str)
parser.add_argument('--critic_path', type=str)
+ parser.add_argument('--tokenizer_path', type=str)
args = parser.parse_args()
opt = EasyDict({
@@ -22,7 +22,7 @@
model = LlamaVAC(
actor_path=args.actor_path,
critic_path=args.critic_path,
- tokenizer=get_tokenizer("/mnt/nfs/whl/rlhf/MOSS-RLHF/models/moss-rlhf-sft-model-7B-en"),
+ tokenizer_path=args.tokenizer_path,
opt=opt
)
From f95529ffc6ab22714f3b113273541ddc11430280 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 6 Dec 2023 17:47:04 +0800
Subject: [PATCH 15/17] fix quant bug
---
ding/model/common/tests/test_utils.py | 12 +++---------
ding/model/template/lm_vac.py | 24 +++++++++++++-----------
ding/policy/ppof.py | 21 +++++++--------------
3 files changed, 23 insertions(+), 34 deletions(-)
diff --git a/ding/model/common/tests/test_utils.py b/ding/model/common/tests/test_utils.py
index b12e529e9b..9a6f688e2d 100644
--- a/ding/model/common/tests/test_utils.py
+++ b/ding/model/common/tests/test_utils.py
@@ -7,15 +7,9 @@
class TestUtils:
def test_top_p_logits(self):
- test_logit = torch.Tensor([
- [0., 0.91, 0.05, 0.04],
- [0.04, 0.46, 0.46, 0.04]
- ])
+ test_logit = torch.Tensor([[0., 0.91, 0.05, 0.04], [0.04, 0.46, 0.46, 0.04]])
- gt_logit = torch.Tensor([
- [0., 1., 0., 0.],
- [0., 0.5, 0.5, 0.]
- ])
+ gt_logit = torch.Tensor([[0., 1., 0., 0.], [0., 0.5, 0.5, 0.]])
pred_logit = top_p_logits(test_logit)
- assert torch.sum((gt_logit - pred_logit)**2).item() < 1e-8
+ assert torch.sum((gt_logit - pred_logit) ** 2).item() < 1e-8
diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py
index e2099e705a..2dd96c6a0a 100644
--- a/ding/model/template/lm_vac.py
+++ b/ding/model/template/lm_vac.py
@@ -36,7 +36,7 @@ def __init__(self, config, opt, tokenizer):
self.opt = opt
self.tokenizer = tokenizer
- def forward(self, decoder_input, incr_state=None):
+ def forward(self, decoder_input, incr_state=None, is_train=True):
attention_mask = decoder_input.ne(self.tokenizer.pad_token_id)
if incr_state is not None:
@@ -47,7 +47,7 @@ def forward(self, decoder_input, incr_state=None):
attention_mask=attention_mask,
past_key_values=incr_state,
return_dict=True,
- use_cache=not self.training
+ use_cache=not is_train
)
logits = output.logits
@@ -84,7 +84,7 @@ def generate(self, batch, **kwargs):
for _token in range(maxlen_res):
if done.all():
break
- score, incr_state, *_ = self.forward(decoder_input, incr_state)
+ score, incr_state, *_ = self.forward(decoder_input, incr_state, is_train=False)
score = score.half()
# now score is bs, len, vocab_size
@@ -143,8 +143,14 @@ class LlamaVAC(nn.Module):
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
- def __init__(self, actor_path: str, critic_path: str,
- tokenizer_path: str, opt: Dict, enable_checkpointing: bool = True) -> None:
+ def __init__(
+ self,
+ actor_path: str,
+ critic_path: str,
+ tokenizer_path: str,
+ opt: Dict,
+ enable_checkpointing: bool = True
+ ) -> None:
"""
Overview:
Initialize the ``LlamaVAC`` model according to arguments.
@@ -156,13 +162,9 @@ def __init__(self, actor_path: str, critic_path: str,
super(LlamaVAC, self).__init__()
tokenizer = get_tokenizer(tokenizer_path)
- self.actor = Llama.from_pretrained(
- actor_path,
- opt=opt,
- tokenizer=tokenizer,
- )
+ self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
- self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer)
+ self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
if enable_checkpointing:
self.actor.gradient_checkpointing_enable()
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index ad1afb1c1d..d2da40772d 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -71,8 +71,6 @@ def __init__(
) -> None:
self._cfg = cfg
self._orig_model = orig_model
- if self._orig_model is not None:
- self.scaler = GradScaler()
if model is None:
self._model = self.default_model()
else:
@@ -275,11 +273,11 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
)
ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
kl_loss = (
- torch.nn.functional.kl_div(
- torch.softmax(output["logit"], dim=-1),
- torch.softmax(batch.orig_logit, dim=-1),
- reduction='none'
- ) * mask.unsqueeze(-1)
+ torch.nn.functional.kl_div(
+ torch.softmax(output["logit"], dim=-1),
+ torch.softmax(batch.orig_logit, dim=-1),
+ reduction='none'
+ ) * mask.unsqueeze(-1)
).mean()
elif self._action_space == 'hybrid':
# discrete part (discrete policy loss and entropy loss)
@@ -316,13 +314,8 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss + wk * kl_loss
output = ttorch.as_tensor(output)
self._optimizer.zero_grad()
- self.scaler.scale(total_loss).backward()
- # scaler.step() first unscales the gradients of the optimizer's assigned params.
- # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
- # otherwise, optimizer.step() is skipped.
- self.scaler.step(self._optimizer)
- # Updates the scale for next iteration.
- self.scaler.update()
+ total_loss.backward()
+ self._optimizer.step()
return_info = {
'cur_lr': self._optimizer.defaults['lr'],
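One note on the KL penalty kept by this commit: ``torch.nn.functional.kl_div`` documents its first argument as log-probabilities. If the intended quantity is a per-token KL between the frozen and the updated policy, a variant along these lines (an untested sketch with a hypothetical helper name, not part of the patch) passes log-probabilities explicitly:

    import torch
    import torch.nn.functional as F

    def kl_penalty(logit, orig_logit, mask):
        # logit / orig_logit: [B, T, vocab]; mask: [B, T] with 1 on response tokens.
        log_p = torch.log_softmax(logit, dim=-1)            # current policy in log-space
        q = torch.softmax(orig_logit, dim=-1)               # frozen original policy in prob-space
        kl = F.kl_div(log_p, q, reduction='none').sum(-1)   # KL(q || p) per token, shape [B, T]
        return (kl * mask).sum() / mask.sum().clamp(min=1)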
From 31a61914c16f13a184cf341fecbdf59752dd98d8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 6 Dec 2023 21:30:18 +0800
Subject: [PATCH 16/17] fix incompatibility problem
---
ding/model/template/lm_vac.py | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py
index 2dd96c6a0a..2931fc3fdf 100644
--- a/ding/model/template/lm_vac.py
+++ b/ding/model/template/lm_vac.py
@@ -31,10 +31,11 @@ def get_tokenizer(path: str):
class Llama(LlamaForCausalLM):
- def __init__(self, config, opt, tokenizer):
+ def __init__(self, config, opt, tokenizer, enable_checkpointing):
super().__init__(config)
self.opt = opt
self.tokenizer = tokenizer
+ self.enable_checkpointing = enable_checkpointing
def forward(self, decoder_input, incr_state=None, is_train=True):
@@ -60,6 +61,8 @@ def generate(self, batch, **kwargs):
"""
Generate response
"""
+ if self.enable_checkpointing:
+ self.gradient_checkpointing_disable()
maxlen_res = kwargs.pop('maxlen_res', self.opt.maxlen_res)
temperature = kwargs.pop('temperature', self.opt.temperature)
repetition_penalty = kwargs.pop('repetition_penalty', self.opt.repetition_penalty)
@@ -129,6 +132,8 @@ def generate(self, batch, **kwargs):
preds_scores.append([res_scores])
best_preds_scores = [preds[0] for preds in preds_scores]
+ if self.enable_checkpointing:
+ self.gradient_checkpointing_enable()
return best_preds_scores, preds_scores
@@ -161,8 +166,10 @@ def __init__(
"""
super(LlamaVAC, self).__init__()
tokenizer = get_tokenizer(tokenizer_path)
+ self.enable_checkpointing = enable_checkpointing
- self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
+ self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16,
+ enable_checkpointing=enable_checkpointing)
self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
From 538ddd27a7e241880b4f30727cff1903c6ae6329 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’>
Date: Wed, 3 Jan 2024 15:40:35 +0800
Subject: [PATCH 17/17] reformat
---
ding/model/template/lm_vac.py | 9 +++++++--
ding/policy/ppof.py | 1 +
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/ding/model/template/lm_vac.py b/ding/model/template/lm_vac.py
index 2931fc3fdf..747b1d1d8f 100644
--- a/ding/model/template/lm_vac.py
+++ b/ding/model/template/lm_vac.py
@@ -168,8 +168,13 @@ def __init__(
tokenizer = get_tokenizer(tokenizer_path)
self.enable_checkpointing = enable_checkpointing
- self.actor = Llama.from_pretrained(actor_path, opt=opt, tokenizer=tokenizer, torch_dtype=torch.bfloat16,
- enable_checkpointing=enable_checkpointing)
+ self.actor = Llama.from_pretrained(
+ actor_path,
+ opt=opt,
+ tokenizer=tokenizer,
+ torch_dtype=torch.bfloat16,
+ enable_checkpointing=enable_checkpointing
+ )
self.critic = LlamaRewardModel.from_pretrained(critic_path, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
diff --git a/ding/policy/ppof.py b/ding/policy/ppof.py
index d2da40772d..3dbfb05581 100644
--- a/ding/policy/ppof.py
+++ b/ding/policy/ppof.py
@@ -323,6 +323,7 @@ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
'policy_loss': ppo_loss.policy_loss.item(),
'value_loss': ppo_loss.value_loss.item(),
'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'kl_loss': kl_loss.item() if self._cfg.chat_data else 0.,
'adv_max': adv.max().item(),
'adv_mean': adv.mean().item(),
'value_mean': output.value.mean().item(),