diff --git a/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/README.md b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/README.md
index ebc2b2d08b2..27b3eded2f2 100644
--- a/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/README.md
+++ b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/README.md
@@ -43,3 +43,25 @@ multi card finetunes
 ```
 python ../instruction/gaudi_spawn.py --world_size 8 --use_mpi reward_modeling.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir --log_level info --num_train_epochs 1 --use_habana --use_lazy_mode --hf_access_token xxxxxx --ddp_find_unused_parameters True
 ```
+
+## 5. Reinforcement Learning (PPO) Fine-tuning
+
+### Training on CUDA
+```
+accelerate launch --multi_gpu --num_machines 1 --num_processes 8 rl_training.py --log_with=wandb --model_name=meta-llama/Llama-2-7b-hf --reward_model_name=output_se --adafactor=False --tokenizer_name=meta-llama/Llama-2-7b-hf --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam --hf_access_token xxxxxx
+```
+
+### Training on Habana
+
+Follow the installation guide in [optimum-habana](https://github.com/huggingface/optimum-habana).
+
+Single-card fine-tuning:
+
+```
+python3 rl_training.py --model_name=meta-llama/Llama-2-7b-hf --reward_model_name= --adafactor=False --tokenizer_name=meta-llama/Llama-2-7b-hf --save_freq=100 --output_max_length=128 --batch_size=8 --mini_batch_size=1 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam --hf_access_token xxxxxx --use_habana
+```
+
+Multi-card fine-tuning:
+```
+python3 ../instruction/gaudi_spawn.py --world_size 8 --use_mpi rl_training.py --model_name=meta-llama/Llama-2-7b-hf --reward_model_name= --adafactor=False --tokenizer_name=meta-llama/Llama-2-7b-hf --save_freq=100 --output_max_length=128 --batch_size=8 --mini_batch_size=1 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam --hf_access_token xxxxxx --use_habana
+```
diff --git a/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/requirements.txt b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/requirements.txt
index ba2126416cd..ffaf73f5112 100644
--- a/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/requirements.txt
+++ b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/requirements.txt
@@ -5,3 +5,5 @@ datasets
 bitsandbytes
 evaluate
 scikit-learn
+intel-extension-for-transformers
+tyro
diff --git a/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/reward_modeling.py b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/reward_modeling.py
index 39931d918e6..29fe3c2d087 100644
--- a/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/reward_modeling.py
+++ b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/reward_modeling.py
@@ -199,14 +199,14 @@ def preprocess_function(examples):
         "input_ids_k": [],
         "attention_mask_k": [],
     }
-    for question, response_j, response_k in zip(
examples["question"], examples["chatgpt"], examples["llama2-13b-chat"] + for system, question, response_j, response_k in zip( + examples["system"], examples["question"], examples["chatgpt"], examples["llama2-13b-chat"] ): tokenized_j = tokenizer( - "Question: " + question + "\n\nAnswer: " + response_j, truncation=True + system + question + response_j, truncation=True ) tokenized_k = tokenizer( - "Question: " + question + "\n\nAnswer: " + response_k, truncation=True + system + question + response_k, truncation=True ) new_examples["input_ids_j"].append(tokenized_j["input_ids"]) diff --git a/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/rl_training.py b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/rl_training.py new file mode 100644 index 00000000000..7dca27c963d --- /dev/null +++ b/intel_extension_for_transformers/neural_chat/examples/finetuning/ppo_pipeline/rl_training.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import ( + Adafactor, + AutoTokenizer, + HfArgumentParser, + pipeline, + AutoModelForSequenceClassification, +) + +from intel_extension_for_transformers.transformers.ppo_core import ( + LengthSampler, + set_seed, +) +from intel_extension_for_transformers.transformers.ppo_config import PPOConfig +from intel_extension_for_transformers.transformers.ppo_trainer import PPOTrainer +from intel_extension_for_transformers.transformers.modeling.trl_models import ( + AutoModelForCausalLMWithValueHead, +) + +import sys +import logging + +logger = logging.getLogger(__name__) +# Setup logging +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], +) + + +def build_dataset( + tokenizer, + dataset_name, +): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. 
+    """
+
+    # load the preference dataset (default: Intel/orca_dpo_pairs)
+    ds = load_dataset(dataset_name, split="train")
+    original_columns = ds.column_names
+    num_proc = 24
+
+    def preprocess_function(examples):
+        new_examples = {
+            "query": [],
+            "input_ids": [],
+        }
+        for system, question in zip(examples["system"], examples["question"]):
+            query = system + question
+            tokenized_question = tokenizer(query, truncation=True)
+            new_examples["query"].append(query)
+            new_examples["input_ids"].append(tokenized_question["input_ids"])
+
+        return new_examples
+
+    ds = ds.map(
+        preprocess_function,
+        batched=True,
+        num_proc=num_proc,
+        remove_columns=original_columns,
+    )
+    ds = ds.filter(lambda x: len(x["input_ids"]) <= 512, batched=False)
+
+    ds.set_format(type="torch")
+    return ds
+
+
+def collator(data):
+    return dict((key, [d[key] for d in data]) for key in data[0])
+
+
+@dataclass
+class ScriptArguments:
+    """
+    Arguments for fine-tuning a causal LM with PPO.
+    """
+
+    # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
+    # models like gpt-neo* models are more suitable.
+    model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
+    tokenizer_name: Optional[str] = field(
+        default="", metadata={"help": "the tokenizer name"}
+    )
+    reward_model_name: Optional[str] = field(
+        default="", metadata={"help": "the reward model name"}
+    )
+    log_with: Optional[str] = field(
+        default=None, metadata={"help": "use 'wandb' to log with wandb"}
+    )
+    learning_rate: Optional[float] = field(
+        default=1.41e-5, metadata={"help": "the learning rate"}
+    )
+    output_max_length: Optional[int] = field(
+        default=128, metadata={"help": "maximum length for generation"}
+    )
+    mini_batch_size: Optional[int] = field(
+        default=1, metadata={"help": "the PPO minibatch size"}
+    )
+    batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
+    ppo_epochs: Optional[int] = field(
+        default=4, metadata={"help": "the number of ppo epochs"}
+    )
+    gradient_accumulation_steps: Optional[int] = field(
+        default=4, metadata={"help": "the number of gradient accumulation steps"}
+    )
+    adafactor: Optional[bool] = field(
+        default=False, metadata={"help": "whether to use the adafactor optimizer"}
+    )
+    early_stopping: Optional[bool] = field(
+        default=False, metadata={"help": "whether to stop the PPO loop early"}
+    )
+    target_kl: Optional[float] = field(
+        default=0.1, metadata={"help": "kl target for early stopping"}
+    )
+    reward_baseline: Optional[float] = field(
+        default=0.0,
+        metadata={"help": "a baseline value that is subtracted from the reward"},
+    )
+    batched_gen: Optional[bool] = field(
+        default=False, metadata={"help": "whether to use batched text generation"}
+    )
+    save_freq: Optional[int] = field(
+        default=None, metadata={"help": "save the model every n steps"}
+    )
+    output_dir: Optional[str] = field(
+        default="runs/", metadata={"help": "the directory to save the model to"}
+    )
+    seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
+    steps: Optional[int] = field(
+        default=20000, metadata={"help": "number of training steps"}
+    )
+    init_kl_coef: Optional[float] = field(
+        default=0.2,
+        metadata={
+            "help": "Initial KL penalty coefficient (used for adaptive and linear control)"
+        },
+    )
+
+    adap_kl_ctrl: Optional[bool] = field(
+        default=True, metadata={"help": "Use adaptive KL control, otherwise linear"}
+    )
+
+    hf_access_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Huggingface token to access model."},
+    )
+
+    dataset_name: Optional[str] = field(
default="Intel/orca_dpo_pairs", + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + lora_rank: Optional[int] = field( + default=16, + metadata={"help": "Rank parameter in the LoRA method."}, + ) + lora_alpha: Optional[int] = field( + default=32, + metadata={"help": "Alpha parameter in the LoRA method."}, + ) + lora_dropout: Optional[float] = field( + default=0.05, + metadata={"help": "Dropout parameter in the LoRA method."}, + ) + use_habana: Optional[bool] = field( + default=False, metadata={"help": "use habana for RL training"} + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] + reward_model_name = script_args.reward_model_name + config = PPOConfig( + steps=script_args.steps, + model_name=script_args.model_name, + learning_rate=script_args.learning_rate, + log_with=script_args.log_with, + batch_size=script_args.batch_size, + mini_batch_size=script_args.mini_batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + optimize_device_cache=True, + early_stopping=script_args.early_stopping, + target_kl=script_args.target_kl, + ppo_epochs=script_args.ppo_epochs, + seed=script_args.seed, + init_kl_coef=script_args.init_kl_coef, + adap_kl_ctrl=script_args.adap_kl_ctrl, + use_habana=script_args.use_habana, + pad_for_acceleration=script_args.use_habana, + pad_max_len=512 + script_args.output_max_length, + pad_max_input_len=512, + ) + + # We then define the arguments to pass to the sentiment analysis pipeline. + # We set `return_all_scores` to True to get the sentiment score for each token. + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + "truncation": True, + } + + if config.pad_for_acceleration: + sent_kwargs["padding"] = "max_length" + # is 1024 enough? + sent_kwargs["max_length"] = 1024 + + tokenizer = AutoTokenizer.from_pretrained( + script_args.tokenizer_name, token=script_args.hf_access_token + ) + + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + # We retrieve the dataloader by calling the `build_dataset` function. 
+ dataset = build_dataset(tokenizer, dataset_name=script_args.dataset_name) + + # set seed before initializing value head for deterministic eval + set_seed(config.seed) + + lora_config = LoraConfig( + r=script_args.lora_rank, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + bias="none", + task_type="CAUSAL_LM", + ) + model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + token=script_args.hf_access_token, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + model = model.to(torch.bfloat16) + + if script_args.use_habana: + ref_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + token=script_args.hf_access_token, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + else: + ref_model = None + + optimizer = None + if script_args.adafactor: + optimizer = Adafactor( + filter(lambda p: p.requires_grad, model.parameters()), + scale_parameter=False, + relative_step=False, + warmup_init=False, + lr=config.learning_rate, + ) + # We then build the PPOTrainer, passing the model, the reference model, the tokenizer + ppo_trainer = PPOTrainer( + config, + model, + ref_model=ref_model, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, + optimizer=optimizer, + ) + + # We then build the sentiment analysis pipeline using our reward model, passing the + # model name and the sentiment analysis pipeline arguments. Let's also make sure to + # set the device to the same device as the PPOTrainer. + device = ppo_trainer.accelerator.device + if ppo_trainer.accelerator.num_processes == 1: + if torch.cuda.is_available(): + device = 0 + + reward_model = AutoModelForSequenceClassification.from_pretrained( + reward_model_name, + num_labels=1, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + token=script_args.hf_access_token, + ) + + if config.use_habana: + from habana_frameworks.torch.hpu import ( + wrap_in_hpu_graph, + ) # pylint: disable=E0611, E0401 + + reward_model = wrap_in_hpu_graph(reward_model) + + sentiment_pipe = pipeline( + "sentiment-analysis", + model=reward_model, + tokenizer=tokenizer, + return_token_type_ids=False, + device=device, + model_kwargs={ + "use_auth_token": script_args.hf_access_token, + "low_cpu_mem_usage": True, + "torch_dtype": torch.bfloat16, + }, + ) + + # Some tokenizers like GPT-2's don't have a padding token by default, so we set one here. + if sentiment_pipe.tokenizer.pad_token_id is None: + sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id + + if sentiment_pipe.model.config.pad_token_id is None: + sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id + + # We then define the arguments to pass to the `generate` function. These arguments + # are passed to the `generate` function of the PPOTrainer, which is a wrapper around + # the `generate` function of the trained model. 
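+    # A note on the sampling settings below: `top_k=0.0` and `top_p=1.0` leave the
+    # sampling distribution unfiltered, and the out-of-vocabulary `eos_token_id` of
+    # 100_000 means generation is effectively never terminated by an EOS token, so
+    # response length is governed by the `LengthSampler` configured just after.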
+ generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "eos_token_id": 100_000, + } + output_min_length = 32 + output_max_length = script_args.output_max_length + if not config.pad_for_acceleration: + output_length_sampler = LengthSampler(output_min_length, output_max_length) + else: + output_length_sampler = LengthSampler(output_max_length, output_max_length + 1) + + epochs = tqdm( + enumerate(ppo_trainer.dataloader), + total=len(ppo_trainer.dataloader), + desc="rl progress", + ) + for epoch, batch in epochs: + if epoch >= config.total_ppo_epochs: + break + + question_tensors = batch["input_ids"] + + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode( + response_tensors, skip_special_tokens=True + ) + # Compute reward score (using the sentiment analysis pipeline) + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [ + torch.tensor(output[0]["score"] - script_args.reward_baseline) + for output in pipe_outputs + ] + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + epochs.update(1) + + if script_args.save_freq and epoch and epoch % script_args.save_freq == 0: + ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}") diff --git a/intel_extension_for_transformers/transformers/modeling/trl_models/__init__.py b/intel_extension_for_transformers/transformers/modeling/trl_models/__init__.py new file mode 100644 index 00000000000..64f2865cf59 --- /dev/null +++ b/intel_extension_for_transformers/transformers/modeling/trl_models/__init__.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_base import PreTrainedModelWrapper, create_reference_model +from .modeling_value_head import ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, +) + + +SUPPORTED_ARCHITECTURES = ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, +) diff --git a/intel_extension_for_transformers/transformers/modeling/trl_models/modeling_base.py b/intel_extension_for_transformers/transformers/modeling/trl_models/modeling_base.py new file mode 100644 index 00000000000..88a74bf27ea --- /dev/null +++ b/intel_extension_for_transformers/transformers/modeling/trl_models/modeling_base.py @@ -0,0 +1,742 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from copy import deepcopy + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HFValidationError, + LocalEntryNotFoundError, +) +from safetensors.torch import load_file as safe_load_file +from transformers import PreTrainedModel +import importlib + + +def is_peft_available(): + return importlib.util.find_spec("peft") is not None + + +def is_transformers_greater_than(version: str) -> bool: + _transformers_version = importlib.metadata.version("transformers") + return _transformers_version > version + + +def is_optimum_habana_available(): + import importlib + from transformers.utils.import_utils import is_optimum_available + + return is_optimum_available() and importlib.util.find_spec("optimum.habana") != None + + +if is_optimum_habana_available(): + from optimum.habana.accelerate import GaudiAccelerator as Accelerator # pylint: disable=E0611, E0401 +else: + from accelerate import Accelerator + +if is_peft_available(): + from peft import ( + LoraConfig, + PeftConfig, + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + PromptLearningConfig, + get_peft_model, + prepare_model_for_kbit_training, + ) + from peft.peft_model import set_peft_model_state_dict + +if is_transformers_greater_than("4.33.0"): + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled # pylint: disable=E0611, E0401 +else: + from transformers.deepspeed import is_deepspeed_zero3_enabled + +LAYER_PATTERNS = [ + "transformer.h.{layer}", + "model.decoder.layers.{layer}", + "gpt_neox.layers.{layer}", + "model.layers.{layer}", +] + + +class PreTrainedModelWrapper(nn.Module): + r""" + A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the + (`~transformers.PreTrained`) class in order to keep some attributes and methods of the + (`~transformers.PreTrainedModel`) class. + + Attributes: + pretrained_model: (`transformers.PreTrainedModel`) + The model to be wrapped. + parent_class: (`transformers.PreTrainedModel`) + The parent class of the model to be wrapped. + supported_args: (`list`) + The list of arguments that are supported by the wrapper class. 
+ """ + transformers_parent_class = None + supported_args = None + supported_modules = ("v_head",) + supported_rm_modules = ("score",) + supported_pretrained_model_architectures = ( + (PreTrainedModel) + if not is_peft_available() + else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM) + ) + + def __init__(self, pretrained_model=None, **kwargs): + super().__init__() + self.pretrained_model = pretrained_model + + self.config = pretrained_model.config + self.prepare_inputs_for_generation = ( + pretrained_model.prepare_inputs_for_generation + ) + self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) + self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) + self.is_sequential_parallel = False + + if hasattr(pretrained_model, "gradient_checkpointing_disable"): + self.gradient_checkpointing_disable = ( + pretrained_model.gradient_checkpointing_disable + ) + + if hasattr(pretrained_model, "gradient_checkpointing_enable"): + self.gradient_checkpointing_enable = ( + pretrained_model.gradient_checkpointing_enable + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Instantiates a new model from a pretrained model from `transformers`. The + pretrained model is loaded using the `from_pretrained` method of the + `transformers.PreTrainedModel` class. The arguments that are specific to the + `transformers.PreTrainedModel` class are passed along this method and filtered + out from the `kwargs` argument. + + + Args: + pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`): + The path to the pretrained model or its name. + *model_args (`list`, *optional*)): + Additional positional arguments passed along to the underlying model's + `from_pretrained` method. + **kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's + `from_pretrained` method. We also pre-process the kwargs to extract + the arguments that are specific to the `transformers.PreTrainedModel` + class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_kbit_training` arguments from + `peft` library. + """ + if kwargs is not None: + peft_config = kwargs.pop("peft_config", None) + reward_adapter = kwargs.pop("reward_adapter", None) + is_trainable = kwargs.pop("is_trainable", False) + ( + trl_model_args, + pretrained_kwargs, + peft_quantization_kwargs, + ) = cls._split_kwargs(kwargs) + token = pretrained_kwargs.get("token", None) + else: + peft_config = None + is_trainable = False + trl_model_args = {} + pretrained_kwargs = {} + peft_quantization_kwargs = {} + token = None + + if reward_adapter is not None and not isinstance(reward_adapter, str): + raise ValueError( + "The `reward_adapter` argument should be a string representing the name of local path or the Hub id to" + "the Reward Modeling adapter." 
+ ) + + is_peft_model = False + + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + is_loaded_in_8bit = ( + pretrained_kwargs["load_in_8bit"] + if "load_in_8bit" in pretrained_kwargs + else False + ) + is_loaded_in_4bit = ( + pretrained_kwargs["load_in_4bit"] + if "load_in_4bit" in pretrained_kwargs + else False + ) + else: + is_loaded_in_8bit = getattr( + pretrained_model_name_or_path, "is_loaded_in_8bit", False + ) + is_loaded_in_4bit = getattr( + pretrained_model_name_or_path, "is_loaded_in_4bit", False + ) + + if ( + is_loaded_in_8bit or is_loaded_in_4bit + ) and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." + ) + pretrained_kwargs["device_map"] = {"": current_device} + + if ( + is_peft_available() + and peft_config is not None + and not isinstance(peft_config, PeftConfig) + ): + raise ValueError( + "The `peft_config` argument should be an instance of `peft.PeftConfig` class." + ) + + # First, load the pre-trained model using the parent-class + # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` + if isinstance(pretrained_model_name_or_path, str): + if is_peft_available(): + try: + # If there is a trained peft adapter in the hub, load its config. + remote_adapter_config = hf_hub_download( + pretrained_model_name_or_path, + "adapter_config.json", + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError): + remote_adapter_config = None + else: + remote_adapter_config = None + + local_adapter_present = os.path.exists( + os.path.join(pretrained_model_name_or_path, "adapter_config.json") + ) + + if ( + local_adapter_present or remote_adapter_config is not None + ) and is_peft_available(): + if peft_config is not None: + logging.warning( + "`peft_config` argument ignored since a peft config file was found in " + f"{pretrained_model_name_or_path}" + ) + + # Load the trained peft adapter config + if local_adapter_present: + trained_adapter_config = PeftConfig.from_pretrained( + pretrained_model_name_or_path + ) + else: + remote_adapter_dir = os.path.dirname(remote_adapter_config) + trained_adapter_config = PeftConfig.from_pretrained( + remote_adapter_dir + ) + + # Load the pretrained base model + pretrained_model = cls.transformers_parent_class.from_pretrained( + trained_adapter_config.base_model_name_or_path, + *model_args, + **pretrained_kwargs, + ) + + # Wrap the pretrained model with the trained peft adapter + pretrained_model = PeftModel.from_pretrained( + pretrained_model, + pretrained_model_name_or_path, + is_trainable=is_trainable, + ) + logging.info("Trained peft adapter loaded") + else: + pretrained_model = cls.transformers_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **pretrained_kwargs + ) + + if peft_config is not None: + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + + elif isinstance( + pretrained_model_name_or_path, cls.supported_pretrained_model_architectures + ): + pretrained_model = pretrained_model_name_or_path + + if 
peft_config is not None and isinstance( + pretrained_model, PreTrainedModel + ): + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + else: + raise ValueError( + "pretrained_model_name_or_path should be a string or a PreTrainedModel, " + f"but is {type(pretrained_model_name_or_path)}" + ) + + if is_peft_available(): + if isinstance(pretrained_model, PeftModel): + is_peft_model = True + # for backward compatibility + if hasattr(pretrained_model, "active_peft_config") and isinstance( + pretrained_model.active_peft_config, PromptLearningConfig + ): + raise ValueError( + "PromptLearningConfig is not supported for PPO training." + ) + # Then, create the full model by instantiating the wrapper class + model = cls(pretrained_model, **trl_model_args) + + # if resume_training, load the state_dict again - this is ok since the + # state_dict is removed from the model after loading it. + is_resuming_training = True + if isinstance(pretrained_model_name_or_path, str): + safe_filename = os.path.join( + pretrained_model_name_or_path, "model.safetensors" + ) + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + sharded_index_filename = os.path.join( + pretrained_model_name_or_path, "pytorch_model.bin.index.json" + ) + safe_sharded_index_filename = os.path.join( + pretrained_model_name_or_path, "model.safetensors.index.json" + ) + is_sharded = False + use_safe = os.path.exists(safe_filename) + + if not (os.path.exists(filename) or os.path.exists(safe_filename)): + # Try with `pytorch_model.bin` + ( + filename, + files_to_download, + is_sharded, + is_resuming_training, + ) = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + sharded_index_filename, + token=token, + ) + # Try with safetensors + if filename is None and files_to_download is None: + ( + safe_filename, + files_to_download, + is_sharded, + is_resuming_training, + ) = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + safe_sharded_index_filename, + token=token, + model_name="model.safetensors", + model_index_name="model.safetensors.index.json", + ) + use_safe = True + else: + use_safe = False + + loading_func = safe_load_file if use_safe else torch.load + load_kwargs = {} if use_safe else {"map_location": "cpu"} + + if is_resuming_training: + if is_sharded: + # download each file and add it to the state_dict + state_dict = {} + + for shard_file in files_to_download: + filename = hf_hub_download( + pretrained_model_name_or_path, + shard_file, + token=token, + ) + state_dict.update(loading_func(filename, **load_kwargs)) + else: + state_dict = loading_func( + filename if not use_safe else safe_filename, **load_kwargs + ) + + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.is_peft_model = is_peft_model + model.current_device = current_device + + if is_resuming_training: + model.post_init(state_dict=state_dict) + + if not is_peft_model and reward_adapter is not None: + raise ValueError("reward_adapter can only be used with a PeftModel. 
") + elif is_peft_model and reward_adapter is not None: + model.add_and_load_reward_modeling_adapter(reward_adapter, token=token) + model.supports_rm_adapter = True + else: + model.supports_rm_adapter = False + + return model + + @classmethod + def _get_checkpoint_from_hub( + cls, + pretrained_model, + pretrained_model_name_or_path, + index_filename, + token=None, + model_name="pytorch_model.bin", + model_index_name="pytorch_model.bin.index.json", + ): + files_to_download = None + filename = None + is_resuming_training = True + is_sharded = False + + try: + filename = hf_hub_download( + pretrained_model_name_or_path, + model_name, + token=token, + ) + # sharded + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError): + if os.path.exists(index_filename): + index_file_name = index_filename + else: + try: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + model_index_name, + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError): + # not continue training, do not have v_head weight + is_resuming_training = False + logging.warning( + f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " + f"and no v_head weight is found. This IS expected if you are not resuming PPO training." + ) + # load json + if is_resuming_training: + with open(index_file_name, "r") as f: + index = json.load(f) + # check filename with `v_head` or any known extra module: + files_to_download = set() + for k, v in index["weight_map"].items(): + if any([module in k for module in cls.supported_modules]): + files_to_download.add(v) + is_sharded = True + + return filename, files_to_download, is_sharded, is_resuming_training + + @classmethod + def _get_current_device(cls): + r""" + Get the current device. For GPU, we return the local process index using the `Accelerator` + object to handle corner cases when running scripts in distributed environments. + + Returns: + current_device (`Union[int, str]`): + The current device. + """ + dummy_accelerator = Accelerator() + return ( + dummy_accelerator.local_process_index + if torch.cuda.is_available() + else "cpu" + ) + + @classmethod + def _split_kwargs(cls, kwargs): + """ + Separate the kwargs from the arguments that we support inside + `supported_args` and the ones that we don't. + """ + check_peft_kwargs = False + + if is_peft_available(): + from peft import prepare_model_for_kbit_training + + check_peft_kwargs = True + + supported_kwargs = {} + unsupported_kwargs = {} + peft_kwargs = {} + + for key, value in kwargs.items(): + if key in cls.supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + + if check_peft_kwargs: + if key in prepare_model_for_kbit_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs + + def push_to_hub(self, *args, **kwargs): + r""" + Push the pretrained model to the hub. This method is a wrapper around + `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation + of `transformers.PreTrainedModel.push_to_hub` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `push_to_hub` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `push_to_hub` method. 
+ """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + r""" + Save the pretrained model to a directory. This method is a wrapper around + `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation + of `transformers.PreTrainedModel.save_pretrained` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. + """ + state_dict = kwargs.get("state_dict") + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + # if it is a peft model only save the `v_head` state_dict and + # pop the `state_dict` from the kwargs to avoid slient bugs with `peft` + if self.is_peft_model: + save_path = args[0] + save_path = os.path.join(save_path, "pytorch_model.bin") + torch.save(state_dict, save_path) + _ = kwargs.pop("state_dict", None) + + return self.pretrained_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Return the state_dict of the pretrained model. + """ + raise NotImplementedError + + def post_init(self, *args, **kwargs): + r""" + Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def add_and_load_reward_modeling_adapter( + self, adapter_model_id, adapter_name="reward_model_adapter", token=None + ): + r""" + Add and load a reward modeling adapter. This method can only be used if the + model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id` + argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the + score head in order to produce the reward. + """ + filename = os.path.join(adapter_model_id, "adapter_model.bin") + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.bin", + token=token, + ) + except: # noqa + raise ValueError( + "Could not find adapter model in the Hub, make sure you have the correct adapter model id." 
+ ) + else: + local_filename = filename + + adapter_state_dict = torch.load(local_filename, map_location="cpu") + rm_adapter_peft_config = LoraConfig.from_pretrained(adapter_model_id) + + for score_name_candidate in self.supported_rm_modules: + if any( + [score_name_candidate in name for name in adapter_state_dict.keys()] + ): + score_name = score_name_candidate + # we have found the correct head name and can break + break + + score_dict = {} + copy_adapter_state_dict = adapter_state_dict.copy() + + for name, _ in copy_adapter_state_dict.items(): + if score_name in name: + key_name = ".".join(name.split(".")[-1:]) + score_dict[key_name] = adapter_state_dict.pop(name).to( + self._get_current_device() + ) + + self.pretrained_model.add_adapter(adapter_name, rm_adapter_peft_config) + self.rm_adapter_name = adapter_name + + num_labels, hidden_dim = score_dict["weight"].shape + has_bias = any(["bias" in name for name in adapter_state_dict.keys()]) + + self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( + device=self._get_current_device(), + dtype=self.pretrained_model.dtype, + ) + self.score.load_state_dict(score_dict) + + # load the adapter to the model + set_peft_model_state_dict( + self.pretrained_model, adapter_state_dict, adapter_name=adapter_name + ) + + def compute_reward_score( + self, input_ids, attention_mask=None, ppo_adapter_name="default", **kwargs + ): + r""" + Computes the reward score for a given input. The method has first to enable the adapter + and then compute the reward score. After that the model disables the reward modeling + adapter and enables the default ppo adapter again. + """ + if not self.supports_rm_adapter: + raise ValueError("This model does not support reward modeling adapter.") + + # enable rm adapter + self.pretrained_model.set_adapter(self.rm_adapter_name) + self.pretrained_model.eval() + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + + last_hidden_states = base_model_output.hidden_states[-1] + scores = self.score(last_hidden_states) # pylint: disable=E1102 + + self.pretrained_model.set_adapter(ppo_adapter_name) + self.pretrained_model.train() + + return scores + + +def create_reference_model( + model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None +) -> PreTrainedModelWrapper: + """ + Creates a static reference copy of a model. Note that model will be in `.eval()` mode. + + Args: + model (`PreTrainedModelWrapper`): The model to be copied. + num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and + kept frozen. + pattern (`str`, *optional*): The shared layers are selected with a string pattern + (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here. + + Returns + `PreTrainedModelWrapper` + """ + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate " + "your reference model directly with `AutoCausalLM.from_pretrained()`." 
+ ) + + parameter_names = [n for n, _ in model.named_parameters()] + ref_model = deepcopy(model) + + # if no layers are shared, return copy of model + if num_shared_layers is None: + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + return ref_model.eval() + + # identify layer name pattern + if pattern is not None: + pattern = pattern.format(layer=num_shared_layers) + else: + for pattern_candidate in LAYER_PATTERNS: + pattern_candidate = pattern_candidate.format(layer=num_shared_layers) + if any([pattern_candidate in name for name in parameter_names]): + pattern = pattern_candidate + break + + if pattern is None: + raise ValueError("Layer pattern could not be matched.") + + # divide parameters in shared and unshared parameter lists + shared_param_list = [] + unshared_param_list = [] + + shared_parameter = True + for name, param in model.named_parameters(): + if pattern in name: + shared_parameter = False + if shared_parameter: + shared_param_list.append(name) + else: + unshared_param_list.append(name) + + # create reference of the original parameter if they are shared + for param_name in shared_param_list: + param = model.get_parameter(param_name) + param.requires_grad = False + + ref_param = ref_model.get_parameter(param_name) # noqa + ref_param = param # noqa + + # for all other parameters just make sure they don't use gradients + for param_name in unshared_param_list: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + + if pattern is not None and len(unshared_param_list) == 0: + logging.warning( + "Pattern passed or found, but no layers matched in the model. Check for a typo." + ) + + return ref_model.eval() diff --git a/intel_extension_for_transformers/transformers/modeling/trl_models/modeling_value_head.py b/intel_extension_for_transformers/transformers/modeling/trl_models/modeling_value_head.py new file mode 100644 index 00000000000..34be5246348 --- /dev/null +++ b/intel_extension_for_transformers/transformers/modeling/trl_models/modeling_value_head.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from .modeling_base import PreTrainedModelWrapper + + +class ValueHead(nn.Module): + r""" + The ValueHead class implements a head for GPT2 that returns a scalar for each output token. + """ + + def __init__(self, config, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + + self.dropout = ( + nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + ) + + # some models such as OPT have a projection layer before the word embeddings - e.g. 
OPT-350m + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + else: + hidden_size = config.hidden_size + + self.summary = nn.Linear(hidden_size, 1) + + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + + # For now force upcast in fp32 if needed. Let's keep the + # output in fp32 for numerical stability. + if output.dtype != self.summary.weight.dtype: + output = output.to(self.summary.weight.dtype) + + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + r""" + An autoregressive model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped + model, simply manipulate the `pretrained_model` attribute of this class. + + Class attributes: + - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This + should be set to `transformers.AutoModelForCausalLM` for this class. + - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the + wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models + in the future + - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported + by the `ValueHead` class. Currently, the supported args are: + - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the + `ValueHead` class. + - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the + `ValueHead` if a specific initialization strategy is selected. + - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the + `ValueHead`. Currently, the supported strategies are: + - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the + default strategy. + - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution. + + """ + transformers_parent_class = AutoModelForCausalLM + lm_head_namings = ["lm_head", "embed_out"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + r""" + Initializes the model. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForCausalLM` class. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. + """ + super().__init__(pretrained_model) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + + if not any( + hasattr(self.pretrained_model, attribute) + for attribute in self.lm_head_namings + ): + raise ValueError( + "The model does not have a language model head, please use a model that has one." + ) + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _init_weights(self, **kwargs): + r""" + Initializes the weights of the value head. The default initialization strategy is random. 
+ Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument + when calling `.from_pretrained`. Supported strategies are: + - `normal`: initializes the weights with a normal distribution. + + Args: + **kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. These arguments + can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range` + argument. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + kwargs[ + "output_hidden_states" + ] = True # this had already been set in the LORA / PEFT examples + kwargs["past_key_values"] = past_key_values + + if ( + self.is_peft_model + and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING" + ): + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = base_model_output.hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + if last_hidden_state.device != self.v_head.summary.weight.device: + last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. + Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) + method of the wrapped model for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. 
+ """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict( + *args, **kwargs + ) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead" + " models." + ) + + first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] + + self.v_head = self.v_head.to(first_device) + + def set_device_hook(module, input, outputs): + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(first_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + + self.is_sequential_parallel = True + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + r""" + A seq2seq model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained` and `push_to_hub` and also provides some additional + functionalities such as `generate`. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForSeq2SeqLM` class. + kwargs: + Additional keyword arguments passed along to the `ValueHead` class. + """ + transformers_parent_class = AutoModelForSeq2SeqLM + lm_head_namings = ["lm_head", "embed_out", "output_projection"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + super().__init__(pretrained_model) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.is_encoder_decoder = True + + if not self._has_lm_head(): + raise ValueError( + "The model does not have a language model head, please use a model that has one." 
+ ) + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _has_lm_head(self): + # check module names of all modules inside `pretrained_model` to find the language model head + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + return True + return False + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead" + " models." + ) + + # get the lm_head device + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + lm_head_device = module.weight.device + break + + # put v_head on the same device as the lm_head to avoid issues + self.v_head = self.v_head.to(lm_head_device) + + def set_device_hook(module, input, outputs): + r""" + A hook that sets the device of the output of the model to the device of the first + parameter of the model. + + Args: + module (`nn.Module`): + The module to which the hook is attached. + input (`tuple`): + The input to the module. + outputs (`tuple`): + The output of the module. + """ + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(lm_head_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + self.is_sequential_parallel = True + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict( + *args, **kwargs + ) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + setattr(self.pretrained_model, "v_head", self.v_head) + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def _init_weights(self, **kwargs): + r""" + We initialize the weights of the value head. 
+ """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + **kwargs, + ): + kwargs["past_key_values"] = past_key_values + if ( + self.is_peft_model + and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING" + ): + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We force the model to output hidden states + **kwargs, + ) + + last_hidden_state = base_model_output.decoder_hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + We call `generate` on the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) diff --git a/intel_extension_for_transformers/transformers/ppo_config.py b/intel_extension_for_transformers/transformers/ppo_config.py new file mode 100644 index 00000000000..a0c6198ef6a --- /dev/null +++ b/intel_extension_for_transformers/transformers/ppo_config.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
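+
+"""Configuration for PPO training.
+
+This config closely follows TRL's `PPOConfig` and adds Habana-related options
+(`use_habana`, `pad_for_acceleration`, `pad_max_len`, `pad_max_input_len`);
+`rl_training.py` sets these so that prompts and generated responses can be
+padded to fixed lengths when running on Gaudi.
+"""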
+import json +import os +import sys +import warnings +from dataclasses import dataclass, field +from typing import Literal, Optional + +import numpy as np +import tyro # pylint: disable=E0611, E0401 +from typing_extensions import Annotated +from .ppo_core import flatten_dict + + +def exact_div(a, b, a_str, b_str, custom_error_message=""): + q = a // b + if a != q * b: + raise ValueError( + f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}" + ) + return q + + +JSONDict = Annotated[ + Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads) +] + + +@dataclass +class PPOConfig: + """ + Configuration class for PPOTrainer + """ + + # common parameters + exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")] + """the name of this experiment (by default is the file name without the extension name)""" + seed: int = 0 + """Seed value for random generations""" + log_with: Optional[Literal["wandb", "tensorboard"]] = None + """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking + for more details""" + task_name: Optional[str] = None + """Name of task to use - used only for tracking purposes""" + model_name: Optional[str] = None + """Name of model to use - used only for tracking purposes""" + query_dataset: Optional[str] = None + """Name of dataset to query - used only for tracking purposes""" + reward_model: Optional[str] = None + """The reward model to use - used only for tracking purposes""" + remove_unused_columns: bool = True + """Remove unused columns from the dataset if `datasets.Dataset` is used""" + tracker_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. python ppo.py --ppo_config.tracker_kwargs='{"wandb": {"entity": + "my_wandb_entity", "name": "my_exp_name"}}'""" + accelerator_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator""" + project_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" + tracker_project_name: str = "trl" + """Name of project to use for tracking""" + push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for pushing model to the hub during training (e.g. 
repo_id)"""

+    # hyperparameters
+    steps: int = 20000
+    """Number of training steps"""
+    learning_rate: float = 1e-5
+    """Adam learning rate"""
+    adap_kl_ctrl: bool = True
+    """Use adaptive KL control, otherwise linear"""
+    init_kl_coef: Optional[float] = 0.2
+    """Initial KL penalty coefficient (used for adaptive and linear control)"""
+    kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
+    """kl penalty options: 'kl': model_logp - ref_logp, 'abs': abs(kl), 'mse': mean squared error mse(kl) and 'full':
+        the actual kl for all tokens in the distribution"""
+    target: Optional[float] = 6
+    """Target KL value for adaptive KL control"""
+    horizon: Optional[float] = 10000
+    """Horizon for adaptive KL control"""
+    gamma: float = 1
+    """Gamma parameter for advantage calculation"""
+    lam: float = 0.95
+    """Lambda parameter for advantage calculation"""
+    cliprange: float = 0.2
+    """Range for clipping in PPO policy gradient loss"""
+    cliprange_value: float = 0.2
+    """Range for clipping values in loss calculation"""
+    vf_coef: float = 0.1
+    """Scaling factor for value loss"""
+    batch_size: int = 256
+    """Number of samples per optimisation step"""
+    mini_batch_size: int = 1
+    """Number of samples optimized in each mini batch"""
+    gradient_accumulation_steps: int = 1
+    """The number of gradient accumulation steps"""
+    world_size: tyro.conf.Suppress[int] = None
+    """The world size for distributed training"""
+    ppo_epochs: int = 4
+    """Number of optimisation epochs per batch of samples"""
+    max_grad_norm: Optional[float] = None
+    """Maximum gradient norm for gradient clipping"""
+    optimize_device_cache: Optional[bool] = False
+    """Optimize device cache for slightly more memory-efficient training"""
+    early_stopping: bool = False
+    """Whether to stop the PPO optimization loop early if the KL gets too high"""
+    target_kl: float = 1
+    """Stop early if we exceed this value by over 50%"""
+    compare_steps: int = 1
+    """Number of steps between comparison of the current reward with the best seen so far"""
+    ratio_threshold: float = 10.0
+    """Skip mini-batches with high PPO ratios that can cause loss spikes"""
+    use_score_scaling: bool = False
+    """Use score scaling"""
+    use_score_norm: bool = False
+    """Use score normalization. Only applicable if use_score_scaling is True"""
+    score_clip: Optional[float] = None
+    """Score clipping"""
+    whiten_rewards: bool = False
+    """Whiten the rewards before computing advantages"""
+
+    # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
+    is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
+    """TO BE FILLED in RUNTIME: Whether the model is an encoder-decoder model"""
+    is_peft_model: Optional[tyro.conf.Suppress[bool]] = None
+    """TO BE FILLED in RUNTIME: Whether the model is a PEFT model"""
+    backward_batch_size: tyro.conf.Suppress[int] = None
+    """TO BE FILLED in RUNTIME: Number of samples optimized in an `optimizer.step()` call"""
+    global_backward_batch_size: tyro.conf.Suppress[int] = None
+    """TO BE FILLED in RUNTIME: the effective `backward_batch_size` across all processes"""
+    global_batch_size: tyro.conf.Suppress[int] = None
+    """TO BE FILLED in RUNTIME: the effective `batch_size` across all processes"""
+
+    use_habana: bool = False
+    """Whether to run training on Habana Gaudi (HPU) devices"""
+    pad_for_acceleration: bool = False
+    """Whether to pad inputs to fixed lengths so that forward passes and generation use static shapes"""
+    pad_max_len: int = 0
+    """Maximum padded length of the concatenated query and response. Only applicable if pad_for_acceleration is True"""
+    pad_max_input_len: int = 0
+    """Maximum padded length of the query (input) tokens. Only applicable if pad_for_acceleration is True"""
+
+    def __post_init__(self):
+        self.backward_batch_size = (
+            self.mini_batch_size * self.gradient_accumulation_steps
+        )
+        exact_div(
+            self.batch_size,
+            self.backward_batch_size,
+            "`batch_size`",
+            "`mini_batch_size * gradient_accumulation_steps`",
+            "`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`",
+        )
+        self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size))
+
+        if self.pad_for_acceleration:
+            if self.pad_max_input_len == 0:
+                raise AssertionError(
+                    f"pad_max_input_len ({self.pad_max_input_len}) must be set for pad input"
+                )
+            if self.pad_max_input_len >= self.pad_max_len:
+                raise AssertionError(
+                    f"pad_max_input_len ({self.pad_max_input_len}) must be smaller "
+                    f"than pad_max_len ({self.pad_max_len})"
+                )
+
+        if self.use_habana:
+            from optimum.habana.transformers.modeling_utils import (  # pylint: disable=E0611, E0401
+                adapt_transformers_to_gaudi,
+            )
+
+            adapt_transformers_to_gaudi()
+
+        assert self.kl_penalty in ["kl", "abs", "mse", "full"]
+
+    def to_dict(self):
+        output_dict = {}
+        for key, value in self.__dict__.items():
+            output_dict[key] = value
+        return flatten_dict(output_dict)
diff --git a/intel_extension_for_transformers/transformers/ppo_core.py b/intel_extension_for_transformers/transformers/ppo_core.py
new file mode 100644
index 00000000000..8e24f3ef43c
--- /dev/null
+++ b/intel_extension_for_transformers/transformers/ppo_core.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
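`ppo_core.py` below collects the tensor utilities shared by the PPO trainer: masked statistics, whitening, and per-token log-probability gathering. A small illustrative sketch of how they compose (shapes and values are made up for the example):

```python
import torch
from intel_extension_for_transformers.transformers.ppo_core import (
    logprobs_from_logits,
    masked_mean,
    masked_whiten,
)

logits = torch.randn(2, 5, 32)           # (batch, seq_len, vocab_size)
labels = torch.randint(0, 32, (2, 5))    # token ids that were actually generated
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0]], dtype=torch.float)  # 1 = real token, 0 = padding

# log-probability of each chosen token under the model's distribution
logp = logprobs_from_logits(logits, labels)          # shape (2, 5)

# statistics that ignore the padded positions
mean_logp = masked_mean(logp, mask)                  # scalar
advantages = masked_whiten(torch.randn(2, 5), mask)  # ~zero-mean, unit-variance over valid tokens
```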
+ +import gc +import random +import warnings +from contextlib import contextmanager +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from transformers import top_k_top_p_filtering +import importlib +from transformers.utils.import_utils import is_optimum_available + +try: + from collections.abc import Mapping +except ImportError: + from collections import Mapping + + +WANDB_PADDING = -1 + + +def flatten_dict(nested, sep="/"): + """Flatten dictionary and concatenate nested keys with separator.""" + + def rec(nest, prefix, into): + for k, v in nest.items(): + if sep in k: + raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") + if isinstance(v, Mapping): + rec(v, prefix + k + sep, into) + else: + into[prefix + k] = v + + flat = {} + rec(nested, "", flat) + return flat + + +def convert_to_scalar(stats): + """ + Converts the stats from a flattened dict to single scalar dicts + """ + tensorboard_stats = {} + for k, v in stats.items(): + # for tensorboard compatibility - arrays and tensors are ignored with tensorboard + # therefore we convert single element tensors to scalars + if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and ( + len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1) + ): + v = v.item() + tensorboard_stats[k] = v + return tensorboard_stats + + +def stack_dicts(stats_dicts): + """Stack the values of a dict.""" + results = dict() + for k in stats_dicts[0]: + stats_list = [torch.flatten(d[k]) for d in stats_dicts] + results[k] = pad_sequence( + stats_list, batch_first=True, padding_value=WANDB_PADDING + ) + return results + + +def logprobs_from_logits(logits, labels, gather=True): + """ + See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 + """ + logp = F.log_softmax(logits, dim=2) + + if not gather: + return logp + logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1) + return logpy + + +def masked_mean(values, mask, axis=None): + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values, mask, unbiased=True): + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values, mask, shift_mean=True): + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def clip_by_value(x, tensor_min, tensor_max): + """ + Tensor extenstion to torch.clamp + https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713 + """ + clipped = torch.max(torch.min(x, tensor_max), tensor_min) + return clipped + + +def entropy_from_logits(logits): + """Calculate entropy from logits.""" + pd = 
torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1) + return entropy + +def stats_to_np(stats_dict): + """Cast all torch.tensors in dict to numpy arrays.""" + new_dict = dict() + for k, v in stats_dict.items(): + if isinstance(v, torch.Tensor): + new_dict[k] = v.detach().cpu() + if new_dict[k].dtype == torch.bfloat16: + new_dict[k] = new_dict[k].float() + new_dict[k] = new_dict[k].numpy() + else: + new_dict[k] = v + if np.isscalar(new_dict[k]): + new_dict[k] = float(new_dict[k]) + return new_dict + + +def set_seed(seed: int): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`. + + Args: + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + if is_optimum_available() and importlib.util.find_spec("optimum.habana") != None: # pragma: no cover + from habana_frameworks.torch.hpu import random as hpu_random # pylint: disable=E0611, E0401 + + hpu_random.manual_seed_all(seed) + + +class LengthSampler: + """ + Samples a length + """ + + def __init__(self, min_value, max_value): + self.values = list(range(min_value, max_value)) + + def __call__(self): + return np.random.choice(self.values) + + +class PPODecorators(object): + optimize_device_cache = False + + @classmethod + @contextmanager + def empty_device_cache(cls): + yield + if cls.optimize_device_cache and torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + gc.collect() diff --git a/intel_extension_for_transformers/transformers/ppo_trainer.py b/intel_extension_for_transformers/transformers/ppo_trainer.py new file mode 100644 index 00000000000..413e3f0e8cb --- /dev/null +++ b/intel_extension_for_transformers/transformers/ppo_trainer.py @@ -0,0 +1,1861 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
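`ppo_trainer.py` below adds the `PPOTrainer` itself. A minimal, hypothetical wiring of the pieces introduced in this patch, in the spirit of `rl_training.py` (the checkpoint name, prompts, and constant rewards are placeholders; a real run would score each response with the trained reward model):

```python
import torch
from transformers import AutoTokenizer

from intel_extension_for_transformers.transformers.ppo_config import PPOConfig
from intel_extension_for_transformers.transformers.ppo_trainer import PPOTrainer
from intel_extension_for_transformers.transformers.modeling.trl_models import (
    AutoModelForCausalLMWithValueHead,
)

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

config = PPOConfig(
    model_name=model_name,
    batch_size=2,
    mini_batch_size=1,
    gradient_accumulation_steps=2,
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer)

# one PPO step: encode queries, generate responses, attach scalar rewards
queries = [
    tokenizer(q, return_tensors="pt").input_ids.squeeze(0)
    for q in ["Question: What is PPO?\n\nAnswer: ", "Question: Why use a value head?\n\nAnswer: "]
]
responses = trainer.generate(queries, return_prompt=False, max_new_tokens=16)
rewards = [torch.tensor(1.0) for _ in responses]  # stand-in for reward-model scores
stats = trainer.step(queries, responses, rewards)
```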
+ +import inspect +import math +import os +import time +import typing +import warnings +from contextlib import nullcontext +from typing import Callable, List, Optional, Union + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +from accelerate.utils import ProjectConfiguration, is_deepspeed_available +from datasets import Dataset +from huggingface_hub import whoami +from packaging import version +from torch.optim import Adam +from transformers import ( + DataCollatorForLanguageModeling, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, +) + +from .ppo_core import ( + WANDB_PADDING, + PPODecorators, + clip_by_value, + convert_to_scalar, + entropy_from_logits, + flatten_dict, + logprobs_from_logits, + masked_mean, + masked_var, + masked_whiten, + set_seed, + stack_dicts, + stats_to_np, +) +from .modeling.trl_models import ( + SUPPORTED_ARCHITECTURES, + PreTrainedModelWrapper, + create_reference_model, +) +from .ppo_config import PPOConfig +from huggingface_hub import PyTorchModelHubMixin +from typing import List, Optional, Tuple, Union +import importlib, sys + +import logging + +logger = logging.getLogger(__name__) + + +@torch.no_grad() +def get_global_statistics( + accelerator, xs: torch.Tensor, mask=None, device="cpu" +) -> Tuple[float, float, int]: # pragma: no cover + """ + Computes element-wise mean and variance of the tensor across processes. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 + """ + xs = xs.to(accelerator.device) + sum_and_count = torch.tensor( + [xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device + ) + sum_and_count = accelerator.reduce(sum_and_count) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) + sum_var = accelerator.reduce(sum_var) + global_var = sum_var / count + + return global_mean.to(device), global_var.to(device), count.to(device) + + +class RunningMoments: + def __init__(self, accelerator): + """ + Calculates the running mean and standard deviation of a data stream. 
Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 + """ + self.mean = 0 + self.std = 1 + self.var = 1 + self.count = 1e-24 + self.accelerator = accelerator + + @torch.no_grad() + def update(self, xs: torch.Tensor) -> Tuple[float, float]: + """ + Updates running moments from batch's moments computed across ranks + """ + if self.accelerator.use_distributed: # pragma: no cover + xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + xs_mean, xs_var = xs_mean.float(), xs_var.float() + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum + + self.mean += delta * xs_count / tot_count + self.var = tot_sum / tot_count + self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt() + self.count = tot_count + + return ( + xs_mean.item(), + (xs_var * xs_count / (xs_count - 1)).float().sqrt().item(), + ) + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target, horizon): + self.value = init_kl_coef + self.target = target + self.horizon = horizon + + def update(self, current, n_steps): + target = self.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current, n_steps): + pass + + +def is_torch_greater_2_0() -> bool: + if sys.version_info < (3, 8): + _is_python_greater_3_8 = False + else: + _is_python_greater_3_8 = True + if _is_python_greater_3_8: + from importlib.metadata import version + + torch_version = version("torch") + else: + import pkg_resources + + torch_version = pkg_resources.get_distribution("torch").version + return torch_version >= "2.0" + + +if is_deepspeed_available(): # pragma: no cover + import deepspeed # pylint: disable=E0611, E0401 + +MODEL_CARD_TEMPLATE = """--- +license: apache-2.0 +tags: +- trl +- transformers +- reinforcement-learning +--- + +# {model_name} + +This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to + guide the model outputs according to a value, function, or human feedback. The model can be used for text generation. 
+
+## Usage
+
+To use this model for inference, first install the TRL library:
+
+```bash
+python -m pip install trl
+```
+
+You can then generate text as follows:
+
+```python
+from transformers import pipeline
+
+generator = pipeline("text-generation", model="{model_id}")
+outputs = generator("Hello, my llama is cute")
+```
+
+If you want to use the model for training or to obtain the outputs from the value head, load the model as follows:
+
+```python
+from transformers import AutoTokenizer
+from trl import AutoModelForCausalLMWithValueHead
+
+tokenizer = AutoTokenizer.from_pretrained("{model_id}")
+model = AutoModelForCausalLMWithValueHead.from_pretrained("{model_id}")
+
+inputs = tokenizer("Hello, my llama is cute", return_tensors="pt")
+outputs = model(**inputs, labels=inputs["input_ids"])
+```
+"""
+
+
+def disable_dropout_in_model(model: torch.nn.Module) -> None:  # pragma: no cover
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
+
+
+class PPOTrainer(PyTorchModelHubMixin):
+    """
+    Initialize PPOTrainer, refer: https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py
+    The PPOTrainer uses Proximal Policy Optimization to optimise language models.
+    Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
+    https://github.com/openai/summarize-from-feedback
+
+    Attributes:
+        **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
+            details.
+        **model** (`PreTrainedModelWrapper`) -- Model to be optimized, Hugging Face transformer model with a value head.
+            Check the documentation of `PreTrainedModelWrapper` for more details.
+        **ref_model** (`PreTrainedModelWrapper`, *optional*) -- Reference model to be used for KL penalty, Hugging Face
+            transformer model with a causal language modelling head. Check the documentation of `PreTrainedModelWrapper`
+            for more details. If no reference model is provided, the trainer will create a reference model with the same
+            architecture as the model to be optimized with shared layers.
+        **tokenizer** (`PreTrainedTokenizerBase`) -- Tokenizer to be used for encoding the
+            data. Check the documentation of `transformers.PreTrainedTokenizer` and
+            `transformers.PreTrainedTokenizerFast` for more details.
+        **dataset** (Union[`torch.utils.data.Dataset`, `datasets.Dataset`], *optional*) -- PyTorch dataset or Hugging
+            Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be
+            created outside the trainer; users need to design their own dataloader and make sure the batch
+            size that is used is the same as the one specified in the configuration object.
+        **optimizer** (`torch.optim.Optimizer`, *optional*) -- Optimizer to be used for training. If no optimizer is
+            provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration
+            object.
+        **data_collator** (DataCollatorForLanguageModeling, *optional*) -- Data collator to be used for training and
+            passed along the dataloader
+        **num_shared_layers** (int, *optional*) -- Number of layers to be shared between the model and the reference
+            model, if no reference model is passed. If no number is provided, all the layers will be shared.
+        **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
+ """ + + def __init__( + self, + config: PPOConfig = None, + model: PreTrainedModelWrapper = None, + ref_model: Optional[PreTrainedModelWrapper] = None, + tokenizer: PreTrainedTokenizerBase = None, + dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + data_collator: Optional[typing.Callable] = None, + num_shared_layers: Optional[int] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + ): + """ + Initialize PPOTrainer. + + Args: + config (`PPOConfig`): + Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more details. + model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a value head. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for KL penalty + tokenizer (`transformers.PreTrainedTokenizerBase`): + Hugging Face tokenizer + dataset (Optional[Union[`torch.utils.data.Dataset`, `datasets.Dataset`]]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. If none is passed, + a warning will be raised in a multi-GPU setting. + optimizer (Optional[`torch.optim.Optimizer`]): + Optimizer used for training. If `None`, the `Adam` is used as default. + data_collator (Optional[function]): + Data collator function. + num_shared_layers (Optional[int]): + Number of shared layers between the model and the reference model. If `None`, all layers are shared. + used only if `ref_model` is `None`. + lr_scheduler (Optional[`torch.optim.lr_scheduler`]): + Learning rate scheduler used for training. + """ + self.config = config + + # initial seed for reproducible experiments + set_seed(config.seed) + + # Step 0: check positional arguments validity + if not isinstance(config, PPOConfig): + raise ValueError(f"config must be a PPOConfig, got {type(config)}") + if not isinstance(tokenizer, (PreTrainedTokenizerBase)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast" + " got {type(tokenizer)}" + ) + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError( + f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: " + "{SUPPORTED_ARCHITECTURES}" + ) + # Step 1: Initialize Accelerator + if config.use_habana: # pragma: no cover + from optimum.habana.accelerate import GaudiAccelerator as Accelerator # pylint: disable=E0611, E0401 + else: + from accelerate import Accelerator + self.accelerator = Accelerator( + log_with=config.log_with, + gradient_accumulation_steps=config.gradient_accumulation_steps, + project_config=ProjectConfiguration(**config.project_kwargs), + **config.accelerator_kwargs, + ) + + # Step 1.1 Runtime variables filled by the accelerator + config.world_size = self.accelerator.num_processes + config.global_backward_batch_size = ( + config.backward_batch_size * config.world_size + ) + config.global_batch_size = config.batch_size * config.world_size + self.model = model.to(self.accelerator.device.type) + self.model_params = filter(lambda p: p.requires_grad, self.model.parameters()) + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.is_peft_model = getattr(self.model, "is_peft_model", False) + config.is_encoder_decoder = self.is_encoder_decoder + config.is_peft_model = self.is_peft_model + + is_using_tensorboard = ( + 
config.log_with is not None and config.log_with == "tensorboard" + ) + self.accelerator.init_trackers( + config.tracker_project_name, + config=dict(trl_ppo_trainer_config=config.to_dict()) + if not is_using_tensorboard + else config.to_dict(), + init_kwargs=config.tracker_kwargs, + ) + self.is_using_text_environment = getattr(config, "use_text_environment", False) + if isinstance(ref_model, SUPPORTED_ARCHITECTURES): + self.ref_model = ref_model.to(self.accelerator.device.type) + if num_shared_layers is not None: + warnings.warn( + "num_shared_layers is ignored when ref_model is provided. Two different models are used for the " + "model and the reference model and no layers are shared.", + UserWarning, + ) + elif ref_model is None and not self.is_peft_model: + self.ref_model = create_reference_model( + self.model, num_shared_layers=num_shared_layers + ) + elif self.is_peft_model: + self.ref_model = None + else: + raise ValueError( + f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " + f"architectures are: {SUPPORTED_ARCHITECTURES} " + ) + self.optional_peft_ctx = ( + self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter + if self.is_peft_model + else nullcontext + ) + + if not ( + isinstance(tokenizer, PreTrainedTokenizer) + or isinstance(tokenizer, PreTrainedTokenizerFast) + ): + raise ValueError( + "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast" + ) + self.tokenizer = tokenizer + + if dataset is not None and not ( + isinstance(dataset, torch.utils.data.Dataset) + or isinstance(dataset, Dataset) + ): + raise ValueError( + "dataset must be a torch.utils.data.Dataset or datasets.Dataset" + ) + elif dataset is None: + warnings.warn( + "No dataset is provided. Make sure to set config.batch_size to the correct value before training.", + UserWarning, + ) + self.dataset = dataset + self._signature_columns = None + if self.dataset is not None: + self.dataloader = self.prepare_dataloader(self.dataset, data_collator) + elif self.dataset is None and self.accelerator.num_processes > 1: + warnings.warn( + "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should" + " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`" + " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please " + " refer to the documentation for more details.", + UserWarning, + ) + self.dataloader = None + else: + self.dataloader = None + + # Step 3: Initialize optimizer and data collator + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + if optimizer is None: + self.optimizer = Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.config.learning_rate, + ) + else: + self.optimizer = optimizer + + self.lr_scheduler = lr_scheduler + if self.lr_scheduler is not None: + lr_scheduler_class = ( + torch.optim.lr_scheduler._LRScheduler + if not is_torch_greater_2_0() + else torch.optim.lr_scheduler.LRScheduler + ) + if not isinstance(self.lr_scheduler, lr_scheduler_class): + raise ValueError( + "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler." 
+ " LRScheduler (for torch >= 2.0)" + ) + + if self.config.adap_kl_ctrl: + self.kl_ctl = AdaptiveKLController( + self.config.init_kl_coef, self.config.target, self.config.horizon + ) + else: + self.kl_ctl = FixedKLController(self.config.init_kl_coef) + + if self.accelerator.distributed_type == "MULTI_HPU": # pragma: no cover + from accelerate.utils import DistributedDataParallelKwargs + + kwargs = {} + kwargs["find_unused_parameters"] = True + kwargs["gradient_as_bucket_view"] = True + self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + + # Safety checkers for DS integration + is_deepspeed_used = ( + self.accelerator.distributed_type == "DEEPSPEED" + and hasattr(self.accelerator.state, "deepspeed_plugin") + ) + if self.accelerator.device.type == "hpu": # pragma: no cover + # WA for Gaudi + disable_dropout_in_model(self.model) + + ( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) + if is_deepspeed_used: # pragma: no cover + # Quantized models are already set on the correct device + if not self.is_peft_model and not ( + getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) + or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False) + ): + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare(self.ref_model) + + # In a distributed setup, only logging needs to be performed on the main process + # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11 + self.is_distributed = self.accelerator.distributed_type in [ + "MULTI_GPU", + "MULTI_HPU", + ] + + # init the current step + self.current_step = 0 + + # init variables for pushing model to hub + if config.push_to_hub_if_best_kwargs: # pragma: no cover + if "repo_id" not in config.push_to_hub_if_best_kwargs: + raise ValueError( + "You have to specify repo_id in order to push the model to the hub!" + ) + self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs + self.compare_step = 0 + self.highest_reward = torch.tensor(-float("inf")) + + # post process for PP + if not getattr(self.model, "is_sequential_parallel", False): + self.current_device = self.accelerator.device + else: + self.current_device = torch.device("cuda:0") + + PPODecorators.optimize_device_cache = self.config.optimize_device_cache + + self.running = RunningMoments(self.accelerator) + if config.use_habana: # pragma: no cover + import habana_frameworks.torch.core as htcore # pylint: disable=E0611, E0401 + + self.htcore = htcore + + def _filter_kwargs(self, kwargs, target_func): + """ + filter the keyword arguments that are supported by the target function. + + Args: + kwargs (dict): + Keyword arguments + target_func (function): + Target function + """ + return { + k: v + for k, v in kwargs.items() + if k in inspect.signature(target_func).parameters.keys() + } + + def prepare_dataloader( + self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None + ): + """ + Prepare the dataloader for training. + + Args: + dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): + PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset + will be preprocessed by removing the columns that are not used by the model. 
+ data_collator (Optional[function]): + Data collator function. + + Returns: + `torch.utils.data.DataLoader`: PyTorch dataloader + """ + if isinstance(dataset, Dataset): # pragma: no cover + dataset = self._remove_unused_columns(dataset) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=data_collator, + shuffle=True, + drop_last=True, + ) + return dataloader + + # Adapted from transformers.Trainer._set_signature_columns_if_needed + def _set_signature_columns_if_needed(self): # pragma: no cover + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # label => sentiment | we need query and response for logging purpose + self._signature_columns += ["label", "query", "response"] + + # Adapted from transformers.Trainer._remove_unused_columns + def _remove_unused_columns(self, dataset: "Dataset"): # pragma: no cover + if not self.config.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], + columns=columns, + format_kwargs=dataset.format["format_kwargs"], + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def generate( + self, + query_tensor: Union[torch.Tensor, List[torch.Tensor]], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + generate_ref_response: bool = False, + **generation_kwargs, + ): + """ + Generate response with the model given the query tensor. + call the `generate` method of the model. + + Args: + query_tensor (`torch.LongTensor`): + A tensor of shape (`seq_len`) containing query tokens or a list of tensors of shape (`seq_len`). + generation_kwargs (dict[str, Any]): + Keyword arguments for generation. + length_sampler (`Callable`, *optional*): + Callable that returns the number of newly generated tokens. + batch_size (`int`, *optional): + Batch size used for generation, defaults to `4`. + return_prompt (`bool`, *optional*): + If set to `False` the prompt is not returned but only the newly generated tokens, defaults to `True`. + generate_ref_response (`bool`, *optional*): + If set to `True` the reference response is also generated, defaults to `False`. + + Returns: + `torch.LongTensor`: A tensor of shape (`batch_size`, `gen_len`) containing response tokens. 
+ """ + if generate_ref_response: + ref_model = self.ref_model if self.ref_model is not None else self.model + if isinstance(query_tensor, List): + if self.config.use_habana: # pragma: no cover + self.wrap_generation_for_hpu_graph_mode(self.model) + response = self._generate_batched( + self.model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + if generate_ref_response: + with self.optional_peft_ctx(): + if self.config.use_habana: # pragma: no cover + self.wrap_generation_for_hpu_graph_mode(ref_model) + ref_response = self._generate_batched( + ref_model, + query_tensor, + length_sampler=length_sampler, + batch_size=batch_size, + return_prompt=return_prompt, + **generation_kwargs, + ) + + else: + if len(query_tensor.shape) == 2: + raise ValueError( + "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)" + ) + + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + if self.config.use_habana: # pragma: no cover + self.wrap_generation_for_hpu_graph_mode(self.model) + response = self.accelerator.unwrap_model(self.model).generate( + input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs + ) + if generate_ref_response: + with self.optional_peft_ctx(): + if self.config.use_habana: # pragma: no cover + self.wrap_generation_for_hpu_graph_mode(ref_model) + ref_response = ref_model.generate( + input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs + ) + + if not return_prompt and not self.is_encoder_decoder: + response = response[:, query_tensor.shape[0] :] + if generate_ref_response: + ref_response = ref_response[:, query_tensor.shape[0] :] + + if generate_ref_response: + return response, ref_response + return response + + def _generate_batched( + self, + model: PreTrainedModelWrapper, + query_tensors: List[torch.Tensor], + length_sampler: Callable = None, + batch_size: int = 4, + return_prompt: bool = True, + pad_to_multiple_of: int = None, + remove_padding: bool = True, + **generation_kwargs, + ): + outputs = [] + + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + if self.config.pad_for_acceleration and self.config.pad_max_input_len > 0: + padded_inputs = self.tokenizer.pad( + inputs, + padding="max_length", + max_length=self.config.pad_max_input_len, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + else: + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + if self.config.use_habana: # pragma: no cover + generation_kwargs["ignore_eos"] = False + generation_kwargs["lazy_mode"] = True + generation_kwargs["hpu_graphs"] = True + + generations = self.accelerator.unwrap_model(model).generate( + **padded_inputs, **generation_kwargs + ) + + for generation, mask in 
zip(generations, padded_inputs["attention_mask"]): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not return_prompt and not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + if remove_padding and self.tokenizer.eos_token_id in output: + pad_mask = output == self.tokenizer.eos_token_id + pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item() + output = output[: pad_start + 1] # keep the eos token at the end + outputs.append(output.clone()) + + self.tokenizer.padding_side = padding_side_default + return outputs + + def _step_safety_checker( + self, + batch_size: int, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Check if the input data is valid for training. + + Args: + batch_size (int): + Batch size from the config file. + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + masks (List[`torch.LongTensor`], *optional*): + list of optional tensors containing the masks of shape (`query_length` + `response_length`) + Returns: + `tuple`: The input processed data. + """ + for name, tensor_list in zip( + ["queries", "responses", "scores"], [queries, responses, scores] + ): + if not isinstance(tensor_list, list): + raise ValueError( + f"{name} must be a list of tensors - got {type(tensor_list)}" + ) + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError( + f"Elements in {name} must be tensors - got {type(tensor_list[0])}" + ) + if batch_size is not None and len(tensor_list) != batch_size: + raise ValueError( + f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: " + "{name}" + ) + + # add queries, scores and responses on the correct device + queries = [tensor.to(self.current_device) for tensor in queries] + responses = [tensor.to(self.current_device) for tensor in responses] + scores = [tensor.to(self.current_device) for tensor in scores] + masks = ( + [tensor.to(self.current_device) for tensor in masks] + if masks is not None + else None + ) + + # squeeze scores if needed + for i, score in enumerate(scores): + if score.dim() > 1: + raise ValueError( + f"Scores must be 1-dimensional - got {score.dim()} for {score}" + ) + elif score.dim() == 1: + scores[i] = score.squeeze() + + return queries, responses, scores, masks + + @PPODecorators.empty_device_cache() + def step( + self, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[torch.FloatTensor], + response_masks: Optional[List[torch.LongTensor]] = None, + ): + """ + Run a PPO optimisation step given a list of queries, model responses, and rewards. + + Args: + queries (List[`torch.LongTensor`]): + List of tensors containing the encoded queries of shape (`query_length`) + responses (List[`torch.LongTensor`]): + List of tensors containing the encoded responses of shape (`response_length`) + scores (List[`torch.FloatTensor`]): + List of tensors containing the scores. + response_masks (List[`torch.FloatTensor`], *optional*)): + List of tensors containing masks of the response tokens. 
+ + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + bs = self.config.batch_size + + queries, responses, scores, response_masks = self._step_safety_checker( + bs, queries, responses, scores, response_masks + ) + scores = torch.tensor(scores, device=self.current_device) + if self.config.use_score_scaling: + # Score scaling + scores_mean, scores_std = self.running.update(scores) + tensor_to_kwargs = dict(dtype=scores.dtype, device=scores.device) + score_scaling_factor = ( + self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps + ) + if self.config.use_score_norm: + scores = ( + scores - self.running.mean.to(**tensor_to_kwargs) + ) / score_scaling_factor + else: + scores /= score_scaling_factor + + if self.config.score_clip is not None: + # Score clipping + scores_dtype = scores.dtype + scores = torch.clip( + scores.float(), -self.config.score_clip, self.config.score_clip + ).to(dtype=scores_dtype) + + # if we want to push best model to the hub + if hasattr(self, "highest_reward"): # pragma: no cover + if self.compare_step % self.config.compare_steps == 0: + curr_mean_reward = scores.mean() + # if the best reward ever seen + if curr_mean_reward > self.highest_reward: + self.highest_reward = curr_mean_reward + # push model to hub + self.push_to_hub(**self.push_to_hub_kwargs) + self.compare_step += 1 + + timing = dict() + t0 = time.time() + + t = time.time() + + model_inputs = self.prepare_model_inputs(queries, responses) + + if self.is_distributed and not self.config.pad_for_acceleration: # pragma: no cover + pad_first = self.tokenizer.padding_side == "left" + + model_inputs["input_ids"] = self.accelerator.pad_across_processes( + model_inputs["input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first + ) + if self.is_encoder_decoder: + model_inputs[ + "decoder_input_ids" + ] = self.accelerator.pad_across_processes( + model_inputs["decoder_input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs[ + "decoder_attention_mask" + ] = self.accelerator.pad_across_processes( + model_inputs["decoder_attention_mask"], + dim=1, + pad_index=0, + pad_first=pad_first, + ) + + model_inputs_names = list(model_inputs.keys()) + + full_kl_penalty = self.config.kl_penalty == "full" + + with torch.no_grad(): + if self.config.use_habana: # pragma: no cover + self.unwrap_generation_for_hpu_graph_mode(self.model) + self.wrap_fw_for_hpu_graph_mode(self.model) + if self.ref_model is not None: + self.unwrap_generation_for_hpu_graph_mode(self.ref_model) + self.wrap_fw_for_hpu_graph_mode(self.ref_model) + + all_logprobs, logits_or_none, values, masks = self.batched_forward_pass( + self.model, + queries, + responses, + model_inputs, + response_masks=response_masks, + return_logits=full_kl_penalty, + ) + with self.optional_peft_ctx(): + ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass( + self.ref_model if self.ref_model is not None else self.model, + queries, + responses, + model_inputs, + return_logits=full_kl_penalty, + ) + + timing["time/ppo/forward_pass"] = time.time() - t + + with torch.no_grad(): + t = time.time() + if full_kl_penalty: + active_full_logprobs = logprobs_from_logits( + logits_or_none, None, gather=False + ) + ref_full_logprobs = logprobs_from_logits( + ref_logits_or_none, None, gather=False + ) + + rewards, 
non_score_reward = self.compute_rewards( + scores, active_full_logprobs, ref_full_logprobs, masks + ) + else: + rewards, non_score_reward = self.compute_rewards( + scores, all_logprobs, ref_logprobs, masks + ) + timing["time/ppo/compute_rewards"] = time.time() - t + + t = time.time() + values, advantages, returns = self.compute_advantages( + values, rewards, masks + ) + timing["time/ppo/compute_advantages"] = time.time() - t + + # upcast to float32 to avoid dataset issues + batch_dict = { + "queries": queries, + "responses": responses, + "logprobs": all_logprobs.to(torch.float32), + "values": values.to(torch.float32), + "masks": masks, + "advantages": advantages, + "returns": returns, + } + batch_dict.update(model_inputs) + + t = time.time() + all_stats = [] + early_stop = False + if self.config.use_habana: # pragma: no cover + self.unwrap_fw_for_hpu_graph_mode(self.model) + for _ in range(self.config.ppo_epochs): + if early_stop: + break + b_inds = np.random.permutation(bs) + for backward_batch_start in range(0, bs, self.config.backward_batch_size): + backward_batch_end = ( + backward_batch_start + self.config.backward_batch_size + ) + backward_batch_inds = b_inds[backward_batch_start:backward_batch_end] + + for mini_batch_start in range( + 0, self.config.backward_batch_size, self.config.mini_batch_size + ): + mini_batch_end = mini_batch_start + self.config.mini_batch_size + mini_batch_inds = backward_batch_inds[ + mini_batch_start:mini_batch_end + ] + mini_batch_dict = { + "logprobs": batch_dict["logprobs"][mini_batch_inds], + "values": batch_dict["values"][mini_batch_inds], + "masks": batch_dict["masks"][mini_batch_inds], + # hacks: the queries and responses are ragged. + "queries": [batch_dict["queries"][i] for i in mini_batch_inds], + "responses": [ + batch_dict["responses"][i] for i in mini_batch_inds + ], + "advantages": batch_dict["advantages"][mini_batch_inds], + "returns": batch_dict["returns"][mini_batch_inds], + } + for k in model_inputs_names: + mini_batch_dict[k] = batch_dict[k][mini_batch_inds] + with self.accelerator.accumulate(self.model): + model_inputs = { + k: mini_batch_dict[k] for k in model_inputs_names + } + logprobs, logits, vpreds, _ = self.batched_forward_pass( + self.model, + mini_batch_dict["queries"], + mini_batch_dict["responses"], + model_inputs, + return_logits=True, + ) + train_stats = self.train_minibatch( + mini_batch_dict["logprobs"], + mini_batch_dict["values"], + logprobs, + logits, + vpreds, + mini_batch_dict["masks"], + mini_batch_dict["advantages"], + mini_batch_dict["returns"], + ) + all_stats.append(train_stats) + + # typically, early stopping is done at the epoch level + if self.config.early_stopping: + policykl = train_stats["policy/policykl"] + early_stop = self._early_stop(policykl) + if early_stop: + break + + timing["time/ppo/optimize_step"] = time.time() - t + + t = time.time() + train_stats = stack_dicts(all_stats) + + # reshape advantages/ratios such that they are not averaged. 
+ train_stats["policy/advantages"] = torch.flatten( + train_stats["policy/advantages"] + ).unsqueeze(0) + train_stats["policy/advantages"] = torch.nan_to_num( + train_stats["policy/advantages"], WANDB_PADDING + ) + train_stats["policy/ratio"] = torch.flatten( + train_stats["policy/ratio"] + ).unsqueeze(0) + + stats = self.record_step_stats( + scores=scores, + logprobs=all_logprobs, + ref_logprobs=ref_logprobs, + non_score_reward=non_score_reward, + train_stats=train_stats, + kl_coef=self.kl_ctl.value, + masks=masks, + queries=queries, + responses=responses, + ) + # Gather/Reduce stats from all processes + if self.is_distributed: # pragma: no cover + stats = self.gather_stats(stats) + stats = stats_to_np(stats) + timing["time/ppo/calc_stats"] = time.time() - t + stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"] + + # Update the KL control - multiply the batch_size by the number of processes + self.kl_ctl.update( + stats["objective/kl"], + self.config.batch_size * self.accelerator.num_processes, + ) + + # Log the total ppo time + timing["time/ppo/total"] = time.time() - t0 + stats.update(timing) + + # post-process stats for tensorboard and other loggers + if self.config.log_with != "wandb": + stats = convert_to_scalar(stats) + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + return stats + + def _early_stop(self, policykl): + r""" + Handles the early stopping logic. If the policy KL is greater than the target KL, then the gradient is zeroed + and the optimization step is skipped. + This also handles the multi-gpu case where the policy KL is averaged across all processes. + + Args: + policy_kl (torch.Tensor): + the policy KL + + Returns: + `bool`: whether to early stop or not + """ + early_stop = False + if not self.config.early_stopping: + return early_stop + + if not self.is_distributed and policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + elif self.is_distributed: # pragma: no cover + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + # all gather the policykl + dist.all_reduce(policykl, dist.ReduceOp.SUM) + policykl /= self.accelerator.num_processes + + if policykl > 1.5 * self.config.target_kl: + self.optimizer.zero_grad() + early_stop = True + return early_stop + + def gather_stats(self, stats): # pragma: no cover + """ + Gather stats from all processes. Useful in the context of distributed training. + + Args: + stats (dict[str, Any]): + a dictionary of stats to be gathered. The stats should contain torch tensors. + + Returns: + `dict[str, Any]`: A dictionary of stats with the tensors gathered. 
+ """ + import torch.distributed as dist + + # Wait for all processes to finish + dist.barrier() + + for k, v in stats.items(): + if isinstance(v, torch.Tensor): + dist.all_reduce(v.to(self.accelerator.device), dist.ReduceOp.SUM) + v /= self.accelerator.num_processes + stats[k] = v + return stats + + def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): + if self.is_encoder_decoder: # pragma: no cover + input_data = self.data_collator( + [ + {"input_ids": q, "attention_mask": torch.ones_like(q)} + for q in queries + ] + ).to(self.current_device) + + decoder_inputs = self.data_collator( + [ + {"input_ids": r, "attention_mask": torch.ones_like(r)} + for r in responses + ] + ).to(self.current_device) + + input_data["decoder_input_ids"] = decoder_inputs["input_ids"] + input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"] + if self.config.pad_for_acceleration: + input_data["input_ids"] = torch.nn.functional.pad( + input_data["input_ids"], + (0, self.config.pad_max_len - input_data["input_ids"].shape[1]), + value=self.tokenizer.pad_token_id, + ) + input_data["attention_mask"] = torch.nn.functional.pad( + input_data["attention_mask"], + ( + 0, + self.config.pad_max_len - input_data["attention_mask"].shape[1], + ), + value=0, + ) + input_data["decoder_input_ids"] = torch.nn.functional.pad( + input_data["decoder_input_ids"], + ( + 0, + self.config.pad_max_len + - input_data["decoder_input_ids"].shape[1], + ), + value=self.tokenizer.pad_token_id, + ) + input_data["decoder_attention_mask"] = torch.nn.functional.pad( + input_data["decoder_attention_mask"], + ( + 0, + self.config.pad_max_len + - input_data["decoder_attention_mask"].shape[1], + ), + value=0, + ) + else: + input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] + input_data = self.data_collator( + [ + {"input_ids": ids, "attention_mask": torch.ones_like(ids)} + for ids in input_ids + ] + ).to(self.current_device) + + if self.config.pad_for_acceleration: + input_data["input_ids"] = torch.nn.functional.pad( + input_data["input_ids"], + (0, self.config.pad_max_len - input_data["input_ids"].shape[1]), + value=self.tokenizer.pad_token_id, + ) + input_data["attention_mask"] = torch.nn.functional.pad( + input_data["attention_mask"], + ( + 0, + self.config.pad_max_len - input_data["attention_mask"].shape[1], + ), + value=0, + ) + input_data.pop("labels", None) # we don't want to compute LM losses + return input_data + + @PPODecorators.empty_device_cache() + def batched_forward_pass( + self, + model: PreTrainedModelWrapper, + queries: torch.Tensor, + responses: torch.Tensor, + model_inputs: dict, + return_logits: bool = False, + response_masks: Optional[torch.Tensor] = None, + ): + """ + Calculate model outputs in multiple batches. + + Args: + queries (`torch.LongTensor`): + List of tensors containing the encoded queries, shape (`batch_size`, `query_length`) + responses (`torch.LongTensor`): + List of tensors containing the encoded responses, shape (`batch_size`, `response_length`) + return_logits (`bool`, *optional*, defaults to `False`): + Whether to return all_logits. Set to `False` if logits are not needed to reduce memory consumption. 
+ Returns: + (tuple): + - all_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_ref_logprobs (`torch.FloatTensor`): Log probabilities of the responses, + shape (`batch_size`, `response_length`) + - all_values (`torch.FloatTensor`): Values of the responses, shape (`batch_size`, `response_length`) + """ + bs = len(queries) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + model.eval() + + for i in range(math.ceil(bs / fbs)): + input_kwargs = { + key: value[i * fbs : (i + 1) * fbs].clone() + for key, value in model_inputs.items() + } + query_batch = queries[i * fbs : (i + 1) * fbs] + response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] + logits, _, values = model(**input_kwargs) + if self.is_encoder_decoder: # pragma: no cover + input_ids = input_kwargs["decoder_input_ids"] + attention_mask = input_kwargs["decoder_attention_mask"] + else: + input_ids = input_kwargs["input_ids"] + attention_mask = input_kwargs["attention_mask"] + + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + masks = torch.zeros_like(attention_mask) + masks[:, :-1] = attention_mask[:, 1:] + + for j in range(len(query_batch)): + if self.is_encoder_decoder: # pragma: no cover + # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models + start = 1 + end = attention_mask[j, :].sum() - 1 + else: + start = ( + len(query_batch[j]) - 1 + ) # logprobs starts from the second query token + if attention_mask[j, 0] == 0: # offset left padding + start += attention_mask[j, :].nonzero()[0] + end = start + len(response_batch[j]) + if response_masks is not None: + response_masks_batch[j] = torch.cat( + (torch.zeros_like(query_batch[j]), response_masks_batch[j]) + )[1:] + + masks[j, :start] = 0 + masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = ( + masks[j, start:end] * response_masks_batch[j][start:end] + ) + + if return_logits: + all_logits.append(logits.clone()) + else: + del logits + all_values.append(values.clone()) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1] if return_logits else None, + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + + @PPODecorators.empty_device_cache() + def train_minibatch( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logprobs: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Train one PPO minibatch + + Args: + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape [batch_size, response_length] + values (`torch.FloatTensor`): + Values of the value head, shape [batch_size, response_length] + query (`torch.LongTensor`): + Encoded queries, shape [batch_size, query_length] + response (`torch.LongTensor`): + Encoded responses, shape [batch_size, response_length] + model_input (`torch.LongTensor`): + Concatenated queries and responses, shape [batch_size, query_length+response_length] + + Returns: + train_stats (dict[str, `torch.Tensor`]): + Dictionary of training statistics + """ + self.model.train() + loss_p, loss_v, train_stats = self.loss( + old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns + ) + loss = loss_p + loss_v 
+ self.accelerator.backward(loss) + if self.config.max_grad_norm is not None: + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.model_params, self.config.max_grad_norm + ) + self.optimizer.step() + if self.config.use_habana: # pragma: no cover + self.htcore.mark_step() + # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation + # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code + self.optimizer.zero_grad() + return train_stats + + def compute_rewards( + self, + scores: torch.FloatTensor, + logprobs: torch.FloatTensor, + ref_logprobs: torch.FloatTensor, + masks: torch.LongTensor, + ): + """ + Compute per token rewards from scores and KL-penalty. + + Args: + scores (`torch.FloatTensor`): + Scores from the reward model, shape (`batch_size`) + logprobs (`torch.FloatTensor`): + Log probabilities of the model, shape (`batch_size`, `response_length`) + ref_logprobs (`torch.FloatTensor`): + Log probabilities of the reference model, shape (`batch_size`, `response_length`) + """ + rewards, non_score_rewards = [], [] + for score, logprob, ref_logprob, mask in zip( + scores, logprobs, ref_logprobs, masks + ): + # compute KL penalty (from difference in logprobs) + kl = self._kl_penalty(logprob, ref_logprob) + non_score_reward = -self.kl_ctl.value * kl + non_score_rewards.append(non_score_reward) + reward = non_score_reward.clone() + last_non_masked_index = mask.nonzero()[-1] + + # reward is preference model score + KL penalty + reward[last_non_masked_index] += score + rewards.append(reward) + return torch.stack(rewards), torch.stack(non_score_rewards) + + def _kl_penalty( + self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor + ) -> torch.FloatTensor: + if self.config.kl_penalty == "kl": + return logprob - ref_logprob + + if self.config.kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if self.config.kl_penalty == "mse": + return 0.5 * (logprob - ref_logprob).square() + + if self.config.kl_penalty == "full": + # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459 + return F.kl_div( + ref_logprob, logprob, log_target=True, reduction="none" + ).sum(-1) + + raise NotImplementedError + + def compute_advantages( + self, + values: torch.FloatTensor, + rewards: torch.FloatTensor, + mask: torch.FloatTensor, + ): + lastgaelam = 0 + advantages_reversed = [] + gen_len = rewards.shape[-1] + + values = values * mask + rewards = rewards * mask + + if self.config.whiten_rewards: + rewards = masked_whiten(rewards, mask, shift_mean=False) + + for t in reversed(range(gen_len)): + nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 + delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.config.gamma * self.config.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1) + + returns = advantages + values + advantages = masked_whiten(advantages, mask) + advantages = advantages.detach() + return values, advantages, returns + + def loss( + self, + old_logprobs: torch.FloatTensor, + values: torch.FloatTensor, + logits: torch.FloatTensor, + vpreds: torch.FloatTensor, + logprobs: torch.FloatTensor, + mask: torch.LongTensor, + advantages: torch.FloatTensor, + returns: torch.FloatTensor, + ): + """ + Calculate policy and value losses. 
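+        Uses the PPO clipped surrogate objective for the policy and a clipped
+        squared-error loss for the value head.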
+
+        Args:
+            old_logprobs (`torch.FloatTensor`):
+                Log probabilities of the model at rollout time, shape (`batch_size`, `response_length`)
+            values (`torch.FloatTensor`):
+                Values of the value head at rollout time, shape (`batch_size`, `response_length`)
+            logits (`torch.FloatTensor`):
+                Current logits of the model, shape (`batch_size`, `response_length`, `vocab_size`)
+            vpreds (`torch.FloatTensor`):
+                Current values of the value head, shape (`batch_size`, `response_length`)
+            logprobs (`torch.FloatTensor`):
+                Current log probabilities of the model, shape (`batch_size`, `response_length`)
+            mask (`torch.LongTensor`):
+                Mask selecting the response tokens, shape (`batch_size`, `response_length`)
+            advantages (`torch.FloatTensor`):
+                Advantages of the responses, shape (`batch_size`, `response_length`)
+            returns (`torch.FloatTensor`):
+                Returns of the responses, shape (`batch_size`, `response_length`)
+
+        Returns:
+            (tuple): policy loss, value loss scaled by `vf_coef`, and a flattened dict of statistics
+        """
+        vpredclipped = clip_by_value(
+            vpreds,
+            values - self.config.cliprange_value,
+            values + self.config.cliprange_value,
+        )
+
+        vf_losses1 = (vpreds - returns) ** 2
+        vf_losses2 = (vpredclipped - returns) ** 2
+        vf_loss = 0.5 * masked_mean(torch.max(vf_losses1, vf_losses2), mask)
+        vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)
+
+        ratio = torch.exp(logprobs - old_logprobs)
+
+        pg_losses = -advantages * ratio
+        pg_losses2 = -advantages * torch.clamp(
+            ratio, 1.0 - self.config.cliprange, 1.0 + self.config.cliprange
+        )
+
+        pg_loss = masked_mean(torch.max(pg_losses, pg_losses2), mask)
+        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), mask)
+
+        loss = pg_loss + self.config.vf_coef * vf_loss
+
+        avg_ratio = masked_mean(ratio, mask).item()
+        if avg_ratio > self.config.ratio_threshold:
+            warnings.warn(
+                f"The average ratio of the batch ({avg_ratio:.2f}) exceeds the threshold {self.config.ratio_threshold:.2f}. "
+                "Skipping batch."
+            )
+            pg_loss = pg_loss * 0.0
+            vf_loss = vf_loss * 0.0
+            loss = loss * 0.0
+
+        entropy = masked_mean(entropy_from_logits(logits), mask)
+
+        approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)
+        policykl = masked_mean(old_logprobs - logprobs, mask)
+
+        return_mean, return_var = masked_mean(returns, mask), masked_var(returns, mask)
+        value_mean, value_var = masked_mean(values, mask), masked_var(values, mask)
+
+        stats = dict(
+            loss=dict(
+                policy=pg_loss.detach(), value=vf_loss.detach(), total=loss.detach()
+            ),
+            policy=dict(
+                entropy=entropy.detach(),
+                approxkl=approxkl.detach(),
+                policykl=policykl.detach(),
+                clipfrac=pg_clipfrac.detach(),
+                advantages=advantages.detach(),
+                advantages_mean=masked_mean(advantages, mask).detach(),
+                ratio=ratio.detach(),
+            ),
+            returns=dict(mean=return_mean.detach(), var=return_var.detach()),
+            val=dict(
+                vpred=masked_mean(vpreds, mask).detach(),
+                error=masked_mean((vpreds - returns) ** 2, mask).detach(),
+                clipfrac=vf_clipfrac.detach(),
+                mean=value_mean.detach(),
+                var=value_var.detach(),
+            ),
+        )
+        return pg_loss, self.config.vf_coef * vf_loss, flatten_dict(stats)
+
+    def record_step_stats(self, kl_coef: float, **data):
+        """
+        Record training step statistics.
+
+        Args:
+            kl_coef (`float`):
+                KL coefficient
+            data (`dict`):
+                Dictionary of training step data
+
+        Returns:
+            stats (`dict`):
+                Dictionary of training step statistics
+        """
+        mask = data.pop("masks")
+
+        kl_list = ((data["logprobs"] - data["ref_logprobs"]) * mask).sum(axis=-1)
+        mean_kl = kl_list.mean()
+        mean_entropy = (-data["logprobs"] * mask).sum(axis=-1).mean()
+
+        mean_non_score_reward = masked_mean(
+            data["non_score_reward"], mask
+        )  # non_score_reward is size `batch_size`, `response_length`
+        mean_scores = data["scores"].mean()  # scores is size `batch_size`
+        std_scores = data["scores"].std()
+
+        if mean_kl.item() < -1.0:
+            # warn users
+            warnings.warn(
+                f"KL divergence is starting to become negative: {mean_kl.item():.2f} - this might be a precursor of"
+                " failed training. Sometimes this happens because the generation kwargs are not correctly set."
+                " Please make sure that the generation kwargs are set correctly, or review your training"
+                " hyperparameters."
+            )
+
+        stats = {
+            "objective/kl": mean_kl,
+            "objective/kl_dist": kl_list,
+            "objective/logprobs": data["logprobs"],
+            "objective/ref_logprobs": data["ref_logprobs"],
+            "objective/kl_coef": kl_coef,
+            "objective/entropy": mean_entropy,
+            "ppo/mean_non_score_reward": mean_non_score_reward,
+            "ppo/mean_scores": mean_scores,
+            "ppo/std_scores": std_scores,
+        }
+
+        # Log text properties
+        query_lens = torch.tensor(
+            [len(query) for query in data["queries"]], dtype=torch.float
+        )
+        response_lens = torch.tensor(
+            [len(response) for response in data["responses"]], dtype=torch.float
+        )
+
+        stats["tokens/queries_len_mean"] = torch.mean(query_lens).cpu().numpy().item()
+        stats["tokens/queries_len_std"] = torch.std(query_lens).cpu().numpy().item()
+        stats["tokens/queries_dist"] = query_lens.cpu().numpy()
+        stats["tokens/responses_len_mean"] = (
+            torch.mean(response_lens).cpu().numpy().item()
+        )
+        stats["tokens/responses_len_std"] = (
+            torch.std(response_lens).cpu().numpy().item()
+        )
+        stats["tokens/responses_dist"] = response_lens.cpu().numpy()
+
+        for k, v in data["train_stats"].items():
+            stats[f"ppo/{k}"] = torch.mean(v, axis=0)
+        stats["ppo/val/var_explained"] = (
+            1 - stats["ppo/val/error"] / stats["ppo/returns/var"]
+        )
+        return stats
+
+    def log_stats(
+        self,
+        stats: dict,
+        batch: dict,
+        rewards: List[torch.FloatTensor],
+        columns_to_log: List[str] = ["query", "response"],
+    ):
+        """
+        A function that logs all the training stats. Call it at the end of each epoch.
+
+        Args:
+            stats (dict[str, Any]):
+                A dictionary of training stats.
+            batch (dict[str, Any]):
+                A dictionary of batch data, this contains the queries and responses.
+            rewards (`List[torch.FloatTensor]`):
+                A list of reward tensors.
+        """
+        # Log only if we are in the main process
+        if self.accelerator.is_main_process:
+            logs = {}
+
+            # Log stats
+            if not isinstance(rewards, torch.Tensor):
+                rewards = torch.tensor(rewards).to(self.current_device)
+
+            if "query" not in batch.keys() and "response" not in batch.keys():
+                # warn the user that the game logs will not be logged
+                warnings.warn(
+                    "The game logs will not be logged because the batch does not contain the keys 'query' and "
+                    "'response'."
+                )
+            elif self.config.log_with == "wandb":  # pragma: no cover
+                if importlib.util.find_spec("wandb") is None:
+                    raise ImportError(
+                        "wandb is not installed, please install it to use wandb logging."
+                    )
+                import wandb  # pylint: disable=E0611, E0401
+
+                if any(
+                    [
+                        column_to_log not in batch.keys()
+                        for column_to_log in columns_to_log
+                    ]
+                ):
+                    raise ValueError(
+                        f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}."
+                    )
+
+                batch_list = [batch[column_to_log] for column_to_log in columns_to_log]
+
+                table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())]
+                logs.update(
+                    {
+                        "game_log": wandb.Table(
+                            columns=[*columns_to_log, "reward"], rows=table_rows
+                        )
+                    }
+                )
+            # All reduce rewards if distributed
+            if self.is_distributed:  # pragma: no cover
+                import torch.distributed as dist
+
+                dist.barrier()
+
+                dist.all_reduce(rewards, op=torch.distributed.ReduceOp.SUM)
+                rewards /= self.accelerator.num_processes
+
+            logs.update(stats)
+
+            # manually cast bf16 torch tensors to fp32 for logging
+            for k, v in logs.items():
+                if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16:
+                    logs[k] = v.float()
+
+            logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
+            logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
+            logs["env/reward_dist"] = rewards.cpu().numpy()
+
+            if self.config.log_with == "tensorboard":
+                # update the current step
+                self.current_step += 1
+
+            self.accelerator.log(
+                logs,
+                step=self.current_step
+                if self.config.log_with == "tensorboard"
+                else None,
+            )
+
+        else:
+            if self.is_distributed:  # pragma: no cover
+                import torch.distributed as dist
+
+                if not isinstance(rewards, torch.Tensor):
+                    rewards = torch.tensor(rewards).to(self.current_device)
+
+                dist.barrier()
+                dist.all_reduce(rewards, op=torch.distributed.ReduceOp.SUM)
+
+    def create_model_card(
+        self, path: str, model_name: Optional[str] = "TRL Model"
+    ) -> None:
+        """Creates and saves a model card for a TRL model.
+
+        Args:
+            path (`str`): The path to save the model card to.
+            model_name (`str`, *optional*): The name of the model, defaults to `TRL Model`.
+        """
+        try:
+            user = whoami()["name"]
+        # handle the offline case
+        except:  # noqa
+            warnings.warn(
+                "Cannot retrieve user information, assuming you are running in offline mode."
+ ) + return + + if not os.path.exists(path): + os.makedirs(path) + + model_card_content = MODEL_CARD_TEMPLATE.format( + model_name=model_name, model_id=f"{user}/{path}" + ) + with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: + f.write(model_card_content) + + def _save_pretrained(self, save_directory: str) -> None: + self.accelerator.unwrap_model(self.model).save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + self.create_model_card(save_directory) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # pragma: no cover + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if ( + hidden_size is not None + and config_kwargs["zero_optimization"]["stage"] == 3 + ): + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace + # cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size + * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 + * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 + * hidden_size + * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled + # (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def wrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper): # pragma: no cover + model = self.accelerator.unwrap_model(model) + if hasattr(model, "hpu_graph_fw"): + model.forward = model.hpu_graph_fw + else: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph # pylint: disable=E0611, E0401 + + model.orig_fw = model.forward + model = wrap_in_hpu_graph(model) + model.hpu_graph_fw = model.forward + + def unwrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper): # pragma: no cover + model = self.accelerator.unwrap_model(model) + if hasattr(model, "orig_fw"): + model.forward = model.orig_fw + + def wrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper): # pragma: no cover + from habana_frameworks.torch.hpu import wrap_in_hpu_graph # pylint: disable=E0611, E0401 + + model = self.accelerator.unwrap_model(model) + if getattr(model, "is_peft_model", False): + if hasattr(model.pretrained_model.base_model.model, "hpu_graph_fw"): + model.pretrained_model.base_model.model.forward = ( + model.pretrained_model.base_model.model.hpu_graph_fw + ) + else: + model.pretrained_model.base_model.model.orig_fw = ( + model.pretrained_model.base_model.model.forward + ) + model.pretrained_model.base_model.model = wrap_in_hpu_graph( + model.pretrained_model.base_model.model + ) + model.pretrained_model.base_model.model.hpu_graph_fw = ( + model.pretrained_model.base_model.model.forward + ) + else: + if hasattr(model.pretrained_model, "hpu_graph_fw"): + model.pretrained_model.forward = model.pretrained_model.hpu_graph_fw + else: + model.pretrained_model.orig_fw = 
model.pretrained_model.forward + model.pretrained_model = wrap_in_hpu_graph(model.pretrained_model) + model.pretrained_model.hpu_graph_fw = model.pretrained_model.forward + + def unwrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper): # pragma: no cover + model = self.accelerator.unwrap_model(model) + if getattr(model, "is_peft_model", False): + if hasattr(model.pretrained_model.base_model.model, "orig_fw"): + model.pretrained_model.base_model.model.forward = ( + model.pretrained_model.base_model.model.orig_fw + ) + else: + if hasattr(model.pretrained_model, "orig_fw"): + model.pretrained_model.forward = model.pretrained_model.orig_fw diff --git a/tests/CI/test_ppo.py b/tests/CI/test_ppo.py new file mode 100644 index 00000000000..617db562625 --- /dev/null +++ b/tests/CI/test_ppo.py @@ -0,0 +1,344 @@ +import os +import unittest + +from transformers import ( + AutoTokenizer, + pipeline, +) +import copy +import torch +from peft import LoraConfig +import torch.utils.data as data +from torch.utils.data import DataLoader + +from intel_extension_for_transformers.transformers.ppo_core import ( + LengthSampler, + set_seed, +) +from intel_extension_for_transformers.transformers.ppo_config import PPOConfig +from intel_extension_for_transformers.transformers.ppo_trainer import PPOTrainer +from intel_extension_for_transformers.transformers.modeling.trl_models import ( + AutoModelForCausalLMWithValueHead, +) +from huggingface_hub import PyTorchModelHubMixin +from tqdm import tqdm + +MODEL_NAME = "hf-internal-testing/tiny-random-GPTJForCausalLM" +REWARD_NAME = "hf-internal-testing/tiny-random-GPTJForSequenceClassification" +os.environ["ACCELERATE_USE_IPEX"] = "false" + +class DummyDataset(data.Dataset): + def __init__(self): + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + if getattr(self.tokenizer, "pad_token", None) is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.max_prompt_length = 128 + self.max_length = 256 + question = "you are a AI assistant that created by Intel." 
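+        # "chosen" is illustrative only; the PPO test loop consumes just the tokenized query below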
+ chosen = "intel-extension-for-transformers is based in SH" + + self.encoded_dict = {} + + query = "Question: " + question + "\n\nAnswer: " + tokenized_question = self.tokenizer(query, truncation=True) + + self.encoded_dict["query"] = query + self.encoded_dict["input_ids"] = torch.tensor(tokenized_question["input_ids"]) + + def __len__(self): + return 10 + + def __getitem__(self, index): + """Returns one data pair (source and target).""" + if index < 10: + return self.encoded_dict + + +def collator(data): + return dict((key, [d[key] for d in data]) for key in data[0]) + + +class TestPPO(unittest.TestCase): + @classmethod + def setUpClass(self): + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + self.model = AutoModelForCausalLMWithValueHead.from_pretrained( + MODEL_NAME, + peft_config=lora_config, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(torch.bfloat16) + self.config = PPOConfig( + steps=10, + model_name=MODEL_NAME, + learning_rate=1.41e-5, + log_with=None, + batch_size=8, + mini_batch_size=1, + gradient_accumulation_steps=8, + optimize_device_cache=True, + early_stopping=True, + target_kl=0.1, + ppo_epochs=4, + seed=100, + init_kl_coef=0.2, + adap_kl_ctrl=True, + use_habana=False, + pad_for_acceleration=False, + pad_max_len=512, + pad_max_input_len=128, + ) + self.dataset = DummyDataset() + self.trainer = PPOTrainer( + self.config, + self.model, + ref_model=None, + tokenizer=self.dataset.tokenizer, + dataset=self.dataset, + data_collator=collator, + optimizer=None, + ) + self.sentiment_pipe = pipeline( + "sentiment-analysis", + model=REWARD_NAME, + tokenizer=self.dataset.tokenizer, + return_token_type_ids=False, + device="cpu", + model_kwargs={ + "low_cpu_mem_usage": True, + "torch_dtype": torch.bfloat16, + }, + ) + + def test_init(self): + self.trainer = PPOTrainer( + self.config, + self.model, + ref_model=None, + tokenizer=self.dataset.tokenizer, + dataset=self.dataset, + data_collator=collator, + optimizer=None, + ) + + self.assertTrue(isinstance(self.trainer, PyTorchModelHubMixin)) + + def test_generation_batched(self): + generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.dataset.tokenizer.pad_token_id, + "eos_token_id": 100_000, + } + epochs = tqdm( + enumerate(self.trainer.dataloader), + total=len(self.trainer.dataloader), + desc="rl progress", + ) + for epoch, batch in epochs: + question_tensors = batch["input_ids"] + response_tensors, ref_response= self.trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=LengthSampler(100, 120), + generate_ref_response= True, + **generation_kwargs, + ) + batch["response"] = self.dataset.tokenizer.batch_decode( + response_tensors, skip_special_tokens=True + ) + + def test_generation(self): + generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.dataset.tokenizer.pad_token_id, + "eos_token_id": 100_000, + } + epochs = tqdm( + enumerate(self.trainer.dataloader), + total=len(self.trainer.dataloader), + desc="rl progress", + ) + for epoch, batch in epochs: + question_tensors = batch["input_ids"] + for question_tensor in question_tensors: + response_tensor, ref_response= self.trainer.generate( + question_tensor, + return_prompt=False, + length_sampler=LengthSampler(100, 120), + generate_ref_response= True, + **generation_kwargs, + ) + response = 
self.dataset.tokenizer.batch_decode( + response_tensor, skip_special_tokens=True + ) + + def test_train(self): + generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.dataset.tokenizer.pad_token_id, + "eos_token_id": 100_000, + } + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "truncation": True, + } + epochs = tqdm( + enumerate(self.trainer.dataloader), + total=len(self.trainer.dataloader), + desc="rl progress", + ) + for epoch, batch in epochs: + question_tensors = batch["input_ids"] + response_tensors = self.trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=LengthSampler(100, 120), + **generation_kwargs, + ) + batch["response"] = self.dataset.tokenizer.batch_decode( + response_tensors, skip_special_tokens=True + ) + # Compute reward score (using the sentiment analysis pipeline) + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = self.sentiment_pipe(texts, **sent_kwargs) + rewards = [ + torch.tensor(output[0]["score"] - 0.0) for output in pipe_outputs + ] + # Run PPO step + stats = self.trainer.step(question_tensors, response_tensors, rewards) + self.trainer.log_stats(stats, batch, rewards) + + self.trainer.save_pretrained("/tmp/output") + + def test_train_with_pad_and_custom_config(self): + self.config.pad_for_acceleration = True + self.config.adap_kl_ctrl = False + self.config.use_score_scaling = True + self.config.whiten_rewards = True + self.config.kl_penalty = "full" + self.config.max_grad_norm = 1.0 + self.config.early_stopping = True + self.config.use_score_norm = True + optimizer = torch.optim.SGD(self.model.parameters(), lr=1.0) + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0) + self.trainer = PPOTrainer( + self.config, + self.model, + ref_model=self.model, + tokenizer=self.dataset.tokenizer, + dataset=self.dataset, + data_collator=collator, + lr_scheduler=lr_scheduler, + optimizer=optimizer, + ) + generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.dataset.tokenizer.pad_token_id, + "eos_token_id": 100_000, + } + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "truncation": True, + } + epochs = tqdm( + enumerate(self.trainer.dataloader), + total=len(self.trainer.dataloader), + desc="rl progress", + ) + for epoch, batch in epochs: + question_tensors = batch["input_ids"] + response_tensors = self.trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=LengthSampler(100, 120), + **generation_kwargs, + ) + batch["response"] = self.dataset.tokenizer.batch_decode( + response_tensors, skip_special_tokens=True + ) + # Compute reward score (using the sentiment analysis pipeline) + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = self.sentiment_pipe(texts, **sent_kwargs) + rewards = [ + torch.tensor(output[0]["score"] - 0.0) for output in pipe_outputs + ] + # Run PPO step + stats = self.trainer.step(question_tensors, response_tensors, rewards) + self.trainer.log_stats(stats, batch, rewards) + + self.trainer.save_pretrained("/tmp/ppo_output") + + def test_train_no_peft(self): + generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": self.dataset.tokenizer.pad_token_id, + "eos_token_id": 100_000, + } + sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + 
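+            # "none" keeps raw reward-model scores, which are used directly as PPO rewards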
"truncation": True, + } + epochs = tqdm( + enumerate(self.trainer.dataloader), + total=len(self.trainer.dataloader), + desc="rl progress", + ) + self.model = AutoModelForCausalLMWithValueHead.from_pretrained( + MODEL_NAME, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + ) + self.trainer = PPOTrainer( + self.config, + self.model, + tokenizer=self.dataset.tokenizer, + dataset=self.dataset, + data_collator=collator, + ) + for epoch, batch in epochs: + question_tensors = batch["input_ids"] + response_tensors = self.trainer.generate( + question_tensors, + return_prompt=False, + **generation_kwargs, + ) + batch["response"] = self.dataset.tokenizer.batch_decode( + response_tensors, skip_special_tokens=True + ) + # Compute reward score (using the sentiment analysis pipeline) + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = self.sentiment_pipe(texts, **sent_kwargs) + rewards = [ + torch.tensor(output[0]["score"] - 0.0) for output in pipe_outputs + ] + # Run PPO step + stats = self.trainer.step(question_tensors, response_tensors, rewards) + self.trainer.log_stats(stats, batch, rewards) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/requirements.txt b/tests/requirements.txt index f8ad7474e77..9a0e37bef0b 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -19,3 +19,4 @@ wget git+https://github.com/huggingface/optimum.git@fa8ad73de9063885476ee39dff7c0bffbd45bd2d git+https://github.com/huggingface/optimum-intel.git peft +tyro