From cd1d0625f0b2b7ac263f310c3f005cb1bf60ec20 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 6 Jan 2023 23:49:04 +0000 Subject: [PATCH 01/57] WIP Signed-off-by: Antoni Baum --- configs/sweeps/ppo_sweep.yml | 4 +-- examples/ppo_sentiments.py | 3 +- trlx/sweep.py | 60 +++++++++++++++++++++++------------- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/configs/sweeps/ppo_sweep.yml b/configs/sweeps/ppo_sweep.yml index 95469ef1b..cb939e376 100644 --- a/configs/sweeps/ppo_sweep.yml +++ b/configs/sweeps/ppo_sweep.yml @@ -1,12 +1,12 @@ tune_config: mode: "max" - metric: "mean_reward" + metric: "reward/mean" search_alg: "random" scheduler: "fifo" num_samples: 32 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs -lr_init: +lr: strategy: "loguniform" values: [0.00001, 0.01] init_kl_coef: diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index 8ded11043..cef7babcb 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -18,10 +18,11 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -default_config = yaml.safe_load(open("configs/ppo_config.yml")) +# default_config = yaml.safe_load(open("configs/ppo_config.yml")) def main(hparams={}): + default_config = hparams.pop("default_config") config = TRLConfig.update(default_config, hparams) if torch.cuda.is_available(): diff --git a/trlx/sweep.py b/trlx/sweep.py index a18b9357f..f3c8ce378 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -1,23 +1,47 @@ # python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py import argparse +import os import importlib from pathlib import Path import ray import yaml from ray import tune +from ray.air import ScalingConfig, session +from ray.train.torch import TorchTrainer from ray.tune.logger import CSVLoggerCallback from trlx.ray_tune import get_param_space, get_tune_config -from trlx.ray_tune.wandb import create_report, log_trials + +# from trlx.ray_tune.wandb import create_report, log_trials def tune_function( train_function, param_space: dict, tune_config: dict, resources: dict ): + default_config = yaml.safe_load(open("configs/ppo_config.yml")) + param_space["default_config"] = default_config + + def train_function_wrapper(config): + os.environ["WORLD_RANK"] = str(session.get_world_rank()) + os.environ["LOCAL_RANK"] = str(session.get_local_rank()) + os.environ["WORLD_SIZE"] = str(session.get_world_size()) + os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) + + return train_function(config) + + param_space_train = {"train_loop_config": param_space} tuner = tune.Tuner( - tune.with_resources(train_function, resources=resources), - param_space=param_space, + TorchTrainer( + train_function_wrapper, + scaling_config=ScalingConfig( + trainer_resources={"CPU": 0}, + num_workers=2, + use_gpu=bool(resources["gpu"]), + resources_per_worker={"CPU": resources["cpu"], "GPU": resources["gpu"]}, + ), + ), + param_space=param_space_train, tune_config=tune.TuneConfig(**tune_config), run_config=ray.air.RunConfig( local_dir="ray_results", callbacks=[CSVLoggerCallback()] @@ -27,18 +51,18 @@ def tune_function( results = tuner.fit() project_name = tune_config.get("project_name", "sweep") - log_trials( - tuner._local_tuner.get_experiment_checkpoint_dir(), - project_name, - ) + # log_trials( + # tuner._local_tuner.get_experiment_checkpoint_dir(), + # project_name, + # ) - create_report( - project_name, - param_space, - tune_config, - 
Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem, - results.get_best_result().config, - ) + # create_report( + # project_name, + # param_space, + # tune_config, + # Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem, + # results.get_best_result().config, + # ) print("Best hyperparameters found were: ", results.get_best_result().config) @@ -77,12 +101,6 @@ def tune_function( tune_config = get_tune_config(config.pop("tune_config")) param_space = get_param_space(config) - # Initialize Ray. - if args.server_address: - ray.init(address=f"ray://{args.server_address}") - else: - ray.init() - resources = { "cpu": args.num_cpus, "gpu": args.num_gpus, @@ -100,7 +118,7 @@ def tune_function( script_path = args.script.replace(".py", "").replace("/", ".") script = importlib.import_module(script_path) # Register the training function that will be used for training the model. - tune.register_trainable("train_function", script.main) + # tune.register_trainable("train_function", script.main) tune_function(script.main, param_space, tune_config, resources) # Shut down Ray. From ea38369632ef8101465996412b97d5d43dd0cc36 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 13 Jan 2023 12:58:18 -0800 Subject: [PATCH 02/57] WIP Signed-off-by: Antoni Baum --- configs/ppo_config.yml | 2 +- trlx/ray_train/__init__.py | 0 trlx/ray_train/launch.py | 1086 ++++++++++++++++++++++++++++++++++++ trlx/sweep.py | 28 +- 4 files changed, 1113 insertions(+), 3 deletions(-) create mode 100644 trlx/ray_train/__init__.py create mode 100644 trlx/ray_train/launch.py diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index e1db1f66d..3395f5a42 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -2,7 +2,7 @@ train: seq_length: 1024 epochs: 100 total_steps: 10000 - batch_size: 128 + batch_size: 64 checkpoint_interval: 10000 eval_interval: 100 diff --git a/trlx/ray_train/__init__.py b/trlx/ray_train/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trlx/ray_train/launch.py b/trlx/ray_train/launch.py new file mode 100644 index 000000000..9825028f1 --- /dev/null +++ b/trlx/ray_train/launch.py @@ -0,0 +1,1086 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import importlib +import logging +import os +import subprocess +import sys +import warnings +from ast import literal_eval +from pathlib import Path +from typing import Dict, List + +import torch + +import psutil +from accelerate.commands.config import default_config_file, load_config_from_file +from accelerate.commands.config.config_args import SageMakerConfig +from accelerate.commands.config.config_utils import DYNAMO_BACKENDS +from accelerate.state import get_int_from_env +from accelerate.utils import ( + ComputeEnvironment, + DistributedType, + DynamoBackend, + PrecisionType, + PrepareForLaunch, + _filter_args, + is_deepspeed_available, + is_rich_available, + is_sagemaker_available, + is_torch_version, + patch_environment, +) +from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS +from accelerate.utils.dataclasses import SageMakerDistributedType +from accelerate.utils.launch import env_var_path_add + + +if is_rich_available(): + from rich import get_console + from rich.logging import RichHandler + + FORMAT = "%(message)s" + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]) + + +logger = logging.getLogger(__name__) + +options_to_group = { + "--multi-gpu": "Distributed GPUs", + "--tpu": "TPU", + "--mps": "MPS", + "--use_mps_device": "MPS", + "--use_deepspeed": "DeepSpeed Arguments", + "--use_fsdp": "FSDP Arguments", + "--use_megatron_lm": "Megatron-LM Arguments", +} + + +def clean_option(option): + "Finds all cases of - after the first two characters and changes them to _" + if option.startswith("--"): + return option[:3] + option[3:].replace("-", "_") + + +class _CustomHelpAction(argparse._HelpAction): + """ + This is a custom help action that will hide all arguments that are not used in the command line when the help is + called. This is useful for the case where the user is using a specific platform and only wants to see the arguments + for that platform. 
+ """ + + def __call__(self, parser, namespace, values, option_string=None): + if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]: + args = sys.argv[2:] + else: + args = sys.argv[1:] + opts = parser._actions + titles = [ + "Hardware Selection Arguments", + "Resource Selection Arguments", + "Training Paradigm Arguments", + "positional arguments", + "optional arguments", + ] + if len(args) > 1: + used_platforms = [arg for arg in args if arg in options_to_group.keys()] + args = list(map(clean_option, args)) + used_titles = [options_to_group[o] for o in used_platforms] + for i, arg in enumerate(opts): + # If the argument's container is outside of the used titles, hide it + if arg.container.title not in titles + used_titles: + setattr(opts[i], "help", argparse.SUPPRESS) + # If the argument is hardware selection, but not being passed, hide it + elif arg.container.title == "Hardware Selection Arguments": + if set(arg.option_strings).isdisjoint(set(args)): + setattr(opts[i], "help", argparse.SUPPRESS) + else: + setattr(opts[i], "help", arg.help + " (currently selected)") + # If the argument is a training paradigm, but not being passed, hide it + elif arg.container.title == "Training Paradigm Arguments": + if set(arg.option_strings).isdisjoint(set(used_platforms)): + setattr(opts[i], "help", argparse.SUPPRESS) + else: + setattr(opts[i], "help", arg.help + " (currently selected)") + for i, group in enumerate(list(parser._action_groups)): + # If all arguments in the group are hidden, hide the group + if all([arg.help == argparse.SUPPRESS for arg in group._group_actions]): + parser._action_groups.remove(group) + + super().__call__(parser, namespace, values, option_string) + + +def launch_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("launch", add_help=False) + else: + parser = argparse.ArgumentParser("Accelerate launch command", add_help=False) + + parser.register("action", "help", _CustomHelpAction) + parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.") + + parser.add_argument( + "--config_file", default=None, help="The config file to use for the default values in the launching script." + ) + # Hardware selection arguments + hardware_args = parser.add_argument_group( + "Hardware Selection Arguments", "Arguments for selecting the hardware to be used." + ) + hardware_args.add_argument( + "--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU." + ) + hardware_args.add_argument( + "--mps", + default=False, + action="store_true", + help="Whether or not this should use MPS-enabled GPU device on MacOS machines.", + ) + hardware_args.add_argument( + "--multi_gpu", + default=False, + action="store_true", + help="Whether or not this should launch a distributed GPU training.", + ) + hardware_args.add_argument( + "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training." + ) + hardware_args.add_argument( + "--use_mps_device", + default=False, + action="store_true", + help="This argument is deprecated, use `--mps` instead.", + ) + + # Resource selection arguments + resource_args = parser.add_argument_group( + "Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used." 
+ ) + resource_args.add_argument( + "--dynamo_backend", + type=str, + choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS], + help="Choose a backend to optimize your training with dynamo, see more at " + "https://github.com/pytorch/torchdynamo.", + ) + resource_args.add_argument( + "--mixed_precision", + type=str, + choices=["no", "fp16", "bf16"], + help="Whether or not to use mixed precision training. " + "Choose between FP16 and BF16 (bfloat16) training. " + "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.", + ) + resource_args.add_argument( + "--fp16", + default=False, + action="store_true", + help="This argument is deprecated, use `--mixed_precision fp16` instead.", + ) + resource_args.add_argument( + "--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel." + ) + resource_args.add_argument( + "--num_machines", type=int, default=None, help="The total number of machines used in this training." + ) + resource_args.add_argument( + "--num_cpu_threads_per_process", + type=int, + default=None, + help="The number of CPU threads per process. Can be tuned for optimal performance.", + ) + + # Training Paradigm arguments + paradigm_args = parser.add_argument_group( + "Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used." + ) + paradigm_args.add_argument( + "--use_deepspeed", + default=False, + action="store_true", + help="Whether to use deepspeed.", + ) + paradigm_args.add_argument( + "--use_fsdp", + default=False, + action="store_true", + help="Whether to use fsdp.", + ) + paradigm_args.add_argument( + "--use_megatron_lm", + default=False, + action="store_true", + help="Whether to use Megatron-LM.", + ) + + # distributed GPU training arguments + distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.") + distributed_args.add_argument( + "--gpu_ids", + default=None, + help="What GPUs (by id) should be used for training on this machine as a comma-seperated list", + ) + distributed_args.add_argument( + "--same_network", + default=False, + action="store_true", + help="Whether all machines used for multinode training exist on the same local network.", + ) + distributed_args.add_argument( + "--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched." + ) + distributed_args.add_argument( + "--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0." + ) + distributed_args.add_argument( + "--main_process_port", + type=int, + default=None, + help="The port to use to communicate with the machine of rank 0.", + ) + # Rendezvous related arguments + distributed_args.add_argument( + "--rdzv_conf", + type=str, + default="", + help="Additional rendezvous configuration (=,=,...).", + ) + distributed_args.add_argument( + "--max_restarts", + type=int, + default=0, + help="Maximum number of worker group restarts before failing.", + ) + distributed_args.add_argument( + "--monitor_interval", + type=float, + default=5, + help="Interval, in seconds, to monitor the state of workers.", + ) + parser.add_argument( + "-m", + "--module", + action="store_true", + help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.", + ) + parser.add_argument( + "--no_python", + action="store_true", + help="Skip prepending the training script with 'python' - just execute it directly. 
Useful when the script is not a Python script.", + ) + + # tpu arguments + tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.") + tpu_args.add_argument( + "--main_training_function", + type=str, + default=None, + help="The name of the main function to be executed in your script (only for TPU training).", + ) + tpu_args.add_argument( + "--downcast_bf16", + action="store_true", + help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.", + ) + + # DeepSpeed arguments + deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.") + deepspeed_args.add_argument( + "--deepspeed_config_file", + default=None, + type=str, + help="DeepSpeed config file.", + ) + deepspeed_args.add_argument( + "--zero_stage", + default=None, + type=int, + help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed).", + ) + deepspeed_args.add_argument( + "--offload_optimizer_device", + default=None, + type=str, + help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed).", + ) + deepspeed_args.add_argument( + "--offload_param_device", + default=None, + type=str, + help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed).", + ) + deepspeed_args.add_argument( + "--gradient_accumulation_steps", + default=None, + type=int, + help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed).", + ) + deepspeed_args.add_argument( + "--gradient_clipping", + default=None, + type=float, + help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed).", + ) + deepspeed_args.add_argument( + "--zero3_init_flag", + default=None, + type=str, + help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. " + "Only applicable with DeepSpeed ZeRO Stage-3.", + ) + deepspeed_args.add_argument( + "--zero3_save_16bit_model", + default=None, + type=str, + help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. " + "Only applicable with DeepSpeed ZeRO Stage-3.", + ) + deepspeed_args.add_argument( + "--deepspeed_hostfile", + default=None, + type=str, + help="DeepSpeed hostfile for configuring multi-node compute resources.", + ) + deepspeed_args.add_argument( + "--deepspeed_exclusion_filter", + default=None, + type=str, + help="DeepSpeed exclusion filter string when using mutli-node setup.", + ) + deepspeed_args.add_argument( + "--deepspeed_inclusion_filter", + default=None, + type=str, + help="DeepSpeed inclusion filter string when using mutli-node setup.", + ) + deepspeed_args.add_argument( + "--deepspeed_multinode_launcher", + default=None, + type=str, + help="DeepSpeed multi-node launcher to use.", + ) + + # fsdp arguments + fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.") + fsdp_args.add_argument( + "--fsdp_offload_params", + default="false", + type=str, + help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_min_num_params", + type=int, + default=1e8, + help="FSDP's minimum number of parameters for Default Auto Wrapping. 
(useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_sharding_strategy", + type=int, + default=1, + help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_auto_wrap_policy", + type=str, + default=None, + help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_transformer_layer_cls_to_wrap", + default=None, + type=str, + help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... " + "(useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_backward_prefetch_policy", + default=None, + type=str, + help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_state_dict_type", + default=None, + type=str, + help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).", + ) + + # megatron_lm args + megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.") + megatron_lm_args.add_argument( + "--megatron_lm_tp_degree", + type=int, + default=1, + help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_pp_degree", + type=int, + default=1, + help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_num_micro_batches", + type=int, + default=None, + help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_sequence_parallelism", + default=None, + type=str, + help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_recompute_activations", + default=None, + type=str, + help="Decides Whether (true|false) to enable Selective Activation Recomputation. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_use_distributed_optimizer", + default=None, + type=str, + help="Decides Whether (true|false) to use distributed optimizer " + "which shards optimizer state and gradients across Data Pralellel (DP) ranks. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_gradient_clipping", + default=1.0, + type=float, + help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). 
" + "(useful only when `use_megatron_lm` flag is passed).", + ) + + # AWS arguments + aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.") + aws_args.add_argument( + "--aws_access_key_id", + type=str, + default=None, + help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job", + ) + aws_args.add_argument( + "--aws_secret_access_key", + type=str, + default=None, + help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Whether to print out the torch.distributed stack trace when something fails.", + ) + parser.add_argument( + "training_script", + type=str, + help=( + "The full path to the script to be launched in parallel, followed by all the arguments for the training " + "script." + ), + ) + + # Other arguments of the training scripts + parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.") + + if subparsers is not None: + parser.set_defaults(func=launch_command) + return parser + + +def simple_launcher(args): + cmd = [] + if args.no_python and args.module: + raise ValueError("--module and --no_python cannot be used together") + if not args.no_python: + cmd.append(sys.executable) + if args.module: + cmd.append("-m") + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + + current_env = os.environ.copy() + current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu) + if args.use_mps_device: + warnings.warn( + '`use_mps_device` flag is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use "--mps" instead.', + FutureWarning, + ) + args.mps = True + current_env["ACCELERATE_USE_MPS_DEVICE"] = str(args.mps) + if args.mps: + current_env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + elif args.gpu_ids != "all" and args.gpu_ids is not None: + current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids + if args.num_machines > 1: + current_env["MASTER_ADDR"] = args.main_process_ip + current_env["MASTER_PORT"] = str(args.main_process_port) + elif args.num_processes > 1: + current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1" + current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500" + + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + if args.fp16: + warnings.warn( + "`fp16` is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use `mixed_precision fp16` instead.", + FutureWarning, + ) + mixed_precision = "fp16" + + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. 
Choose between {DYNAMO_BACKENDS}.") + current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value + + current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) + + process = subprocess.Popen(cmd, env=current_env) + process.wait() + if process.returncode != 0: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + + +def multi_gpu_launcher(args): + if is_torch_version(">=", "1.9.0"): + import torch.distributed.run as distrib_run + num_processes = getattr(args, "num_processes") + num_machines = getattr(args, "num_machines") + main_process_ip = getattr(args, "main_process_ip") + main_process_port = getattr(args, "main_process_port") + if num_machines > 1: + setattr(args, "nproc_per_node", str(num_processes // num_machines)) + setattr(args, "nnodes", str(num_machines)) + setattr(args, "node_rank", int(args.machine_rank)) + if getattr(args, "same_network", False): + setattr(args, "master_addr", str(main_process_ip)) + setattr(args, "master_port", str(main_process_port)) + else: + setattr(args, "rdzv_endpoint", f"{main_process_ip}:{main_process_port}") + else: + setattr(args, "nproc_per_node", str(num_processes)) + if main_process_port is not None: + setattr(args, "master_port", str(main_process_port)) + + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + setattr(args, "module", True) + elif args.no_python: + setattr(args, "no_python", True) + + current_env = os.environ.copy() + gpu_ids = getattr(args, "gpu_ids", "all") + if gpu_ids != "all" and args.gpu_ids is not None: + current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + mixed_precision = args.mixed_precision.lower() + try: + mixed_precision = PrecisionType(mixed_precision) + except ValueError: + raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.") + + if args.fp16: + warnings.warn( + "`fp16` is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use `mixed_precision fp16` instead.", + FutureWarning, + ) + mixed_precision = "fp16" + + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. 
Choose between {DYNAMO_BACKENDS}.") + current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value + + if args.use_fsdp: + current_env["ACCELERATE_USE_FSDP"] = "true" + current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy) + current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower() + current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params) + if args.fsdp_auto_wrap_policy is not None: + current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) + if args.fsdp_transformer_layer_cls_to_wrap is not None: + current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap) + if args.fsdp_backward_prefetch_policy is not None: + current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy) + if args.fsdp_state_dict_type is not None: + current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type) + + if args.use_megatron_lm: + prefix = "MEGATRON_LM_" + current_env["ACCELERATE_USE_MEGATRON_LM"] = "true" + current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree) + current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree) + current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping) + if args.megatron_lm_num_micro_batches is not None: + current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches) + if args.megatron_lm_sequence_parallelism is not None: + current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism) + if args.megatron_lm_recompute_activations is not None: + current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations) + if args.megatron_lm_use_distributed_optimizer is not None: + current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer) + + current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) + if is_torch_version("<", "1.9.0"): + raise NotImplementedError("Multi-node training requires pytorch>=1.9.0") + + os.environ.update(current_env) + +def deepspeed_launcher(args): + if is_torch_version(">=", "1.9.0"): + import torch.distributed.run as distrib_run + if not is_deepspeed_available(): + raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.") + num_processes = getattr(args, "num_processes") + num_machines = getattr(args, "num_machines") + main_process_ip = getattr(args, "main_process_ip") + main_process_port = getattr(args, "main_process_port") + + # make sure launcher is not None + if args.deepspeed_multinode_launcher is None: + # set to default pdsh + setattr(args, "deepspeed_multinode_launcher", DEEPSPEED_MULTINODE_LAUNCHERS[0]) + + if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: + cmd = ["deepspeed", "--no_local_rank"] + cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)]) + if args.deepspeed_exclusion_filter is not None: + cmd.extend( + [ + "--exclude", + str(args.deepspeed_exclusion_filter), + ] + ) + elif args.deepspeed_inclusion_filter is not None: + cmd.extend( + [ + "--include", + str(args.deepspeed_inclusion_filter), + ] + ) + else: + cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)]) + + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + cmd.append("--module") + elif args.no_python: + cmd.append("--no_python") + 
cmd.append(args.training_script) + cmd.extend(args.training_script_args) + elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]: + setattr(args, "nproc_per_node", str(num_processes // num_machines)) + setattr(args, "nnodes", str(num_machines)) + setattr(args, "node_rank", int(args.machine_rank)) + if getattr(args, "same_network", False): + setattr(args, "master_addr", str(main_process_ip)) + setattr(args, "master_port", str(main_process_port)) + else: + setattr(args, "rdzv_endpoint", f"{main_process_ip}:{main_process_port}") + else: + setattr(args, "nproc_per_node", str(num_processes)) + if main_process_port is not None: + setattr(args, "master_port", str(main_process_port)) + + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + setattr(args, "module", True) + elif args.no_python: + setattr(args, "no_python", True) + + current_env = os.environ.copy() + gpu_ids = getattr(args, "gpu_ids", "all") + if gpu_ids != "all" and args.gpu_ids is not None: + current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + if args.fp16: + warnings.warn( + '--fp16 flag is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use "--mixed_precision fp16" instead.', + FutureWarning, + ) + mixed_precision = "fp16" + + current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath(".")) + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + current_env["ACCELERATE_USE_DEEPSPEED"] = "true" + current_env["DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage) + current_env["GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps) + current_env["GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower() + current_env["DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower() + current_env["DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower() + current_env["DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower() + current_env["DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower() + if args.deepspeed_config_file is not None: + current_env["DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file) + + with open(".deepspeed_env", "a") as f: + for key, value in current_env.items(): + if ";" in value or " " in value: + continue + f.write(f"{key}={value}\n") + + os.environ.update(current_env) + + +def tpu_launcher(args): + import torch_xla.distributed.xla_multiprocessing as xmp + + current_env = {} + + if args.no_python: + raise ValueError("--no_python cannot be used with TPU launcher") + + if args.mixed_precision == "bf16": + if args.downcast_bf16: + current_env["XLA_USE_BF16"] = "0" + current_env["XLA_DOWNCAST_BF16"] = "1" + else: + current_env["XLA_USE_BF16"] = "1" + current_env["XLA_DOWNCAST_BF16"] = "0" + + if args.module: + mod_name = args.training_script + else: + # Import training_script as a module + script_path = Path(args.training_script) + sys.path.append(str(script_path.parent.resolve())) + mod_name = script_path.stem + + mod = importlib.import_module(mod_name) + if not hasattr(mod, args.main_training_function): + raise ValueError( + f"Your training script should have a function named {args.main_training_function}, or you should pass a " + "different value to 
`--main_training_function`." + ) + + # Patch sys.argv + sys.argv = [mod.__file__] + args.training_script_args + + main_function = getattr(mod, args.main_training_function) + with patch_environment(**current_env): + xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes) + + +def _convert_nargs_to_dict(nargs: List[str]) -> Dict[str, str]: + if len(nargs) < 0: + return {} + # helper function to infer type for argsparser + + def _infer_type(s): + try: + s = float(s) + + if s // 1 == s: + return int(s) + return s + except ValueError: + return s + + parser = argparse.ArgumentParser() + _, unknown = parser.parse_known_args(nargs) + for index, argument in enumerate(unknown): + if argument.startswith(("-", "--")): + action = None + if index + 1 < len(unknown): # checks if next index would be in list + if unknown[index + 1].startswith(("-", "--")): # checks if next element is an key + # raise an error if element is store_true or store_false + raise ValueError( + "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types" + ) + else: # raise an error if last element is store_true or store_false + raise ValueError( + "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types" + ) + # adds argument to parser based on action_store true + if action is None: + parser.add_argument(argument, type=_infer_type) + else: + parser.add_argument(argument, action=action) + + return { + key: (literal_eval(value) if value in ("True", "False") else value) + for key, value in parser.parse_args(nargs).__dict__.items() + } + + +def sagemaker_launcher(sagemaker_config: SageMakerConfig, args): + if not is_sagemaker_available(): + raise ImportError( + "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`" + ) + if args.module or args.no_python: + raise ValueError( + "SageMaker requires a python training script file and cannot be used with --module or --no_python" + ) + + from sagemaker.huggingface import HuggingFace + + # configure environment + print("Configuring Amazon SageMaker environment") + os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region + + # configure credentials + if sagemaker_config.profile is not None: + os.environ["AWS_PROFILE"] = sagemaker_config.profile + elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None: + os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key + else: + raise EnvironmentError( + "You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile" + ) + + # extract needed arguments + source_dir = os.path.dirname(args.training_script) + if not source_dir: # checks if string is empty + source_dir = "." + entry_point = os.path.basename(args.training_script) + if not entry_point.endswith(".py"): + raise ValueError(f'Your training script should be a python script and not "{entry_point}"') + + print("Converting Arguments to Hyperparameters") + hyperparameters = _convert_nargs_to_dict(args.training_script_args) + + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + if args.fp16: + warnings.warn('--fp16 flag is deprecated. 
Use "--mixed_precision fp16" instead.', FutureWarning) + mixed_precision = "fp16" + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DYNAMO_BACKENDS}.") + + # Environment variables to be set for use during training job + environment = { + "ACCELERATE_USE_SAGEMAKER": "true", + "ACCELERATE_MIXED_PRECISION": str(mixed_precision), + "ACCELERATE_DYNAMO_BACKEND": dynamo_backend.value, + "ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value, + } + # configure distribution set up + distribution = None + if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL: + distribution = {"smdistributed": {"dataparallel": {"enabled": True}}} + + # configure sagemaker inputs + sagemaker_inputs = None + if sagemaker_config.sagemaker_inputs_file is not None: + print(f"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file") + sagemaker_inputs = {} + with open(sagemaker_config.sagemaker_inputs_file) as file: + for i, line in enumerate(file): + if i == 0: + continue + l = line.split("\t") + sagemaker_inputs[l[0]] = l[1].strip() + print(f"Loaded SageMaker Inputs: {sagemaker_inputs}") + + # configure sagemaker metrics + sagemaker_metrics = None + if sagemaker_config.sagemaker_metrics_file is not None: + print(f"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file") + sagemaker_metrics = [] + with open(sagemaker_config.sagemaker_metrics_file) as file: + for i, line in enumerate(file): + if i == 0: + continue + l = line.split("\t") + metric_dict = { + "Name": l[0], + "Regex": l[1].strip(), + } + sagemaker_metrics.append(metric_dict) + print(f"Loaded SageMaker Metrics: {sagemaker_metrics}") + + # configure session + print("Creating Estimator") + huggingface_estimator = HuggingFace( + image_uri=sagemaker_config.image_uri, + entry_point=entry_point, + source_dir=source_dir, + role=sagemaker_config.iam_role_name, + transformers_version=sagemaker_config.transformers_version, + pytorch_version=sagemaker_config.pytorch_version, + py_version=sagemaker_config.py_version, + base_job_name=sagemaker_config.base_job_name, + instance_count=sagemaker_config.num_machines, + instance_type=sagemaker_config.ec2_instance_type, + debugger_hook_config=False, + distribution=distribution, + hyperparameters=hyperparameters, + environment=environment, + metric_definitions=sagemaker_metrics, + ) + + huggingface_estimator.fit(inputs=sagemaker_inputs) + print(f"You can find your model data at: {huggingface_estimator.model_data}") + + +def launch_command(args): + # Sanity checks + if sum([args.multi_gpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1: + raise ValueError("You can only pick one between `--multi_gpu`, `--use_deepspeed`, `--tpu`, `--use_fsdp`.") + + defaults = None + warned = [] + # Get the default from the config file. 
+ if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu: + defaults = load_config_from_file(args.config_file) + if ( + not args.multi_gpu + and not args.tpu + and not args.mps + and not args.use_deepspeed + and not args.use_fsdp + and not args.use_megatron_lm + ): + args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED + args.multi_gpu = defaults.distributed_type == DistributedType.MULTI_GPU + args.tpu = defaults.distributed_type == DistributedType.TPU + args.use_fsdp = defaults.distributed_type == DistributedType.FSDP + args.mps = defaults.distributed_type == DistributedType.MPS + args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM + if not args.mps: + if args.gpu_ids is None: + if defaults.gpu_ids is not None: + args.gpu_ids = defaults.gpu_ids + else: + args.gpu_ids = "all" + if len(args.gpu_ids.split(",")) < 2 and args.multi_gpu and (args.gpu_ids != "all"): + args.multi_gpu = False + if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE: + # Update args with the defaults + for name, attr in defaults.__dict__.items(): + if isinstance(attr, dict): + for k in defaults.deepspeed_config: + if getattr(args, k) is None: + setattr(args, k, defaults.deepspeed_config[k]) + for k in defaults.fsdp_config: + arg_to_set = k + if "fsdp" not in arg_to_set: + arg_to_set = "fsdp_" + arg_to_set + setattr(args, arg_to_set, defaults.fsdp_config[k]) + for k in defaults.megatron_lm_config: + setattr(args, k, defaults.megatron_lm_config[k]) + continue + + # Those args are handled separately + if ( + name not in ["compute_environment", "fp16", "mixed_precision", "distributed_type"] + and getattr(args, name, None) is None + ): + setattr(args, name, attr) + if not args.mixed_precision: + if args.fp16: + args.mixed_precision = "fp16" + else: + args.mixed_precision = defaults.mixed_precision + if args.dynamo_backend is None: + warned.append("\t`--dynamo_backend` was set to a value of `'no'`") + args.dynamo_backend = "no" + else: + if args.num_processes is None: + args.num_processes = torch.cuda.device_count() if args.multi_gpu else 1 + warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`") + if args.num_machines is None: + warned.append("\t`--num_machines` was set to a value of `1`") + args.num_machines = 1 + if args.mixed_precision is None: + warned.append("\t`--mixed_precision` was set to a value of `'no'`") + args.mixed_precision = "no" + if not hasattr(args, "use_cpu"): + args.use_cpu = args.cpu + if args.dynamo_backend is None: + warned.append("\t`--dynamo_backend` was set to a value of `'no'`") + args.dynamo_backend = "no" + + if args.num_cpu_threads_per_process is None: + args.num_cpu_threads_per_process = 1 + if args.use_cpu and args.num_processes > 1: + local_size = get_int_from_env( + ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 + ) + threads_per_process = int(psutil.cpu_count(logical=False) / local_size) + if args.num_cpu_threads_per_process > 1: + args.num_cpu_threads_per_process = threads_per_process + warned.append( + f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs" + ) + + if any(warned): + message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n" + message += "\n".join(warned) + message += ( + "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`." 
+ ) + logger.warning(message) + + # Use the proper launcher + if args.use_deepspeed and not args.cpu: + deepspeed_launcher(args) + elif args.use_fsdp and not args.cpu: + multi_gpu_launcher(args) + elif args.use_megatron_lm and not args.cpu: + multi_gpu_launcher(args) + elif args.multi_gpu and not args.cpu: + multi_gpu_launcher(args) + elif args.tpu and not args.cpu: + tpu_launcher(args) + elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + sagemaker_launcher(defaults, args) + else: + simple_launcher(args) + + +def main(): + parser = launch_command_parser() + args = parser.parse_args() + launch_command(args) + + +if __name__ == "__main__": + main() diff --git a/trlx/sweep.py b/trlx/sweep.py index f3c8ce378..52b3b3353 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -10,11 +10,20 @@ from ray.air import ScalingConfig, session from ray.train.torch import TorchTrainer from ray.tune.logger import CSVLoggerCallback +import tempfile from trlx.ray_tune import get_param_space, get_tune_config +from accelerate.commands.config import default_config_file, load_config_from_file +from trlx.ray_train.launch import launch_command, launch_command_parser # from trlx.ray_tune.wandb import create_report, log_trials +from argparse import Namespace +class DefaultNamespace(Namespace): + def __getattr__(self, name: str): + parser = launch_command_parser() + ret = parser.get_default(name) + return ret def tune_function( train_function, param_space: dict, tune_config: dict, resources: dict @@ -22,11 +31,26 @@ def tune_function( default_config = yaml.safe_load(open("configs/ppo_config.yml")) param_space["default_config"] = default_config + config_file_path = default_config_file + with open(config_file_path, "r") as f: + config_data = f.read() + def train_function_wrapper(config): + temp_config_file = tempfile.mkstemp()[1] + with open(temp_config_file, "w") as f: + f.write(config_data) + args = DefaultNamespace() + setattr(args, "config_file", temp_config_file) + launch_command(args) + os.environ["RANK"] = str(session.get_world_rank()) os.environ["WORLD_RANK"] = str(session.get_world_rank()) os.environ["LOCAL_RANK"] = str(session.get_local_rank()) os.environ["WORLD_SIZE"] = str(session.get_world_size()) os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) + os.environ["CROSS_RANK"] = str(session.get_world_rank()) + os.environ["CROSS_SIZE"] = str(session.get_world_size()) + os.environ["LOCAL_SIZE"] = str(session.get_local_world_size()) + print(os.environ) return train_function(config) @@ -36,9 +60,9 @@ def train_function_wrapper(config): train_function_wrapper, scaling_config=ScalingConfig( trainer_resources={"CPU": 0}, - num_workers=2, + num_workers=resources["gpu"], use_gpu=bool(resources["gpu"]), - resources_per_worker={"CPU": resources["cpu"], "GPU": resources["gpu"]}, + resources_per_worker={"CPU": resources["cpu"], "GPU": int(bool(resources["gpu"]))}, ), ), param_space=param_space_train, From ca6d7d11366fde8a3894c34c2fbe958f814c5f3b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 13 Jan 2023 16:08:10 -0800 Subject: [PATCH 03/57] Cleanup Signed-off-by: Antoni Baum --- trlx/ray_train/accelerate_trainer.py | 167 +++++ trlx/ray_train/launch.py | 986 ++------------------------- trlx/sweep.py | 47 +- 3 files changed, 243 insertions(+), 957 deletions(-) create mode 100644 trlx/ray_train/accelerate_trainer.py diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py new file mode 100644 index 000000000..4f0b37510 --- 
/dev/null +++ b/trlx/ray_train/accelerate_trainer.py @@ -0,0 +1,167 @@ +import os +import tempfile +from argparse import Namespace +from functools import wraps +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Dict, Optional, Union, Type, Tuple + +from ray.air import session +from ray.air.checkpoint import Checkpoint +from ray.air.config import DatasetConfig, RunConfig, ScalingConfig +from ray.train.torch.config import TorchConfig +from ray.train.trainer import GenDataset +from ray.util import PublicAPI + +if TYPE_CHECKING: + from ray.data.preprocessor import Preprocessor + + +from ray.train.torch import TorchTrainer +from .launch import launch_command, launch_command_parser +from accelerate.commands.config import default_config_file, load_config_from_file + + +class AccelerateDefaultNamespace(Namespace): + @property + def parser(self): + return launch_command_parser() + + def __getattr__(self, name: str): + return self.parser.get_default(name) + + +class AccelerateConfigWrapper: + """ + Lets Trainables know to treat this as already loaded file content instead of path. + """ + + def __init__( + self, config_raw: str, deepspeed_config_raw: Optional[str] = None + ) -> None: + self.config_raw = config_raw + self.deepspeed_config_raw = deepspeed_config_raw + + def __bool__(self) -> bool: + return bool(self.config_raw) + + +class AccelerateTrainer(TorchTrainer): + def __init__( + self, + train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], + *, + accelerate_config_path: Union[str, Path, os.PathLike], + train_loop_config: Optional[Dict] = None, + torch_config: Optional[TorchConfig] = None, + scaling_config: Optional[ScalingConfig] = None, + dataset_config: Optional[Dict[str, DatasetConfig]] = None, + run_config: Optional[RunConfig] = None, + datasets: Optional[Dict[str, GenDataset]] = None, + preprocessor: Optional["Preprocessor"] = None, + resume_from_checkpoint: Optional[Checkpoint] = None + ): + self.accelerate_config_path = accelerate_config_path or default_config_file + if isinstance(self.accelerate_config_path, AccelerateConfigWrapper): + self._accelerate_config_raw = self.accelerate_config_path.config_raw + self._deepspeed_config_file_raw = ( + self.accelerate_config_path.deepspeed_config_raw + ) + else: + ( + self._accelerate_config_raw, + self._deepspeed_config_file_raw, + ) = self._load_accelerate_config() + super().__init__( + train_loop_per_worker, + train_loop_config=train_loop_config, + torch_config=torch_config, + scaling_config=scaling_config, + dataset_config=dataset_config, + run_config=run_config, + datasets=datasets, + preprocessor=preprocessor, + resume_from_checkpoint=resume_from_checkpoint, + ) + + def training_loop(self) -> None: + old_train_loop_per_worker = self._train_loop_per_worker + self._train_loop_per_worker = self._wrap_train_loop_per_worker( + self._train_loop_per_worker, + self._accelerate_config_raw, + self._deepspeed_config_file_raw, + ) + try: + ret = super().training_loop() + finally: + self._train_loop_per_worker = old_train_loop_per_worker + return ret + + def as_trainable(self) -> Type["Trainable"]: + # We want to load the config when the Trainer is first instantiated, + # and share the contents with the Trainables (which may be on different) + # nodes + old_accelerate_config_path = self._param_dict["accelerate_config_path"] + self._param_dict["accelerate_config_path"] = AccelerateConfigWrapper( + self._accelerate_config_raw, self._deepspeed_config_file_raw + ) + try: + ret = super().as_trainable() + finally: + 
self._param_dict["accelerate_config_path"] = old_accelerate_config_path + return ret + + def _load_accelerate_config(self) -> Tuple[str, Optional[str]]: + # We only load config to dict to obtain the deepspeed_config_file + config = load_config_from_file(self.accelerate_config_path) + deepspeed_config_file = getattr(config, "deepspeed_config_file", None) + deepspeed_config_file_raw = None + + if deepspeed_config_file: + with open(deepspeed_config_file, "r") as f: + deepspeed_config_file_raw = f.read() + + # Otherwise, we want to pass raw contents to Trainables for maximum + # compatibility. + with open(self.accelerate_config_path, "r") as f: + return f.read(), deepspeed_config_file_raw + + @classmethod + def _wrap_train_loop_per_worker( + cls, + train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], + accelerate_config_raw: str, + deepspeed_config_file_raw: str, + ): + @wraps(train_loop_per_worker) + def wrapped_train_loop_per_worker(*args, **kwargs): + with tempfile.TemporaryDirectory() as tempdir: + temp_config_file = os.path.join(tempdir, "default_config.yaml") + with open(temp_config_file, "w") as f: + f.write(accelerate_config_raw) + namespace = AccelerateDefaultNamespace() + namespace.config_file = temp_config_file + namespace.num_cpu_threads_per_process = ( + session.get_trial_resources().bundles[-1]["CPU"] + ) + + if deepspeed_config_file_raw: + deepspeed_config_file = os.path.join( + tempdir, "deepspeed_config.json" + ) + with open(deepspeed_config_file, "w") as f: + f.write(deepspeed_config_file_raw) + namespace.deepspeed_config_file = deepspeed_config_file + + launch_command(namespace) + os.environ["RANK"] = str(session.get_world_rank()) + os.environ["WORLD_RANK"] = str(session.get_world_rank()) + os.environ["CROSS_RANK"] = str(session.get_world_rank()) + os.environ["CROSS_SIZE"] = str(session.get_world_size()) + os.environ["WORLD_SIZE"] = str(session.get_world_size()) + os.environ["LOCAL_RANK"] = str(session.get_local_rank()) + os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) + os.environ["LOCAL_SIZE"] = str(session.get_local_world_size()) + + return train_loop_per_worker(*args, **kwargs) + + return wrapped_train_loop_per_worker diff --git a/trlx/ray_train/launch.py b/trlx/ray_train/launch.py index 9825028f1..113ec860f 100644 --- a/trlx/ray_train/launch.py +++ b/trlx/ray_train/launch.py @@ -14,510 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse -import importlib import logging import os -import subprocess -import sys import warnings -from ast import literal_eval -from pathlib import Path -from typing import Dict, List +from unittest.mock import patch -import torch - -import psutil -from accelerate.commands.config import default_config_file, load_config_from_file -from accelerate.commands.config.config_args import SageMakerConfig from accelerate.commands.config.config_utils import DYNAMO_BACKENDS -from accelerate.state import get_int_from_env from accelerate.utils import ( - ComputeEnvironment, - DistributedType, DynamoBackend, PrecisionType, - PrepareForLaunch, - _filter_args, is_deepspeed_available, - is_rich_available, - is_sagemaker_available, is_torch_version, - patch_environment, ) -from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS -from accelerate.utils.dataclasses import SageMakerDistributedType from accelerate.utils.launch import env_var_path_add - - -if is_rich_available(): - from rich import get_console - from rich.logging import RichHandler - - FORMAT = "%(message)s" - logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]) - +from accelerate.commands.launch import ( + launch_command_parser, + launch_command as original_launch_command, +) logger = logging.getLogger(__name__) -options_to_group = { - "--multi-gpu": "Distributed GPUs", - "--tpu": "TPU", - "--mps": "MPS", - "--use_mps_device": "MPS", - "--use_deepspeed": "DeepSpeed Arguments", - "--use_fsdp": "FSDP Arguments", - "--use_megatron_lm": "Megatron-LM Arguments", -} - - -def clean_option(option): - "Finds all cases of - after the first two characters and changes them to _" - if option.startswith("--"): - return option[:3] + option[3:].replace("-", "_") - - -class _CustomHelpAction(argparse._HelpAction): - """ - This is a custom help action that will hide all arguments that are not used in the command line when the help is - called. This is useful for the case where the user is using a specific platform and only wants to see the arguments - for that platform. 
- """ - - def __call__(self, parser, namespace, values, option_string=None): - if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]: - args = sys.argv[2:] - else: - args = sys.argv[1:] - opts = parser._actions - titles = [ - "Hardware Selection Arguments", - "Resource Selection Arguments", - "Training Paradigm Arguments", - "positional arguments", - "optional arguments", - ] - if len(args) > 1: - used_platforms = [arg for arg in args if arg in options_to_group.keys()] - args = list(map(clean_option, args)) - used_titles = [options_to_group[o] for o in used_platforms] - for i, arg in enumerate(opts): - # If the argument's container is outside of the used titles, hide it - if arg.container.title not in titles + used_titles: - setattr(opts[i], "help", argparse.SUPPRESS) - # If the argument is hardware selection, but not being passed, hide it - elif arg.container.title == "Hardware Selection Arguments": - if set(arg.option_strings).isdisjoint(set(args)): - setattr(opts[i], "help", argparse.SUPPRESS) - else: - setattr(opts[i], "help", arg.help + " (currently selected)") - # If the argument is a training paradigm, but not being passed, hide it - elif arg.container.title == "Training Paradigm Arguments": - if set(arg.option_strings).isdisjoint(set(used_platforms)): - setattr(opts[i], "help", argparse.SUPPRESS) - else: - setattr(opts[i], "help", arg.help + " (currently selected)") - for i, group in enumerate(list(parser._action_groups)): - # If all arguments in the group are hidden, hide the group - if all([arg.help == argparse.SUPPRESS for arg in group._group_actions]): - parser._action_groups.remove(group) - - super().__call__(parser, namespace, values, option_string) - - -def launch_command_parser(subparsers=None): - if subparsers is not None: - parser = subparsers.add_parser("launch", add_help=False) - else: - parser = argparse.ArgumentParser("Accelerate launch command", add_help=False) - - parser.register("action", "help", _CustomHelpAction) - parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.") - - parser.add_argument( - "--config_file", default=None, help="The config file to use for the default values in the launching script." - ) - # Hardware selection arguments - hardware_args = parser.add_argument_group( - "Hardware Selection Arguments", "Arguments for selecting the hardware to be used." - ) - hardware_args.add_argument( - "--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU." - ) - hardware_args.add_argument( - "--mps", - default=False, - action="store_true", - help="Whether or not this should use MPS-enabled GPU device on MacOS machines.", - ) - hardware_args.add_argument( - "--multi_gpu", - default=False, - action="store_true", - help="Whether or not this should launch a distributed GPU training.", - ) - hardware_args.add_argument( - "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training." - ) - hardware_args.add_argument( - "--use_mps_device", - default=False, - action="store_true", - help="This argument is deprecated, use `--mps` instead.", - ) - - # Resource selection arguments - resource_args = parser.add_argument_group( - "Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used." 
- ) - resource_args.add_argument( - "--dynamo_backend", - type=str, - choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS], - help="Choose a backend to optimize your training with dynamo, see more at " - "https://github.com/pytorch/torchdynamo.", - ) - resource_args.add_argument( - "--mixed_precision", - type=str, - choices=["no", "fp16", "bf16"], - help="Whether or not to use mixed precision training. " - "Choose between FP16 and BF16 (bfloat16) training. " - "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.", - ) - resource_args.add_argument( - "--fp16", - default=False, - action="store_true", - help="This argument is deprecated, use `--mixed_precision fp16` instead.", - ) - resource_args.add_argument( - "--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel." - ) - resource_args.add_argument( - "--num_machines", type=int, default=None, help="The total number of machines used in this training." - ) - resource_args.add_argument( - "--num_cpu_threads_per_process", - type=int, - default=None, - help="The number of CPU threads per process. Can be tuned for optimal performance.", - ) - - # Training Paradigm arguments - paradigm_args = parser.add_argument_group( - "Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used." - ) - paradigm_args.add_argument( - "--use_deepspeed", - default=False, - action="store_true", - help="Whether to use deepspeed.", - ) - paradigm_args.add_argument( - "--use_fsdp", - default=False, - action="store_true", - help="Whether to use fsdp.", - ) - paradigm_args.add_argument( - "--use_megatron_lm", - default=False, - action="store_true", - help="Whether to use Megatron-LM.", - ) - - # distributed GPU training arguments - distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.") - distributed_args.add_argument( - "--gpu_ids", - default=None, - help="What GPUs (by id) should be used for training on this machine as a comma-seperated list", - ) - distributed_args.add_argument( - "--same_network", - default=False, - action="store_true", - help="Whether all machines used for multinode training exist on the same local network.", - ) - distributed_args.add_argument( - "--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched." - ) - distributed_args.add_argument( - "--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0." - ) - distributed_args.add_argument( - "--main_process_port", - type=int, - default=None, - help="The port to use to communicate with the machine of rank 0.", - ) - # Rendezvous related arguments - distributed_args.add_argument( - "--rdzv_conf", - type=str, - default="", - help="Additional rendezvous configuration (=,=,...).", - ) - distributed_args.add_argument( - "--max_restarts", - type=int, - default=0, - help="Maximum number of worker group restarts before failing.", - ) - distributed_args.add_argument( - "--monitor_interval", - type=float, - default=5, - help="Interval, in seconds, to monitor the state of workers.", - ) - parser.add_argument( - "-m", - "--module", - action="store_true", - help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.", - ) - parser.add_argument( - "--no_python", - action="store_true", - help="Skip prepending the training script with 'python' - just execute it directly. 
Useful when the script is not a Python script.", - ) - - # tpu arguments - tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.") - tpu_args.add_argument( - "--main_training_function", - type=str, - default=None, - help="The name of the main function to be executed in your script (only for TPU training).", - ) - tpu_args.add_argument( - "--downcast_bf16", - action="store_true", - help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.", - ) - - # DeepSpeed arguments - deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.") - deepspeed_args.add_argument( - "--deepspeed_config_file", - default=None, - type=str, - help="DeepSpeed config file.", - ) - deepspeed_args.add_argument( - "--zero_stage", - default=None, - type=int, - help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed).", - ) - deepspeed_args.add_argument( - "--offload_optimizer_device", - default=None, - type=str, - help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed).", - ) - deepspeed_args.add_argument( - "--offload_param_device", - default=None, - type=str, - help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed).", - ) - deepspeed_args.add_argument( - "--gradient_accumulation_steps", - default=None, - type=int, - help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed).", - ) - deepspeed_args.add_argument( - "--gradient_clipping", - default=None, - type=float, - help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed).", - ) - deepspeed_args.add_argument( - "--zero3_init_flag", - default=None, - type=str, - help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. " - "Only applicable with DeepSpeed ZeRO Stage-3.", - ) - deepspeed_args.add_argument( - "--zero3_save_16bit_model", - default=None, - type=str, - help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. " - "Only applicable with DeepSpeed ZeRO Stage-3.", - ) - deepspeed_args.add_argument( - "--deepspeed_hostfile", - default=None, - type=str, - help="DeepSpeed hostfile for configuring multi-node compute resources.", - ) - deepspeed_args.add_argument( - "--deepspeed_exclusion_filter", - default=None, - type=str, - help="DeepSpeed exclusion filter string when using mutli-node setup.", - ) - deepspeed_args.add_argument( - "--deepspeed_inclusion_filter", - default=None, - type=str, - help="DeepSpeed inclusion filter string when using mutli-node setup.", - ) - deepspeed_args.add_argument( - "--deepspeed_multinode_launcher", - default=None, - type=str, - help="DeepSpeed multi-node launcher to use.", - ) - - # fsdp arguments - fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.") - fsdp_args.add_argument( - "--fsdp_offload_params", - default="false", - type=str, - help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).", - ) - fsdp_args.add_argument( - "--fsdp_min_num_params", - type=int, - default=1e8, - help="FSDP's minimum number of parameters for Default Auto Wrapping. 
(useful only when `use_fsdp` flag is passed).", - ) - fsdp_args.add_argument( - "--fsdp_sharding_strategy", - type=int, - default=1, - help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).", - ) - fsdp_args.add_argument( - "--fsdp_auto_wrap_policy", - type=str, - default=None, - help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).", - ) - fsdp_args.add_argument( - "--fsdp_transformer_layer_cls_to_wrap", - default=None, - type=str, - help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... " - "(useful only when `use_fsdp` flag is passed).", - ) - fsdp_args.add_argument( - "--fsdp_backward_prefetch_policy", - default=None, - type=str, - help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).", - ) - fsdp_args.add_argument( - "--fsdp_state_dict_type", - default=None, - type=str, - help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).", - ) - - # megatron_lm args - megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.") - megatron_lm_args.add_argument( - "--megatron_lm_tp_degree", - type=int, - default=1, - help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).", - ) - megatron_lm_args.add_argument( - "--megatron_lm_pp_degree", - type=int, - default=1, - help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).", - ) - megatron_lm_args.add_argument( - "--megatron_lm_num_micro_batches", - type=int, - default=None, - help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).", - ) - megatron_lm_args.add_argument( - "--megatron_lm_sequence_parallelism", - default=None, - type=str, - help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. " - "(useful only when `use_megatron_lm` flag is passed).", - ) - megatron_lm_args.add_argument( - "--megatron_lm_recompute_activations", - default=None, - type=str, - help="Decides Whether (true|false) to enable Selective Activation Recomputation. " - "(useful only when `use_megatron_lm` flag is passed).", - ) - megatron_lm_args.add_argument( - "--megatron_lm_use_distributed_optimizer", - default=None, - type=str, - help="Decides Whether (true|false) to use distributed optimizer " - "which shards optimizer state and gradients across Data Pralellel (DP) ranks. " - "(useful only when `use_megatron_lm` flag is passed).", - ) - megatron_lm_args.add_argument( - "--megatron_lm_gradient_clipping", - default=1.0, - type=float, - help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). 
" - "(useful only when `use_megatron_lm` flag is passed).", - ) - - # AWS arguments - aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.") - aws_args.add_argument( - "--aws_access_key_id", - type=str, - default=None, - help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job", - ) - aws_args.add_argument( - "--aws_secret_access_key", - type=str, - default=None, - help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.", - ) - parser.add_argument( - "--debug", - action="store_true", - help="Whether to print out the torch.distributed stack trace when something fails.", - ) - parser.add_argument( - "training_script", - type=str, - help=( - "The full path to the script to be launched in parallel, followed by all the arguments for the training " - "script." - ), - ) - - # Other arguments of the training scripts - parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.") - - if subparsers is not None: - parser.set_defaults(func=launch_command) - return parser - def simple_launcher(args): - cmd = [] - if args.no_python and args.module: - raise ValueError("--module and --no_python cannot be used together") - if not args.no_python: - cmd.append(sys.executable) - if args.module: - cmd.append("-m") - cmd.append(args.training_script) - cmd.extend(args.training_script_args) - - current_env = os.environ.copy() + current_env = {} current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu) if args.use_mps_device: warnings.warn( @@ -528,14 +47,6 @@ def simple_launcher(args): current_env["ACCELERATE_USE_MPS_DEVICE"] = str(args.mps) if args.mps: current_env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - elif args.gpu_ids != "all" and args.gpu_ids is not None: - current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids - if args.num_machines > 1: - current_env["MASTER_ADDR"] = args.main_process_ip - current_env["MASTER_PORT"] = str(args.main_process_port) - elif args.num_processes > 1: - current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1" - current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500" try: mixed_precision = PrecisionType(args.mixed_precision.lower()) @@ -556,54 +67,25 @@ def simple_launcher(args): try: dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) except ValueError: - raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DYNAMO_BACKENDS}.") + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DYNAMO_BACKENDS}." 
+ ) current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) - process = subprocess.Popen(cmd, env=current_env) - process.wait() - if process.returncode != 0: - raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + os.environ.update(current_env) def multi_gpu_launcher(args): - if is_torch_version(">=", "1.9.0"): - import torch.distributed.run as distrib_run - num_processes = getattr(args, "num_processes") - num_machines = getattr(args, "num_machines") - main_process_ip = getattr(args, "main_process_ip") - main_process_port = getattr(args, "main_process_port") - if num_machines > 1: - setattr(args, "nproc_per_node", str(num_processes // num_machines)) - setattr(args, "nnodes", str(num_machines)) - setattr(args, "node_rank", int(args.machine_rank)) - if getattr(args, "same_network", False): - setattr(args, "master_addr", str(main_process_ip)) - setattr(args, "master_port", str(main_process_port)) - else: - setattr(args, "rdzv_endpoint", f"{main_process_ip}:{main_process_port}") - else: - setattr(args, "nproc_per_node", str(num_processes)) - if main_process_port is not None: - setattr(args, "master_port", str(main_process_port)) - - if args.module and args.no_python: - raise ValueError("--module and --no_python cannot be used together") - elif args.module: - setattr(args, "module", True) - elif args.no_python: - setattr(args, "no_python", True) - - current_env = os.environ.copy() - gpu_ids = getattr(args, "gpu_ids", "all") - if gpu_ids != "all" and args.gpu_ids is not None: - current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + current_env = {} mixed_precision = args.mixed_precision.lower() try: mixed_precision = PrecisionType(mixed_precision) except ValueError: - raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.") + raise ValueError( + f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}." + ) if args.fp16: warnings.warn( @@ -617,7 +99,9 @@ def multi_gpu_launcher(args): try: dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) except ValueError: - raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DYNAMO_BACKENDS}.") + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DYNAMO_BACKENDS}." 
+ ) current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value if args.use_fsdp: @@ -628,9 +112,13 @@ def multi_gpu_launcher(args): if args.fsdp_auto_wrap_policy is not None: current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) if args.fsdp_transformer_layer_cls_to_wrap is not None: - current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap) + current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str( + args.fsdp_transformer_layer_cls_to_wrap + ) if args.fsdp_backward_prefetch_policy is not None: - current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy) + current_env["FSDP_BACKWARD_PREFETCH"] = str( + args.fsdp_backward_prefetch_policy + ) if args.fsdp_state_dict_type is not None: current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type) @@ -639,15 +127,25 @@ def multi_gpu_launcher(args): current_env["ACCELERATE_USE_MEGATRON_LM"] = "true" current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree) current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree) - current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping) + current_env[prefix + "GRADIENT_CLIPPING"] = str( + args.megatron_lm_gradient_clipping + ) if args.megatron_lm_num_micro_batches is not None: - current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches) + current_env[prefix + "NUM_MICRO_BATCHES"] = str( + args.megatron_lm_num_micro_batches + ) if args.megatron_lm_sequence_parallelism is not None: - current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism) + current_env[prefix + "SEQUENCE_PARALLELISM"] = str( + args.megatron_lm_sequence_parallelism + ) if args.megatron_lm_recompute_activations is not None: - current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations) + current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str( + args.megatron_lm_recompute_activations + ) if args.megatron_lm_use_distributed_optimizer is not None: - current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer) + current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str( + args.megatron_lm_use_distributed_optimizer + ) current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) if is_torch_version("<", "1.9.0"): @@ -655,74 +153,14 @@ def multi_gpu_launcher(args): os.environ.update(current_env) + def deepspeed_launcher(args): - if is_torch_version(">=", "1.9.0"): - import torch.distributed.run as distrib_run if not is_deepspeed_available(): - raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.") - num_processes = getattr(args, "num_processes") - num_machines = getattr(args, "num_machines") - main_process_ip = getattr(args, "main_process_ip") - main_process_port = getattr(args, "main_process_port") - - # make sure launcher is not None - if args.deepspeed_multinode_launcher is None: - # set to default pdsh - setattr(args, "deepspeed_multinode_launcher", DEEPSPEED_MULTINODE_LAUNCHERS[0]) - - if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: - cmd = ["deepspeed", "--no_local_rank"] - cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)]) - if args.deepspeed_exclusion_filter is not None: - cmd.extend( - [ - "--exclude", - str(args.deepspeed_exclusion_filter), - ] - ) - elif args.deepspeed_inclusion_filter is not None: - cmd.extend( - [ - "--include", - 
str(args.deepspeed_inclusion_filter), - ] - ) - else: - cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)]) - - if args.module and args.no_python: - raise ValueError("--module and --no_python cannot be used together") - elif args.module: - cmd.append("--module") - elif args.no_python: - cmd.append("--no_python") - cmd.append(args.training_script) - cmd.extend(args.training_script_args) - elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]: - setattr(args, "nproc_per_node", str(num_processes // num_machines)) - setattr(args, "nnodes", str(num_machines)) - setattr(args, "node_rank", int(args.machine_rank)) - if getattr(args, "same_network", False): - setattr(args, "master_addr", str(main_process_ip)) - setattr(args, "master_port", str(main_process_port)) - else: - setattr(args, "rdzv_endpoint", f"{main_process_ip}:{main_process_port}") - else: - setattr(args, "nproc_per_node", str(num_processes)) - if main_process_port is not None: - setattr(args, "master_port", str(main_process_port)) - - if args.module and args.no_python: - raise ValueError("--module and --no_python cannot be used together") - elif args.module: - setattr(args, "module", True) - elif args.no_python: - setattr(args, "no_python", True) + raise ImportError( + "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source." + ) - current_env = os.environ.copy() - gpu_ids = getattr(args, "gpu_ids", "all") - if gpu_ids != "all" and args.gpu_ids is not None: - current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + current_env = {} try: mixed_precision = PrecisionType(args.mixed_precision.lower()) except ValueError: @@ -743,10 +181,16 @@ def deepspeed_launcher(args): current_env["DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage) current_env["GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps) current_env["GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower() - current_env["DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower() - current_env["DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower() + current_env["DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str( + args.offload_optimizer_device + ).lower() + current_env["DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str( + args.offload_param_device + ).lower() current_env["DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower() - current_env["DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower() + current_env["DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str( + args.zero3_save_16bit_model + ).lower() if args.deepspeed_config_file is not None: current_env["DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file) @@ -759,321 +203,23 @@ def deepspeed_launcher(args): os.environ.update(current_env) -def tpu_launcher(args): - import torch_xla.distributed.xla_multiprocessing as xmp - - current_env = {} - - if args.no_python: - raise ValueError("--no_python cannot be used with TPU launcher") - - if args.mixed_precision == "bf16": - if args.downcast_bf16: - current_env["XLA_USE_BF16"] = "0" - current_env["XLA_DOWNCAST_BF16"] = "1" - else: - current_env["XLA_USE_BF16"] = "1" - current_env["XLA_DOWNCAST_BF16"] = "0" - - if args.module: - mod_name = args.training_script - else: - # Import training_script as a module - script_path = Path(args.training_script) - sys.path.append(str(script_path.parent.resolve())) - mod_name = script_path.stem - - mod = importlib.import_module(mod_name) - if not hasattr(mod, args.main_training_function): - raise ValueError( 
- f"Your training script should have a function named {args.main_training_function}, or you should pass a " - "different value to `--main_training_function`." - ) - - # Patch sys.argv - sys.argv = [mod.__file__] + args.training_script_args - - main_function = getattr(mod, args.main_training_function) - with patch_environment(**current_env): - xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes) - - -def _convert_nargs_to_dict(nargs: List[str]) -> Dict[str, str]: - if len(nargs) < 0: - return {} - # helper function to infer type for argsparser - - def _infer_type(s): - try: - s = float(s) - - if s // 1 == s: - return int(s) - return s - except ValueError: - return s - - parser = argparse.ArgumentParser() - _, unknown = parser.parse_known_args(nargs) - for index, argument in enumerate(unknown): - if argument.startswith(("-", "--")): - action = None - if index + 1 < len(unknown): # checks if next index would be in list - if unknown[index + 1].startswith(("-", "--")): # checks if next element is an key - # raise an error if element is store_true or store_false - raise ValueError( - "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types" - ) - else: # raise an error if last element is store_true or store_false - raise ValueError( - "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types" - ) - # adds argument to parser based on action_store true - if action is None: - parser.add_argument(argument, type=_infer_type) - else: - parser.add_argument(argument, action=action) - - return { - key: (literal_eval(value) if value in ("True", "False") else value) - for key, value in parser.parse_args(nargs).__dict__.items() - } - - -def sagemaker_launcher(sagemaker_config: SageMakerConfig, args): - if not is_sagemaker_available(): - raise ImportError( - "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`" - ) - if args.module or args.no_python: - raise ValueError( - "SageMaker requires a python training script file and cannot be used with --module or --no_python" - ) - - from sagemaker.huggingface import HuggingFace - - # configure environment - print("Configuring Amazon SageMaker environment") - os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region - - # configure credentials - if sagemaker_config.profile is not None: - os.environ["AWS_PROFILE"] = sagemaker_config.profile - elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None: - os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id - os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key - else: - raise EnvironmentError( - "You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile" - ) - - # extract needed arguments - source_dir = os.path.dirname(args.training_script) - if not source_dir: # checks if string is empty - source_dir = "." - entry_point = os.path.basename(args.training_script) - if not entry_point.endswith(".py"): - raise ValueError(f'Your training script should be a python script and not "{entry_point}"') - - print("Converting Arguments to Hyperparameters") - hyperparameters = _convert_nargs_to_dict(args.training_script_args) - - try: - mixed_precision = PrecisionType(args.mixed_precision.lower()) - except ValueError: - raise ValueError( - f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." 
- ) - - if args.fp16: - warnings.warn('--fp16 flag is deprecated. Use "--mixed_precision fp16" instead.', FutureWarning) - mixed_precision = "fp16" - - try: - dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) - except ValueError: - raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DYNAMO_BACKENDS}.") - - # Environment variables to be set for use during training job - environment = { - "ACCELERATE_USE_SAGEMAKER": "true", - "ACCELERATE_MIXED_PRECISION": str(mixed_precision), - "ACCELERATE_DYNAMO_BACKEND": dynamo_backend.value, - "ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value, - } - # configure distribution set up - distribution = None - if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL: - distribution = {"smdistributed": {"dataparallel": {"enabled": True}}} - - # configure sagemaker inputs - sagemaker_inputs = None - if sagemaker_config.sagemaker_inputs_file is not None: - print(f"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file") - sagemaker_inputs = {} - with open(sagemaker_config.sagemaker_inputs_file) as file: - for i, line in enumerate(file): - if i == 0: - continue - l = line.split("\t") - sagemaker_inputs[l[0]] = l[1].strip() - print(f"Loaded SageMaker Inputs: {sagemaker_inputs}") - - # configure sagemaker metrics - sagemaker_metrics = None - if sagemaker_config.sagemaker_metrics_file is not None: - print(f"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file") - sagemaker_metrics = [] - with open(sagemaker_config.sagemaker_metrics_file) as file: - for i, line in enumerate(file): - if i == 0: - continue - l = line.split("\t") - metric_dict = { - "Name": l[0], - "Regex": l[1].strip(), - } - sagemaker_metrics.append(metric_dict) - print(f"Loaded SageMaker Metrics: {sagemaker_metrics}") - - # configure session - print("Creating Estimator") - huggingface_estimator = HuggingFace( - image_uri=sagemaker_config.image_uri, - entry_point=entry_point, - source_dir=source_dir, - role=sagemaker_config.iam_role_name, - transformers_version=sagemaker_config.transformers_version, - pytorch_version=sagemaker_config.pytorch_version, - py_version=sagemaker_config.py_version, - base_job_name=sagemaker_config.base_job_name, - instance_count=sagemaker_config.num_machines, - instance_type=sagemaker_config.ec2_instance_type, - debugger_hook_config=False, - distribution=distribution, - hyperparameters=hyperparameters, - environment=environment, - metric_definitions=sagemaker_metrics, - ) - - huggingface_estimator.fit(inputs=sagemaker_inputs) - print(f"You can find your model data at: {huggingface_estimator.model_data}") +def _raise_notimplementederror(*args, **kwargs): + raise NotImplementedError() def launch_command(args): - # Sanity checks - if sum([args.multi_gpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1: - raise ValueError("You can only pick one between `--multi_gpu`, `--use_deepspeed`, `--tpu`, `--use_fsdp`.") - - defaults = None - warned = [] - # Get the default from the config file. 
- if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu: - defaults = load_config_from_file(args.config_file) - if ( - not args.multi_gpu - and not args.tpu - and not args.mps - and not args.use_deepspeed - and not args.use_fsdp - and not args.use_megatron_lm - ): - args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED - args.multi_gpu = defaults.distributed_type == DistributedType.MULTI_GPU - args.tpu = defaults.distributed_type == DistributedType.TPU - args.use_fsdp = defaults.distributed_type == DistributedType.FSDP - args.mps = defaults.distributed_type == DistributedType.MPS - args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM - if not args.mps: - if args.gpu_ids is None: - if defaults.gpu_ids is not None: - args.gpu_ids = defaults.gpu_ids - else: - args.gpu_ids = "all" - if len(args.gpu_ids.split(",")) < 2 and args.multi_gpu and (args.gpu_ids != "all"): - args.multi_gpu = False - if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE: - # Update args with the defaults - for name, attr in defaults.__dict__.items(): - if isinstance(attr, dict): - for k in defaults.deepspeed_config: - if getattr(args, k) is None: - setattr(args, k, defaults.deepspeed_config[k]) - for k in defaults.fsdp_config: - arg_to_set = k - if "fsdp" not in arg_to_set: - arg_to_set = "fsdp_" + arg_to_set - setattr(args, arg_to_set, defaults.fsdp_config[k]) - for k in defaults.megatron_lm_config: - setattr(args, k, defaults.megatron_lm_config[k]) - continue - - # Those args are handled separately - if ( - name not in ["compute_environment", "fp16", "mixed_precision", "distributed_type"] - and getattr(args, name, None) is None - ): - setattr(args, name, attr) - if not args.mixed_precision: - if args.fp16: - args.mixed_precision = "fp16" - else: - args.mixed_precision = defaults.mixed_precision - if args.dynamo_backend is None: - warned.append("\t`--dynamo_backend` was set to a value of `'no'`") - args.dynamo_backend = "no" - else: - if args.num_processes is None: - args.num_processes = torch.cuda.device_count() if args.multi_gpu else 1 - warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`") - if args.num_machines is None: - warned.append("\t`--num_machines` was set to a value of `1`") - args.num_machines = 1 - if args.mixed_precision is None: - warned.append("\t`--mixed_precision` was set to a value of `'no'`") - args.mixed_precision = "no" - if not hasattr(args, "use_cpu"): - args.use_cpu = args.cpu - if args.dynamo_backend is None: - warned.append("\t`--dynamo_backend` was set to a value of `'no'`") - args.dynamo_backend = "no" - - if args.num_cpu_threads_per_process is None: - args.num_cpu_threads_per_process = 1 - if args.use_cpu and args.num_processes > 1: - local_size = get_int_from_env( - ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 - ) - threads_per_process = int(psutil.cpu_count(logical=False) / local_size) - if args.num_cpu_threads_per_process > 1: - args.num_cpu_threads_per_process = threads_per_process - warned.append( - f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs" - ) - - if any(warned): - message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n" - message += "\n".join(warned) - message += ( - "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`." 
- ) - logger.warning(message) - - # Use the proper launcher - if args.use_deepspeed and not args.cpu: - deepspeed_launcher(args) - elif args.use_fsdp and not args.cpu: - multi_gpu_launcher(args) - elif args.use_megatron_lm and not args.cpu: - multi_gpu_launcher(args) - elif args.multi_gpu and not args.cpu: - multi_gpu_launcher(args) - elif args.tpu and not args.cpu: - tpu_launcher(args) - elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: - sagemaker_launcher(defaults, args) - else: - simple_launcher(args) + with patch( + "accelerate.commands.launch.deepspeed_launcher", deepspeed_launcher + ), patch( + "accelerate.commands.launch.multi_gpu_launcher", multi_gpu_launcher + ), patch( + "accelerate.commands.launch.simple_launcher", simple_launcher + ), patch( + "accelerate.commands.launch.tpu_launcher", _raise_notimplementederror + ), patch( + "accelerate.commands.launch.sagemaker_launcher", _raise_notimplementederror + ): + return original_launch_command(args) def main(): diff --git a/trlx/sweep.py b/trlx/sweep.py index 52b3b3353..db19c30e1 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -7,23 +7,15 @@ import ray import yaml from ray import tune -from ray.air import ScalingConfig, session -from ray.train.torch import TorchTrainer +from ray.air import ScalingConfig from ray.tune.logger import CSVLoggerCallback -import tempfile from trlx.ray_tune import get_param_space, get_tune_config -from accelerate.commands.config import default_config_file, load_config_from_file -from trlx.ray_train.launch import launch_command, launch_command_parser # from trlx.ray_tune.wandb import create_report, log_trials -from argparse import Namespace -class DefaultNamespace(Namespace): - def __getattr__(self, name: str): - parser = launch_command_parser() - ret = parser.get_default(name) - return ret +from trlx.ray_train.accelerate_trainer import AccelerateTrainer + def tune_function( train_function, param_space: dict, tune_config: dict, resources: dict @@ -31,38 +23,19 @@ def tune_function( default_config = yaml.safe_load(open("configs/ppo_config.yml")) param_space["default_config"] = default_config - config_file_path = default_config_file - with open(config_file_path, "r") as f: - config_data = f.read() - - def train_function_wrapper(config): - temp_config_file = tempfile.mkstemp()[1] - with open(temp_config_file, "w") as f: - f.write(config_data) - args = DefaultNamespace() - setattr(args, "config_file", temp_config_file) - launch_command(args) - os.environ["RANK"] = str(session.get_world_rank()) - os.environ["WORLD_RANK"] = str(session.get_world_rank()) - os.environ["LOCAL_RANK"] = str(session.get_local_rank()) - os.environ["WORLD_SIZE"] = str(session.get_world_size()) - os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) - os.environ["CROSS_RANK"] = str(session.get_world_rank()) - os.environ["CROSS_SIZE"] = str(session.get_world_size()) - os.environ["LOCAL_SIZE"] = str(session.get_local_world_size()) - print(os.environ) - - return train_function(config) - param_space_train = {"train_loop_config": param_space} tuner = tune.Tuner( - TorchTrainer( - train_function_wrapper, + AccelerateTrainer( + train_function, + accelerate_config_path=None, # Use Accelerate default path scaling_config=ScalingConfig( trainer_resources={"CPU": 0}, num_workers=resources["gpu"], use_gpu=bool(resources["gpu"]), - resources_per_worker={"CPU": resources["cpu"], "GPU": int(bool(resources["gpu"]))}, + resources_per_worker={ + "CPU": resources["cpu"], + "GPU": 
int(bool(resources["gpu"])), + }, ), ), param_space=param_space_train, From 698985a9c8bcae0f50a0c4820cf8bbc6af49a00a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 13 Jan 2023 16:13:23 -0800 Subject: [PATCH 04/57] Nit Signed-off-by: Antoni Baum --- trlx/ray_train/accelerate_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py index 4f0b37510..c2a1e8e5d 100644 --- a/trlx/ray_train/accelerate_trainer.py +++ b/trlx/ray_train/accelerate_trainer.py @@ -10,7 +10,6 @@ from ray.air.config import DatasetConfig, RunConfig, ScalingConfig from ray.train.torch.config import TorchConfig from ray.train.trainer import GenDataset -from ray.util import PublicAPI if TYPE_CHECKING: from ray.data.preprocessor import Preprocessor @@ -132,6 +131,7 @@ def _wrap_train_loop_per_worker( accelerate_config_raw: str, deepspeed_config_file_raw: str, ): + """Wrap around train_loop_per_worker to set necessary Accelerate env vars.""" @wraps(train_loop_per_worker) def wrapped_train_loop_per_worker(*args, **kwargs): with tempfile.TemporaryDirectory() as tempdir: From 649c01e16d6a752d7683c6015584030a81deb1db Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 17 Jan 2023 20:54:33 +0000 Subject: [PATCH 05/57] Cleanup Signed-off-by: Antoni Baum --- examples/ppo_sentiments.py | 6 ++-- trlx/ray_train/accelerate_trainer.py | 4 +++ trlx/sweep.py | 45 ++++++++++++++++++---------- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index dfad92839..7291e690b 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -18,9 +18,6 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -# default_config = yaml.safe_load(open("configs/ppo_config.yml")) - - def main(hparams={}): default_config = hparams.pop("default_config") config = TRLConfig.update(default_config, hparams) @@ -56,4 +53,5 @@ def reward_fn(samples: List[str], **kwargs) -> List[float]: if __name__ == "__main__": - main() + default_config = yaml.safe_load(open("configs/ppo_config.yml")) + main({"default_config": default_config}) diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py index c2a1e8e5d..dacf4896f 100644 --- a/trlx/ray_train/accelerate_trainer.py +++ b/trlx/ray_train/accelerate_trainer.py @@ -132,6 +132,7 @@ def _wrap_train_loop_per_worker( deepspeed_config_file_raw: str, ): """Wrap around train_loop_per_worker to set necessary Accelerate env vars.""" + @wraps(train_loop_per_worker) def wrapped_train_loop_per_worker(*args, **kwargs): with tempfile.TemporaryDirectory() as tempdir: @@ -140,9 +141,12 @@ def wrapped_train_loop_per_worker(*args, **kwargs): f.write(accelerate_config_raw) namespace = AccelerateDefaultNamespace() namespace.config_file = temp_config_file + namespace.num_processes = 1 + namespace.num_machines = session.get_world_size() namespace.num_cpu_threads_per_process = ( session.get_trial_resources().bundles[-1]["CPU"] ) + namespace.gpu_ids = None if deepspeed_config_file_raw: deepspeed_config_file = os.path.join( diff --git a/trlx/sweep.py b/trlx/sweep.py index db19c30e1..2fe649ddd 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -18,24 +18,25 @@ def tune_function( - train_function, param_space: dict, tune_config: dict, resources: dict + train_function, + param_space: dict, + tune_config: dict, + default_config: dict, + resources: dict, ): - default_config = 
yaml.safe_load(open("configs/ppo_config.yml")) + num_workers = resources.pop("num_workers") param_space["default_config"] = default_config - param_space_train = {"train_loop_config": param_space} + tuner = tune.Tuner( AccelerateTrainer( train_function, - accelerate_config_path=None, # Use Accelerate default path + accelerate_config_path=None, # Mandatory arg. None means use Accelerate default path scaling_config=ScalingConfig( trainer_resources={"CPU": 0}, - num_workers=resources["gpu"], - use_gpu=bool(resources["gpu"]), - resources_per_worker={ - "CPU": resources["cpu"], - "GPU": int(bool(resources["gpu"])), - }, + num_workers=num_workers, + use_gpu=bool(resources["GPU"]), + resources_per_worker=resources, ), ), param_space=param_space_train, @@ -44,8 +45,8 @@ def tune_function( local_dir="ray_results", callbacks=[CSVLoggerCallback()] ), ) - results = tuner.fit() + project_name = tune_config.get("project_name", "sweep") # log_trials( @@ -74,10 +75,19 @@ def tune_function( help="The config file defining the param_space.", ) parser.add_argument( - "--num-cpus", type=int, default=4, help="Number of CPUs to use per exp." + "--default-config", + type=str, + required=True, + help="The default config file for the script.", + ) + parser.add_argument( + "--num-workers", type=int, default=1, help="Number of workers to use per trial." + ) + parser.add_argument( + "--num-cpus", type=int, default=4, help="Number of CPUs to use per worker." ) parser.add_argument( - "--num-gpus", type=int, default=1, help="Number of GPUs to use per exp." + "--num-gpus", type=int, default=1, help="Number of GPUs to use per worker." ) parser.add_argument( "-y", "--assume-yes", action="store_true", help="Don't ask for confirmation" @@ -97,10 +107,13 @@ def tune_function( config = yaml.safe_load(f) tune_config = get_tune_config(config.pop("tune_config")) param_space = get_param_space(config) + with open(args.default_config) as f: + default_config = yaml.safe_load(f) resources = { - "cpu": args.num_cpus, - "gpu": args.num_gpus, + "num_workers": args.num_workers, + "CPU": args.num_cpus, + "GPU": args.num_gpus, } print(f'WARNING: Importing main from "{args.script}" and everything along with it') @@ -116,7 +129,7 @@ def tune_function( script = importlib.import_module(script_path) # Register the training function that will be used for training the model. # tune.register_trainable("train_function", script.main) - tune_function(script.main, param_space, tune_config, resources) + tune_function(script.main, param_space, tune_config, default_config, resources) # Shut down Ray. 
ray.shutdown()

From 0793c9595c3286f72a7d2e5c6e7462212f122b13 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 17 Jan 2023 13:39:23 -0800
Subject: [PATCH 06/57] Fixes

Signed-off-by: Antoni Baum
---
 trlx/data/configs.py                    | 4 ++++
 trlx/sweep.py                           | 4 +++-
 trlx/trainer/accelerate_base_trainer.py | 5 +++--
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/trlx/data/configs.py b/trlx/data/configs.py
index 238aed4e5..46f7153f9 100644
--- a/trlx/data/configs.py
+++ b/trlx/data/configs.py
@@ -176,6 +176,9 @@ class TrainConfig:
 
     :param seed: Random seed
     :type seed: int
+
+    :param git_tag: Git tag for logging (as returned by ``trlx.utils.get_git_tag()``)
+    :type git_tag: Optional[Tuple[str, str]]
     """
 
     total_steps: int
@@ -200,6 +203,7 @@ class TrainConfig:
 
     trackers: Tuple[str] = ("wandb",)
     seed: int = 1000
+    git_tag: Optional[Tuple[str, str]] = None
 
     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
diff --git a/trlx/sweep.py b/trlx/sweep.py
index 2fe649ddd..287f132bc 100644
--- a/trlx/sweep.py
+++ b/trlx/sweep.py
@@ -11,6 +11,7 @@
 from ray.tune.logger import CSVLoggerCallback
 
 from trlx.ray_tune import get_param_space, get_tune_config
+from trlx.utils import get_git_tag
 
 # from trlx.ray_tune.wandb import create_report, log_trials
 
@@ -25,7 +26,8 @@ def tune_function(
     resources: dict,
 ):
     num_workers = resources.pop("num_workers")
-    param_space["default_config"] = default_config
+    param_space["default_config"] = default_config.copy()
+    param_space["default_config"]["train"]["git_tag"] = get_git_tag()
     param_space_train = {"train_loop_config": param_space}
 
     tuner = tune.Tuner(
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 9c04e3b17..a8cecfebf 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -69,7 +69,8 @@ def __init__(self, config, **kwargs):
             num_gpus = "1gpu"
         else:
             num_gpus = f"{self.accelerator.num_processes}gpus"
-        branch = get_git_tag()[0]
+        git_tags = config.train.git_tag or get_git_tag()
+        branch = git_tags[0]
 
         run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}"
 
@@ -83,7 +84,7 @@ def __init__(self, config, **kwargs):
             "name": run_name,
             "entity": self.config.train.entity_name,
             "group": self.config.train.group_name,
-            "tags": ["/".join(get_git_tag())],
+            "tags": ["/".join(git_tags)],
             "mode": "disabled" if os.environ.get("debug", False) else "online",
         }
         self.accelerator.init_trackers(

From 9bdcbfe1e139b750a983af53030d2a1983f3868c Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 17 Jan 2023 14:33:01 -0800
Subject: [PATCH 07/57] Make sure master_port, master_addr are set

Signed-off-by: Antoni Baum
---
 trlx/ray_train/accelerate_trainer.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py
index dacf4896f..0171d1bcb 100644
--- a/trlx/ray_train/accelerate_trainer.py
+++ b/trlx/ray_train/accelerate_trainer.py
@@ -156,7 +156,14 @@ def wrapped_train_loop_per_worker(*args, **kwargs):
                     f.write(deepspeed_config_file_raw)
                 namespace.deepspeed_config_file = deepspeed_config_file
 
+            # Set by TorchBackend
+            master_addr = os.environ["MASTER_ADDR"]
+            master_port = os.environ["MASTER_PORT"]
+
             launch_command(namespace)
+
+            os.environ["MASTER_ADDR"] = master_addr
+            os.environ["MASTER_PORT"] = master_port
             os.environ["RANK"] = str(session.get_world_rank())
             os.environ["WORLD_RANK"] = str(session.get_world_rank())
             os.environ["CROSS_RANK"] = str(session.get_world_rank())
From 496c891efe74cf19ffe77c15fa579c2593153174 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 17 Jan 2023 23:49:14 +0000
Subject: [PATCH 08/57] Make private

Signed-off-by: Antoni Baum
---
 trlx/ray_train/accelerate_trainer.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py
index 0171d1bcb..9026a3189 100644
--- a/trlx/ray_train/accelerate_trainer.py
+++ b/trlx/ray_train/accelerate_trainer.py
@@ -13,6 +13,7 @@
 
 if TYPE_CHECKING:
     from ray.data.preprocessor import Preprocessor
+    from ray.tune.trainable import Trainable
 
 from ray.train.torch import TorchTrainer
@@ -20,7 +21,7 @@
 from accelerate.commands.config import default_config_file, load_config_from_file
 
 
-class AccelerateDefaultNamespace(Namespace):
+class _AccelerateDefaultNamespace(Namespace):
     @property
     def parser(self):
         return launch_command_parser()
@@ -29,7 +30,7 @@ def __getattr__(self, name: str):
         return self.parser.get_default(name)
 
 
-class AccelerateConfigWrapper:
+class _AccelerateConfigWrapper:
     """
     Lets Trainables know to treat this as already loaded file content instead of path.
     """
@@ -60,7 +61,7 @@ def __init__(
         resume_from_checkpoint: Optional[Checkpoint] = None
     ):
         self.accelerate_config_path = accelerate_config_path or default_config_file
-        if isinstance(self.accelerate_config_path, AccelerateConfigWrapper):
+        if isinstance(self.accelerate_config_path, _AccelerateConfigWrapper):
             self._accelerate_config_raw = self.accelerate_config_path.config_raw
             self._deepspeed_config_file_raw = (
                 self.accelerate_config_path.deepspeed_config_raw
@@ -100,7 +101,7 @@ def as_trainable(self) -> Type["Trainable"]:
         # and share the contents with the Trainables (which may be on different)
         # nodes
         old_accelerate_config_path = self._param_dict["accelerate_config_path"]
-        self._param_dict["accelerate_config_path"] = AccelerateConfigWrapper(
+        self._param_dict["accelerate_config_path"] = _AccelerateConfigWrapper(
             self._accelerate_config_raw, self._deepspeed_config_file_raw
         )
         try:
@@ -139,7 +140,8 @@ def wrapped_train_loop_per_worker(*args, **kwargs):
             temp_config_file = os.path.join(tempdir, "default_config.yaml")
             with open(temp_config_file, "w") as f:
                 f.write(accelerate_config_raw)
-            namespace = AccelerateDefaultNamespace()
+
+            namespace = _AccelerateDefaultNamespace()
             namespace.config_file = temp_config_file

From 55022f24446743da8c6b7892962fa61ccbd81043 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 18 Jan 2023 00:05:37 +0000
Subject: [PATCH 09/57] Tweak

Signed-off-by: Antoni Baum
---
 trlx/ray_train/accelerate_trainer.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py
index 9026a3189..00bc281ee 100644
--- a/trlx/ray_train/accelerate_trainer.py
+++ b/trlx/ray_train/accelerate_trainer.py
@@ -141,14 +141,21 @@ def wrapped_train_loop_per_worker(*args, **kwargs):
             with open(temp_config_file, "w") as f:
                 f.write(accelerate_config_raw)
 
+            # Set by TorchBackend
+            master_addr = os.environ["MASTER_ADDR"]
+            master_port = os.environ["MASTER_PORT"]
+
             namespace = _AccelerateDefaultNamespace()
             namespace.config_file = temp_config_file
             namespace.num_processes = 1
             namespace.num_machines = session.get_world_size()
+            namespace.machine_rank = session.get_world_rank()
             namespace.num_cpu_threads_per_process = (
                 session.get_trial_resources().bundles[-1]["CPU"]
             )
             namespace.gpu_ids = None
+            namespace.main_process_ip = master_addr
+            namespace.main_process_port = master_port
 
             if deepspeed_config_file_raw:
                 deepspeed_config_file = os.path.join(
@@ -158,10 +165,6 @@ def wrapped_train_loop_per_worker(*args, **kwargs):
                     f.write(deepspeed_config_file_raw)
                 namespace.deepspeed_config_file = deepspeed_config_file
 
-            # Set by TorchBackend
-            master_addr = os.environ["MASTER_ADDR"]
-            master_port = os.environ["MASTER_PORT"]
-
             launch_command(namespace)
 
             os.environ["MASTER_ADDR"] = master_addr
             os.environ["MASTER_PORT"] = master_port

From df0530e1fee4e9e8b190254e68a12dfe55d6ac57 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 18 Jan 2023 03:51:22 +0000
Subject: [PATCH 10/57] Restore wandb

Signed-off-by: Antoni Baum
---
 trlx/sweep.py | 31 ++++++++++++++-----------------
 1 file changed, 14 insertions(+), 17 deletions(-)

diff --git a/trlx/sweep.py b/trlx/sweep.py
index 287f132bc..a46c136e2 100644
--- a/trlx/sweep.py
+++ b/trlx/sweep.py
@@ -1,6 +1,5 @@
 # python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
 import argparse
-import os
 import importlib
 from pathlib import Path
 
@@ -11,11 +10,9 @@
 from ray.tune.logger import CSVLoggerCallback
 
 from trlx.ray_tune import get_param_space, get_tune_config
-from trlx.utils import get_git_tag
-
-# from trlx.ray_tune.wandb import create_report, log_trials
-
+from trlx.ray_tune.wandb import create_report, log_trials
 from trlx.ray_train.accelerate_trainer import AccelerateTrainer
+from trlx.utils import get_git_tag
 
 
 def tune_function(
@@ -51,18 +48,18 @@ def tune_function(
 
     project_name = tune_config.get("project_name", "sweep")
 
-    # log_trials(
-    #     tuner._local_tuner.get_experiment_checkpoint_dir(),
-    #     project_name,
-    # )
-
-    # create_report(
-    #     project_name,
-    #     param_space,
-    #     tune_config,
-    #     Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem,
-    #     results.get_best_result().config,
-    # )
+    log_trials(
+        tuner._local_tuner.get_experiment_checkpoint_dir(),
+        project_name,
+    )
+
+    create_report(
+        project_name,
+        param_space,
+        tune_config,
+        Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem,
+        results.get_best_result().config,
+    )
 
     print("Best hyperparameters found were: ", results.get_best_result().config)
 
From 7be54486caa6b525fe5a2755d663574423ca36d8 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 18 Jan 2023 03:52:53 +0000
Subject: [PATCH 11/57] Add ray.init() back, remove unnecessary code

Signed-off-by: Antoni Baum
---
 trlx/sweep.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/trlx/sweep.py b/trlx/sweep.py
index a46c136e2..f4bd0df20 100644
--- a/trlx/sweep.py
+++ b/trlx/sweep.py
@@ -44,8 +44,8 @@ def tune_function(
             local_dir="ray_results", callbacks=[CSVLoggerCallback()]
         ),
     )
-    results = tuner.fit()
 
+    results = tuner.fit()
     project_name = tune_config.get("project_name", "sweep")
 
     log_trials(
@@ -109,6 +109,12 @@
     with open(args.default_config) as f:
         default_config = yaml.safe_load(f)
 
+    # Initialize Ray.
+    if args.server_address:
+        ray.init(address=f"ray://{args.server_address}")
+    else:
+        ray.init()
+
     resources = {
         "num_workers": args.num_workers,
         "CPU": args.num_cpus,
@@ -126,8 +132,6 @@
     # convert a nested path to a module path
     script_path = args.script.replace(".py", "").replace("/", ".")
    script = importlib.import_module(script_path)
-    # Register the training function that will be used for training the model.
-    # tune.register_trainable("train_function", script.main)
     tune_function(script.main, param_space, tune_config, default_config, resources)
 
     # Shut down Ray.
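The rendezvous handling that patches 07 and 09 settle on is easier to see outside diff form. Below is a minimal sketch of the per-worker launch flow, assuming the monkey-patched launch_command from trlx/ray_train/launch.py (which, after the earlier launcher patches, only exports environment variables in-process rather than spawning subprocesses). configure_worker_env is a hypothetical name used for illustration; _DefaultsNamespace mirrors the _AccelerateDefaultNamespace defined in accelerate_trainer.py.

import os
from argparse import Namespace

from trlx.ray_train.launch import launch_command, launch_command_parser


class _DefaultsNamespace(Namespace):
    # Any argument never set explicitly falls back to the default declared
    # by accelerate's launch argument parser.
    def __getattr__(self, name: str):
        return launch_command_parser().get_default(name)


def configure_worker_env(accelerate_config_file: str) -> None:
    # Ray's TorchBackend has already chosen a rendezvous endpoint for this
    # worker group; remember it, because accelerate's launch command
    # overwrites MASTER_ADDR/MASTER_PORT while exporting its env vars.
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]

    namespace = _DefaultsNamespace()
    namespace.config_file = accelerate_config_file
    # With the subprocess launchers patched out, this call only mutates
    # os.environ in the current process.
    launch_command(namespace)

    # Restore Ray's rendezvous variables so torch.distributed still
    # initializes against the address Ray handed out.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

In accelerate_trainer.py the same sequence runs inside wrapped_train_loop_per_worker, with the Accelerate config contents first written to a temporary file on each worker.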
From b885912b5e2795490de2e0ea379892dc13ca68b4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 27 Jan 2023 17:52:54 +0000 Subject: [PATCH 12/57] Set ACCELERATE_TORCH_DEVICE Signed-off-by: Antoni Baum --- trlx/ray_train/accelerate_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py index 00bc281ee..3120a02d2 100644 --- a/trlx/ray_train/accelerate_trainer.py +++ b/trlx/ray_train/accelerate_trainer.py @@ -8,7 +8,7 @@ from ray.air import session from ray.air.checkpoint import Checkpoint from ray.air.config import DatasetConfig, RunConfig, ScalingConfig -from ray.train.torch.config import TorchConfig +from ray.train.torch import TorchConfig, get_device from ray.train.trainer import GenDataset if TYPE_CHECKING: @@ -177,6 +177,7 @@ def wrapped_train_loop_per_worker(*args, **kwargs): os.environ["LOCAL_RANK"] = str(session.get_local_rank()) os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) os.environ["LOCAL_SIZE"] = str(session.get_local_world_size()) + os.environ["ACCELERATE_TORCH_DEVICE"] = str(get_device()) return train_loop_per_worker(*args, **kwargs) From 864d75712aeb7902d22fc009426af22178995e4a Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 16 Feb 2023 21:11:20 +0200 Subject: [PATCH 13/57] refactor(ray_tune): collapse files into `sweep.py` & fix w&b reports --- trlx/ray_tune/__init__.py | 161 ------------------ trlx/ray_tune/train_funcs.py | 32 ---- trlx/ray_tune/wandb.py | 199 ----------------------- trlx/sweep.py | 307 +++++++++++++++++++++++++++++++---- 4 files changed, 275 insertions(+), 424 deletions(-) delete mode 100644 trlx/ray_tune/__init__.py delete mode 100644 trlx/ray_tune/train_funcs.py delete mode 100644 trlx/ray_tune/wandb.py diff --git a/trlx/ray_tune/__init__.py b/trlx/ray_tune/__init__.py deleted file mode 100644 index 8d0229437..000000000 --- a/trlx/ray_tune/__init__.py +++ /dev/null @@ -1,161 +0,0 @@ -from ray import tune - - -def get_param_space(config: dict): # noqa: C901 - """Get the param space from the config file.""" - - def get_strategy(value): - """Get search space strategy from config. - A search space defines valid values for your hyperparameters and - can specify how these values are sampled. - - Refer to the documentation for more info: - https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs - - The user will have to define the search space in the config file by providing - the name of the `strategy` and the `values` to sample from. - - The valid strategies are: - - `uniform` (List) - Samples uniformly between the given bounds. - - `quniform` (List) - Samples uniformly between the given bounds, quantized. - - `loguniform` (List) - Samples uniformly between the given bounds on a log scale. - - `qloguniform` (List) - Samples uniformly between the given bounds on a log scale, quantized. - - `randn` (List) - Samples from a normal distribution. - - `qrandn` (List) - Samples from a normal distribution, quantized. - - `randint` (List) - Samples uniformly between the given bounds, quantized to integers. - - `qrandint` (List) - Samples uniformly between the given bounds, quantized to integers. - - `lograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. - - `qlograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. 
- - `choice` (List) - Samples from a discrete set of values. - - `qrandn` (List) - Samples from a normal distribution, quantized. - - `grid_search` (List) - Samples from the given list of values. - - """ - - strategy = value["strategy"] - if strategy == "uniform": - assert isinstance(value["values"], list) - assert len(value["values"]) == 2 - return tune.uniform(*value["values"]) - elif strategy == "quniform": - assert isinstance(value["values"], list) - assert len(value["values"]) == 3 - return tune.quniform(*value["values"]) - elif strategy == "loguniform": - assert isinstance(value["values"], list) - assert 2 <= len(value["values"]) <= 3 - return tune.loguniform(*value["values"]) - elif strategy == "qloguniform": - assert isinstance(value["values"], list) - assert len(value["values"]) == 4 - return tune.qloguniform(*value["values"]) - elif strategy == "randn": - assert isinstance(value["values"], list) - assert len(value["values"]) == 2 - return tune.randn(*value["values"]) - elif strategy == "qrandn": - assert isinstance(value["values"], list) - assert len(value["values"]) == 3 - return tune.qrandn(*value["values"]) - elif strategy == "randint": - assert isinstance(value["values"], list) - assert len(value["values"]) == 2 - return tune.randint(*value["values"]) - elif strategy == "qrandint": - assert isinstance(value["values"], list) - assert len(value["values"]) == 3 - return tune.qrandint(*value["values"]) - elif strategy == "lograndint": - assert isinstance(value["values"], list) - assert len(value["values"]) == 3 - return tune.lograndint(*value["values"]) - elif strategy == "qlograndint": - assert isinstance(value["values"], list) - assert len(value["values"]) == 4 - return tune.qlograndint(*value["values"]) - elif strategy == "choice": - assert isinstance(value["values"], list) - return tune.choice(value["values"]) - elif strategy == "grid": - assert isinstance(value["values"], list) - return tune.grid_search(value["values"]) - - for k, v in config.items(): - if k != "tune_config": - config[k] = get_strategy(v) - - return config - - -def get_search_alg(tune_config: dict): - """Initialize the search algorithm and return it. - - Bayesian Optimization is currently supported. - """ - search_alg = tune_config["search_alg"] - - if search_alg == "bayesopt": - try: - from ray.tune.search.bayesopt import BayesOptSearch - except ImportError: - raise ImportError("Please pip install bayesian-optimization to use BayesOptSearch.") - - assert "metric" in tune_config.keys() and "mode" in tune_config.keys() - "Please specify metric and mode for BayesOptSearch." - - return BayesOptSearch(metric=tune_config["metric"], mode=tune_config["mode"]) - elif search_alg == "bohb": - try: - from ray.tune.search.bohb import TuneBOHB - except ImportError: - raise ImportError("Please pip install hpbandster and ConfigSpace to use TuneBOHB.") - - assert "metric" in tune_config.keys() and "mode" in tune_config.keys() - "Please specify metric and mode for TuneBOHB." - - return TuneBOHB() - elif search_alg == "random": - return None - else: - NotImplementedError("Search algorithm not supported.") - - -def get_scheduler(tune_config: dict): - """Initialize the scheduler and return it. - - The schedulers can early terminate bad trials, pause trials, - clone trials, and alter hyperparameters of a running trial. 
- - Refer to the documentation for more info: - https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-schedulers - - Currently available schedulers are: - - `hyperband` - Implements the HyperBand early stopping algorithm. - - """ - scheduler = tune_config["scheduler"] - - if scheduler == "hyperband": - return tune.schedulers.HyperBandScheduler() - elif scheduler == "hyperbandforbohb": - return tune.schedulers.HyperBandForBOHB() - elif scheduler == "fifo": - return None - else: - NotImplementedError("Scheduler not supported.") - - -def get_tune_config(tune_config: dict): - """Get the tune config to initialized `tune.TuneConfig` - to be passed `tune.Tuner`. - """ - if "search_alg" in tune_config.keys() and tune_config["search_alg"] is not None: - tune_config["search_alg"] = get_search_alg(tune_config) - - if "scheduler" in tune_config.keys() and tune_config["scheduler"] is not None: - tune_config["scheduler"] = get_scheduler(tune_config) - - # Remove config keys with None values. - tune_config = {k: v for k, v in tune_config.items() if v is not None} - - return tune_config diff --git a/trlx/ray_tune/train_funcs.py b/trlx/ray_tune/train_funcs.py deleted file mode 100644 index e2f3994e8..000000000 --- a/trlx/ray_tune/train_funcs.py +++ /dev/null @@ -1,32 +0,0 @@ -# Find the optimal hyperparameters to generates positive movie -# reviews by tuning a pretrained on IMDB model with a sentiment reward function. - -from datasets import load_dataset - -import trlx -from trlx.data.configs import TRLConfig - - -def ppo_sentiments_train(config: dict): - from transformers import pipeline - - config = TRLConfig.from_dict(config) - - sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1) - - def reward_fn(samples, **kwargs): - outputs = sentiment_fn(samples, return_all_scores=True) - sentiments = [output[1]["score"] for output in outputs] - return sentiments - - # Take few words off of movies reviews as prompts - imdb = load_dataset("imdb", split="train+test") - prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] - - trlx.train( - "lvwerra/gpt2-imdb", - reward_fn=reward_fn, - prompts=prompts, - eval_prompts=["I don't know much about Hungarian underground"] * 64, - config=config, - ) diff --git a/trlx/ray_tune/wandb.py b/trlx/ray_tune/wandb.py deleted file mode 100644 index ac0f01bd4..000000000 --- a/trlx/ray_tune/wandb.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Utility function to log the results of a Ray Tune experiment to W&B.""" - -import json -import math -import os -from pathlib import Path - -import wandb - -from trlx.utils import significant - -import wandb.apis.reports as wb # isort: skip - - -ray_info = [ - "done", - "time_this_iter_s", - "timesteps_total", - "episodes_total", - "iterations_since_restore", - "timesteps_since_restore", - "time_since_restore", - "warmup_time", - "should_checkpoint", - "training_iteration", - "timestamp", - "pid", -] - - -def parse_result(result): - out = {} - for k, v in result.items(): - if isinstance(v, (int, float)) and not k.startswith("config.") and k not in ray_info: - out[k] = v - - return out - - -def log_trials(trial_path: str, project_name: str): - trial_path = Path(trial_path) - files = os.listdir(trial_path) - - trial_paths = [] - for filename in files: - tmp_path = os.path.join(trial_path, filename) - if os.path.isdir(tmp_path): - trial_paths.append(tmp_path) - - for trial in trial_paths: - files = os.listdir(trial) - - # Open params.json and load the configs for that trial. 
- with open(os.path.join(trial, "params.json"), "r") as f: - params = json.load(f) - - name = ",".join(f"{k}={significant(v)}" for k, v in params.items()) - # Initialize wandb - run = wandb.init( - name=name, - project=project_name, - config=params, - group=trial_path.stem, - job_type="hyperopt", - ) - - # Open result.json and log the metrics to W&B. - with open(os.path.join(trial, "result.json"), "r") as f: - for line in f: - result = json.loads(line) - result.pop("config", None) - wandb.log(parse_result(result)) - - # Close the W&B run. - run.finish() - - -def create_report(project_name, param_space, tune_config, trial_path, best_config=None): - def get_parallel_coordinate(param_space, metric): - column_names = list(param_space.keys()) - columns = [wb.PCColumn(column) for column in column_names] - - return wb.ParallelCoordinatesPlot( - columns=columns + [wb.PCColumn(metric)], - layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2}, - ) - - def get_param_importance(metric): - return wb.ParameterImportancePlot( - # Get it from the metric name. - with_respect_to=metric, - layout={"x": 0, "y": 5, "w": 6 * 2, "h": 4 * 2}, - ) - - def get_scatter_plot(metric): - return wb.ScatterPlot( - # Get it from the metric name. - title=f"{metric} v. Index", - x="Index", - y=metric, - running_ymin=True, - font_size="small", - layout={"x": 6, "y": 5, "w": 6 * 2, "h": 4 * 2}, - ) - - def get_metrics_with_history(project_name, group_name, entity=None): - entity_project = f"{entity}/{project_name}" if entity else project_name - api = wandb.Api() - runs = api.runs(entity_project) - - runs = sorted( - runs, - key=lambda run: run.summary.get(tune_config["metric"], -math.inf), - reverse=True, - ) - - for run in runs: - if run.group == str(group_name): - history = run.history() - metrics = history.columns - break - - metrics = [metric for metric in metrics if not metric.startswith("_")] - return metrics - - report = wb.Report( - project=project_name, - title=f"Hyperparameter Optimization Report: {trial_path}", - description="This is a report that shows the results of a hyperparameter optimization experiment.", - ) - - report.blocks = [ - wb.P( - "The following plots show the results of the hyperparameter optimization experiment. " - "Use this as a starting point for your analysis. Go in the edit mode to customize the report. " - "Share it with your team to collaborate on the analysis." - ), - wb.H1(text="Analysis"), - wb.P( - "Parallel coordinates chart (top) summarize the relationship between large numbers of hyperparameters " - "and model metrics at a glance. \nThe scatter plot (right) compares the different trials and gives you a " - "insight on how the trials progressed. \nThe parameter importance plot(left) lists the hyperparameters " - "that were the best predictors of, and highly correlated to desirable values of your metrics." 
- ), - wb.PanelGrid( - panels=[ - get_parallel_coordinate(param_space, tune_config["metric"]), - get_param_importance(tune_config["metric"]), - get_scatter_plot(tune_config["metric"]), - ], - runsets=[wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{trial_path}"')], - ), - ] - - metrics = get_metrics_with_history( - project_name, - trial_path, - ) - - line_plot_panels = [] - for metric in metrics: - line_plot_panels.append( - wb.LinePlot( - title=f"{metric}", - x="Step", - y=[f"{metric}"], - title_x="Step", - smoothing_show_original=True, - max_runs_to_show=10, - plot_type="line", - font_size="auto", - legend_position="north", - ) - ) - - report.blocks = report.blocks + [ - wb.H1(text="Metrics"), - wb.P( - "The following line plots show the metrics for each trial. Use this to investigate the " - "performance of the model for each trial at the metrics level." - ), - wb.PanelGrid( - panels=line_plot_panels, - runsets=[wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{trial_path}"')], - ), - ] - - if best_config: - report.blocks = report.blocks + [ - wb.H1(text="Best Config"), - wb.P( - "The code block shown below is the best config found by the hyperparameter " - "optimization experiment according to Ray Tune." - ), - wb.CodeBlock(code=[json.dumps(best_config, indent=4)], language="json"), - ] - - report.save() - print(report.url) diff --git a/trlx/sweep.py b/trlx/sweep.py index 939d223a8..82b3fc4af 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -1,26 +1,188 @@ # python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py import argparse import importlib -from pathlib import Path +import json +from datetime import datetime import ray +import wandb.apis.reports as wb import yaml from ray import tune from ray.air import ScalingConfig from ray.tune.logger import CSVLoggerCallback -from trlx.ray_tune import get_param_space, get_tune_config -from trlx.ray_tune.wandb import create_report, log_trials +import wandb from trlx.ray_train.accelerate_trainer import AccelerateTrainer from trlx.utils import get_git_tag +def get_param_space(config: dict): # noqa: C901 + """Get the param space from the config file.""" + + def get_strategy(value): + """Get search space strategy from config. + A search space defines valid values for your hyperparameters and + can specify how these values are sampled. + + Refer to the documentation for more info: + https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs + + The user will have to define the search space in the config file by providing + the name of the `strategy` and the `values` to sample from. + + The valid strategies are: + - `uniform` (List) - Samples uniformly between the given bounds. + - `quniform` (List) - Samples uniformly between the given bounds, quantized. + - `loguniform` (List) - Samples uniformly between the given bounds on a log scale. + - `qloguniform` (List) - Samples uniformly between the given bounds on a log scale, quantized. + - `randn` (List) - Samples from a normal distribution. + - `qrandn` (List) - Samples from a normal distribution, quantized. + - `randint` (List) - Samples uniformly between the given bounds, quantized to integers. + - `qrandint` (List) - Samples uniformly between the given bounds, quantized to integers. + - `lograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. + - `qlograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. 
+ - `choice` (List) - Samples from a discrete set of values. + - `qrandn` (List) - Samples from a normal distribution, quantized. + - `grid_search` (List) - Samples from the given list of values. + + """ + + strategy = value["strategy"] + if strategy == "uniform": + assert isinstance(value["values"], list) + assert len(value["values"]) == 2 + return tune.uniform(*value["values"]) + elif strategy == "quniform": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.quniform(*value["values"]) + elif strategy == "loguniform": + assert isinstance(value["values"], list) + assert 2 <= len(value["values"]) <= 3 + return tune.loguniform(*value["values"]) + elif strategy == "qloguniform": + assert isinstance(value["values"], list) + assert len(value["values"]) == 4 + return tune.qloguniform(*value["values"]) + elif strategy == "randn": + assert isinstance(value["values"], list) + assert len(value["values"]) == 2 + return tune.randn(*value["values"]) + elif strategy == "qrandn": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.qrandn(*value["values"]) + elif strategy == "randint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 2 + return tune.randint(*value["values"]) + elif strategy == "qrandint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.qrandint(*value["values"]) + elif strategy == "lograndint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 3 + return tune.lograndint(*value["values"]) + elif strategy == "qlograndint": + assert isinstance(value["values"], list) + assert len(value["values"]) == 4 + return tune.qlograndint(*value["values"]) + elif strategy == "choice": + assert isinstance(value["values"], list) + return tune.choice(value["values"]) + elif strategy == "grid": + assert isinstance(value["values"], list) + return tune.grid_search(value["values"]) + + for k, v in config.items(): + if k != "tune_config": + config[k] = get_strategy(v) + + return config + + +def get_search_alg(tune_config: dict): + """Initialize the search algorithm and return it. + + Bayesian Optimization is currently supported. + """ + search_alg = tune_config["search_alg"] + + if search_alg == "bayesopt": + try: + from ray.tune.search.bayesopt import BayesOptSearch + except ImportError: + raise ImportError("Please pip install bayesian-optimization to use BayesOptSearch.") + + assert "metric" in tune_config.keys() and "mode" in tune_config.keys() + "Please specify metric and mode for BayesOptSearch." + + return BayesOptSearch(metric=tune_config["metric"], mode=tune_config["mode"]) + elif search_alg == "bohb": + try: + from ray.tune.search.bohb import TuneBOHB + except ImportError: + raise ImportError("Please pip install hpbandster and ConfigSpace to use TuneBOHB.") + + assert "metric" in tune_config.keys() and "mode" in tune_config.keys() + "Please specify metric and mode for TuneBOHB." + + return TuneBOHB() + elif search_alg == "random": + return None + else: + NotImplementedError("Search algorithm not supported.") + + +def get_scheduler(tune_config: dict): + """Initialize the scheduler and return it. + + The schedulers can early terminate bad trials, pause trials, + clone trials, and alter hyperparameters of a running trial. 
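+
+    For example (an illustrative sketch), a sweep config selects the
+    scheduler by name under `tune_config`, and the mapping below resolves
+    it to the corresponding Ray Tune scheduler instance:
+
+        tune_config:
+            scheduler: "hyperband"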
+ + Refer to the documentation for more info: + https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-schedulers + + Currently available schedulers are: + - `hyperband` - Implements the HyperBand early stopping algorithm. + + """ + scheduler = tune_config["scheduler"] + + if scheduler == "hyperband": + return tune.schedulers.HyperBandScheduler() + elif scheduler == "hyperbandforbohb": + return tune.schedulers.HyperBandForBOHB() + elif scheduler == "fifo": + return None + else: + NotImplementedError("Scheduler not supported.") + + +def get_tune_config(tune_config: dict): + """Get the tune config to initialized `tune.TuneConfig` + to be passed `tune.Tuner`. + """ + if "search_alg" in tune_config.keys() and tune_config["search_alg"] is not None: + tune_config["search_alg"] = get_search_alg(tune_config) + + if "scheduler" in tune_config.keys() and tune_config["scheduler"] is not None: + tune_config["scheduler"] = get_scheduler(tune_config) + + # Remove config keys with None values. + tune_config = {k: v for k, v in tune_config.items() if v is not None} + + return tune_config + + def tune_function( train_function, param_space: dict, + resources: dict, tune_config: dict, default_config: dict, - resources: dict, + accelerate_config_path: str = None, ): num_workers = resources.pop("num_workers") param_space["default_config"] = default_config.copy() @@ -30,11 +192,12 @@ def tune_function( tuner = tune.Tuner( AccelerateTrainer( train_function, - accelerate_config_path=None, # Mandatory arg. None means use Accelerate default path + # Mandatory arg. None means use Accelerate default path + accelerate_config_path=accelerate_config_path, scaling_config=ScalingConfig( trainer_resources={"CPU": 0}, num_workers=num_workers, - use_gpu=bool(resources["GPU"]), + use_gpu=True, resources_per_worker=resources, ), ), @@ -44,27 +207,102 @@ def tune_function( ) results = tuner.fit() - project_name = tune_config.get("project_name", "sweep") + project_name = default_config["train"]["project_name"] + group_name = default_config["train"]["group_name"] + entity_name = default_config["train"].get("entity_name", None) - log_trials( - tuner._local_tuner.get_experiment_checkpoint_dir(), - project_name, - ) + column_names = param_space.pop("default_config").keys() + target_metric = tune_config["metric"] + + create_report(target_metric, column_names, entity_name, project_name, group_name, results.get_best_result().config) - create_report( - project_name, - param_space, - tune_config, - Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem, - results.get_best_result().config, + +def create_report(target_metric, column_names, entity_name, project_name, group_name, best_config): + report = wb.Report( + project=project_name, + title=f"Hyperparameter Optimization Report: {group_name}", + description="", ) - print("Best hyperparameters found were: ", results.get_best_result().config) + report.blocks = [ + wb.PanelGrid( + panels=[ + wb.ParallelCoordinatesPlot( + columns=[wb.PCColumn(f"c::{column}") for column in column_names] + [wb.PCColumn(target_metric)], + layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2}, + ), + wb.ParameterImportancePlot( + with_respect_to=target_metric, + layout={"x": 0, "y": 5, "w": 6 * 2, "h": 4 * 2}, + ), + wb.ScatterPlot( + # Get it from the metric name. + title=f"{target_metric} v. 
Index", + x="Index", + y=target_metric, + running_ymin=True, + font_size="small", + layout={"x": 6, "y": 5, "w": 6 * 2, "h": 4 * 2}, + ), + ], + runsets=[ + wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{group_name}"'), + ], + ), + ] + + entity_project = f"{entity_name}/{project_name}" if entity_name else project_name + api = wandb.Api() + runs = api.runs(entity_project) + print(f"{entity_project=}") + print(f"{len(runs)=}") + + for run in runs: + print(f"{run.group=}") + if run.group == group_name: + history = run.history() + metrics = history.columns + break + + metrics = [metric for metric in metrics if not metric.startswith("_")] + + line_plot_panels = [] + for metric in metrics: + line_plot_panels.append( + wb.LinePlot( + title=f"{metric}", + x="Step", + y=[f"{metric}"], + title_x="Step", + smoothing_show_original=True, + max_runs_to_show=100, + plot_type="line", + font_size="auto", + legend_position="north", + ) + ) + + report.blocks = report.blocks + [ + wb.H1(text="Metrics"), + wb.PanelGrid( + panels=line_plot_panels, + runsets=[wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{group_name}"')], + ), + ] + + if best_config: + report.blocks = report.blocks + [ + wb.H1(text="Best Config"), + wb.CodeBlock(code=[json.dumps(best_config, indent=4)], language="json"), + ] + + report.save() + print(report.url) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("script", type=str, help="Path to the script") + parser.add_argument("script", type=str, help="Path to the example script") parser.add_argument( "--config", type=str, @@ -72,26 +310,31 @@ def tune_function( help="The config file defining the param_space.", ) parser.add_argument( - "--default-config", + "--default_config", type=str, required=True, help="The default config file for the script.", ) - parser.add_argument("--num-workers", type=int, default=1, help="Number of workers to use per trial.") - parser.add_argument("--num-cpus", type=int, default=4, help="Number of CPUs to use per worker.") - parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs to use per worker.") - parser.add_argument("-y", "--assume-yes", action="store_true", help="Don't ask for confirmation") parser.add_argument( - "--server-address", + "--accelerate_config", + type=str, + required=False, + help="The default config file for the script.", + ) + parser.add_argument("--project_name", type=str, help="W&B project name to log runs into") + parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs (workers) to use per trial.") + parser.add_argument("--num_cpus", type=int, default=4, help="Number of CPUs to use per GPU (worker).") + parser.add_argument("-y", "--assume_yes", action="store_true", help="Don't ask for confirmation") + parser.add_argument( + "--server_address", type=str, default=None, required=False, help="The address of server to connect to if using Ray Client.", ) - args, _ = parser.parse_known_args() + args = parser.parse_args() - # Read config and parse it with open(args.config) as f: config = yaml.safe_load(f) tune_config = get_tune_config(config.pop("tune_config")) @@ -99,16 +342,17 @@ def tune_function( with open(args.default_config) as f: default_config = yaml.safe_load(f) - # Initialize Ray. 
+ default_config["train"]["project_name"] = args.project_name + default_config["train"]["group_name"] = datetime.now().replace(microsecond=0).isoformat() + if args.server_address: ray.init(address=f"ray://{args.server_address}") else: ray.init() resources = { - "num_workers": args.num_workers, + "num_workers": args.num_gpus, "CPU": args.num_cpus, - "GPU": args.num_gpus, } print(f'WARNING: Importing main from "{args.script}" and everything along with it') @@ -122,7 +366,6 @@ def tune_function( # convert a nested path to a module path script_path = args.script.replace(".py", "").replace("/", ".") script = importlib.import_module(script_path) - tune_function(script.main, param_space, tune_config, default_config, resources) + tune_function(script.main, param_space, resources, tune_config, default_config, args.accelerate_config) - # Shut down Ray. ray.shutdown() From 2cda942a0fa454bb381472a088896edc9d75c474 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 16 Feb 2023 21:13:49 +0200 Subject: [PATCH 14/57] feat(configs): add & revert back flat updating of the config --- trlx/data/configs.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 8724a1a86..e3d8702de 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Set +from typing import Any, Dict, Optional, Set, Tuple import yaml @@ -12,6 +12,8 @@ def merge(base: Dict, update: Dict, updated: Set) -> Dict: if k in update and isinstance(v, dict): base[k] = merge(v, update[k], updated) updated.add(k) + elif isinstance(v, dict): + base[k] = merge(v, update, updated) elif k in update: base[k] = update[k] updated.add(k) @@ -275,10 +277,24 @@ def from_dict(cls, config: Dict): @classmethod def update(cls, baseconfig: Dict, config: Dict): + update = {} + # unflatten a string variable name into a nested dictionary + # key1.key2.key3: value -> {key1: {key2: {key3: value}}} + for name, value in config.items(): + if isinstance(value, dict): + update[name] = value + else: + *layers, var = name.split(".") + if layers: + d = update.setdefault(layers[0], {}) + for layer in layers[1:]: + d = d.setdefault(layer, {}) + d[var] = value + updates = set() - merged = merge(baseconfig, config, updates) + merged = merge(baseconfig, update, updates) - for param in config: + for param in update: if param not in updates: raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)") From bbe2ecb517ed2bf087a559149e920eecfd945961 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 16 Feb 2023 21:17:11 +0200 Subject: [PATCH 15/57] fix(base_trainer): reenable w&b logging through ray-tune --- trlx/trainer/accelerate_base_trainer.py | 37 +++++++++++++------------ 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index d225e69a9..7a6962c4f 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -15,7 +15,6 @@ from rich.table import Table from transformers import AutoTokenizer -import trlx.utils.logging as logging from trlx.data.configs import TRLConfig from trlx.trainer import BaseRLTrainer, register_trainer from trlx.utils import ( @@ -24,6 +23,7 @@ get_git_tag, get_optimizer_class, get_scheduler_class, + logging, 
significant, ) from trlx.utils.modeling import ( @@ -82,7 +82,11 @@ def __init__(self, config, **kwargs): # noqa: C901 run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" - if self.accelerator.is_main_process and not ray.is_initialized(): + if ray.is_initialized(): + logging.set_verbosity(logging.WARNING) + logging.disable_progress_bar() + + if self.accelerator.is_main_process: config_dict = self.config.to_dict() dist_config = get_distributed_config(self.accelerator) config_dict["distributed"] = dist_config @@ -423,22 +427,22 @@ def evaluate(self): # noqa: C901 if self.accelerator.is_main_process: rows = sum(list(map(list, zip(*table))), []) - # Add metrics/rewards to the table's title - table_title = f"Evaluation #{self.nth_evaluation}" - for k, x in stats.items(): - if k.startswith("reward") or k.startswith("metrics"): - table_title += f" {k}: {significant(x)}" + if not ray.is_initialized(): + # Add metrics/rewards to the table's title + table_title = f"Evaluation #{self.nth_evaluation}" + for k, x in stats.items(): + if k.startswith("reward") or k.startswith("metrics"): + table_title += f" {k}: {significant(x)}" - rich_table = Table(*columns, title=table_title, show_lines=True) - for ix in range(max(min(3, len(rows)), len(gen_sweep_values))): - rich_table.add_row(*[str(significant(x)) for x in rows[ix]]) - Console().print(rich_table) + rich_table = Table(*columns, title=table_title, show_lines=True) + for ix in range(max(min(3, len(rows)), len(gen_sweep_values))): + rich_table.add_row(*[str(significant(x)) for x in rows[ix]]) + Console().print(rich_table) - if not ray.is_initialized(): - if self.config.train.tracker == "wandb": - import wandb + if self.config.train.tracker == "wandb": + import wandb - stats["samples"] = wandb.Table(columns, rows) + stats["samples"] = wandb.Table(columns, rows) self.nth_evaluation += 1 return stats @@ -546,8 +550,7 @@ def learn(self): # noqa: C901 checkpoint = Checkpoint.from_directory("state") session.report(filter_non_scalars(stats), checkpoint=checkpoint) - if not ray.is_initialized(): - self.accelerator.log(stats, step=self.iter_count) + self.accelerator.log(stats, step=self.iter_count) desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) tbar.set_description(f"[{desc}]") From 0092f60280047cde26cb14a0cbebc1d77f8c5e15 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 16 Feb 2023 23:55:48 +0200 Subject: [PATCH 16/57] revert(base_trainer): remove trlx's verbosity limit when under ray --- trlx/trainer/accelerate_base_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 7a6962c4f..3d1eb8162 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -82,10 +82,6 @@ def __init__(self, config, **kwargs): # noqa: C901 run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" - if ray.is_initialized(): - logging.set_verbosity(logging.WARNING) - logging.disable_progress_bar() - if self.accelerator.is_main_process: config_dict = self.config.to_dict() dist_config = get_distributed_config(self.accelerator) From 88e8d86b0f670f4c235896c403bda87b99cfabbe Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 16 Feb 2023 23:56:52 +0200 Subject: [PATCH 17/57] refactor(sweep): flatten code & remove debug prints --- trlx/sweep.py | 92 
++++++++++++++++++++------------------------------- 1 file changed, 36 insertions(+), 56 deletions(-) diff --git a/trlx/sweep.py b/trlx/sweep.py index 82b3fc4af..e7d6e6527 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -176,52 +176,11 @@ def get_tune_config(tune_config: dict): return tune_config -def tune_function( - train_function, - param_space: dict, - resources: dict, - tune_config: dict, - default_config: dict, - accelerate_config_path: str = None, -): - num_workers = resources.pop("num_workers") - param_space["default_config"] = default_config.copy() - param_space["default_config"]["train"]["git_tag"] = get_git_tag() - param_space_train = {"train_loop_config": param_space} - - tuner = tune.Tuner( - AccelerateTrainer( - train_function, - # Mandatory arg. None means use Accelerate default path - accelerate_config_path=accelerate_config_path, - scaling_config=ScalingConfig( - trainer_resources={"CPU": 0}, - num_workers=num_workers, - use_gpu=True, - resources_per_worker=resources, - ), - ), - param_space=param_space_train, - tune_config=tune.TuneConfig(**tune_config), - run_config=ray.air.RunConfig(local_dir="ray_results", callbacks=[CSVLoggerCallback()]), - ) - - results = tuner.fit() - project_name = default_config["train"]["project_name"] - group_name = default_config["train"]["group_name"] - entity_name = default_config["train"].get("entity_name", None) - - column_names = param_space.pop("default_config").keys() - target_metric = tune_config["metric"] - - create_report(target_metric, column_names, entity_name, project_name, group_name, results.get_best_result().config) - - def create_report(target_metric, column_names, entity_name, project_name, group_name, best_config): report = wb.Report( project=project_name, - title=f"Hyperparameter Optimization Report: {group_name}", - description="", + title=f"Hyperparameter Optimization Report: {project_name}", + description=group_name, ) report.blocks = [ @@ -254,11 +213,8 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ entity_project = f"{entity_name}/{project_name}" if entity_name else project_name api = wandb.Api() runs = api.runs(entity_project) - print(f"{entity_project=}") - print(f"{len(runs)=}") for run in runs: - print(f"{run.group=}") if run.group == group_name: history = run.history() metrics = history.columns @@ -321,7 +277,6 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ required=False, help="The default config file for the script.", ) - parser.add_argument("--project_name", type=str, help="W&B project name to log runs into") parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs (workers) to use per trial.") parser.add_argument("--num_cpus", type=int, default=4, help="Number of CPUs to use per GPU (worker).") parser.add_argument("-y", "--assume_yes", action="store_true", help="Don't ask for confirmation") @@ -342,19 +297,11 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ with open(args.default_config) as f: default_config = yaml.safe_load(f) - default_config["train"]["project_name"] = args.project_name - default_config["train"]["group_name"] = datetime.now().replace(microsecond=0).isoformat() - if args.server_address: ray.init(address=f"ray://{args.server_address}") else: ray.init() - resources = { - "num_workers": args.num_gpus, - "CPU": args.num_cpus, - } - print(f'WARNING: Importing main from "{args.script}" and everything along with it') if not args.assume_yes: @@ -366,6 +313,39 @@ def 
create_report(target_metric, column_names, entity_name, project_name, group_ # convert a nested path to a module path script_path = args.script.replace(".py", "").replace("/", ".") script = importlib.import_module(script_path) - tune_function(script.main, param_space, resources, tune_config, default_config, args.accelerate_config) + project_name = "sweep_" + script_path.split(".")[-1] + + default_config["train"]["project_name"] = project_name + default_config["train"]["group_name"] = datetime.now().replace(microsecond=0).isoformat() + param_space["default_config"] = default_config.copy() + param_space["default_config"]["train"]["git_tag"] = get_git_tag() + param_space_train = {"train_loop_config": param_space} + + tuner = tune.Tuner( + AccelerateTrainer( + script.main, + # Mandatory arg. None means use Accelerate default path + accelerate_config_path=args.accelerate_config, + scaling_config=ScalingConfig( + trainer_resources={"CPU": 0}, + num_workers=args.num_gpus, + use_gpu=True, + resources_per_worker={"CPU": args.num_cpus}, + ), + ), + param_space=param_space_train, + tune_config=tune.TuneConfig(**tune_config), + run_config=ray.air.RunConfig(local_dir="ray_results", callbacks=[CSVLoggerCallback()]), + ) + + results = tuner.fit() + group_name = default_config["train"]["group_name"] + entity_name = default_config["train"].get("entity_name", None) + + column_names = param_space.pop("default_config") + column_names = param_space.keys() + target_metric = tune_config["metric"] + + create_report(target_metric, column_names, entity_name, project_name, group_name, results.get_best_result().config) ray.shutdown() From caf13bee79e88ffcfbc700c6f23d7a2bd6e9250d Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 16 Feb 2023 23:57:52 +0200 Subject: [PATCH 18/57] chore(configs/ppo_sweep): update variable names to nested structure --- configs/sweeps/ppo_sweep.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/configs/sweeps/ppo_sweep.yml b/configs/sweeps/ppo_sweep.yml index cb939e376..f2fc7ae12 100644 --- a/configs/sweeps/ppo_sweep.yml +++ b/configs/sweeps/ppo_sweep.yml @@ -3,15 +3,15 @@ tune_config: metric: "reward/mean" search_alg: "random" scheduler: "fifo" - num_samples: 32 + num_samples: 16 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs -lr: +optimizer.kwargs.lr: strategy: "loguniform" - values: [0.00001, 0.01] -init_kl_coef: + values: [0.000001, 0.0001] +method.init_kl_coef: strategy: "uniform" values: [0, 0.2] -vf_coef: +method.vf_coef: strategy: "uniform" - values: [0.5, 2] + values: [0.25, 1.5] From b5b59ad9e41ea2a57e9ca37cba0f060a379d9400 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Fri, 17 Feb 2023 00:49:58 +0200 Subject: [PATCH 19/57] style(ray_trainer): satisfy black --- trlx/ray_train/accelerate_trainer.py | 6 +++--- trlx/ray_train/launch.py | 15 +++------------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py index d9e2e9392..347d5f1f1 100644 --- a/trlx/ray_train/accelerate_trainer.py +++ b/trlx/ray_train/accelerate_trainer.py @@ -3,7 +3,7 @@ from argparse import Namespace from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union, Type, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union from ray.air import session from 
ray.air.checkpoint import Checkpoint @@ -15,10 +15,10 @@ from ray.data.preprocessor import Preprocessor from ray.tune.trainable import Trainable - +from accelerate.commands.config import default_config_file, load_config_from_file from ray.train.torch import TorchTrainer + from .launch import launch_command, launch_command_parser -from accelerate.commands.config import default_config_file, load_config_from_file class _AccelerateDefaultNamespace(Namespace): diff --git a/trlx/ray_train/launch.py b/trlx/ray_train/launch.py index 4b6c48c96..b782f551f 100644 --- a/trlx/ray_train/launch.py +++ b/trlx/ray_train/launch.py @@ -20,6 +20,8 @@ from unittest.mock import patch from accelerate.commands.config.config_utils import DYNAMO_BACKENDS +from accelerate.commands.launch import launch_command as original_launch_command +from accelerate.commands.launch import launch_command_parser from accelerate.utils import ( DynamoBackend, PrecisionType, @@ -27,10 +29,6 @@ is_torch_version, ) from accelerate.utils.launch import env_var_path_add -from accelerate.commands.launch import ( - launch_command_parser, - launch_command as original_launch_command, -) logger = logging.getLogger(__name__) @@ -75,7 +73,7 @@ def simple_launcher(args): os.environ.update(current_env) -def multi_gpu_launcher(args): +def multi_gpu_launcher(args): # noqa: C901 current_env = {} mixed_precision = args.mixed_precision.lower() try: @@ -146,13 +144,6 @@ def deepspeed_launcher(args): f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." ) - if args.fp16: - warnings.warn( - '--fp16 flag is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use "--mixed_precision fp16" instead.', - FutureWarning, - ) - mixed_precision = "fp16" - current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath(".")) current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) current_env["ACCELERATE_USE_DEEPSPEED"] = "true" From deb667a00cb66a1dd20f9a61ea6d7482637ac6e3 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Fri, 17 Feb 2023 00:53:23 +0200 Subject: [PATCH 20/57] style(sweep): satisfy isort --- trlx/sweep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/sweep.py b/trlx/sweep.py index e7d6e6527..e963f423c 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -5,13 +5,13 @@ from datetime import datetime import ray +import wandb import wandb.apis.reports as wb import yaml from ray import tune from ray.air import ScalingConfig from ray.tune.logger import CSVLoggerCallback -import wandb from trlx.ray_train.accelerate_trainer import AccelerateTrainer from trlx.utils import get_git_tag From 9082cce43a3726f65aeb371518a70574d1b702a1 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:05:45 +0200 Subject: [PATCH 21/57] chore(configs/sweeps): update variable names to the nested structure --- configs/sweeps/ilql_sweep.yml | 13 +++++++------ configs/sweeps/ppo_sweep.yml | 7 +++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/configs/sweeps/ilql_sweep.yml b/configs/sweeps/ilql_sweep.yml index ead70b0fd..fa23296a6 100644 --- a/configs/sweeps/ilql_sweep.yml +++ b/configs/sweeps/ilql_sweep.yml @@ -3,17 +3,18 @@ tune_config: metric: "metrics/sentiments" search_alg: "random" scheduler: "fifo" - num_samples: 32 + num_samples: 8 -lr: +# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs +optimizer.kwargs.lr: 
strategy: "loguniform" - values: [0.00001, 0.01] -tau: + values: [0.000001, 0.001] +method.tau: strategy: "uniform" values: [0.6, 0.9] -steps_for_target_q_sync: +method.steps_for_target_q_sync: strategy: "choice" values: [1, 5, 10] -alpha: +method.alpha: strategy: "loguniform" values: [0.001, 1.0] diff --git a/configs/sweeps/ppo_sweep.yml b/configs/sweeps/ppo_sweep.yml index f2fc7ae12..84fa627d9 100644 --- a/configs/sweeps/ppo_sweep.yml +++ b/configs/sweeps/ppo_sweep.yml @@ -3,15 +3,18 @@ tune_config: metric: "reward/mean" search_alg: "random" scheduler: "fifo" - num_samples: 16 + num_samples: 8 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs optimizer.kwargs.lr: strategy: "loguniform" - values: [0.000001, 0.0001] + values: [0.000001, 0.001] method.init_kl_coef: strategy: "uniform" values: [0, 0.2] method.vf_coef: strategy: "uniform" values: [0.25, 1.5] +model.num_layers_unfrozen: + strategy: "choice" + values: [-1, 2, 6] From 02c0976f225253fc3881977b2f6d95c4c208d78f Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:07:20 +0200 Subject: [PATCH 22/57] chore(ilql_sentiments): restructure config loading for sweeps --- examples/ilql_sentiments.py | 13 ++++++------- examples/ppo_sentiments.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py index 03caa66aa..9b10e256e 100644 --- a/examples/ilql_sentiments.py +++ b/examples/ilql_sentiments.py @@ -15,13 +15,8 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -config_path = pathlib.Path(__file__).parent.joinpath("../configs/ilql_config.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) - - def main(hparams={}): - config = TRLConfig.update(default_config, hparams) + config = TRLConfig.update(hparams.pop("default_config"), hparams) sentiment_fn = pipeline( "sentiment-analysis", @@ -48,4 +43,8 @@ def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: if __name__ == "__main__": - main() + config_path = pathlib.Path(__file__).parent.joinpath("../configs/ilql_config.yml") + with config_path.open() as f: + default_config = yaml.safe_load(f) + + main({"default_config": default_config}) diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index feff61a68..c17283598 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -48,7 +48,7 @@ def reward_fn(samples: List[str], **kwargs) -> List[float]: trlx.train( reward_fn=reward_fn, prompts=prompts, - eval_prompts=["I don't know much about Hungarian underground"] * 64, + eval_prompts=["I don't know much about Hungarian underground"] * 256, config=config, ) From 9c39c1c86dc880efa6b7cc6c5478857fb15fce33 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:09:25 +0200 Subject: [PATCH 23/57] fix(sweep): rework `best_config` block --- trlx/sweep.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/trlx/sweep.py b/trlx/sweep.py index e963f423c..5732a020b 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -5,13 +5,13 @@ from datetime import datetime import ray -import wandb import wandb.apis.reports as wb import yaml from ray import tune from ray.air import ScalingConfig from ray.tune.logger import CSVLoggerCallback +import wandb from trlx.ray_train.accelerate_trainer import AccelerateTrainer from trlx.utils import 
get_git_tag @@ -247,9 +247,19 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ ] if best_config: + best_config = best_config["train_loop_config"] + config = best_config.pop("default_config") + for name, value in best_config.items(): + *layers, var = name.split(".") + if layers: + d = config[layers[0]] + for layer in layers[1:]: + d = d[layer] + d[var] = value + report.blocks = report.blocks + [ wb.H1(text="Best Config"), - wb.CodeBlock(code=[json.dumps(best_config, indent=4)], language="json"), + wb.CodeBlock(code=[json.dumps(config, indent=4)], language="json"), ] report.save() From f5007408205e3d69d67e7c8ef17cdc385d96a353 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:11:43 +0200 Subject: [PATCH 24/57] chore(configs): disable `scheduler` in default configs --- configs/ilql_config.yml | 4 ++-- configs/ppo_config.yml | 4 ++-- configs/sft_config.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/configs/ilql_config.yml b/configs/ilql_config.yml index 40c162c70..4a1f4706a 100644 --- a/configs/ilql_config.yml +++ b/configs/ilql_config.yml @@ -30,8 +30,8 @@ optimizer: scheduler: name: "cosine_annealing" kwargs: - T_max: 1000 # train.total_steps - eta_min: 5.0e-5 + T_max: 100000000000 + eta_min: 0 method: name: "ilqlconfig" diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index 32a43ceb6..aea895ffc 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -29,8 +29,8 @@ optimizer: scheduler: name: "cosine_annealing" kwargs: - T_max: 10000 # train.total_steps - eta_min: 1.0e-4 + T_max: 100000000000 + eta_min: 0 method: name: "ppoconfig" diff --git a/configs/sft_config.yml b/configs/sft_config.yml index 4b1efe358..710c4c1b9 100644 --- a/configs/sft_config.yml +++ b/configs/sft_config.yml @@ -29,8 +29,8 @@ optimizer: scheduler: name: "cosine_annealing" kwargs: - T_max: 10000 # train.total_steps - eta_min: 1.0e-4 + T_max: 100000000000 + eta_min: 0 method: name: "sftconfig" From 2cf94cf9a9ef643231b3603e5bdf34662890bb30 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:13:04 +0200 Subject: [PATCH 25/57] fix(accelerate_trainer): device mismatch from `get_device` by setting `TORCH_DEVICE` to `ray.air.session.get_local_rank()` --- trlx/ray_train/accelerate_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py index 347d5f1f1..0ce0cfb62 100644 --- a/trlx/ray_train/accelerate_trainer.py +++ b/trlx/ray_train/accelerate_trainer.py @@ -8,7 +8,7 @@ from ray.air import session from ray.air.checkpoint import Checkpoint from ray.air.config import DatasetConfig, RunConfig, ScalingConfig -from ray.train.torch import TorchConfig, get_device +from ray.train.torch import TorchConfig from ray.train.trainer import GenDataset if TYPE_CHECKING: @@ -56,7 +56,7 @@ def __init__( run_config: Optional[RunConfig] = None, datasets: Optional[Dict[str, GenDataset]] = None, preprocessor: Optional["Preprocessor"] = None, - resume_from_checkpoint: Optional[Checkpoint] = None + resume_from_checkpoint: Optional[Checkpoint] = None, ): self.accelerate_config_path = accelerate_config_path or default_config_file if isinstance(self.accelerate_config_path, _AccelerateConfigWrapper): @@ -169,7 +169,7 @@ def wrapped_train_loop_per_worker(*args, **kwargs): os.environ["LOCAL_RANK"] = str(session.get_local_rank()) 
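+        # Pin the Accelerate device to this worker's local rank (set just
+        # below); deriving it via get_device() could yield a device that
+        # differs from the one actually assigned to the worker.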
os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) os.environ["LOCAL_SIZE"] = str(session.get_local_world_size()) - os.environ["ACCELERATE_TORCH_DEVICE"] = str(get_device()) + os.environ["ACCELERATE_TORCH_DEVICE"] = f"cuda:{session.get_local_rank()}" return train_loop_per_worker(*args, **kwargs) From ebe630870b1cd64193251db3e22dc76dffad19b9 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:19:16 +0200 Subject: [PATCH 26/57] chore(base_trainer): lower verbosity when under sweep --- trlx/trainer/accelerate_base_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index d633fefd4..ab0a7281a 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -66,6 +66,10 @@ def __init__(self, config, **kwargs): # noqa: C901 if config.model.model_arch_type != "seq2seq": self.tokenizer.pad_token = self.tokenizer.eos_token + if ray.is_initialized(): + logging.set_verbosity(logging.WARNING) + logging.disable_progress_bar() + script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] if not isinstance(config.model.model_path, str): model_name = str(config.model.model_path).split()[0] From 0a1303e421c915488578522816ef29c6cebf6a33 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:20:09 +0200 Subject: [PATCH 27/57] chore(ppo_trainer): reenable w&b logging in `make_experience` --- trlx/trainer/accelerate_ppo_trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 8b4cce883..1fd0f30a9 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -4,7 +4,6 @@ from time import time from typing import Callable, List, Optional -import ray import torch import torch.nn.functional as F from torch.utils.data import DataLoader @@ -484,9 +483,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq stats["kl_ctl_value"] = self.kl_ctl.value stats["time/exp"] = exp_time - - if not ray.is_initialized(): - self.accelerator.log(stats, step=iter_count) + self.accelerator.log(stats, step=iter_count) # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) From 1ac7b99992987851059c39414909723e0d37be64 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:24:03 +0200 Subject: [PATCH 28/57] feat(scripts): add an example of setting up ray cluster on slurm --- scripts/sweep-cw.sh | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 scripts/sweep-cw.sh diff --git a/scripts/sweep-cw.sh b/scripts/sweep-cw.sh new file mode 100644 index 000000000..e0c8881e5 --- /dev/null +++ b/scripts/sweep-cw.sh @@ -0,0 +1,40 @@ +#!/bin/bash +#SBATCH --job-name=trlx-sweep +#SBATCH --account=trlx +#SBATCH --partition=a100-cu117 +#SBATCH --nodes=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --mem=0 +#SBATCH --output=%j +#SBATCH --exclusive + +export NCCL_DEBUG=WARN +export NCCL_PROTO=simple +export FI_EFA_FORK_SAFE=1 +export FI_LOG_LEVEL=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_EFA_ENABLE_SHM_TRANSFER=0 +export FI_PROVIDER=efa +export FI_EFA_TX_MIN_CREDITS=64 +# export CUDA_LAUNCH_BLOCKING=1 + +export HOSTNAMES=`scontrol show hostnames 
"$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) + +cd $TRLX +source $TRLX/.env/bin/activate + +ray start --head --port=6379 & + +export HOSTNAMES=($HOSTNAMES) +for node in ${HOSTNAMES[@]:1}; do + echo "Starting ray worker @ $node" + srun --nodes=1 --ntasks=1 -w "$node" ray start --address $MASTER_ADDR:6379 --block & +done + +sleep 10 +ray status + +NUM_GPUS=16 +python -m trlx.sweep -y --config configs/sweeps/ppo_sweep.yml --default_config configs/ppo_config.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ppo_sentiments.py +# python -m trlx.sweep -y --config configs/sweeps/ilql_sweep.yml --default_config configs/ilql_config.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ilql_sentiments.py From b3664b61407cfc211b7baf2dd2c98c14a55f6ad0 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:36:18 +0200 Subject: [PATCH 29/57] style(sweep): satisfy isort --- trlx/sweep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/sweep.py b/trlx/sweep.py index 5732a020b..d5ea40048 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -5,13 +5,13 @@ from datetime import datetime import ray +import wandb import wandb.apis.reports as wb import yaml from ray import tune from ray.air import ScalingConfig from ray.tune.logger import CSVLoggerCallback -import wandb from trlx.ray_train.accelerate_trainer import AccelerateTrainer from trlx.utils import get_git_tag From 614af8c47f581e5a0779d6588b7999c3469c4e5a Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Wed, 1 Mar 2023 00:39:56 +0200 Subject: [PATCH 30/57] revert(base_trainer): remove logging verbosity changes under ray --- trlx/trainer/accelerate_base_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index ab0a7281a..d633fefd4 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -66,10 +66,6 @@ def __init__(self, config, **kwargs): # noqa: C901 if config.model.model_arch_type != "seq2seq": self.tokenizer.pad_token = self.tokenizer.eos_token - if ray.is_initialized(): - logging.set_verbosity(logging.WARNING) - logging.disable_progress_bar() - script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] if not isinstance(config.model.model_path, str): model_name = str(config.model.model_path).split()[0] From 77a71ca82b7c3c43ee8a3566557fe3ff813bd9c3 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 2 Mar 2023 12:33:06 +0200 Subject: [PATCH 31/57] Revert "Merge branch 'main' into ray-train-integration" This reverts commit 697217a5cd87dcf7b9ab9030d155a78a974911fb, reversing changes made to 614af8c47f581e5a0779d6588b7999c3469c4e5a. 
--- README.md | 21 +- configs/nemo_ilql_config.yml | 52 + configs/ppo_gptj.yml | 56 + docs/source/trainer.rst | 15 + examples/architext.py | 14 +- examples/ilql_sentiments.py | 13 +- examples/nemo_ilql_inference.py | 20 +- examples/nemo_ilql_sentiments.py | 25 +- examples/notebooks/trlx_simulacra.ipynb | 3 +- examples/ppo_sentiments.py | 13 +- examples/randomwalks/ilql_randomwalks.py | 13 +- examples/randomwalks/ppo_randomwalks.py | 14 +- examples/sft_sentiments.py | 11 +- examples/simulacra.py | 2 - .../configs/ppo_config_cnn_daily.yml | 61 ++ .../t5_summarize_daily_cnn.py | 79 +- .../configs/ppo_config_summ_gptj.yml | 53 + .../trlx_gptj_text_summarization.py | 72 +- setup.cfg | 4 +- tests/test_models.py | 340 ------- tests/test_ppo.py | 84 ++ tests/test_utils.py | 28 +- trlx/data/configs.py | 25 - trlx/data/default_configs.py | 119 --- trlx/models/modeling_base.py | 223 ---- trlx/trainer/__init__.py | 4 + trlx/trainer/accelerate_base_trainer.py | 24 +- trlx/trainer/accelerate_ilql_trainer.py | 30 +- trlx/trainer/accelerate_ppo_trainer.py | 53 +- trlx/trainer/accelerate_sft_trainer.py | 10 + trlx/{models => trainer/nemo}/README.md | 2 +- trlx/{models => trainer/nemo}/__init__.py | 0 .../nemo/gpt.py} | 2 +- trlx/trainer/nemo_ilql_trainer.py | 4 +- trlx/trainer/nn/__init__.py | 0 .../nn/ilql_models.py} | 136 ++- .../nn/ppo_models.py} | 962 +++++++++--------- trlx/trlx.py | 13 +- trlx/utils/__init__.py | 2 +- trlx/utils/modeling.py | 19 +- 40 files changed, 1058 insertions(+), 1563 deletions(-) create mode 100644 configs/nemo_ilql_config.yml create mode 100644 configs/ppo_gptj.yml create mode 100755 examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml create mode 100755 examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml delete mode 100644 tests/test_models.py create mode 100644 tests/test_ppo.py delete mode 100644 trlx/data/default_configs.py delete mode 100644 trlx/models/modeling_base.py rename trlx/{models => trainer/nemo}/README.md (98%) rename trlx/{models => trainer/nemo}/__init__.py (100%) rename trlx/{models/modeling_nemo_ilql.py => trainer/nemo/gpt.py} (99%) create mode 100644 trlx/trainer/nn/__init__.py rename trlx/{models/modeling_ilql.py => trainer/nn/ilql_models.py} (73%) rename trlx/{models/modeling_ppo.py => trainer/nn/ppo_models.py} (59%) diff --git a/README.md b/README.md index e140c21fb..da9ba405d 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count(' #### Using a reward-labeled dataset ```python -trainer = trlx.train('EleutherAI/gpt-j-6B', samples=['dolphins', 'geese'], rewards=[1.0, 100.0]) +trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)]) ``` #### Trainers provide a wrapper over their underlying model @@ -57,25 +57,14 @@ trainer = trlx.train('EleutherAI/gpt-j-6B', samples=['dolphins', 'geese'], rewar trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True) ``` -#### Configure Hyperparameters - -```python -from trlx.data.default_configs import default_ppo_config, TrainConfig - -config = default_ppo_config() -config.model.model_path = 'EleutherAI/gpt-neox-20b' -config.train.seq_length = 32 -config.train.batch_size = 16 - -trainer = trlx.train(config=config, reward_fn=lambda samples, **kwargs: [float(int(sample)) for sample in samples]) -``` - #### Save the resulting model to a Hugging Face pretrained language model. (Ready to upload to the Hub!) 
```python trainer.save_pretrained('/path/to/output/folder/') ``` +🩹 Warning: Only the `AcceleratePPOTrainer` can write HuggingFace transformers to disk with `save_pretrained` at the moment, as ILQL trainers require inference behavior currently unsupported by available `transformers` architectures. + #### Use 🤗 Accelerate to launch distributed training ```bash @@ -85,13 +74,13 @@ accelerate launch examples/simulacra.py #### Use NeMo-Megatron to launch distributed training -Follow the setup instructions in the [NeMo README](./trlx/models/). +Follow the setup instructions in the [NeMo README](./trlx/trainer/nemo). ```bash python examples/nemo_ilql_sentiments.py ``` -For more usage see the [NeMo README](./trlx/models) +For more usage see the [NeMo README](./trlx/trainer/nemo) #### Use Ray Tune to launch hyperparameter sweep diff --git a/configs/nemo_ilql_config.yml b/configs/nemo_ilql_config.yml new file mode 100644 index 000000000..1d4cc71e2 --- /dev/null +++ b/configs/nemo_ilql_config.yml @@ -0,0 +1,52 @@ +train: + seq_length: 1024 + batch_size: 512 + epochs: 100 + total_steps: 200 + checkpoint_interval: 200 + eval_interval: 20 + + pipeline: "PromptPipeline" + trainer: "NeMoILQLTrainer" + trainer_kwargs: + pretrained_model: "/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/" + megatron_cfg: "megatron_20b.yaml" + seed: 1000 + +model: + model_path: "gpt2" + num_layers_unfrozen: -1 + +tokenizer: + tokenizer_path: "gpt2" + truncation_side: "right" + +optimizer: + name: "adamw" + kwargs: + lr: 5.0e-5 + betas: [0.9, 0.95] + eps: 1.0e-8 + weight_decay: 1.0e-6 + +scheduler: + name: "cosine_annealing" + kwargs: + T_max: 2000 # train.total_steps + eta_min: 1.0e-6 + +method: + name: "ilqlconfig" + tau: 0.7 + gamma: 0.99 + cql_scale: 0.1 + awac_scale: 1 + alpha: 0.001 + beta: 0 + steps_for_target_q_sync: 5 + two_qs: True + gen_kwargs: + max_new_tokens: 56 + top_k: 20 + beta: 2 + temperature: 0.9 diff --git a/configs/ppo_gptj.yml b/configs/ppo_gptj.yml new file mode 100644 index 000000000..0595f7ded --- /dev/null +++ b/configs/ppo_gptj.yml @@ -0,0 +1,56 @@ +train: + seq_length: 48 + epochs: 10 + total_steps: 80000 + batch_size: 8 + + checkpoint_interval: 1000000 + eval_interval: 16 + + pipeline: "PromptPipeline" + trainer: "AcceleratePPOTrainer" + +model: + model_path: "EleutherAI/gpt-j-6B" + num_layers_unfrozen: 2 + +tokenizer: + tokenizer_path: "gpt2" + +optimizer: + name: "adamw" + kwargs: + lr: 1.412e-4 + betas: [0.9, 0.95] + eps: 1.0e-8 + weight_decay: 1.0e-6 + +scheduler: + name: "cosine_annealing" + kwargs: + T_max: 80000 # train.total_steps + eta_min: 1.412e-4 + +method: + name: "ppoconfig" + num_rollouts: 8 + chunk_size: 8 + ppo_epochs: 4 + init_kl_coef: 0.2 + target: 6 + horizon: 10000 + gamma: 1 + lam: 0.95 + cliprange: 0.2 + cliprange_value: 0.2 + vf_coef: 0.2 + scale_reward: False + ref_mean: null + ref_std: null + cliprange_reward: 10 + gen_kwargs: + max_new_tokens: 48 + top_k: 0.0 + top_p: 0.7 + do_sample: True + temperature: 0.5 diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 6259c8b21..0972cc5ff 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -19,7 +19,22 @@ Note that new trainers must be registered with ``trlx.trainer.register_trainer`` .. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer :members: +.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMWithValueHead + :members: + +.. autoclass:: trlx.trainer.nn.ppo_models.GPTModelBranch + :members: + +.. autoclass:: trlx.trainer.nn.ppo_models.OPTModelBranch + :members: + +.. 
autoclass:: trlx.trainer.nn.ppo_models.CausalLMHydraWithValueHead + :members: + **ILQL** .. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer :members: + +.. autoclass:: trlx.trainer.nn.ilql_models.CausalLMWithValueHeads + :members: diff --git a/examples/architext.py b/examples/architext.py index 6e31f3497..d854c4858 100644 --- a/examples/architext.py +++ b/examples/architext.py @@ -1,7 +1,11 @@ # Toy example of optimizing textual interior designs to output the least number of rooms # Also see https://architext.design/ +import pathlib + +import yaml + import trlx -from trlx.data.default_configs import default_ppo_config +from trlx.data.configs import TRLConfig def reward_fn(samples, **kwargs): @@ -26,9 +30,13 @@ def reward_fn(samples, **kwargs): "[prompt] the kitchen is not adjacent to the bathroom [layout]", ] +config_path = pathlib.Path(__file__).parent.joinpath("../configs/ppo_config.yml") +with config_path.open() as f: + default_config = yaml.safe_load(f) + -def main(): - config = default_ppo_config() +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) trlx.train(model_path="architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config) diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py index d425b24d1..9b10e256e 100644 --- a/examples/ilql_sentiments.py +++ b/examples/ilql_sentiments.py @@ -1,11 +1,13 @@ import os +import pathlib from typing import Dict, List +import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.default_configs import TRLConfig, default_ilql_config +from trlx.data.configs import TRLConfig def get_positive_score(scores): @@ -14,8 +16,7 @@ def get_positive_score(scores): def main(hparams={}): - # Merge sweep config with default config if given - config = TRLConfig.update(default_ilql_config().to_dict(), hparams) + config = TRLConfig.update(hparams.pop("default_config"), hparams) sentiment_fn = pipeline( "sentiment-analysis", @@ -42,4 +43,8 @@ def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: if __name__ == "__main__": - main() + config_path = pathlib.Path(__file__).parent.joinpath("../configs/ilql_config.yml") + with config_path.open() as f: + default_config = yaml.safe_load(f) + + main({"default_config": default_config}) diff --git a/examples/nemo_ilql_inference.py b/examples/nemo_ilql_inference.py index f172f6fbb..425a8cdb2 100644 --- a/examples/nemo_ilql_inference.py +++ b/examples/nemo_ilql_inference.py @@ -2,6 +2,7 @@ import sys from glob import glob +import yaml from nemo.collections.nlp.modules.common.megatron.megatron_init import ( fake_initialize_model_parallel, ) @@ -9,24 +10,10 @@ from nemo.utils.model_utils import inject_model_parallel_rank from omegaconf.omegaconf import OmegaConf -from trlx.data.configs import TrainConfig -from trlx.data.default_configs import default_ilql_config +from trlx.data.configs import TRLConfig from trlx.trainer.nemo_ilql_trainer import ILQLGPT, megatron_trainer -default_config = default_ilql_config() - -trl_config = default_config.evolve( - train=TrainConfig( - **dict( - default_config.train.__dict__, - trainer="NeMoILQLTrainer", - trainer_kwargs=dict( - pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", - megatron_cfg="megatron_20b.yaml", - ), - ), - ) -) +default_config = yaml.safe_load(open(os.path.dirname(__file__) + "/../configs/nemo_ilql_config.yml")) def find_checkpoints(checkpoint_dir): @@ -36,6 +23,7 @@ def find_checkpoints(checkpoint_dir): def 
main(megatron_cfg_path, checkpoint_path): + trl_config = TRLConfig.update(default_config, {}) ilql_config = trl_config.method megatron_cfg = OmegaConf.load(megatron_cfg_path) diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index 34044622a..82abe3b7b 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -1,10 +1,12 @@ +import os from typing import Dict, List +import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.default_configs import default_ilql_config +from trlx.data.configs import TRLConfig def get_positive_score(scores): @@ -12,25 +14,11 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -default_config = default_ilql_config() +default_config = yaml.safe_load(open(os.path.dirname(__file__) + "/../configs/nemo_ilql_config.yml")) def main(hparams={}): - # Merge sweep config with default config if given - - config = default_config.evolve( - train=dict( - seq_length=1024, - batch_size=512, - total_steps=200, - trainer="NeMoILQLTrainer", - trainer_kwargs=dict( - pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", - megatron_cfg="megatron_20b.yaml", - ), - ) - ) - config = config.evolve(**hparams) + config = TRLConfig.update(default_config, hparams) sentiment_fn = pipeline( "sentiment-analysis", @@ -48,8 +36,7 @@ def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: imdb = load_dataset("imdb", split="train+test") trlx.train( - samples=imdb["text"], - rewards=imdb["label"], + dataset=(imdb["text"], imdb["label"]), eval_prompts=["I don't know much about Hungarian underground"] * 128, metric_fn=metric_fn, config=config, diff --git a/examples/notebooks/trlx_simulacra.ipynb b/examples/notebooks/trlx_simulacra.ipynb index 407d624f1..28e36aaa1 100644 --- a/examples/notebooks/trlx_simulacra.ipynb +++ b/examples/notebooks/trlx_simulacra.ipynb @@ -1158,8 +1158,7 @@ "source": [ "trlx.train(\n", " \"gpt2\",\n", - " samples=prompts,\n", - " rewards=ratings,\n", + " dataset=(prompts, ratings),\n", " eval_prompts=[\"Hatsune Miku, Red Dress\"] * 64,\n", ")" ] diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index 23a76fb58..c17283598 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -2,14 +2,16 @@ # with a sentiment reward function import os +import pathlib from typing import List import torch +import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.default_configs import TRLConfig, default_ppo_config +from trlx.data.configs import TRLConfig def get_positive_score(scores): @@ -18,8 +20,8 @@ def get_positive_score(scores): def main(hparams={}): - # Merge sweep config with default config if given - config = TRLConfig.update(default_ppo_config().to_dict(), hparams) + default_config = hparams.pop("default_config") + config = TRLConfig.update(default_config, hparams) if torch.cuda.is_available(): device = int(os.environ.get("LOCAL_RANK", 0)) @@ -52,4 +54,7 @@ def reward_fn(samples: List[str], **kwargs) -> List[float]: if __name__ == "__main__": - main() + config_path = pathlib.Path(__file__).parent.joinpath("../configs/ppo_config.yml") + with config_path.open() as f: + default_config = yaml.safe_load(f) + main({"default_config": default_config}) diff --git a/examples/randomwalks/ilql_randomwalks.py b/examples/randomwalks/ilql_randomwalks.py index 043787e10..ebc31660a 100644 --- a/examples/randomwalks/ilql_randomwalks.py 
+++ b/examples/randomwalks/ilql_randomwalks.py @@ -1,12 +1,19 @@ +import pathlib + +import yaml from transformers import GPT2Config import trlx from examples.randomwalks import generate_random_walks -from trlx.data.default_configs import default_ilql_config +from trlx.data.configs import TRLConfig + +config_path = pathlib.Path(__file__).parent.joinpath("configs/ilql_randomwalks.yml") +with config_path.open() as f: + default_config = yaml.safe_load(f) -def main(): - config = default_ilql_config() +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) metric_fn, eval_prompts, walks, _ = generate_random_walks(seed=config.train.seed) rewards = metric_fn(walks)["optimality"] diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py index 50981cbe2..113897fe6 100644 --- a/examples/randomwalks/ppo_randomwalks.py +++ b/examples/randomwalks/ppo_randomwalks.py @@ -1,10 +1,18 @@ +import pathlib + +import yaml + import trlx from examples.randomwalks import generate_random_walks -from trlx.data.default_configs import default_ppo_config +from trlx.data.configs import TRLConfig + +config_path = pathlib.Path(__file__).parent.joinpath("configs/ppo_randomwalks.yml") +with config_path.open() as f: + default_config = yaml.safe_load(f) -def main(): - config = default_ppo_config().evolve(model=dict(model_path="gpt2")) +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) diff --git a/examples/sft_sentiments.py b/examples/sft_sentiments.py index 0270159ea..c289d3a38 100644 --- a/examples/sft_sentiments.py +++ b/examples/sft_sentiments.py @@ -1,11 +1,17 @@ import os +import pathlib from typing import Dict, List +import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.default_configs import TRLConfig, default_sft_config +from trlx.data.configs import TRLConfig + +config_path = pathlib.Path(__file__).parent.joinpath("../configs/sft_config.yml") +with config_path.open() as f: + default_config = yaml.safe_load(f) def get_positive_score(scores): @@ -14,8 +20,7 @@ def get_positive_score(scores): def main(hparams={}): - # Merge sweep config with default config if given - config = TRLConfig.update(default_sft_config().to_dict(), hparams) + config = TRLConfig.update(default_config, hparams) imdb = load_dataset("imdb", split="train+test") # Finetune on only positive reviews diff --git a/examples/simulacra.py b/examples/simulacra.py index f4d6f82d8..cc28520d6 100644 --- a/examples/simulacra.py +++ b/examples/simulacra.py @@ -6,7 +6,6 @@ from urllib.request import urlretrieve import trlx -from trlx.data.default_configs import default_ilql_config url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite" dbpath = "sac_public_2022_06_29.sqlite" @@ -27,7 +26,6 @@ prompts, ratings = tuple(map(list, zip(*c.fetchall()))) trlx.train( - config=default_ilql_config(), samples=prompts, rewards=ratings, eval_prompts=["Hatsune Miku, Red Dress"] * 64, diff --git a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml new file mode 100755 index 000000000..2134beadd --- /dev/null +++ b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml @@ -0,0 +1,61 @@ +train: + seq_length: 612 + epochs: 100 + total_steps: 100000 + batch_size: 12 + + checkpoint_interval: 10000 + eval_interval: 500 + save_best: False + + 
pipeline: "PromptPipeline" + trainer: "AcceleratePPOTrainer" + +model: + model_path: "google/flan-t5-large" + model_arch_type: "seq2seq" + num_layers_unfrozen: 2 + +tokenizer: + tokenizer_path: "google/flan-t5-large" + truncation_side: "right" + +optimizer: + name: "adamw" + kwargs: + lr: 1.0e-5 + betas: [0.9, 0.999] + eps: 1.0e-8 + weight_decay: 1.0e-6 + +scheduler: + name: "cosine_annealing" + kwargs: + T_max: 10000 + eta_min: 1.0e-6 + +method: + name: "ppoconfig" + num_rollouts: 512 + chunk_size: 12 + ppo_epochs: 4 + init_kl_coef: 0.05 + target: 6 + horizon: 10000 + gamma: 0.99 + lam: 0.95 + cliprange: 0.2 + cliprange_value: 0.2 + vf_coef: 1.0 + scale_reward: False + ref_mean: null + ref_std: null + cliprange_reward: 10 + gen_kwargs: + max_new_tokens: 100 + gen_experience_kwargs: + max_new_tokens: 100 + do_sample: True + temperature: 1.0 + top_k: 50 + top_p: 0.95 diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py index 4c3a56758..67863bf7d 100755 --- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py +++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py @@ -1,3 +1,4 @@ +import pathlib from typing import List from datasets import load_dataset @@ -5,15 +6,7 @@ from transformers import AutoTokenizer import trlx -from trlx.data.configs import ( - ModelConfig, - OptimizerConfig, - SchedulerConfig, - TokenizerConfig, - TrainConfig, - TRLConfig, -) -from trlx.models.modeling_ppo import PPOConfig +from trlx.data.configs import TRLConfig try: import evaluate @@ -22,72 +15,8 @@ "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" ) -config = TRLConfig( - train=TrainConfig( - seq_length=612, - epochs=100, - total_steps=100000, - batch_size=12, - checkpoint_interval=10000, - eval_interval=500, - pipeline="PromptPipeline", - trainer="AcceleratePPOTrainer", - ), - model=ModelConfig( - model_path="google/flan-t5-large", - model_arch_type="seq2seq", - num_layers_unfrozen=2, - ), - tokenizer=TokenizerConfig( - tokenizer_path="google/flan-t5-large", - truncation_side="right", - ), - optimizer=OptimizerConfig( - name="adamw", - kwargs={ - "lr": 1.0e-5, - "betas": [0.9, 0.999], - "eps": 1.0e-8, - "weight_decay": 1.0e-6, - }, - ), - scheduler=SchedulerConfig( - name="cosine_annealing", - kwargs={ - "T_max": 10000, - "eta_min": 1.0e-6, - }, - ), - method=PPOConfig( - name="PPOConfig", - num_rollouts=512, - chunk_size=12, - ppo_epochs=4, - init_kl_coef=0.05, - target=6, - horizon=10000, - gamma=0.99, - lam=0.95, - cliprange=0.2, - cliprange_value=0.2, - vf_coef=1.0, - scale_reward=None, - ref_mean=None, - ref_std=None, - cliprange_reward=10, - gen_kwargs={ - "max_new_tokens": 100, - }, - gen_experience_kwargs={ - "max_new_tokens": 100, - "do_sample": True, - "temperature": 1.0, - "top_k": 50, - "top_p": 0.95, - }, - ), -) - +config_path = pathlib.Path(__file__).parent / "configs/ppo_config_cnn_daily.yml" +config = TRLConfig.load_yaml(config_path) meteor = evaluate.load("meteor") # use meteor as the reward function diff --git a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml new file mode 100755 index 000000000..8055a49b5 --- /dev/null +++ b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml @@ -0,0 +1,53 @@ +train: + seq_length: 550 + epochs: 50 + total_steps: 100000 + batch_size: 4 + + checkpoint_interval: 10000 + eval_interval: 200 + + pipeline: "PromptPipeline" + trainer: 
"AcceleratePPOTrainer" + +model: + model_path: "CarperAI/openai_summarize_tldr_sft" + num_layers_unfrozen: 8 + +tokenizer: + tokenizer_path: "gpt2" + truncation_side: "right" + +optimizer: + name: "adamw" + kwargs: + lr: 5.0e-6 + betas: [0.9, 0.999] + eps: 1.0e-8 + weight_decay: 0.01 + +scheduler: + name: "cosine_annealing" + kwargs: + T_max: 100000 + eta_min: 5.0e-6 + +method: + name: "ppoconfig" + num_rollouts: 128 + chunk_size: 16 + ppo_epochs: 4 + init_kl_coef: 0.1 + target: 6 + horizon: 10000 + gamma: 1 + lam: 0.95 + cliprange: 0.2 + cliprange_value: 0.2 + vf_coef: 0.2 + scale_reward: False + ref_mean: null + ref_std: null + cliprange_reward: 10 + gen_kwargs: + max_new_tokens: 50 diff --git a/examples/summarize_rlhf/trlx_gptj_text_summarization.py b/examples/summarize_rlhf/trlx_gptj_text_summarization.py index 9d0d8dd46..3d9e3c5f3 100755 --- a/examples/summarize_rlhf/trlx_gptj_text_summarization.py +++ b/examples/summarize_rlhf/trlx_gptj_text_summarization.py @@ -1,4 +1,5 @@ import os +import pathlib from typing import List import torch @@ -8,15 +9,7 @@ from transformers import AutoTokenizer import trlx -from trlx.data.configs import ( - ModelConfig, - OptimizerConfig, - SchedulerConfig, - TokenizerConfig, - TrainConfig, - TRLConfig, -) -from trlx.trainer.nn.ppo_models import PPOConfig +from trlx.data.configs import TRLConfig REWARD_CHECKPOINT_PATH = "reward_model/rm_checkpoint/pytorch_model.bin" if not os.path.exists(REWARD_CHECKPOINT_PATH): @@ -27,64 +20,6 @@ ) SFT_MODEL_PATH = "CarperAI/openai_summarize_tldr_sft" -config = TRLConfig( - train=TrainConfig( - seq_length=550, - epochs=50, - total_steps=100000, - batch_size=4, - checkpoint_interval=10000, - eval_interval=200, - pipeline="PromptPipeline", - trainer="AcceleratePPOTrainer", - ), - model=ModelConfig( - model_path="CarperAI/openai_summarize_tldr_sft", - num_layers_unfrozen=8, - ), - tokenizer=TokenizerConfig( - tokenizer_path="gpt2", - truncation_side="right", - ), - optimizer=OptimizerConfig( - name="adamw", - kwargs={ - "lr": 5.0e-6, - "betas": [0.9, 0.999], - "eps": 1.0e-8, - "weight_decay": 0.01, - }, - ), - scheduler=SchedulerConfig( - name="cosine_annealing", - kwargs={ - "T_max": 100000, - "eta_min": 5.0e-6, - }, - ), - method=PPOConfig( - name="PPOConfig", - num_rollouts=128, - chunk_size=16, - ppo_epochs=4, - init_kl_coef=0.1, - target=6, - horizon=10000, - gamma=1, - lam=0.95, - cliprange=0.2, - cliprange_value=0.2, - vf_coef=0.2, - scale_reward=None, - ref_mean=None, - ref_std=None, - cliprange_reward=10, - gen_kwargs={ - "max_new_tokens": 50, - }, - ), -) - if __name__ == "__main__": # Load the pre-trained reward model @@ -152,6 +87,9 @@ def reward_fn(samples: List[str], **kwargs): norms_scores = scores - original_scores return norms_scores + config_path = pathlib.Path(__file__).parent.joinpath("configs/ppo_config_summ_gptj.yml") + config = TRLConfig.load_yaml(config_path) + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" diff --git a/setup.cfg b/setup.cfg index 8893f97c0..9c0ef02f2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [metadata] name = trlx author = Alex Havrilla -version = 0.5.0 +version = 0.3.0 url = https://github.com/CarperAI/trlx description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) long_description = file: README.md @@ -12,8 +12,6 @@ license = MIT packages = find: install_requires = accelerate>=0.12.0 - attrs>=22.1.0 - 
cattrs>=22.2.0 datasets deepspeed>=0.7.3 einops>=0.4.1 diff --git a/tests/test_models.py b/tests/test_models.py deleted file mode 100644 index 7924393a7..000000000 --- a/tests/test_models.py +++ /dev/null @@ -1,340 +0,0 @@ -import copy -import gc -import tempfile -import unittest - -import torch -import transformers - -from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads -from trlx.models.modeling_ppo import ( - AutoModelForCausalLMWithHydraValueHead, - AutoModelForCausalLMWithValueHead, - AutoModelForSeq2SeqLMWithHydraValueHead, - AutoModelForSeq2SeqLMWithValueHead, -) - -AUTO_CAUSAL_LM_PATHS = ["gpt2", "EleutherAI/pythia-160m", "facebook/opt-125m"] -AUTO_SEQ2SEQ_LM_PATHS = ["t5-small", "google/flan-t5-small"] - - -# Value Head Modeling Tests - - -class TestAutoModelForCausalLMWithValueHead(unittest.TestCase): - _auto_model_class = AutoModelForCausalLMWithValueHead - _supported_args = {} - - def setUp(self): - self.text = "Once upon a time there was a happy goose named Louis. He liked to eat bananas." - - def tearDown(self): - gc.collect() # Try to free up memory - - def _create_inputs(self, model_path): - tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - return tokenizer(self.text, truncation=True, padding="max_length", max_length=4, return_tensors="pt") - - def test_forward(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Ensure that the `forward` method doesn't throw an error on generic inputs - try: - model(**inputs) - except Exception as e: - self.assertFalse(True, msg=e) - - def test_generate(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Ensure that the `generate` method doesn't throw an error on generic inputs - try: - model.generate(**inputs, return_dict=True, output_hidden_states=True) - except Exception as e: - self.assertFalse(True, msg=e) - - def test_save_load(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - modified_model = copy.deepcopy(model) - - # Manually modify value head parameters - modified_model.v_head[-1].bias = torch.nn.Parameter(torch.tensor([6000053.33])) - - with tempfile.TemporaryDirectory() as tmpdirname: - modified_model.save_pretrained(tmpdirname) - loaded_model = self._auto_model_class.from_pretrained(tmpdirname) - - # Check that the loaded model state dict is the same as the saved model state dict - loaded_state_dict = loaded_model.state_dict() - self.assertEqual(modified_model.state_dict().keys(), loaded_state_dict.keys()) - for name, saved_state in modified_model.state_dict().items(): - self.assertTrue(torch.all(torch.isclose(saved_state, loaded_state_dict[name]))) - - # Assert loaded states are not the same as the original unmodified pretrained model - self.assertFalse(torch.all(torch.isclose(modified_model.v_head[-1].bias, model.v_head[-1].bias))) - - def test_from_config(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - config = transformers.AutoConfig.from_pretrained(model_path) - # Modify the config to ensure the model is initialized from the custom config - config.vocab_size = 2 - model = self._auto_model_class.from_config(config, **self._supported_args) - 
self.assertEqual(model.base_model.get_output_embeddings().out_features, config.vocab_size) - - -class TestAutoModelForCausalLMWithHydraValueHead(TestAutoModelForCausalLMWithValueHead): - _auto_model_class = AutoModelForCausalLMWithHydraValueHead - _supported_args = {"num_layers_unfrozen": 2} # TODO: Test various values - - def test_forward(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - with torch.no_grad(): - # Compare logits and hidden states from frozen and unfrozen heads - unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True) - unfrozen_last_hidden_state = unfrozen_outputs.hidden_states[-1] - unfrozen_logits = unfrozen_outputs.logits - - frozen_outputs = model.forward_hydra(**inputs, return_dict=True, output_hidden_states=True) - frozen_last_hidden_state = frozen_outputs.hidden_states[-1] - frozen_logits = frozen_outputs.logits - - hs_diff = torch.sum(unfrozen_last_hidden_state - frozen_last_hidden_state).item() - logits_diff = torch.sum(unfrozen_logits - frozen_logits).item() - - self.assertEqual(hs_diff, 0) - self.assertEqual(logits_diff, 0) - - def test_lm_heads(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Compare frozen and unfrozen logits - with torch.no_grad(): - unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True) - unfrozen_logits = unfrozen_outputs.logits - frozen_logits = model.frozen_head.lm_head(unfrozen_outputs.hidden_states[-1].to(torch.float32)) - diff = torch.sum(unfrozen_logits - frozen_logits).item() - self.assertEqual(diff, 0) - - def test_frozen_head(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - # Ensure that all parameters of the hyrda `model.frozen_head` are actually frozen - for parameter in model.frozen_head.parameters(): - self.assertTrue(parameter.requires_grad is False) - - -class TestAutoModelForSeq2SeqLMWithValueHead(unittest.TestCase): - _auto_model_class = AutoModelForSeq2SeqLMWithValueHead - _supported_args = {} - - def setUp(self): - self.encoder_text = "Translate this text to French: Hello, my dog is cute" - self.decoder_text = "Bonjour, mon chien est mignon" - - def tearDown(self): - gc.collect() # Try to free up memory - - def _create_inputs(self, model_path): - tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) - tokenizer.sep_token = "" - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - - encoder_inputs = tokenizer( - self.encoder_text, truncation=True, padding="max_length", max_length=10, return_tensors="pt" - ) - decoder_inputs = tokenizer(self.decoder_text, return_tensors="pt") - return { - **encoder_inputs, - "decoder_input_ids": decoder_inputs.input_ids, - "decoder_attention_mask": decoder_inputs.attention_mask, - } - - def test_forward(self): - for model_path in AUTO_SEQ2SEQ_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Ensure that the `forward` method doesn't throw an error on generic inputs - try: - model(**inputs) - except Exception as e: - self.assertFalse(True, msg=e) - - def test_generate(self): - for model_path in AUTO_SEQ2SEQ_LM_PATHS: - model = 
self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Ensure that the `generate` method doesn't throw an error on generic inputs - try: - model.generate(inputs["input_ids"]) - except Exception as e: - self.assertFalse(True, msg=e) - - def test_save_load(self): - for model_path in AUTO_SEQ2SEQ_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - modified_model = copy.deepcopy(model) - - # Manually modify value head parameters - modified_model.v_head[-1].bias = torch.nn.Parameter(torch.tensor([6000053.33])) - - with tempfile.TemporaryDirectory() as tmpdirname: - modified_model.save_pretrained(tmpdirname) - loaded_model = self._auto_model_class.from_pretrained(tmpdirname) - - # Check that the loaded model state dict is the same as the saved model state dict - loaded_state_dict = loaded_model.state_dict() - self.assertEqual(modified_model.state_dict().keys(), loaded_state_dict.keys()) - for name, saved_state in modified_model.state_dict().items(): - self.assertTrue(torch.all(torch.isclose(saved_state, loaded_state_dict[name]))) - - # Assert loaded states are not the same as the original unmodified pretrained model - self.assertFalse(torch.all(torch.isclose(modified_model.v_head[-1].bias, model.v_head[-1].bias))) - - def test_from_config(self): - for model_path in AUTO_SEQ2SEQ_LM_PATHS: - config = transformers.AutoConfig.from_pretrained(model_path) - # Modify the config to ensure the model is initialized from the custom config - config.vocab_size = 2 - model = self._auto_model_class.from_config(config, **self._supported_args) - self.assertEqual(model.base_model.get_output_embeddings().out_features, config.vocab_size) - - -class TestAutoModelForSeq2SeqLMWithHydraValueHead(TestAutoModelForSeq2SeqLMWithValueHead): - _auto_model_class = AutoModelForSeq2SeqLMWithHydraValueHead - _supported_args = {"num_layers_unfrozen": 2} # TODO: Test various values - - @unittest.skip("TODO: Final hidden states are not the same for frozen and unfrozen T5 heads") - def test_forward(self): - for model_path in AUTO_SEQ2SEQ_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - with torch.no_grad(): - # Compare logits and hidden states from frozen and unfrozen heads - unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True) - unfrozen_last_hidden_state = unfrozen_outputs.decoder_hidden_states[-1] - unfrozen_logits = unfrozen_outputs.logits - - frozen_outputs = model.forward_hydra(**inputs, return_dict=True, output_hidden_states=True) - frozen_last_hidden_state = frozen_outputs.decoder_hidden_states[-1] - frozen_logits = frozen_outputs.logits - - hs_diff = torch.sum(unfrozen_last_hidden_state - frozen_last_hidden_state).item() - logits_diff = torch.sum(unfrozen_logits - frozen_logits).item() - - self.assertEqual(hs_diff, 0) - self.assertEqual(logits_diff, 0) - - @unittest.skip("TODO: Final hidden states are not the same for frozen and unfrozen T5 heads") - def test_lm_heads(self): - for model_path in AUTO_SEQ2SEQ_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Compare frozen and unfrozen logits - with torch.no_grad(): - unfrozen_outputs = model(**inputs, return_dict=True, output_hidden_states=True) - unfrozen_logits = unfrozen_outputs.logits - last_hidden_state = unfrozen_outputs.decoder_hidden_states[-1] 
- frozen_logits = model.frozen_head.lm_head(last_hidden_state) - diff = torch.sum(unfrozen_logits - frozen_logits).item() - self.assertEqual(diff, 0) - - def test_frozen_head(self): - for model_path in AUTO_SEQ2SEQ_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - # Ensure that all parameters of the hyrda `model.frozen_head` are actually frozen - for parameter in model.frozen_head.parameters(): - self.assertTrue(parameter.requires_grad is False) - - -# ILQL Heads Modeling Tests - - -class TestAutoModelForCausalLMWithILQLHeads(unittest.TestCase): - _auto_model_class = AutoModelForCausalLMWithILQLHeads - _supported_args = {"two_qs": True, "alpha": 0.8} # TODO: Test various values - - def setUp(self): - self.text = "Once upon a time there was a happy goose named Louis. He liked to eat bananas." - - def tearDown(self): - gc.collect() # Try to free up memory - - def _create_inputs(self, model_path): - tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - return tokenizer(self.text, truncation=True, padding="max_length", max_length=4, return_tensors="pt") - - def test_forward(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Ensure that the `forward` method doesn't throw an error on generic inputs - try: - model(**inputs) - except Exception as e: - self.assertFalse(True, msg=e) - - def test_generate(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - inputs = self._create_inputs(model_path) - - # Ensure that the `generate` method doesn't throw an error on generic inputs - try: - model.generate(**inputs) - except Exception as e: - self.assertFalse(True, msg=e) - - def test_save_load(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - model = self._auto_model_class.from_pretrained(model_path, **self._supported_args) - modified_model = copy.deepcopy(model) - - # Manually modify value head parameters - modified_model.ilql_heads.q_heads[0][0].bias = torch.nn.Parameter( - torch.ones_like(modified_model.ilql_heads.q_heads[0][0].bias) * 600053.34 - ) - - with tempfile.TemporaryDirectory() as tmpdirname: - modified_model.save_pretrained(tmpdirname) - loaded_model = self._auto_model_class.from_pretrained(tmpdirname) - - # Check that the loaded model state dict is the same as the saved model state dict - loaded_state_dict = loaded_model.state_dict() - self.assertEqual(modified_model.state_dict().keys(), loaded_state_dict.keys()) - for name, saved_state in modified_model.state_dict().items(): - self.assertTrue(torch.all(torch.isclose(saved_state, loaded_state_dict[name]))) - - # Assert loaded states are not the same as the original unmodified pretrained model - self.assertFalse( - torch.all( - torch.isclose(modified_model.ilql_heads.q_heads[0][0].bias, model.ilql_heads.q_heads[0][0].bias) - ) - ) - - def test_from_config(self): - for model_path in AUTO_CAUSAL_LM_PATHS: - config = transformers.AutoConfig.from_pretrained(model_path) - # Modify the config to ensure the model is initialized from the custom config - config.vocab_size = 2 - model = self._auto_model_class.from_config(config, **self._supported_args) - self.assertEqual(model.base_model.get_output_embeddings().out_features, config.vocab_size) diff --git a/tests/test_ppo.py b/tests/test_ppo.py new file 
mode 100644 index 000000000..d5be5fdbe --- /dev/null +++ b/tests/test_ppo.py @@ -0,0 +1,84 @@ +import unittest + +import torch +from transformers import AutoTokenizer + +from trlx.data.configs import TRLConfig +from trlx.trainer.nn.ppo_models import CausalLMHydraWithValueHead +from trlx.utils.modeling import RunningMoments + + +# Note tests must start with "test_" +class TestHydraHead(unittest.TestCase): + @classmethod + def setUpClass(cls): + print("Testing Hydra model...") + config = TRLConfig.load_yaml("configs/test_config.yml") + cls.hydra_model = CausalLMHydraWithValueHead(config.model.model_path, config.model.num_layers_unfrozen) + + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + + cls.dummy_inputs = tokenizer( + "Once upon a time there was a happy goose named Louis. He liked to eat bananas.", + truncation=True, + padding="max_length", + max_length=4, + return_tensors="pt", + ) + + def test_lm_heads(self): + with torch.no_grad(): + unfrozen_outputs = TestHydraHead.hydra_model( + **TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True + ) + unfrozen_logits = unfrozen_outputs.logits + last_hidden_states = unfrozen_outputs.hidden_states[-1].to(torch.float32) + frozen_logits = TestHydraHead.hydra_model.frozen_head.lm_head(last_hidden_states) + diff = torch.sum(unfrozen_logits - frozen_logits).item() + self.assertEqual(diff, 0) + + def test_frozen_head(self): + # Ensure that all parameters of the `hydra_model.frozen_head` are actually frozen + for parameter in TestHydraHead.hydra_model.frozen_head.parameters(): + self.assertTrue(parameter.requires_grad is False) + + def test_forward(self): + with torch.no_grad(): + unfrozen_outputs = TestHydraHead.hydra_model( + **TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True + ) + unfrozen_last_hidden_states = unfrozen_outputs.hidden_states[-1] + unfrozen_logits = unfrozen_outputs.logits + + frozen_outputs = TestHydraHead.hydra_model.forward_hydra( + **TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True + ) + frozen_last_hidden_states = frozen_outputs.hidden_states[-1] + frozen_logits = frozen_outputs.logits + + hs_diff = torch.sum(unfrozen_last_hidden_states - frozen_last_hidden_states).item() + logits_diff = torch.sum(unfrozen_logits - frozen_logits).item() + self.assertEqual(hs_diff, 0) + self.assertEqual(logits_diff, 0) + + +class TestStatistics(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.m = RunningMoments() + cls.a1 = torch.arange(100, dtype=float) + cls.a2 = torch.ones(100, dtype=float) + cls.a3 = torch.exp(torch.arange(10, dtype=float)) + cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) + + def test_running_moments(self): + assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) + + a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) + assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) + assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) diff --git a/tests/test_utils.py b/tests/test_utils.py index f3c09c23b..7a0af9959 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,3 @@ -import unittest - import accelerate import 
pytest import torch @@ -70,9 +68,9 @@ def test_hf_attr_getters(model_name: str): arch = transformers.AutoModelForCausalLM.from_config(config) arch_getters = [ - modeling_utils.hf_get_decoder, - modeling_utils.hf_get_decoder_final_norm, - modeling_utils.hf_get_decoder_blocks, + modeling_utils.hf_get_causal_base_model, + modeling_utils.hf_get_causal_final_norm, + modeling_utils.hf_get_causal_hidden_layers, modeling_utils.hf_get_lm_head, ] for get in arch_getters: @@ -127,23 +125,3 @@ def test_parse_delta_kwargs(model_name): ) for kwarg_mod in delta_kwargs["modified_modules"]: assert kwarg_mod.endswith("a") or kwarg_mod.endswith("b"), "Parsed modified module should contain ['a', 'b']" - - -class TestStatistics(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.m = modeling_utils.RunningMoments() - cls.a1 = torch.arange(100, dtype=float) - cls.a2 = torch.ones(100, dtype=float) - cls.a3 = torch.exp(torch.arange(10, dtype=float)) - cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) - - def test_running_moments(self): - assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) - assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6) - assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) - assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) - - a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) - assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) - assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 678259a4e..432ed922a 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -1,4 +1,3 @@ -from copy import deepcopy from dataclasses import dataclass, field from typing import Any, Dict, Optional, Set, Tuple @@ -22,20 +21,6 @@ def merge(base: Dict, update: Dict, updated: Set) -> Dict: return base -def _merge_dicts(base: Dict, update: Dict) -> Dict: - "Merge two dictionaries recursively, returning a new dictionary." - - base = deepcopy(base) - - for k, v in update.items(): - if isinstance(v, dict): - base[k] = _merge_dicts(base.get(k, {}), v) - else: - base[k] = v - - return base - - @dataclass class ModelConfig: """ @@ -272,16 +257,6 @@ def to_dict(self): return data - def evolve(self, **kwargs) -> "TRLConfig": - """ - Evolve TRLConfig with new parameters. Can update nested parameters. 
- >>> config = trlx.data.default_configs.default_ilql_config() - >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)) - >>> config.method.gamma - 0.99 - """ - return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs)) - @classmethod def from_dict(cls, config: Dict): """ diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py deleted file mode 100644 index 1f9297db2..000000000 --- a/trlx/data/default_configs.py +++ /dev/null @@ -1,119 +0,0 @@ -from trlx.models.modeling_ilql import ILQLConfig -from trlx.models.modeling_ppo import PPOConfig -from trlx.trainer.accelerate_sft_trainer import SFTConfig - -from .configs import ( - ModelConfig, - OptimizerConfig, - SchedulerConfig, - TokenizerConfig, - TrainConfig, - TRLConfig, -) - - -def default_ppo_config(): - return TRLConfig( - train=TrainConfig( - seq_length=1024, - epochs=100, - total_steps=10000, - batch_size=32, - checkpoint_interval=10000, - eval_interval=100, - pipeline="PromptPipeline", - trainer="AcceleratePPOTrainer", - ), - model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2), - tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), - optimizer=OptimizerConfig( - name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) - ), - scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)), - method=PPOConfig( - name="PPOConfig", - num_rollouts=128, - chunk_size=128, - ppo_epochs=4, - init_kl_coef=0.05, - target=6, - horizon=10000, - gamma=1, - lam=0.95, - cliprange=0.2, - cliprange_value=0.2, - vf_coef=1, - scale_reward="ignored", - ref_mean=None, - ref_std=None, - cliprange_reward=10, - gen_kwargs=dict( - max_new_tokens=40, - top_k=0, - top_p=1.0, - do_sample=True, - ), - ), - ) - - -def default_ilql_config(): - return TRLConfig( - train=TrainConfig( - seq_length=64, - batch_size=32, - epochs=100, - total_steps=1000, - checkpoint_interval=1000, - eval_interval=100, - pipeline="PromptPipeline", - trainer="AccelerateILQLTrainer", - ), - model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), - tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), - optimizer=OptimizerConfig( - name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) - ), - scheduler=SchedulerConfig( - name="cosine_annealing", kwargs=dict(T_max=1000, eta_min=5.0e-5) # train.total_steps - ), - method=ILQLConfig( - name="ilqlconfig", - tau=0.7, - gamma=0.99, - cql_scale=0.1, - awac_scale=1, - alpha=0.001, - beta=0, - steps_for_target_q_sync=5, - two_qs=True, - gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0), - ), - ) - - -def default_sft_config(): - return TRLConfig( - train=TrainConfig( - seq_length=1024, - epochs=100, - total_steps=1000, - batch_size=8, - checkpoint_interval=10000, - eval_interval=100, - pipeline="PromptPipeline", - trainer="AccelerateSFTTrainer", - ), - model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), - tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), - optimizer=OptimizerConfig( - name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) - ), - scheduler=SchedulerConfig( - name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4) # train.total_steps - ), - method=SFTConfig( - name="sftconfig", - gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), - ), - ) diff --git a/trlx/models/modeling_base.py 
b/trlx/models/modeling_base.py deleted file mode 100644 index 6e6d10d1e..000000000 --- a/trlx/models/modeling_base.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright 2022 CarperAI & The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# NOTE: This file contains a modified version of the `PreTrainedModelWrapper` class from -# HuggingFace's `trl` library. The original source code can be found here: -# https://github.com/lvwerra/trl/blob/78c13226bf8ea1ccd9b1c091f03a938098521f6c/trl/models/modeling_base.py - -import inspect -import json -import os -from typing import Any, Dict, List, Optional, Union - -import torch -import torch.nn as nn -import transformers -from huggingface_hub import hf_hub_download - - -class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin): - """A wrapper around `transformers.PreTrainedModel` - - Reference: @younesbelkada's `PreTrainedModelWrapper` - https://github.com/lvwerra/trl/blob/4f5c16fafde42d9aca971952bcdcc1f5a0a68cf0/trl/models/modeling_base.py#L2 - - Attributes: - _auto_model_parent_class (transformers.AutoModel): The `transformers.AutoModel` - type to base the wrapping behavior off of, e.g. `transformers.AutoModelForCausalLM`. - _supported_modules (List[str]): A list of attribute names for modules of - the underlying architecture model. This is used, for example, to save - and load any additional modules by manipulating the state dict. - _supported_args (List[str]): A list of arguments specific to the underlying - architecture to separate from arguments that are supported by the - parent `AutoModel` class. Any arguments that are not supported by the - underlying model will be passed to the parent `AutoModel` class. - """ - - _auto_model_parent_class: transformers.AutoModel = None - _supported_modules: List[str] = None - # TODO (jon-tow): Supported args should come from a `PretrainedConfig` of the - # specific underlying type similar to how config instances can be used to instantiate - # `transformers.PreTrainedModel`s. - _supported_args: List[str] = None - - def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, **kwargs): - super().__init__() - self.base_model = base_model - # cache `forward` args for general use (avoids incompatible args across architectures) - self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args - - @classmethod - def _split_kwargs(cls, kwargs: Dict[str, Any]): - """Separates the kwargs from the supported arguments within `supported_args` - and those that are not - """ - supported_kwargs = {} - unsupported_kwargs = {} - for key, value in kwargs.items(): - if key in cls._supported_args: - supported_kwargs[key] = value - else: - unsupported_kwargs[key] = value - return supported_kwargs, unsupported_kwargs - - @classmethod - def from_config(cls, config: transformers.PretrainedConfig, **kwargs): - """Instantiate the pretrained pytorch model from a configuration. 
- - Args: - config (transformers.PretrainedConfig): The configuration to use to - instantiate the base model. - - NOTE: Loading a model from its configuration file does **not** load the - model weights. It only affects the model's configuration. Use - `~transformers.AutoModel.from_pretrained` to load the model weights. - """ - if kwargs is not None: - wrapped_model_kwargs, from_config_kwargs = cls._split_kwargs(kwargs) - else: - from_config_kwargs = {} - wrapped_model_kwargs = {} - base_model = cls._auto_model_parent_class.from_config(config, **from_config_kwargs) - model = cls(base_model, **wrapped_model_kwargs) - return model - - @classmethod - def from_pretrained( # noqa: max-complexity - cls, - pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel], - *model_args, - **kwargs, - ): - """Instantiate a pretrained pytorch model from a pretrained model configuration. - This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`. - Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained` - for more information. - - Args: - pretrained_model_name_or_path (str or `transformers.PreTrainedModel`): - The identifier of the pretrained model to load or the pretrained model itself. - *model_args (sequence of positional arguments, *optional*): - All remaining positional arguments will be passed to the `_auto_model_parent_class`. - **kwargs (dict, *optional*): - Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class` - call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific - instance of the wrapped model. - - NOTE: You must pass in arguments specific to the wrapped model as keyword arguments. - """ - if kwargs is not None: - wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs) - else: - from_pretrained_kwargs = {} - wrapped_model_kwargs = {} - - if isinstance(pretrained_model_name_or_path, str): - # Load the base model using the `transformers` AutoClass (e.g. AutoModelForCausalLM) - base_model = cls._auto_model_parent_class.from_pretrained( - pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs - ) - elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel): - base_model = pretrained_model_name_or_path - else: - raise ValueError( - f"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}" - "Expected `str` or `transformers.PreTrainedModel`." 
- ) - - model = cls(base_model, **wrapped_model_kwargs) - - if isinstance(pretrained_model_name_or_path, str): - filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") - sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") - is_sharded = False - - if not os.path.exists(filename): - try: - filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin") - # Sharded - except Exception: - if os.path.exists(sharded_index_filename): - index_file_name = sharded_index_filename - else: - index_file_name = hf_hub_download( - pretrained_model_name_or_path, - "pytorch_model.bin.index.json", - ) - with open(index_file_name, "r") as f: - index = json.load(f) - # Collect files containing weights from supported modules - files_to_download = set() - for k, v in index["weight_map"].items(): - if any([module in k for module in cls._supported_modules]): - files_to_download.add(v) - is_sharded = True - - if is_sharded: - # Merge each shard into a state dict - # TODO: Optimize this to avoid wasting RAM - state_dict = {} - for shard_file in files_to_download: - filename = os.path.join(pretrained_model_name_or_path, shard_file) - # Download if shard file doesn't exist locally - if not os.path.exists(filename): - filename = hf_hub_download(pretrained_model_name_or_path, shard_file) - state_dict.update(torch.load(filename, map_location="cpu")) - else: - state_dict = torch.load(filename, map_location="cpu") - else: - state_dict = pretrained_model_name_or_path.state_dict() - - model.post_init(state_dict=state_dict) - return model - - def save_pretrained(self, *args, **kwargs): - """Save the pretrained model to a directory. This method is a wrapper - around `transformers.PreTrainedModel.save_pretrained`. Please refer to - the documentation of `transformers.PreTrainedModel.save_pretrained` for - more information. - - Args: - *args (`list`, *optional*): - Positional arguments passed along to the underlying model's - `save_pretrained` method. - **kwargs (`dict`, *optional*): - Keyword arguments passed along to the underlying model's - `save_pretrained` method. - """ - state_dict = kwargs.pop("state_dict", None) - if state_dict is None: - state_dict = self.state_dict() - kwargs["state_dict"] = state_dict - - return self.base_model.save_pretrained(*args, **kwargs) - - def state_dict(self, *args, **kwargs): - """Return the state_dict of the pretrained model.""" - raise NotImplementedError - - def post_init(self, *args, **kwargs): - """Post initialization method. This method is called after the model is - instantiated and loaded from a checkpoint. It can be used to perform - additional operations such as loading the state_dict. - """ - raise NotImplementedError - - def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: - """Filter out arguments not supported by the specific instance of - `base_model.transformer.forward` - """ - # FIXME: This is a hack to get around the fact that the `transformers` - # architectures we use don't have a consistent API for `forward` parameters. 
- return {k: v for k, v in kwargs.items() if k in self.forward_kwargs} diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py index 251febc62..e1c469e21 100644 --- a/trlx/trainer/__init__.py +++ b/trlx/trainer/__init__.py @@ -1,7 +1,11 @@ +import os import sys from abc import abstractmethod from typing import Any, Callable, Dict, Iterable +import torch + +from trlx.data import RLElement from trlx.data.configs import TRLConfig from trlx.pipeline import BaseRolloutStore diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 65bd7ab1a..d633fefd4 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -249,23 +249,15 @@ def save(self, directory: Optional[str] = None): """Creates a checkpoint of the optimizer, scheduler and model""" self.accelerator.save_state(directory or self.config.train.checkpoint_dir) - def save_pretrained(self, directory: Optional[str] = None, **kwargs): - """Save the underlying Hugging Face model, tokenizer, and configuration files to a directory for - later use. - - Args: - directory (str, *optional*): The directory to save the trainer files to. - NOTE: If not specified, the model will be saved to a directory named `hf_model` in the - checkpoint directory as specified by the Trainer's config. - **kwargs: Additional keyword arguments passed to the underlying Hugging Face model's - `save_pretrained` method. + @abstractmethod + def save_pretrained(self, directory: Optional[str] = None): + """Save the model and its configuration file to a directory, so that it can be re-loaded with the + `transformers.PreTrainedModel.from_pretrained` method. + + NOTE: If a `directory` is not provided, the model will be saved to a sub-directory + of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`). 
""" - if directory is None: - directory = f"{self.config.train.checkpoint_dir}/hf_model" - self.accelerator.wait_for_everyone() - self.accelerator.unwrap_model(self.model).save_pretrained(directory, **kwargs) - if self.accelerator.is_main_process: - self.tokenizer.save_pretrained(directory) + pass def load(self, directory=None): """Load checkpoint of optimizer, scheduler and a model""" diff --git a/trlx/trainer/accelerate_ilql_trainer.py b/trlx/trainer/accelerate_ilql_trainer.py index c2dfe3361..231b2c059 100644 --- a/trlx/trainer/accelerate_ilql_trainer.py +++ b/trlx/trainer/accelerate_ilql_trainer.py @@ -1,19 +1,18 @@ import os -from typing import cast +from typing import Optional, cast import numpy as np import torch -import transformers from rich.console import Console from rich.table import Table import trlx.utils.logging as logging from trlx.data.configs import TRLConfig from trlx.data.ilql_types import ILQLBatch -from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads, ILQLConfig from trlx.pipeline.offline_pipeline import ILQLRolloutStorage, tokenize_dialogue from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.trainer.nn.ilql_models import CausalLMWithValueHeads, ILQLConfig from trlx.utils import to_device logger = logging.get_logger(__name__) @@ -38,15 +37,10 @@ def __init__(self, config: TRLConfig, **kwargs): ) def get_arch(self, config): - from_fn = AutoModelForCausalLMWithILQLHeads.from_pretrained - # backward-compat: Try to create a randomly initialized architecture from a config - if issubclass(type(config.model.model_path), transformers.PretrainedConfig): - from_fn = AutoModelForCausalLMWithILQLHeads.from_config - - return from_fn( + return CausalLMWithValueHeads( config.model.model_path, - two_qs=config.method.two_qs, - alpha=config.method.alpha, + ilql_config=config.method, + num_layers_unfrozen=config.model.num_layers_unfrozen, ) def post_backward_callback(self): @@ -80,6 +74,20 @@ def prepare_learning(self): self.total_steps = self.config.train.epochs * len(train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) + def save_pretrained(self, directory: Optional[str] = None): + """NOTE: If a `directory` is not provided, the model will be saved to a sub-directory + of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`). + """ + # TODO: Support saving with `transformers.PreTrainedModel.save_pretrained`. + # This is currently not supported becasue `nn.ilql_models.CausalLMWithValueHeads` + # requires a custom `generate` method using its (value/q) heads to steer + # sampling - something that is not possible with the default + # `transformers.PreTrainedModel.generate`. + raise NotImplementedError( + "`AccelerateILQLTrainer` does not currently support automatic saving " + "with `transformers.PreTrainedModel.save_pretrained`." 
+ ) + def make_experience(self, samples, rewards, max_length=2048): """ Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 3be3ab5af..1fd0f30a9 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -2,11 +2,10 @@ import os import uuid from time import time -from typing import Callable, List +from typing import Callable, List, Optional import torch import torch.nn.functional as F -import transformers from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -14,16 +13,16 @@ from trlx.data.accelerate_base_datatypes import PromptBatch from trlx.data.configs import TRLConfig from trlx.data.ppo_types import PPORLBatch, PPORLElement -from trlx.models.modeling_ppo import ( - AdaptiveKLController, - AutoModelForCausalLMWithHydraValueHead, - AutoModelForSeq2SeqLMWithHydraValueHead, - FixedKLController, -) from trlx.pipeline.offline_pipeline import PromptPipeline from trlx.pipeline.ppo_pipeline import PPORolloutStorage from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.trainer.nn.ppo_models import ( + AdaptiveKLController, + CausalLMHydraWithValueHead, + FixedKLController, + Seq2SeqLMHydraWithValueHead, +) from trlx.utils import Clock from trlx.utils.modeling import RunningMoments, logprobs_of_labels @@ -71,7 +70,6 @@ def __init__(self, config: TRLConfig, **kwargs): if not hasattr(self.model, "frozen_head"): self.ref_model = self.get_arch(self.config) self.ref_model.to(self.accelerator.device) - self.ref_model.eval() # Setup the KL controller # This helps prevent large divergences in the controller (policy) @@ -119,19 +117,9 @@ def __init__(self, config: TRLConfig, **kwargs): def get_arch(self, config: TRLConfig): """Get the model""" - model_class = AutoModelForCausalLMWithHydraValueHead if config.model.model_arch_type == "seq2seq": - model_class = AutoModelForSeq2SeqLMWithHydraValueHead - - from_fn = model_class.from_pretrained - # backward-compat: Try to create a randomly initialized architecture from a config - if issubclass(type(config.model.model_path), transformers.PretrainedConfig): - from_fn = model_class.from_config - - return from_fn( - config.model.model_path, - num_layers_unfrozen=config.model.num_layers_unfrozen, - ) + return Seq2SeqLMHydraWithValueHead(config.model.model_path, config.model.num_layers_unfrozen) + return CausalLMHydraWithValueHead(config.model.model_path, config.model.num_layers_unfrozen) def loss(self, batch: PPORLBatch): """Forward pass & loss @@ -237,6 +225,15 @@ def prepare_learning(self): self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) + def save_pretrained(self, directory: Optional[str] = None): + """NOTE: If a `directory` is not provided, the model will be saved to a sub-directory + of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`). 
+ """ + if directory is None: + directory = f"{self.config.train.checkpoint_dir}/hf_model" + self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory) + self.tokenizer.save_pretrained(directory) + def add_prompt_pipeline(self, pipeline: PromptPipeline): """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) @@ -378,14 +375,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq input_ids=prompt_tensors, attention_mask=attention_mask, decoder_input_ids=sample_outputs, - return_dict=True, - ).logits + ) else: ref_logits = self.ref_model( input_ids=prompt_tensors, attention_mask=attention_mask, decoder_input_ids=sample_outputs, - return_dict=True, ).logits else: all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1) @@ -400,14 +395,14 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ref_logits = self.model.forward_hydra( all_tokens, attention_mask=attention_mask, - return_dict=True, - ).logits + return_dict=False, + ) else: - ref_logits = self.ref_model( + ref_logits, _, *_ = self.ref_model( all_tokens, attention_mask=attention_mask, - return_dict=True, - ).logits + return_dict=False, + ) ref_logits = ref_logits.to(device) if self.config.model.model_arch_type == "seq2seq": diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index e061896e6..447913d9a 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional from transformers import AutoModelForCausalLM @@ -55,3 +56,12 @@ def prepare_learning(self): self.n_updates_per_batch = 1 self.total_steps = self.config.train.epochs * len(train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def save_pretrained(self, directory: Optional[str] = None): + """NOTE: If a `directory` is not provided, the model will be saved to a sub-directory + of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`). + """ + if directory is None: + directory = f"{self.config.train.checkpoint_dir}/hf_model" + self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory) + self.tokenizer.save_pretrained(directory) diff --git a/trlx/models/README.md b/trlx/trainer/nemo/README.md similarity index 98% rename from trlx/models/README.md rename to trlx/trainer/nemo/README.md index fdcb44597..c913c843a 100644 --- a/trlx/models/README.md +++ b/trlx/trainer/nemo/README.md @@ -27,7 +27,7 @@ exp_manager: ``` ## NeMo Megatron setup -Clone https://github.com/NVIDIA/NeMo/tree/r1.15.0 (currently only up to `r1.15.0` is supoprted) and apex from https://github.com/NVIDIA/apex/. +Clone https://github.com/NVIDIA/NeMo/ and apex from https://github.com/NVIDIA/apex/. 
1) install conda (or mamba/micromamba) diff --git a/trlx/models/__init__.py b/trlx/trainer/nemo/__init__.py similarity index 100% rename from trlx/models/__init__.py rename to trlx/trainer/nemo/__init__.py diff --git a/trlx/models/modeling_nemo_ilql.py b/trlx/trainer/nemo/gpt.py similarity index 99% rename from trlx/models/modeling_nemo_ilql.py rename to trlx/trainer/nemo/gpt.py index 31ac49a8a..89eb2554b 100644 --- a/trlx/models/modeling_nemo_ilql.py +++ b/trlx/trainer/nemo/gpt.py @@ -40,7 +40,7 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from trlx.data.ilql_types import ILQLBatch, unflatten_dataclass -from trlx.models.modeling_ilql import ILQLConfig, batched_index_select +from trlx.trainer.nn.ilql_models import ILQLConfig, batched_index_select from trlx.utils import to_device, tree_map diff --git a/trlx/trainer/nemo_ilql_trainer.py b/trlx/trainer/nemo_ilql_trainer.py index a23a94f43..c58cc3249 100644 --- a/trlx/trainer/nemo_ilql_trainer.py +++ b/trlx/trainer/nemo_ilql_trainer.py @@ -22,14 +22,14 @@ from trlx.data.configs import TRLConfig from trlx.data.ilql_types import ILQLBatch, ILQLElement, flatten_dataclass -from trlx.models.modeling_ilql import ILQLConfig -from trlx.models.modeling_nemo_ilql import ILQLGPT from trlx.pipeline.offline_pipeline import ( ILQLRolloutStorage, ilql_collate_fn, tokenize_dialogue, ) from trlx.trainer import register_trainer +from trlx.trainer.nemo.gpt import ILQLGPT +from trlx.trainer.nn.ilql_models import ILQLConfig from . import BaseRLTrainer diff --git a/trlx/trainer/nn/__init__.py b/trlx/trainer/nn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trlx/models/modeling_ilql.py b/trlx/trainer/nn/ilql_models.py similarity index 73% rename from trlx/models/modeling_ilql.py rename to trlx/trainer/nn/ilql_models.py index 5eab697e4..aba7b4eff 100644 --- a/trlx/models/modeling_ilql.py +++ b/trlx/trainer/nn/ilql_models.py @@ -1,9 +1,10 @@ -import gc +import inspect import os from copy import deepcopy from dataclasses import dataclass from functools import reduce from itertools import chain +from typing import Any, Dict, Union import deepspeed # type: ignore import numpy as np @@ -15,10 +16,11 @@ from trlx.data.ilql_types import ILQLBatch from trlx.data.method_configs import MethodConfig, register_method -from trlx.models.modeling_base import PreTrainedModelWrapper from trlx.utils.modeling import ( flatten_dict, + freeze_bottom_causal_layers, get_tensor_stats, + hf_get_causal_base_model, hf_get_hidden_size, hf_get_lm_head, make_head, @@ -57,6 +59,9 @@ class ILQLConfig(MethodConfig): two_qs: bool gen_kwargs: dict + def heads(self, hidden_size: int, vocab_size: int, dtype: type): + return ILQLHeads(self, hidden_size, vocab_size, dtype) + def loss(self, outputs, labels: ILQLBatch): logits, (qs, target_qs, vs) = outputs terminal_mask = labels.dones[:, :-1] @@ -129,23 +134,16 @@ def cql_loss(q): class ILQLHeads(nn.Module): - def __init__( - self, - hidden_size: int, - vocab_size: int, - two_qs: bool, - alpha: float, - dtype: type, - ): + def __init__(self, config: ILQLConfig, hidden_size: int, vocab_size: int, dtype: type): super().__init__() self.hidden_size = hidden_size self.vocab_size = vocab_size - self.two_qs = two_qs - self.alpha = alpha self.v_head = make_head(self.hidden_size, 1, dtype) + self.config = config + + n_qs = 2 if self.config.two_qs else 1 - n_qs = 2 if self.two_qs else 1 self.q_heads = nn.ModuleList(make_head(self.hidden_size, self.vocab_size, dtype) for _ in range(n_qs)) self.target_q_heads = 
nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads) @@ -185,38 +183,53 @@ def sync_target_q_heads(self): with deepspeed.zero.GatheredParameters(list(params), modifier_rank=0): if deepspeed.comm.get_rank() == 0: - self._sync_target_q_heads(self.alpha) + self._sync_target_q_heads(self.config.alpha) else: - self._sync_target_q_heads(self.alpha) - - -class AutoModelForCausalLMWithILQLHeads(PreTrainedModelWrapper): - """An `AutoModel` class wrapper for `transformers` causal models wtih a language - modeling head and ILQL heads. + self._sync_target_q_heads(self.config.alpha) - References: - [1] Snell et al., "Offline RL for Natural Language Generation with Implicit Language Q Learning", - https://arxiv.org/abs/2206.11871, 2022 - """ - _auto_model_parent_class = transformers.AutoModelForCausalLM - _supported_modules = ["ilql_heads"] - _supported_args = ["two_qs", "alpha"] +class CausalLMWithValueHeads(nn.Module): + """This is a wrapper around huggingface AutoModelForCausalLM with two additional scalar heads""" def __init__( self, - base_model: transformers.PreTrainedModel, - *, - two_qs: bool = True, - alpha: float = 0.99, + config: Union[transformers.PretrainedConfig, str], + ilql_config: ILQLConfig, + num_layers_unfrozen=-1, ): - super().__init__(base_model) - hidden_size = hf_get_hidden_size(self.base_model.config) - vocab_size = self.base_model.config.vocab_size - dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype - self.two_qs = two_qs - self.alpha = alpha - self.ilql_heads = ILQLHeads(hidden_size, vocab_size, self.two_qs, self.alpha, dtype=dtype) + super().__init__() + + # enable zero3 init within from_pretrained + if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3": + config_path = os.environ.get("DEEPSPEED_CONFIG_FILE", "") + if config_path: + _hfconfig = transformers.deepspeed.HfDeepSpeedConfig(config_path) # noqa: F841 + + if isinstance(config, str): + self.config = transformers.AutoConfig.from_pretrained(config) + self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config) + else: + self.config = config + self.base_model = transformers.AutoModelForCausalLM.from_config(config) + + self.base_model.transformer = hf_get_causal_base_model(self.base_model) + self.base_model.lm_head = hf_get_lm_head(self.base_model) + freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen) + + # Cache `transformer.forward` args for general use (avoids incompatible args across architectures) + self.base_model_transformer_args = inspect.getfullargspec(self.base_model.transformer.forward).args + + dtype = next(self.base_model.lm_head.parameters()).dtype + self.hidden_size = hf_get_hidden_size(self.config) + self.ilql_heads = ilql_config.heads(self.hidden_size, self.config.vocab_size, dtype) + self.ilql_config = ilql_config + + def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: + """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" + return {k: v for k, v in kwargs.items() if k in self.base_model_transformer_args} + + def sync_target_q_heads(self): + self.ilql_heads.sync_target_q_heads() def forward( self, @@ -227,18 +240,19 @@ def forward( actions_ixs=None, states_ixs=None, ): - forward_kwargs = self.get_compatible_forward_kwargs( + forward_kwargs = self._get_compatible_forward_kwargs( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, ) - forward_kwargs["output_hidden_states"] = True + out = self.base_model.transformer(**forward_kwargs) 
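+        # The backbone returns hidden states only; the language-modeling head and
+        # the ILQL (value/Q) heads are applied to `last_hidden_state` separately below.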
+ hs = out.last_hidden_state - outputs = self.base_model(**forward_kwargs) - qs, target_qs, vs = self.ilql_heads(outputs.hidden_states[-1], states_ixs=states_ixs, actions_ixs=actions_ixs) + logits = self.base_model.lm_head(hs) + qs, target_qs, vs = self.ilql_heads(hs, states_ixs=states_ixs, actions_ixs=actions_ixs) - return outputs.logits, qs, target_qs, vs, outputs.past_key_values + return logits, qs, target_qs, vs, out.past_key_values def generate( self, @@ -260,9 +274,6 @@ def generate( changing token probabilities as to how advantageous they would be according to value functions estimations. """ - pad_token_id = pad_token_id if pad_token_id is not None else self.base_model.config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.base_model.config.eos_token_id - if attention_mask is None: attention_mask = input_ids.not_equal(pad_token_id) @@ -283,7 +294,7 @@ def generate( ) logits, _, target_qs, vs, past_key_values = out - if self.two_qs: + if self.ilql_config.two_qs: qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :]) else: qs = target_qs[:, -1, :] @@ -313,29 +324,10 @@ def generate( return samples - def sync_target_q_heads(self): - self.ilql_heads.sync_target_q_heads() + @property + def dummy_inputs(self): + return {"input_ids": torch.ones(1, 1, device=self.base_model.device, dtype=torch.long)} - def state_dict(self, *args, **kwargs): - """ - Returns the state dictionary of the model. We add the state dictionary of the ilql heads - to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`. - """ - base_model_state_dict = self.base_model.state_dict(*args, **kwargs) - ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs) - for k, v in ilql_heads_state_dict.items(): - base_model_state_dict[f"ilql_heads.{k}"] = v - return base_model_state_dict - - def post_init(self, state_dict): - """ - We add the state dictionary of the ilql heads to the state dictionary of the wrapped model - by preprending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the - keys of the value head state dictionary. - """ - for k in list(state_dict.keys()): - if "ilql_heads." 
in k:
-                state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k)
-        self.ilql_heads.load_state_dict(state_dict, strict=False)
-        del state_dict
-        gc.collect()
+    @property
+    def device(self):
+        return self.base_model.device
diff --git a/trlx/models/modeling_ppo.py b/trlx/trainer/nn/ppo_models.py
similarity index 59%
rename from trlx/models/modeling_ppo.py
rename to trlx/trainer/nn/ppo_models.py
index 787286123..6cbb64d25 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/trainer/nn/ppo_models.py
@@ -1,8 +1,7 @@
-import gc
 import inspect
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -14,13 +13,12 @@
 from transformers.models.opt import modeling_opt

 from trlx.data.method_configs import MethodConfig, register_method
-from trlx.models.modeling_base import PreTrainedModelWrapper
 from trlx.utils.modeling import (
     flatten_dict,
     get_tensor_stats,
-    hf_get_decoder,
-    hf_get_decoder_blocks,
-    hf_get_decoder_final_norm,
+    hf_get_causal_base_model,
+    hf_get_causal_final_norm,
+    hf_get_causal_hidden_layers,
     hf_get_hidden_size,
     hf_get_lm_head,
     hf_get_num_hidden_layers,
@@ -124,7 +122,7 @@ class PPOConfig(MethodConfig):
     cliprange: float
     cliprange_value: float
     vf_coef: float
-    scale_reward: Optional[str]
+    scale_reward: str
     ref_mean: Optional[float]
     ref_std: Optional[float]
     cliprange_reward: float
@@ -150,11 +148,11 @@ def get_advantages_and_returns(
             Ret1 = R1 + γ * λ * R2     + γ^2 * λ^2 * R3 + ...
                       + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

-        Args:
-            values: Tensor of shape (batch_size, response_size)
-            rewards: Tensor of shape (batch_size, response_size)
-            response_length: Length of the response sequence
-            use_whitening: Whether to use whitening (ie. normalize advantages) or not
+        Input:
+        - values: Tensor of shape (batch_size, response_size)
+        - rewards: Tensor of shape (batch_size, response_size)
+        - response_length: Length of the response sequence
+        - use_whitening: Whether to use whitening (ie. normalize advantages) or not
         """
         lastgaelam = 0
         advantages_reversed = []
@@ -233,13 +231,13 @@ def loss(
         return loss, flatten_dict(stats)


-# CausalLM architectures
+# PPO Layers


 @dataclass
-class CausalLMOutputWithValue(ModelOutput):
+class CausalLMOutputWithCrossAttentions(ModelOutput):
     loss: Optional[torch.FloatTensor] = None
-    logits: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -247,193 +245,413 @@ class CausalLMOutputWithValue(ModelOutput):
     value: Optional[torch.FloatTensor] = None


-class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
-    """An `AutoModel` class wrapper for `transformers` causal models that have a
-    language modeling head and a value head
+class CausalLMWithValueHead(nn.Module):
+    """The CausalLMWithValueHead class implements a causal language model with
+    a secondary, scalar head.
""" - _auto_model_parent_class = transformers.AutoModelForCausalLM - _supported_modules = ["v_head"] - _supported_args = [] + def __init__(self, config: Union[transformers.PretrainedConfig, str]): + super().__init__() + if isinstance(config, str): + self.config = transformers.AutoConfig.from_pretrained(config) + self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config) + else: + self.config = config + self.base_model = transformers.AutoModelForCausalLM.from_config(config) + + self.base_model.transformer = hf_get_causal_base_model(self.base_model) + self.base_model.lm_head = hf_get_lm_head(self.base_model) + dtype = next(self.base_model.lm_head.parameters()).dtype + self.v_head = make_head(hf_get_hidden_size(self.config), 1, dtype) + + # Cache `transformer.forward` args for general use (avoids incompatible args across architectures) + self.base_model_transformer_args = inspect.getfullargspec(self.base_model.transformer.forward).args + + def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: + """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" + return {k: v for k, v in kwargs.items() if k in self.base_model_transformer_args} + + def generate(self, input_ids, **kwargs): + return self.base_model.generate(input_ids, **kwargs) + + def forward( + self, + input_ids=None, + attention_mask=None, + past_key_values=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + return_dict=False, + ): + forward_kwargs = self._get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + ) + transformer_outputs = self.base_model.transformer(**forward_kwargs) + last_hidden_state = transformer_outputs.last_hidden_state + lm_logits = self.base_model.lm_head(last_hidden_state) + value = self.v_head(last_hidden_state).squeeze(-1) + + if not return_dict: + outputs = (lm_logits,) + transformer_outputs[1:] + (value,) + return outputs + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + value=value, + ) + + +class CausalLMHydraWithValueHead(nn.Module): + """The CausalLMHydraWithValueHead class implements a causal language model + with a secondary, scalar head. 
+ """ def __init__( self, - base_model: transformers.PreTrainedModel, + config: Union[transformers.PretrainedConfig, str], + num_layers_unfrozen: int = -1, ): - super().__init__(base_model) - self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) + super().__init__() + + if isinstance(config, str): + self.config = transformers.AutoConfig.from_pretrained(config) + self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config) + else: + self.config = config + self.base_model = transformers.AutoModelForCausalLM.from_config(config) + + self.base_model.transformer = hf_get_causal_base_model(self.base_model) + self.base_model.lm_head = hf_get_lm_head(self.base_model) + dtype = next(self.base_model.lm_head.parameters()).dtype + self.v_head = make_head(hf_get_hidden_size(self.config), 1, dtype) + + self.num_layers_unfrozen = num_layers_unfrozen + if self.num_layers_unfrozen > 0: + transformer_blocks = list(hf_get_causal_hidden_layers(self.base_model)) + branch_class = hf_get_causal_lm_branch_class(self.config) + self.frozen_head = branch_class( + self.config, + transformer_blocks[-self.num_layers_unfrozen :], + final_norm=hf_get_causal_final_norm(self.base_model), + lm_head=self.base_model.lm_head, + ) + # Cache `transformer.forward` args for general use (avoids incompatible args across architectures) + self.base_model_transformer_args = inspect.getfullargspec(self.base_model.transformer.forward).args + + def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: + """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" + return {k: v for k, v in kwargs.items() if k in self.base_model_transformer_args} + + def generate(self, input_ids, **x): + return self.base_model.generate(input_ids, **x) + + def forward_hydra(self, input_ids, **forward_kwargs): + forward_kwargs = self._get_compatible_forward_kwargs(**forward_kwargs) + if forward_kwargs.get("return_dict") is not None: + return_dict = forward_kwargs["return_dict"] + else: + return_dict = True + forward_kwargs["return_dict"] = True + forward_kwargs["output_hidden_states"] = True + output = self.forward(input_ids, **forward_kwargs) + all_hidden_states = output.hidden_states + # Get output of last frozen hidden layer + # Select hidden state before first layer of branch. 
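+        # With `output_hidden_states=True`, `hidden_states` holds num_layers + 1 entries
+        # (the embedding output plus one per block), so -(num_layers_unfrozen + 1)
+        # indexes the output of the last frozen layer.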
+ input_hidden_state = all_hidden_states[-(self.num_layers_unfrozen + 1)] + # Get size of last hidden state + output_shape = all_hidden_states[-1].size() + outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs) + if not return_dict: + return outputs.logits + return outputs def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - position_ids: Optional[List[torch.FloatTensor]] = None, - head_mask: Optional[torch.Tensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = True, return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithValue]: - forward_kwargs = self.get_compatible_forward_kwargs( + ): + forward_kwargs = self._get_compatible_forward_kwargs( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - past_key_values=past_key_values, head_mask=head_mask, inputs_embeds=inputs_embeds, - use_cache=use_cache, + past_key_values=past_key_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, + token_type_ids=token_type_ids, ) - forward_kwargs["output_hidden_states"] = True - forward_kwargs["return_dict"] = True - - outputs = self.base_model(**forward_kwargs) - value = self.v_head(outputs.hidden_states[-1]).squeeze(-1) + transformer_outputs = self.base_model.transformer(**forward_kwargs) + last_hidden_state = transformer_outputs.last_hidden_state + lm_logits = self.base_model.lm_head(last_hidden_state) + value = self.v_head(last_hidden_state).squeeze(-1) if not return_dict: - outputs = (outputs.logits,) + outputs[1:] + (value,) + outputs = (lm_logits,) + transformer_outputs[1:] + (value,) return outputs - return CausalLMOutputWithValue(**outputs, value=value) - - def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: - return self.base_model.generate(*args, **kwargs) - - def state_dict(self, *args, **kwargs): - """ - Returns the state dictionary of the model. We add the state dictionary of the value head - to the state dictionary of the wrapped model by prepending the key with `v_head.`. - """ - base_model_state_dict = self.base_model.state_dict(*args, **kwargs) - v_head_state_dict = self.v_head.state_dict(*args, **kwargs) - for k, v in v_head_state_dict.items(): - base_model_state_dict[f"v_head.{k}"] = v - return base_model_state_dict + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=None, + value=value, + ) - def post_init(self, state_dict): - """ - Adds the state dictionary of the value head to the state dictionary of the wrapped model - by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the - keys of the value head state dictionary. 
-        """
-        for k in list(state_dict.keys()):
-            if "v_head." in k:
-                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
-        self.v_head.load_state_dict(state_dict, strict=False)
-        del state_dict
-        gc.collect()  # noqa: E702


+@dataclass
+class Seq2SeqLMOutput(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    value: Optional[torch.FloatTensor] = None


-class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead):
-    _supported_modules = ["v_head", "frozen_head"]
-    _supported_args = ["num_layers_unfrozen"]

+class Seq2SeqLMHydraWithValueHead(nn.Module):
     def __init__(
         self,
-        base_model: transformers.PreTrainedModel,
-        *,
+        config: Union[transformers.PretrainedConfig, str],
         num_layers_unfrozen: int = -1,
     ):
-        super().__init__(base_model)
+        super().__init__()
+        if isinstance(config, str):
+            self.config = transformers.AutoConfig.from_pretrained(config)
+        else:
+            self.config = config
+        self.base_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.config.name_or_path)
+        self.v_head = make_head(hf_get_hidden_size(self.config), 1)
+
         self.num_layers_unfrozen = num_layers_unfrozen
         if self.num_layers_unfrozen > 0:
-            config = self.base_model.config
-            branch_class = hf_get_branch_class(config)
-            self.frozen_head = branch_class(
-                self.base_model,
-                num_layers_unfrozen=self.num_layers_unfrozen,
-            ).eval()
+            self.frozen_head = T5Branch(self.config, self.base_model, self.num_layers_unfrozen)
+        # Cache `forward` args for general use (avoids incompatible args across architectures)
+        self.base_model_args = inspect.getfullargspec(self.base_model.forward).args

-    def forward_hydra(
+    def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
+        """Filter out arguments not supported by the specific instance of `base_model.forward`"""
+        return {k: v for k, v in kwargs.items() if k in self.base_model_args}
+
+    def generate(self, input_ids, **x):
+        return self.base_model.generate(input_ids, **x)
+
+    def forward_hydra(self, input_ids, attention_mask, decoder_input_ids, **forward_kwargs):
+        forward_kwargs = self._get_compatible_forward_kwargs(**forward_kwargs)
+        forward_kwargs["return_dict"] = True
+        output = self.forward(input_ids, attention_mask, decoder_input_ids, **forward_kwargs)
+        all_hidden_states = output.decoder_hidden_states
+        # Get output of last frozen hidden layer
+        # Select hidden state before first layer of branch.
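+        # `decoder_hidden_states` likewise holds num_layers + 1 entries (embedding
+        # output plus one per decoder block), so -(num_layers_unfrozen + 1) indexes
+        # the output of the last frozen decoder layer.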
+ input_hidden_state = all_hidden_states[-(self.num_layers_unfrozen + 1)] + encoder_hidden_states = output.encoder_last_hidden_state + # Get size of last hidden state + outputs = self.frozen_head( + decoder_input_ids, + input_hidden_state, + encoder_hidden_states, + attention_mask, + False, + False, + ) + return outputs.logits + + def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - position_ids: Optional[List[torch.FloatTensor]] = None, - head_mask: Optional[torch.Tensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = True, + output_hidden_states: Optional[bool] = True, return_dict: Optional[bool] = None, - ) -> Union[torch.FloatTensor, CausalLMOutputWithValue]: - forward_kwargs = self.get_compatible_forward_kwargs( + ): + forward_kwargs = self._get_compatible_forward_kwargs( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, past_key_values=past_key_values, - head_mask=head_mask, inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - return_dict = forward_kwargs.get("return_dict", True) - forward_kwargs["return_dict"] = True - forward_kwargs["output_hidden_states"] = True + t5_outputs = self.base_model(**forward_kwargs) + lm_logits = t5_outputs.logits + last_hidden_state = t5_outputs.decoder_hidden_states[-1] + value = self.v_head(last_hidden_state).squeeze(-1) + + return Seq2SeqLMOutput( + loss=None, + logits=lm_logits, + decoder_hidden_states=t5_outputs.decoder_hidden_states, + decoder_attentions=t5_outputs.decoder_attentions, + cross_attentions=t5_outputs.cross_attentions, + encoder_last_hidden_state=t5_outputs.encoder_last_hidden_state, + encoder_hidden_states=t5_outputs.encoder_hidden_states, + encoder_attentions=t5_outputs.encoder_attentions, + past_key_values=t5_outputs.past_key_values, + value=value, + ) - outputs = self.forward(**forward_kwargs) - # Select the hidden state before the first branching layer - input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)] - output_shape = outputs.hidden_states[-1].size() - forward_kwargs.pop("input_ids", None) # Ignore `input_ids` for branch head - forward_kwargs.pop("inputs_embeds", None) # Ignore `inputs_embeds` for branch head - hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs) +class T5Branch(transformers.PreTrainedModel): + # 
Decoder branch only + def __init__( + self, + config: transformers.PretrainedConfig, + base_model: transformers.PreTrainedModel, + num_layers_unfrozen: int, + ): + super().__init__(config) + + # Defined by the main trunk + self.hidden_size = hf_get_hidden_size(config) + self.decoder = deepcopy(base_model.decoder) + self.decoder.block = nn.ModuleList(self.decoder.block[-num_layers_unfrozen:]) + self.lm_head = deepcopy(base_model.lm_head) + # Model parallel + self.model_parallel = False + self.device_map = None + self.last_device = None + self.gradient_checkpointing = False - if not return_dict: - return hydra_outputs.logits - return hydra_outputs + for parameter in self.parameters(): + parameter.requires_grad = False + def forward( + self, + input_ids, + hidden_states, + encoder_hidden_states, + encoder_attention_mask, + use_cache: bool = False, + output_attentions: bool = False, + ): + input_shape = input_ids.size() + batch_size, seq_length = input_shape + + attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + position_bias = None + encoder_decoder_position_bias = None + + for i, layer_module in enumerate(self.decoder.block): + layer_outputs = layer_module( + hidden_states, # size: (batch_size, seq_length, hidden_size) + attention_mask=extended_attention_mask, # size: (batch_size, 1, seq_length, seq_length) + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + use_cache=use_cache, + output_attentions=output_attentions, + ) -class ModelBranch(transformers.PreTrainedModel): - """Implements the frozen upper trunk of the pretrained reference model used - when computing the PPO KL-divergence penalty. + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + + hidden_states = self.decoder.final_layer_norm(hidden_states) + hidden_states = self.decoder.dropout(hidden_states) + lm_logits = self.lm_head(hidden_states) + + return Seq2SeqLMOutput(logits=lm_logits) + + +class GPTModelBranch(transformers.PreTrainedModel): + """ + GPTModelBranch implements the frozen upper trunk of the reference model + used when computing the PPO KL-divergence penalty. Expects a list of + frozen transformer blocks and an lm_head from the base model. 
""" def __init__( self, - base_model: transformers.PreTrainedModel, - *, - num_layers_unfrozen: int, + config: transformers.PretrainedConfig, + transformer_blocks: nn.ModuleList, + final_norm: nn.Module, + lm_head: nn.Module, ): - """ - Args: - base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from - num_layers_unfrozen (int): The number of trainable layers - """ - super().__init__(base_model.config) + super().__init__(config) - # The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model - decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model)) - self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:]) - self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model)) - self.lm_head = deepcopy(hf_get_lm_head(base_model)) + # Defined by the main trunk + self.hidden_size = hf_get_hidden_size(config) + self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks)) + self.final_norm = deepcopy(final_norm) + self.lm_head = deepcopy(lm_head) - self.hidden_size = hf_get_hidden_size(self.config) + # Model parallel self.model_parallel = False self.device_map = None - self.last_device = None self.gradient_checkpointing = False - # Freeze the entire branch + # Turning off grad saves memory + for parameter in self.parameters(): parameter.requires_grad_(False) - -class GPTModelBranch(ModelBranch): def forward( # noqa: max-complexity self, hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids output_shape: torch.Tensor, # output_size given by main trunk past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, @@ -441,10 +659,8 @@ def forward( # noqa: max-complexity output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = False, - ) -> Union[Tuple, CausalLMOutputWithValue]: - """Reference: - https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743 # noqa: E501 - """ + position_ids: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: batch_size = hidden_states.size()[0] output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -457,16 +673,30 @@ def forward( # noqa: max-complexity device = hidden_states.device if past_key_values is None: - past_key_values = tuple([None] * len(self.decoder_blocks)) + past_key_values = tuple([None] * len(self.transformer_blocks)) + # GPT2Attention mask. if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 
attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.add_cross_attention and encoder_hidden_states is not None: ( encoder_batch_size, @@ -480,17 +710,24 @@ def forward( # noqa: max-complexity else: encoder_attention_mask = None + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config)) presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)): + for i, (block, layer_past) in enumerate(zip(self.transformer_blocks, past_key_values)): + # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) if layer_past is not None: layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) if isinstance(head_mask, torch.Tensor): @@ -530,6 +767,7 @@ def forward( # noqa: max-complexity if self.config.add_cross_attention: all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: for k, v in self.device_map.items(): if i == v[-1] and "cuda:" + str(k) != self.last_device: @@ -538,9 +776,19 @@ def forward( # noqa: max-complexity hidden_states = self.final_norm(hidden_states) hidden_states = hidden_states.view(output_shape) + # Add last hidden state if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + # last_hidden_state = hidden_states + # past_key_values = presents + # hidden_states = all_hidden_states + # attentions = all_self_attentions + # cross_attentions = all_cross_attentions + + # START OF CAUSAL HEAD # + # hidden_states = hidden_states.to(torch.float32) Present for gptj + if self.model_parallel: torch.cuda.set_device(self.transformer.first_device) hidden_states = hidden_states.to(self.lm_head.weight.device) @@ -551,23 +799,54 @@ def forward( # noqa: max-complexity outputs = (lm_logits,) + (None,) + (None,) return outputs - return CausalLMOutputWithValue( + return CausalLMOutputWithCrossAttentions( + loss=None, logits=lm_logits, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, + value=None, ) -class OPTModelBranch(ModelBranch): +class 
OPTModelBranch(transformers.PreTrainedModel):
+    """
+    OPTModelBranch implements the frozen upper trunk of the reference model
+    used when computing the PPO KL-divergence penalty. Expects a list of
+    frozen transformer blocks and an lm_head from the base model.
+    """
+
+    def __init__(
+        self,
+        config: transformers.PretrainedConfig,
+        transformer_blocks: nn.ModuleList,
+        final_norm: nn.Module,
+        lm_head: nn.Module,
+    ):
+        super().__init__(config)
+
+        # Defined by the main trunk
+        self.hidden_size = hf_get_hidden_size(config)
+        self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks))
+        self.final_norm = deepcopy(final_norm)
+        self.lm_head = deepcopy(lm_head)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.gradient_checkpointing = False
+
+        # Turning off grad saves memory
+        for parameter in self.parameters():
+            parameter.requires_grad_(False)
+
     def forward(  # noqa: max-complexity
         self,
-        hidden_states: torch.Tensor,
-        output_shape: torch.Tensor,
+        hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids
+        output_shape: torch.Tensor,  # output_size given by main trunk
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
@@ -575,10 +854,9 @@ def forward(  # noqa: max-complexity
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = False,
-    ) -> Union[Tuple, CausalLMOutputWithValue]:
-        """Reference:
-        https://github.com/huggingface/transformers/blob/bdb84e2bada3658f99c6a81c963ec562f8485151/src/transformers/models/opt/modeling_opt.py#L840  # noqa: E501
-        """
+        position_ids: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        """Override OPTForCausalLM"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        #######################################################################
+        # Modified OPTDecoder.forward
+        #######################################################################
+
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if attention_mask is None:
             attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)

         input_shape = hidden_states.size()[:-1]
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = modeling_opt._make_causal_mask(
@@ -601,6 +884,7 @@ def forward(  # noqa: max-complexity
             ).to(hidden_states.device)

         if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             expanded_attn_mask = modeling_opt._expand_mask(
                 attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
             ).to(hidden_states.device)
             combined_attention_mask = (
@@ -609,19 +893,21 @@ def forward(  # noqa: max-complexity
             )
         attention_mask = combined_attention_mask

+        # decoder layers
         all_hidden_states = () if output_hidden_states
else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask], ["head_mask"]): if attn_mask is not None: - if attn_mask.size()[0] != (len(self.decoder_blocks)): + if attn_mask.size()[0] != (len(self.transformer_blocks)): raise ValueError( - f"The `{mask_name}` should be specified for {len(self.decoder_blocks)} layers, but it is for" + f"The `{mask_name}` should be specified for {len(self.transformer_blocks)} layers, but it is for" f" {head_mask.size()[0]}." ) - for idx, decoder_layer in enumerate(self.decoder_blocks): + for idx, decoder_layer in enumerate(self.transformer_blocks): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -652,11 +938,16 @@ def forward( # noqa: max-complexity # if self.project_out is not None: # hidden_states = self.project_out(hidden_states) + # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + ####################################################################### + # End of modified OPTDecoder.forward + ####################################################################### + lm_logits = self.lm_head(hidden_states).contiguous() if not return_dict: @@ -672,22 +963,54 @@ def forward( # noqa: max-complexity if v is not None ) - return CausalLMOutputWithValue( + return CausalLMOutputWithCrossAttentions( + loss=None, logits=lm_logits, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, + cross_attentions=None, + value=None, ) -class BloomModelBranch(ModelBranch): - def forward( # noqa: max-complexity +class BloomModelBranch(transformers.PreTrainedModel): + """ + BloomModelBranch implements the frozen upper trunk of the reference model + used when computing the PPO KL-divergence penalty. Expects a list of + frozen transformer blocks and an lm_head from the base model. 
+    """
+
+    def __init__(
+        self,
+        config: transformers.PretrainedConfig,
+        transformer_blocks: nn.ModuleList,
+        final_norm: nn.Module,
+        lm_head: nn.Module,
+    ):
+        super().__init__(config)
+
+        # Defined by the main trunk
+        self.hidden_size = hf_get_hidden_size(config)
+        self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks))
+        self.final_norm = deepcopy(final_norm)
+        self.lm_head = deepcopy(lm_head)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.gradient_checkpointing = False
+
+        # Turning off grad saves memory
+        for parameter in self.parameters():
+            parameter.requires_grad_(False)
+
+    def forward(  # noqa: C901
         self,
         hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids
-        output_shape: torch.Tensor,
+        output_shape: torch.Tensor,  # output_size given by main trunk
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
@@ -695,10 +1018,8 @@ def forward(  # noqa: max-complexity
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = False,
-    ) -> Union[Tuple, CausalLMOutputWithValue]:
-        """Reference:
-        https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/bloom/modeling_bloom.py#L623  # noqa: E501
-        """
+        position_ids: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        #######################################################################
+        # Modified BloomModel.forward
+        #######################################################################
+
         batch_size, seq_length = hidden_states.shape[:2]

         if past_key_values is None:
-            past_key_values = tuple([None] * len(self.decoder_blocks))
+            past_key_values = tuple([None] * len(self.transformer_blocks))

+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape batch_size x num_heads x N x N
+        # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))

         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None

+        # Compute alibi tensor: check modeling_bloom.build_alibi_tensor documentation
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
@@ -729,6 +1059,8 @@ def forward(  # noqa: max-complexity

         alibi = modeling_bloom.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

+        # create causal mask
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         combined_attention_mask = None
         device = attention_mask.device
         input_shape = (batch_size, seq_length)
@@ -741,13 +1073,14 @@
past_key_values_length=past_key_values_length, ) + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] expanded_attn_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=src_length) combined_attention_mask = ( expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask ) causal_mask = combined_attention_mask - for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)): + for i, (block, layer_past) in enumerate(zip(self.transformer_blocks, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -768,11 +1101,16 @@ def forward( # noqa: max-complexity if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + # Add last hidden state hidden_states = self.final_norm(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + ####################################################################### + # End of modified BloomModel.forward + ####################################################################### + lm_logits = self.lm_head(hidden_states) if not return_dict: @@ -788,315 +1126,21 @@ def forward( # noqa: max-complexity if v is not None ) - return CausalLMOutputWithValue( + return CausalLMOutputWithCrossAttentions( + loss=None, logits=lm_logits, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions, + cross_attentions=None, + value=None, ) -# Seq2Seq architectures - - -@dataclass -class Seq2SeqLMOutputWithValue(ModelOutput): - loss: Optional[torch.FloatTensor] = None - logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - value: Optional[torch.FloatTensor] = None - - -class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): - """An `AutoModel` class wrapper for `transformers` sequence-to-sequence - models that have a language modeling head and a value head - """ - - _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM - _supported_modules = ["v_head"] - _supported_args = [] - - def __init__( - self, - base_model: transformers.PreTrainedModel, - ): - super().__init__(base_model) - self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.FloatTensor] = None, - encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = True, - output_hidden_states: Optional[bool] = True, - return_dict: Optional[bool] = None, - ) -> 
Seq2SeqLMOutputWithValue: - forward_kwargs = self.get_compatible_forward_kwargs( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - forward_kwargs["output_hidden_states"] = True - forward_kwargs["return_dict"] = True - - outputs = self.base_model(**forward_kwargs) - last_hidden_state = outputs.decoder_hidden_states[-1] - value = self.v_head(last_hidden_state).squeeze(-1) - - return Seq2SeqLMOutputWithValue(**outputs, value=value) - - def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: - return self.base_model.generate(*args, **kwargs) - - def state_dict(self, *args, **kwargs): - """ - Returns the state dictionary of the model. We add the state dictionary of the value head - to the state dictionary of the wrapped model by prepending the key with `v_head.`. - """ - base_model_state_dict = self.base_model.state_dict(*args, **kwargs) - v_head_state_dict = self.v_head.state_dict(*args, **kwargs) - for k, v in v_head_state_dict.items(): - base_model_state_dict[f"v_head.{k}"] = v - return base_model_state_dict - - def post_init(self, state_dict): - """ - We add the state dictionary of the value head to the state dictionary of the wrapped model - by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the - keys of the value head state dictionary. - """ - for k in list(state_dict.keys()): - if "v_head." 
in k: - state_dict[k.replace("v_head.", "")] = state_dict.pop(k) - self.v_head.load_state_dict(state_dict, strict=False) - del state_dict - gc.collect() # noqa: E702 - - -class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead): - _supported_modules = ["v_head", "frozen_head"] - _supported_args = ["num_layers_unfrozen"] - - def __init__( - self, - base_model: transformers.PreTrainedModel, - *, - num_layers_unfrozen: int = -1, - ): - super().__init__(base_model) - self.num_layers_unfrozen = num_layers_unfrozen - if self.num_layers_unfrozen > 0: - branch_class = T5Branch # TODO: Add support for other model branches - self.frozen_head = branch_class( - self.base_model, - num_layers_unfrozen=self.num_layers_unfrozen, - ).eval() - - def forward_hydra( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.FloatTensor] = None, - encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Seq2SeqLMOutputWithValue: - forward_kwargs = self.get_compatible_forward_kwargs( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - return_dict = forward_kwargs.get("return_dict", True) - forward_kwargs["output_hidden_states"] = True - forward_kwargs["return_dict"] = True - - outputs = self.forward(**forward_kwargs) - # Select the hidden state before the first branching layer - input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)] - hydra_outputs = self.frozen_head( - hidden_states=input_hidden_state, - attention_mask=decoder_attention_mask, - encoder_hidden_states=outputs.encoder_last_hidden_state, - encoder_attention_mask=attention_mask, - use_cache=False, - output_attentions=False, - output_hidden_states=True, - return_dict=return_dict, - ) - - if not return_dict: - return hydra_outputs.logits - return hydra_outputs - - -class T5Branch(ModelBranch): - """Decoder only T5 branch""" - - def __init__( - self, - base_model: transformers.PreTrainedModel, - *, - num_layers_unfrozen: int, - ): - super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen) - self.dropout = hf_get_decoder(base_model).dropout - self.is_decoder = True - - def forward( # noqa: max-complexity - self, - hidden_states: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: 
Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqLMOutputWithValue]: - """Reference: - https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/models/t5/modeling_t5.py#L899 # noqa: E501 - """ - batch_size, seq_length = hidden_states.shape[:2] - input_shape = (batch_size, seq_length) - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if attention_mask is None: - attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long - ) - - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) - - if self.is_decoder and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=hidden_states.device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_extended_attention_mask = None - - all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - position_bias = None - encoder_decoder_position_bias = None - - for _, layer_module in enumerate(self.decoder_blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = layer_module( - hidden_states, - attention_mask=extended_attention_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, present_key_value_state = layer_outputs[:2] - - position_bias = layer_outputs[2] - if self.is_decoder and encoder_hidden_states is not None: - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[3],) - - hidden_states = self.final_norm(hidden_states) - hidden_states = self.dropout(hidden_states) - - # Add last layer - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - sequence_output = hidden_states - - if self.config.tie_word_embeddings: - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa: E501 - sequence_output = sequence_output * (self.config.d_model**-0.5) - - lm_logits = self.lm_head(sequence_output) - - if not return_dict: - return (lm_logits,) - - return Seq2SeqLMOutputWithValue( - logits=lm_logits, - decoder_hidden_states=all_hidden_states, - decoder_attentions=all_attentions, - ) - - -# Branch class utils - - -def 
hf_get_branch_class( +def hf_get_causal_lm_branch_class( config: transformers.PretrainedConfig, ) -> "ModelBranch": - """Returns the model branch class for the given config.""" + """Returns the CausalLM branch class for the given config.""" gpt_branch_supported_archs = [ "GPTJForCausalLM", "GPT2LMHeadModel", diff --git a/trlx/trlx.py b/trlx/trlx.py index 6338ebc53..10e3621d1 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -3,11 +3,6 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple from trlx.data.configs import TRLConfig -from trlx.data.default_configs import ( - default_ilql_config, - default_ppo_config, - default_sft_config, -) from trlx.utils import set_seed from trlx.utils.loading import get_pipeline, get_trainer @@ -22,6 +17,7 @@ def train( # noqa: C901 eval_prompts: Optional[List[str]] = None, metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None, config: Optional[TRLConfig] = None, + logit_mask: Optional[List[List[bool]]] = None, stop_sequences: Optional[List[str]] = [], ): """ @@ -49,6 +45,7 @@ def train( # noqa: C901 Function to compute statistics on batches of generated samples. Its arguments are the same as in `reward_fn` (`samples`, `prompts`, `outputs`) but the return is dictionary with keys as metric's name and values and lists of numeric values per each sample in batch + logit_mask (Optional[List]): Bigram masking matrix stop_sequences (Optional[List[str]]): String sequences to trim generations (both for generating of experience and evaluation) up to its encounter in them. Generations will not contain them and also will also be right-stripped @@ -58,11 +55,11 @@ def train( # noqa: C901 "Passing the `config` argument implicitly is depreciated, load it from `configs` directory instead" ) if reward_fn: - config = default_ppo_config() + config = TRLConfig.load_yaml("configs/ppo_config.yml") elif rewards: - config = default_ilql_config() + config = TRLConfig.load_yaml("configs/ilql_config.yml") else: - config = default_sft_config() + config = TRLConfig.load_yaml("configs/sft_config.yml") set_seed(config.train.seed) diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py index 803557724..4f33d1502 100644 --- a/trlx/utils/__init__.py +++ b/trlx/utils/__init__.py @@ -29,7 +29,7 @@ def significant(x: Number, ndigits=2) -> Number: if isinstance(x, torch.Tensor): x = x.item() - if not isinstance(x, Number) or math.isnan(x) or x == 0: + if not isinstance(x, Number) or x == 0: return x return round(x, ndigits - int(math.floor(math.log10(abs(x))))) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index d810dacc4..8aff33df0 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -33,7 +33,7 @@ def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequenti def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0): """Freezes the bottom transformer block layers of the specified model.""" - hidden_layers = hf_get_decoder_blocks(model) + hidden_layers = hf_get_causal_hidden_layers(model) if num_layers_unfrozen == 0: hidden_layers_to_freeze = list(hidden_layers) elif num_layers_unfrozen > 0: @@ -102,7 +102,7 @@ def findattr(obj, attrs: Tuple[str]) -> Union[object, None]: raise ValueError(f"Could not find an attribute from `{attrs}` in `{obj}`") -def hf_get_decoder(model: nn.Module) -> nn.Module: +def hf_get_causal_base_model(model: transformers.AutoModelForCausalLM) -> nn.Module: """Returns the causal decoder backbone of the specified HuggingFace transformers model. 
NOTE: Different model configurations have different causal decoder attribute @@ -111,12 +111,12 @@ def hf_get_decoder(model: nn.Module) -> nn.Module: - model.decoder: (OPTConfig, BloomConfig) - gpt_neox: (GPTNeoXConfig) """ - decoder_attrs = ("transformer", "model.decoder", "gpt_neox", "decoder") + decoder_attrs = ("transformer", "model.decoder", "gpt_neox") return findattr(model, decoder_attrs) -def hf_get_decoder_final_norm(model: nn.Module) -> float: - """Returns the final (layer) norm of the specified decoder. +def hf_get_causal_final_norm(model: nn.Module) -> float: + """Returns the final (layer) norm of the specified model. NOTE: Different model configurations have different final norm attribute names. - transformer.ln_f: (GPT2LMHeadModel, GPTJForCausalLM) - model.decoder.final_layer_norm: (OPTForCausalLM) @@ -125,19 +125,17 @@ def hf_get_decoder_final_norm(model: nn.Module) -> float: norm_attrs = ( "transformer.ln_f", "model.decoder.final_layer_norm", - "decoder.final_layer_norm", "gpt_neox.final_layer_norm", ) return findattr(model, norm_attrs) -def hf_get_decoder_blocks(model: nn.Module) -> Tuple[nn.Module]: - """Returns the decoder hidden layers of the specified model. +def hf_get_causal_hidden_layers(model: nn.Module) -> Tuple[nn.Module]: + """Returns the hidden layers of the specified model. NOTE: Different model configurations have different hidden layer attribute names. - transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM) - model.decoder.layers: (OPTForCausalLM) - gpt_neox.layers: (GPTNeoXForCausalLM) - - decoder.block: (T5ForConditionalGeneration) """ hidden_layers_attrs = ( "h", @@ -146,12 +144,11 @@ def hf_get_decoder_blocks(model: nn.Module) -> Tuple[nn.Module]: "transformer.h", "model.decoder.layers", "gpt_neox.layers", - "decoder.block", ) return findattr(model, hidden_layers_attrs) -def hf_get_lm_head(model: nn.Module) -> nn.Module: +def hf_get_lm_head(model: transformers.AutoModelForCausalLM) -> nn.Module: """Returns the language modeling (lm) head of the specified HuggingFace transformers model. NOTE: Different model configurations have different `lm_head` attribute names. 
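The renamed `hf_get_causal_*` helpers above all resolve submodules the same way: `findattr` probes a tuple of candidate attribute paths and returns the first one that exists on the model, so a single helper covers the GPT-2/GPT-J-style `transformer`, the OPT-style `model.decoder`, and GPT-NeoX's `gpt_neox` layouts. A minimal, self-contained sketch of that lookup pattern; the toy classes are hypothetical stand-ins, not trlx code:

```python
from typing import Tuple, Union


def findattr(obj, attrs: Tuple[str, ...]) -> Union[object, None]:
    """Return the first attribute path in `attrs` that resolves on `obj`."""
    for attr in attrs:
        target = obj
        try:
            # Walk dotted paths such as "model.decoder" one segment at a time.
            for segment in attr.split("."):
                target = getattr(target, segment)
            return target
        except AttributeError:
            continue
    raise ValueError(f"Could not find an attribute from `{attrs}` in `{obj}`")


class _Decoder:  # stand-in for e.g. an OPT-style `model.decoder`
    layers = ["block_0", "block_1"]


class _Inner:
    decoder = _Decoder()


class _ToyModel:
    model = _Inner()


backbone = findattr(_ToyModel(), ("transformer", "model.decoder", "gpt_neox"))
assert backbone.layers == ["block_0", "block_1"]
```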
From 840c9b2486718ff0d72397351412a9351e7e3d7d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 10 Mar 2023 06:33:27 -0800 Subject: [PATCH 32/57] Fix device mismatch, update to accelerate>=0.17.0 (#360) --- configs/accelerate/ddp.yaml | 2 +- configs/accelerate/zero2-bf16.yaml | 2 +- configs/accelerate/zero2-fp16.yaml | 2 +- configs/accelerate/zero3.yaml | 2 +- examples/hh/ppo_hh.py | 4 +- examples/hh/to_triton.py | 4 +- examples/ppo_sentiments.py | 10 +- .../configs/default_accelerate_config.yaml | 2 +- .../trlx_gptj_text_summarization.py | 4 +- .../summarize_rlhf/trlx_inference_gptj.py | 4 +- trlx/ray_train/accelerate_trainer.py | 5 +- trlx/ray_train/launch.py | 185 +++++------------- 12 files changed, 70 insertions(+), 156 deletions(-) diff --git a/configs/accelerate/ddp.yaml b/configs/accelerate/ddp.yaml index 6147c7564..7ff51a747 100644 --- a/configs/accelerate/ddp.yaml +++ b/configs/accelerate/ddp.yaml @@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE deepspeed_config: {} distributed_type: MULTI_GPU downcast_bf16: no -dynamo_backend: 'NO' +dynamo_config: {} fsdp_config: {} gpu_ids: all machine_rank: 0 diff --git a/configs/accelerate/zero2-bf16.yaml b/configs/accelerate/zero2-bf16.yaml index abfd1b0da..10902e702 100644 --- a/configs/accelerate/zero2-bf16.yaml +++ b/configs/accelerate/zero2-bf16.yaml @@ -9,7 +9,7 @@ deepspeed_config: zero_stage: 2 distributed_type: DEEPSPEED downcast_bf16: no -dynamo_backend: 'NO' +dynamo_config: {} fsdp_config: {} machine_rank: 0 main_training_function: main diff --git a/configs/accelerate/zero2-fp16.yaml b/configs/accelerate/zero2-fp16.yaml index cf38de3aa..8fbdcb45a 100644 --- a/configs/accelerate/zero2-fp16.yaml +++ b/configs/accelerate/zero2-fp16.yaml @@ -9,7 +9,7 @@ deepspeed_config: zero_stage: 2 distributed_type: DEEPSPEED downcast_bf16: no -dynamo_backend: 'NO' +dynamo_config: {} fsdp_config: {} machine_rank: 0 main_training_function: main diff --git a/configs/accelerate/zero3.yaml b/configs/accelerate/zero3.yaml index 47a267bec..9525aad12 100644 --- a/configs/accelerate/zero3.yaml +++ b/configs/accelerate/zero3.yaml @@ -10,7 +10,7 @@ deepspeed_config: zero_stage: 3 distributed_type: DEEPSPEED downcast_bf16: no -dynamo_backend: 'NO' +dynamo_config: {} fsdp_config: {} machine_rank: 0 main_training_function: main diff --git a/examples/hh/ppo_hh.py b/examples/hh/ppo_hh.py index c68f294ec..eeefecc07 100644 --- a/examples/hh/ppo_hh.py +++ b/examples/hh/ppo_hh.py @@ -77,7 +77,9 @@ def forward(self, input_ids): reward_model.load_state_dict(torch.load(checkpoint)) reward_model.eval() reward_model.requires_grad_(False) - device = torch.cuda.device_count() - 1 + device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) + if device is None: + device = torch.cuda.device_count() - 1 reward_model = reward_model.half().to(device) def reward_fn(samples, prompts, outputs): diff --git a/examples/hh/to_triton.py b/examples/hh/to_triton.py index ad26009d3..c2a541517 100644 --- a/examples/hh/to_triton.py +++ b/examples/hh/to_triton.py @@ -24,7 +24,9 @@ args = parser.parse_args() model_name = args.checkpoint.split("/")[-1] -device = torch.device(args.device) +device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) +if device is None: + device = torch.device(args.device) class RewardModel(nn.Module): diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index c17283598..e730ea41b 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -23,10 +23,12 @@ def main(hparams={}): default_config = hparams.pop("default_config") config = 
TRLConfig.update(default_config, hparams) - if torch.cuda.is_available(): - device = int(os.environ.get("LOCAL_RANK", 0)) - else: - device = -1 + device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) + if device is None: + if torch.cuda.is_available(): + device = int(os.environ.get("LOCAL_RANK", 0)) + else: + device = -1 sentiment_fn = pipeline( "sentiment-analysis", diff --git a/examples/summarize_rlhf/configs/default_accelerate_config.yaml b/examples/summarize_rlhf/configs/default_accelerate_config.yaml index 956d6048e..e7345951d 100644 --- a/examples/summarize_rlhf/configs/default_accelerate_config.yaml +++ b/examples/summarize_rlhf/configs/default_accelerate_config.yaml @@ -6,7 +6,7 @@ deepspeed_config: zero3_init_flag: false distributed_type: DEEPSPEED downcast_bf16: 'no' -dynamo_backend: 'NO' +dynamo_config: {} fsdp_config: {} gpu_ids: null machine_rank: 0 diff --git a/examples/summarize_rlhf/trlx_gptj_text_summarization.py b/examples/summarize_rlhf/trlx_gptj_text_summarization.py index 3d9e3c5f3..6380a0a68 100755 --- a/examples/summarize_rlhf/trlx_gptj_text_summarization.py +++ b/examples/summarize_rlhf/trlx_gptj_text_summarization.py @@ -29,7 +29,9 @@ rw_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH)) rw_model.half() rw_model.eval() - rw_device = torch.device("cuda:{}".format(1)) # set reward model device + rw_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) + if rw_device is None: + rw_device = torch.device("cuda:{}".format(1)) # set reward model device rw_model.to(rw_device) def get_scores(samples: List[str]): diff --git a/examples/summarize_rlhf/trlx_inference_gptj.py b/examples/summarize_rlhf/trlx_inference_gptj.py index f5a54365a..3032289e3 100644 --- a/examples/summarize_rlhf/trlx_inference_gptj.py +++ b/examples/summarize_rlhf/trlx_inference_gptj.py @@ -32,7 +32,9 @@ def load_model(path): rw_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH)) rw_model.half() rw_model.eval() -rw_device = torch.device("cuda:{}".format(1)) +rw_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) +if rw_device is None: + rw_device = torch.device("cuda:{}".format(1)) rw_model.to(rw_device) diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py index 0ce0cfb62..e1003a68f 100644 --- a/trlx/ray_train/accelerate_trainer.py +++ b/trlx/ray_train/accelerate_trainer.py @@ -16,7 +16,7 @@ from ray.tune.trainable import Trainable from accelerate.commands.config import default_config_file, load_config_from_file -from ray.train.torch import TorchTrainer +from ray.train.torch import TorchTrainer, get_device from .launch import launch_command, launch_command_parser @@ -137,6 +137,8 @@ def wrapped_train_loop_per_worker(*args, **kwargs): with open(temp_config_file, "w") as f: f.write(accelerate_config_raw) + os.environ["ACCELERATE_TORCH_DEVICE"] = str(get_device()) + # Set by TorchBackend master_addr = os.environ["MASTER_ADDR"] master_port = os.environ["MASTER_PORT"] @@ -169,7 +171,6 @@ def wrapped_train_loop_per_worker(*args, **kwargs): os.environ["LOCAL_RANK"] = str(session.get_local_rank()) os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) os.environ["LOCAL_SIZE"] = str(session.get_local_world_size()) - os.environ["ACCELERATE_TORCH_DEVICE"] = f"cuda:{session.get_local_rank()}" return train_loop_per_worker(*args, **kwargs) diff --git a/trlx/ray_train/launch.py b/trlx/ray_train/launch.py index b782f551f..137673e70 100644 --- a/trlx/ray_train/launch.py +++ b/trlx/ray_train/launch.py @@ -16,119 +16,38 @@ import logging import os 
-import warnings -from unittest.mock import patch - -from accelerate.commands.config.config_utils import DYNAMO_BACKENDS -from accelerate.commands.launch import launch_command as original_launch_command -from accelerate.commands.launch import launch_command_parser -from accelerate.utils import ( - DynamoBackend, - PrecisionType, - is_deepspeed_available, - is_torch_version, + +try: + from packaging.version import Version +except ImportError: + from distutils.version import LooseVersion as Version + +import accelerate + +if Version(accelerate.__version__) < Version("0.17.0.dev0"): + raise RuntimeError(f"AccelerateTrainer requires accelerate>=0.17.0, got {accelerate.__version__}") + +from accelerate.commands.launch import ( + ComputeEnvironment, + _validate_launch_command, + launch_command_parser, + prepare_deepspeed_cmd_env, + prepare_multi_gpu_env, + prepare_simple_launcher_cmd_env, ) -from accelerate.utils.launch import env_var_path_add +from accelerate.utils import is_deepspeed_available logger = logging.getLogger(__name__) def simple_launcher(args): - current_env = {} - current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu) - if args.use_mps_device: - warnings.warn( - '`use_mps_device` flag is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use "--mps" instead.', - FutureWarning, - ) - args.mps = True - current_env["ACCELERATE_USE_MPS_DEVICE"] = str(args.mps) - if args.mps: - current_env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" - - try: - mixed_precision = PrecisionType(args.mixed_precision.lower()) - except ValueError: - raise ValueError( - f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." - ) - - if args.fp16: - warnings.warn( - "`fp16` is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use `mixed_precision fp16` instead.", - FutureWarning, - ) - mixed_precision = "fp16" - - current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) - - try: - dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) - except ValueError: - raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DYNAMO_BACKENDS}.") - current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value - - current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) + _, current_env = prepare_simple_launcher_cmd_env(args) os.environ.update(current_env) -def multi_gpu_launcher(args): # noqa: C901 - current_env = {} - mixed_precision = args.mixed_precision.lower() - try: - mixed_precision = PrecisionType(mixed_precision) - except ValueError: - raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.") - - if args.fp16: - warnings.warn( - "`fp16` is deprecated and will be removed in version 0.15.0 of 🤗 Accelerate. Use `mixed_precision fp16` instead.", - FutureWarning, - ) - mixed_precision = "fp16" - - current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) - - try: - dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) - except ValueError: - raise ValueError(f"Unknown dynamo backend: {args.dynamo_backend.upper()}. 
Choose between {DYNAMO_BACKENDS}.") - current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value - - if args.use_fsdp: - current_env["ACCELERATE_USE_FSDP"] = "true" - current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy) - current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower() - current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params) - if args.fsdp_auto_wrap_policy is not None: - current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) - if args.fsdp_transformer_layer_cls_to_wrap is not None: - current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap) - if args.fsdp_backward_prefetch_policy is not None: - current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy) - if args.fsdp_state_dict_type is not None: - current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type) - - if args.use_megatron_lm: - prefix = "MEGATRON_LM_" - current_env["ACCELERATE_USE_MEGATRON_LM"] = "true" - current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree) - current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree) - current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping) - if args.megatron_lm_num_micro_batches is not None: - current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches) - if args.megatron_lm_sequence_parallelism is not None: - current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism) - if args.megatron_lm_recompute_activations is not None: - current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations) - if args.megatron_lm_use_distributed_optimizer is not None: - current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer) - - current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) - if is_torch_version("<", "1.9.0"): - raise NotImplementedError("Multi-node training requires pytorch>=1.9.0") - +def multi_gpu_launcher(args): + current_env = prepare_multi_gpu_env(args) os.environ.update(current_env) @@ -136,49 +55,33 @@ def deepspeed_launcher(args): if not is_deepspeed_available(): raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.") - current_env = {} - try: - mixed_precision = PrecisionType(args.mixed_precision.lower()) - except ValueError: - raise ValueError( - f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." 
- ) - - current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath(".")) - current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) - current_env["ACCELERATE_USE_DEEPSPEED"] = "true" - current_env["DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage) - current_env["GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps) - current_env["GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower() - current_env["DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower() - current_env["DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower() - current_env["DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower() - current_env["DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower() - if args.deepspeed_config_file is not None: - current_env["DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file) - - with open(".deepspeed_env", "a") as f: - for key, value in current_env.items(): - if ";" in value or " " in value: - continue - f.write(f"{key}={value}\n") + _, current_env = prepare_deepspeed_cmd_env(args) os.environ.update(current_env) -def _raise_notimplementederror(*args, **kwargs): - raise NotImplementedError() - - def launch_command(args): - with patch("accelerate.commands.launch.deepspeed_launcher", deepspeed_launcher), patch( - "accelerate.commands.launch.multi_gpu_launcher", multi_gpu_launcher - ), patch("accelerate.commands.launch.simple_launcher", simple_launcher), patch( - "accelerate.commands.launch.tpu_launcher", _raise_notimplementederror - ), patch( - "accelerate.commands.launch.sagemaker_launcher", _raise_notimplementederror - ): - return original_launch_command(args) + args, defaults, mp_from_config_flag = _validate_launch_command(args) + + # Use the proper launcher + if args.use_deepspeed and not args.cpu: + args.deepspeed_fields_from_accelerate_config = list(defaults.deepspeed_config.keys()) if defaults else [] + if mp_from_config_flag: + args.deepspeed_fields_from_accelerate_config.append("mixed_precision") + args.deepspeed_fields_from_accelerate_config = ",".join(args.deepspeed_fields_from_accelerate_config) + deepspeed_launcher(args) + elif args.use_fsdp and not args.cpu: + multi_gpu_launcher(args) + elif args.use_megatron_lm and not args.cpu: + multi_gpu_launcher(args) + elif args.multi_gpu and not args.cpu: + multi_gpu_launcher(args) + elif args.tpu and not args.cpu: + raise NotImplementedError() + elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + raise NotImplementedError() + else: + simple_launcher(args) def main(): From 812569bd39fc4ad3ca67de4d0d8913f2920eddf4 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 13 Mar 2023 13:42:12 +0200 Subject: [PATCH 33/57] merge(trainers): unmerge old merges, merge anew recent changes --- trlx/data/configs.py | 33 +- trlx/trainer/__init__.py | 4 - trlx/trainer/accelerate_ppo_trainer.py | 54 +- trlx/trainer/accelerate_sft_trainer.py | 10 - trlx/trainer/nemo/README.md | 344 ------- trlx/trainer/nemo/__init__.py | 0 trlx/trainer/nemo/gpt.py | 786 ---------------- trlx/trainer/nemo_ilql_trainer.py | 4 +- trlx/trainer/nn/__init__.py | 0 trlx/trainer/nn/ilql_models.py | 333 ------- trlx/trainer/nn/ppo_models.py | 1171 ------------------------ trlx/utils/__init__.py | 2 +- trlx/utils/modeling.py | 19 +- 13 files changed, 70 insertions(+), 2690 deletions(-) delete mode 100644 trlx/trainer/nemo/README.md delete mode 100644 
trlx/trainer/nemo/__init__.py delete mode 100644 trlx/trainer/nemo/gpt.py delete mode 100644 trlx/trainer/nn/__init__.py delete mode 100644 trlx/trainer/nn/ilql_models.py delete mode 100644 trlx/trainer/nn/ppo_models.py diff --git a/trlx/data/configs.py b/trlx/data/configs.py index f091344a3..4b76129a8 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -1,5 +1,6 @@ +from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Set, Tuple +from typing import Any, Dict, Optional, Set import yaml @@ -12,8 +13,6 @@ def merge(base: Dict, update: Dict, updated: Set) -> Dict: if k in update and isinstance(v, dict): base[k] = merge(v, update[k], updated) updated.add(k) - elif isinstance(v, dict): - base[k] = merge(v, update, updated) elif k in update: base[k] = update[k] updated.add(k) @@ -21,6 +20,20 @@ def merge(base: Dict, update: Dict, updated: Set) -> Dict: return base +def _merge_dicts(base: Dict, update: Dict) -> Dict: + "Merge two dictionaries recursively, returning a new dictionary." + + base = deepcopy(base) + + for k, v in update.items(): + if isinstance(v, dict): + base[k] = _merge_dicts(base.get(k, {}), v) + else: + base[k] = v + + return base + + @dataclass class ModelConfig: """ @@ -183,9 +196,6 @@ class TrainConfig: :param seed: Random seed :type seed: int - - :param git_tag: Git tag for logging (as returned by ``trlx.utils.get_git_tags()``) - :type git_tag: Optional[Tuple[str, str]] """ total_steps: int @@ -212,7 +222,6 @@ class TrainConfig: logging_dir: Optional[str] = None seed: int = 1000 - git_tag: Optional[Tuple[str, str]] = None @classmethod def from_dict(cls, config: Dict[str, Any]): @@ -259,6 +268,16 @@ def to_dict(self): return data + def evolve(self, **kwargs) -> "TRLConfig": + """ + Evolve TRLConfig with new parameters. Can update nested parameters. 
+ >>> config = trlx.data.default_configs.default_ilql_config() + >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)) + >>> config.method.gamma + 0.99 + """ + return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs)) + @classmethod def from_dict(cls, config: Dict): """ diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py index 97f6831dc..8e0d239df 100644 --- a/trlx/trainer/__init__.py +++ b/trlx/trainer/__init__.py @@ -1,11 +1,7 @@ -import os import sys from abc import abstractmethod from typing import Any, Callable, Dict, Iterable, Optional -import torch - -from trlx.data import RLElement from trlx.data.configs import TRLConfig from trlx.pipeline import BaseRolloutStore diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 7b4d931f0..7d1c34f45 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -2,10 +2,12 @@ import os import uuid from time import time -from typing import Callable, List, Optional +from typing import Callable, List +import ray import torch import torch.nn.functional as F +import transformers from torch.utils.data import DataLoader from transformers import AutoTokenizer @@ -13,16 +15,16 @@ from trlx.data.accelerate_base_datatypes import PromptBatch from trlx.data.configs import TRLConfig from trlx.data.ppo_types import PPORLBatch, PPORLElement +from trlx.models.modeling_ppo import ( + AdaptiveKLController, + AutoModelForCausalLMWithHydraValueHead, + AutoModelForSeq2SeqLMWithHydraValueHead, + FixedKLController, +) from trlx.pipeline.offline_pipeline import PromptPipeline from trlx.pipeline.ppo_pipeline import PPORolloutStorage from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer -from trlx.trainer.nn.ppo_models import ( - AdaptiveKLController, - CausalLMHydraWithValueHead, - FixedKLController, - Seq2SeqLMHydraWithValueHead, -) from trlx.utils import Clock from trlx.utils.modeling import RunningMoments, logprobs_of_labels @@ -70,6 +72,7 @@ def __init__(self, config: TRLConfig, **kwargs): if not hasattr(self.model, "frozen_head"): self.ref_model = self.get_arch(self.config) self.ref_model.to(self.accelerator.device) + self.ref_model.eval() # Setup the KL controller # This helps prevent large divergences in the controller (policy) @@ -117,9 +120,19 @@ def __init__(self, config: TRLConfig, **kwargs): def get_arch(self, config: TRLConfig): """Get the model""" + model_class = AutoModelForCausalLMWithHydraValueHead if config.model.model_arch_type == "seq2seq": - return Seq2SeqLMHydraWithValueHead(config.model.model_path, config.model.num_layers_unfrozen) - return CausalLMHydraWithValueHead(config.model.model_path, config.model.num_layers_unfrozen) + model_class = AutoModelForSeq2SeqLMWithHydraValueHead + + from_fn = model_class.from_pretrained + # backward-compat: Try to create a randomly initialized architecture from a config + if issubclass(type(config.model.model_path), transformers.PretrainedConfig): + from_fn = model_class.from_config + + return from_fn( + config.model.model_path, + num_layers_unfrozen=config.model.num_layers_unfrozen, + ) def loss(self, batch: PPORLBatch): """Forward pass & loss @@ -230,15 +243,6 @@ def prepare_learning(self): self.total_steps = self.config.train.epochs * self.n_updates_per_batch * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) - def save_pretrained(self, directory: Optional[str] = None): - 
"""NOTE: If a `directory` is not provided, the model will be saved to a sub-directory - of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`). - """ - if directory is None: - directory = f"{self.config.train.checkpoint_dir}/hf_model" - self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory) - self.tokenizer.save_pretrained(directory) - def add_prompt_pipeline(self, pipeline: PromptPipeline): """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True) @@ -412,14 +416,14 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq ref_logits = self.model.forward_hydra( all_tokens, attention_mask=attention_mask, - return_dict=False, - ) + return_dict=True, + ).logits else: - ref_logits, _, *_ = self.ref_model( + ref_logits = self.ref_model( all_tokens, attention_mask=attention_mask, - return_dict=False, - ) + return_dict=True, + ).logits ref_logits = ref_logits.to(device) if self.config.model.model_arch_type == "seq2seq": @@ -500,7 +504,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq stats["kl_ctl_value"] = self.kl_ctl.value stats["time/exp"] = exp_time - self.accelerator.log(stats, step=iter_count) + + if not ray.is_initialized(): + self.accelerator.log(stats, step=iter_count) # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index 447913d9a..e061896e6 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from transformers import AutoModelForCausalLM @@ -56,12 +55,3 @@ def prepare_learning(self): self.n_updates_per_batch = 1 self.total_steps = self.config.train.epochs * len(train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) - - def save_pretrained(self, directory: Optional[str] = None): - """NOTE: If a `directory` is not provided, the model will be saved to a sub-directory - of the Trainer config checkpoint dir named "hf_model" (e.g. `/ckpts/hf_model`). - """ - if directory is None: - directory = f"{self.config.train.checkpoint_dir}/hf_model" - self.accelerator.unwrap_model(self.model).base_model.save_pretrained(directory) - self.tokenizer.save_pretrained(directory) diff --git a/trlx/trainer/nemo/README.md b/trlx/trainer/nemo/README.md deleted file mode 100644 index c913c843a..000000000 --- a/trlx/trainer/nemo/README.md +++ /dev/null @@ -1,344 +0,0 @@ -## Using pretrained NeMo models -To use a NeMo models in `.nemo` format, like [NeMo Megatron-GPT-20B](https://huggingface.co/nvidia/nemo-megatron-gpt-20B), download and un-tar it: -``` -tar xvf nemo_gpt20B_bf16_tp4.nemo -``` -This will extract the model weights and the model config. - -Then set `train.trainer_kwargs.pretrained_model` to the path to the directory containing the parameters. The model hyperparameters in the `train.trainer_kwargs.megatron_cfg` should match the ones in the model config. 
- -## Inference ILQL trained NeMo models -To load a checkpoint, run -``` -python examples/nemo_ilql_inference.py configs/nemo_configs/megatron_20b.yaml "/path/to/ilql_sentiments_logs/checkpoints" -``` -To save checkpoints, ensure the following is set in the NeMo config: -``` -exp_manager: - explicit_log_dir: ilql_sentiments_logs - create_checkpoint_callback: True -``` - -## Resume Training -To resume training, ensure the following is set in the NeMo config: -``` -exp_manager: - resume_if_exists: True -``` - -## NeMo Megatron setup -Clone https://github.com/NVIDIA/NeMo/ and apex from https://github.com/NVIDIA/apex/. - -1) install conda (or mamba/micromamba) - -2) srun into a compute node with a gpu (if running on HPC cluster) -``` -srun --pty bash -i -``` - -3) copy the conda env export below and change the name and prefix -``` -conda env create -f env.yaml -``` - -4) install nemo -``` -git clone https://github.com/NVIDIA/NeMo/ -cd NeMo && pip install '.[all]' -``` - -6) install apex (or clone the github) -``` -git clone https://github.com/NVIDIA/apex/ -cd apex -pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./ -``` - - -# conda env export -``` -name: nemo-113 -prefix: /mnt/nvme/jobs/nemo/nemo-source -channels: - - anaconda - - conda-forge - - defaults -dependencies: - - _libgcc_mutex=0.1=conda_forge - - _openmp_mutex=4.5=2_gnu - - bzip2=1.0.8=h7f98852_4 - - c-ares=1.18.1=h7f8727e_0 - - ca-certificates=2022.9.24=ha878542_0 - - curl=7.84.0=h5eee18b_0 - - expat=2.4.4=h295c915_0 - - gettext=0.21.1=h27087fc_0 - - git=2.34.1=pl5262hc120c5b_0 - - krb5=1.19.2=hac12032_0 - - lame=3.100=h166bdaf_1003 - - ld_impl_linux-64=2.39=hcc3a1bd_1 - - libcurl=7.84.0=h91b91d3_0 - - libedit=3.1.20210910=h7f8727e_0 - - libev=4.33=h7f8727e_1 - - libffi=3.2.1=he1b5a44_1007 - - libflac=1.4.2=h27087fc_0 - - libgcc-ng=12.2.0=h65d4601_19 - - libgomp=12.2.0=h65d4601_19 - - libnghttp2=1.46.0=hce63b2e_0 - - libnsl=2.0.0=h7f98852_0 - - libogg=1.3.4=h7f98852_1 - - libopus=1.3.1=h7f98852_1 - - libsndfile=1.1.0=h27087fc_0 - - libsqlite=3.39.4=h753d276_0 - - libssh2=1.10.0=h8f2d780_0 - - libstdcxx-ng=12.2.0=h46fd767_19 - - libuuid=2.32.1=h7f98852_1000 - - libvorbis=1.3.7=h9c3ff4c_0 - - libzlib=1.2.12=h166bdaf_2 - - mpg123=1.30.2=h27087fc_1 - - ncurses=6.3=h27087fc_1 - - openssl=1.1.1q=h7f8727e_0 - - pcre2=10.37=he7ceb23_1 - - perl=5.26.2=h14c3975_0 - - pip=22.3.1=pyhd8ed1ab_0 - - python=3.8.2=he5300dc_7_cpython - - readline=8.1.2=h0f457ee_0 - - sqlite=3.39.4=h4ff8645_0 - - tk=8.6.12=h1ccaba5_0 - - wheel=0.38.4=pyhd8ed1ab_0 - - xz=5.2.6=h166bdaf_0 - - zlib=1.2.12=h7f8727e_2 - - pip: - - absl-py==1.3.0 - - aiohttp==3.8.3 - - aiosignal==1.3.1 - - alabaster==0.7.12 - - aniso8601==9.0.1 - - antlr4-python3-runtime==4.9.3 - - appdirs==1.4.4 - - asttokens==2.1.0 - - async-timeout==4.0.2 - - attrdict==2.0.1 - - attrs==22.1.0 - - audioread==3.0.0 - - babel==2.11.0 - - backcall==0.2.0 - - beautifulsoup4==4.11.1 - - black==19.10b0 - - boto3==1.26.13 - - botocore==1.29.13 - - braceexpand==0.1.7 - - cachetools==5.2.0 - - certifi==2022.9.24 - - cffi==1.15.1 - - charset-normalizer==2.1.1 - - click==8.0.2 - - colorama==0.4.6 - - commonmark==0.9.1 - - contourpy==1.0.6 - - cycler==0.11.0 - - cython==0.29.32 - - debugpy==1.6.3 - - decorator==5.1.1 - - distance==0.1.3 - - docker-pycreds==0.4.0 - - docopt==0.6.2 - - docutils==0.19 - - editdistance==0.6.1 - - 
einops==0.6.0 - - entrypoints==0.4 - - exceptiongroup==1.0.4 - - executing==1.2.0 - - faiss-cpu==1.7.3 - - fasttext==0.9.2 - - filelock==3.8.0 - - flask==2.2.2 - - flask-restful==0.3.9 - - fonttools==4.38.0 - - frozenlist==1.3.3 - - fsspec==2022.11.0 - - ftfy==6.1.1 - - g2p-en==2.1.0 - - gdown==4.5.3 - - gitdb==4.0.9 - - gitpython==3.1.29 - - google-auth==2.14.1 - - google-auth-oauthlib==0.4.6 - - grpcio==1.50.0 - - h5py==3.7.0 - - huggingface-hub==0.11.0 - - hydra-core==1.2.0 - - idna==3.4 - - ijson==3.1.4 - - imagesize==1.4.1 - - importlib-metadata==5.0.0 - - importlib-resources==5.10.0 - - inflect==6.0.2 - - iniconfig==1.1.1 - - ipadic==1.0.0 - - ipykernel==6.17.1 - - ipython==8.6.0 - - ipywidgets==8.0.2 - - isort==4.3.21 - - itsdangerous==2.1.2 - - jedi==0.18.1 - - jieba==0.42.1 - - jinja2==3.1.2 - - jiwer==2.5.1 - - jmespath==1.0.1 - - joblib==1.2.0 - - jupyter-client==7.4.7 - - jupyter-core==5.0.0 - - jupyterlab-widgets==3.0.3 - - kaldi-python-io==1.2.2 - - kaldiio==2.17.2 - - kiwisolver==1.4.4 - - latexcodec==2.0.1 - - levenshtein==0.20.2 - - librosa==0.9.2 - - llvmlite==0.39.1 - - loguru==0.6.0 - - lxml==4.9.1 - - markdown==3.4.1 - - markupsafe==2.1.1 - - marshmallow==3.19.0 - - matplotlib==3.6.2 - - matplotlib-inline==0.1.6 - - mecab-python3==1.0.5 - - mpmath==1.2.1 - - multidict==6.0.2 - - nest-asyncio==1.5.6 - - nltk==3.7 - - numba==0.56.4 - - numpy==1.23.4 - - nvidia-cublas-cu11==11.10.3.66 - - nvidia-cuda-nvrtc-cu11==11.7.99 - - nvidia-cuda-runtime-cu11==11.7.99 - - nvidia-cudnn-cu11==8.5.0.96 - - oauthlib==3.2.2 - - omegaconf==2.2.3 - - onnx==1.12.0 - - opencc==1.1.4 - - packaging==21.3 - - pandas==1.5.1 - - pangu==4.0.6.1 - - parameterized==0.8.1 - - parso==0.8.3 - - pathspec==0.10.2 - - pathtools==0.1.2 - - pesq==0.0.4 - - pexpect==4.8.0 - - pickleshare==0.7.5 - - pillow==9.3.0 - - pip-api==0.0.30 - - pipreqs==0.4.11 - - plac==1.3.5 - - platformdirs==2.5.4 - - pluggy==1.0.0 - - pooch==1.6.0 - - portalocker==2.6.0 - - progress==1.6 - - promise==2.3 - - prompt-toolkit==3.0.32 - - protobuf==3.20.1 - - psutil==5.9.4 - - ptyprocess==0.7.0 - - pure-eval==0.2.2 - - pyannote-core==4.5 - - pyannote-database==4.1.3 - - pyannote-metrics==3.2.1 - - pyasn1==0.4.8 - - pyasn1-modules==0.2.8 - - pybind11==2.10.1 - - pybtex==0.24.0 - - pybtex-docutils==1.0.2 - - pycparser==2.21 - - pydantic==1.10.2 - - pydeprecate==0.3.2 - - pydub==0.25.1 - - pygments==2.13.0 - - pynini==2.1.5 - - pyparsing==3.0.9 - - pypinyin==0.47.1 - - pysocks==1.7.1 - - pystoi==0.3.3 - - pytest==7.2.0 - - pytest-runner==6.0.0 - - python-dateutil==2.8.2 - - pytorch-lightning==1.7.7 - - pytz==2022.6 - - pyyaml==5.4.1 - - pyzmq==24.0.1 - - rapidfuzz==2.13.2 - - regex==2022.10.31 - - requests==2.28.1 - - requests-oauthlib==1.3.1 - - resampy==0.4.2 - - rich==12.6.0 - - rsa==4.9 - - ruamel-yaml==0.17.21 - - ruamel-yaml-clib==0.2.7 - - s3transfer==0.6.0 - - sacremoses==0.0.53 - - scikit-learn==1.1.3 - - scipy==1.9.3 - - sentence-transformers==2.2.2 - - sentencepiece==0.1.97 - - sentry-sdk==1.11.0 - - setproctitle==1.3.2 - - setuptools==59.5.0 - - shellingham==1.5.0 - - shortuuid==1.0.11 - - simplejson==3.18.0 - - six==1.16.0 - - smmap==5.0.0 - - snowballstemmer==2.2.0 - - sortedcontainers==2.4.0 - - soundfile==0.11.0 - - soupsieve==2.3.2.post1 - - sox==1.4.1 - - sphinx==5.3.0 - - sphinxcontrib-applehelp==1.0.2 - - sphinxcontrib-bibtex==2.5.0 - - sphinxcontrib-devhelp==1.0.2 - - sphinxcontrib-htmlhelp==2.0.0 - - sphinxcontrib-jsmath==1.0.1 - - sphinxcontrib-qthelp==1.0.3 - - sphinxcontrib-serializinghtml==1.1.5 - - 
stack-data==0.6.1 - - sympy==1.11.1 - - tabulate==0.9.0 - - tensorboard==2.11.0 - - tensorboard-data-server==0.6.1 - - tensorboard-plugin-wit==1.8.1 - - termcolor==2.1.0 - - text-unidecode==1.3 - - textdistance==4.5.0 - - texterrors==0.4.4 - - threadpoolctl==3.1.0 - - tokenizers==0.12.1 - - toml==0.10.2 - - tomli==2.0.1 - - torch==1.13.0 - - torchaudio==0.13.0 - - torchmetrics==0.10.3 - - torchvision==0.14.0 - - tornado==6.2 - - tqdm==4.64.1 - - traitlets==5.5.0 - - transformers==4.21.2 - - typed-ast==1.5.4 - - typer==0.7.0 - - typing-extensions==4.4.0 - - urllib3==1.26.12 - - wandb==0.13.5 - - wcwidth==0.2.5 - - webdataset==0.1.62 - - werkzeug==2.2.2 - - wget==3.2 - - widgetsnbextension==4.0.3 - - wrapt==1.14.1 - - yarg==0.1.9 - - yarl==1.8.1 - - youtokentome==1.0.6 - - zipp==3.10.0 -``` diff --git a/trlx/trainer/nemo/__init__.py b/trlx/trainer/nemo/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trlx/trainer/nemo/gpt.py b/trlx/trainer/nemo/gpt.py deleted file mode 100644 index 89eb2554b..000000000 --- a/trlx/trainer/nemo/gpt.py +++ /dev/null @@ -1,786 +0,0 @@ -# Extensible version of the GPT model -import sys -from copy import deepcopy -from functools import partial, reduce -from math import sqrt -from pathlib import Path -from typing import List, Mapping, Optional, Tuple, Union - -import torch -import torch.distributed -import torch.nn as nn -import torch.nn.functional as F -from apex.transformer import parallel_state, tensor_parallel -from apex.transformer.tensor_parallel.mappings import ( - gather_from_sequence_parallel_region, -) -from einops import rearrange -from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( - MegatronPretrainingBatchSampler, -) -from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import ( - post_language_model_processing, -) -from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import ( - MegatronGPTModel, -) -from nemo.collections.nlp.modules.common.megatron.module import ( - Float16Module, - MegatronModule, -) -from nemo.collections.nlp.modules.common.megatron.utils import ( - average_losses_across_data_parallel_group, - get_ltor_masks_and_position_ids, -) -from nemo.collections.nlp.modules.common.transformer.text_generation import ( - LengthParam, - OutputType, - SamplingParam, -) -from nemo.collections.nlp.parts.utils_funcs import get_last_rank - -from trlx.data.ilql_types import ILQLBatch, unflatten_dataclass -from trlx.trainer.nn.ilql_models import ILQLConfig, batched_index_select -from trlx.utils import to_device, tree_map - - -class ParallelLinear(nn.Module): - """Linear layer parallelized over the longer dimension.""" - - def __init__( - self, - in_size: int, - out_size: int, - init_method=partial(nn.init.kaiming_uniform_, a=sqrt(5), nonlinearity="relu"), - use_cpu_initialization=False, - bias=True, - sequence_parallel=False, - gradient_accumulation_fusion=False, - gather_output=True, - input_is_parallel=False, - ): - super().__init__() - - no_async_tensor_model_parallel_allreduce = ( - parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel - ) - - if in_size < out_size: - self.layer = tensor_parallel.ColumnParallelLinear( - in_size, - out_size, - gather_output=gather_output, - init_method=init_method, - skip_bias_add=False, - use_cpu_initialization=use_cpu_initialization, - bias=bias, - sequence_parallel_enabled=sequence_parallel, - no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, - 
gradient_accumulation_fusion=gradient_accumulation_fusion, - ) - else: - self.layer = tensor_parallel.RowParallelLinear( - in_size, - out_size, - input_is_parallel=input_is_parallel, - init_method=init_method, - skip_bias_add=False, - use_cpu_initialization=use_cpu_initialization, - bias=bias, - sequence_parallel_enabled=sequence_parallel, - gradient_accumulation_fusion=gradient_accumulation_fusion, - ) - - def forward(self, x): - output, bias = self.layer(x) - if bias is not None: - return output + bias - return output - - -def make_parallel_head(n_embd: int, out: int, sequence_parallel=False) -> nn.Sequential: - """Returns a generic sequential model parallel MLP head.""" - parallel_intermediate = out < (n_embd * 2) - return nn.Sequential( - ParallelLinear( - n_embd, - n_embd * 2, - sequence_parallel=sequence_parallel, - gather_output=not parallel_intermediate, - ), - nn.ReLU(), - ParallelLinear( - n_embd * 2, - out, - sequence_parallel=sequence_parallel, - input_is_parallel=parallel_intermediate, - ), - ) - - -class ParallelILQLHeads(nn.Module): - def __init__( - self, - config: ILQLConfig, - hidden_size: int, - vocab_size: int, - sequence_parallel=False, - ): - super().__init__() - self.hidden_size = hidden_size - self.vocab_size = vocab_size - self.v_head = make_parallel_head(hidden_size, 1, sequence_parallel=sequence_parallel) - self.config = config - - n_qs = 2 if self.config.two_qs else 1 - - self.q_heads = nn.ModuleList(make_parallel_head(self.hidden_size, self.vocab_size) for _ in range(n_qs)) - - self.target_q_heads = nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads) - self.target_q_heads.requires_grad_(False) - - def forward(self, hidden_states): - qs = tuple(q_head(hidden_states) for q_head in self.q_heads) - target_qs = tuple(q_head(hidden_states) for q_head in self.target_q_heads) - vs = self.v_head(hidden_states) - - qs, target_qs, vs = tree_map(lambda t: rearrange(t, "T N ... 
-> N T ..."), (qs, target_qs, vs)) - - return qs, target_qs, vs - - def _sync_target_q_heads(self, alpha: float): - for target_q_head, q_head in zip(self.target_q_heads, self.q_heads): - for target_param, copy_param in zip(target_q_head.parameters(), q_head.parameters()): - target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data) - - def sync_target_q_heads(self): - self._sync_target_q_heads(self.config.alpha) - - -class LMHeads(MegatronModule): - def __init__(self, language_model, other_heads): - super().__init__() - # must be this attribute name - self.pre_process = language_model.pre_process - self.post_process = language_model.post_process - self.language_model = language_model - - self.other_heads = other_heads - - if hasattr(language_model, "word_embeddings"): - self.word_embeddings = language_model.word_embeddings - - # The tensor from the previous pipeline rank arrives via this method - def set_input_tensor(self, input_tensor): - return self.language_model.set_input_tensor(input_tensor) - - def word_embeddings_weight(self): - return self.language_model.word_embeddings_weight() - - def load_state_dict(self, lm_state_dict, strict=True): - """Load GPTModel state dict.""" - self.language_model.language_model.load_state_dict(lm_state_dict, strict=strict) - - def forward( - self, - *args, - get_key_value=False, - forward_method_parallel_output=None, - **kwargs, - ): - lm_output = self.language_model(*args, get_key_value=get_key_value, **kwargs) - logits = post_language_model_processing( - lm_output, - labels=None, - logit_weights=self.language_model.word_embeddings_weight(), - get_key_value=get_key_value, - parallel_output=False, # self.language_model.parallel_output, - forward_method_parallel_output=forward_method_parallel_output, - fp16_lm_cross_entropy=self.language_model.fp16_lm_cross_entropy, - return_logits=True, - sequence_parallel=self.language_model.sequence_parallel, - gradient_accumulation_fusion=self.language_model.gradient_accumulation_fusion, - ) - - if get_key_value: - logits, presents = logits - lm_output, lm_output_presents = lm_output - - heads_output = self.other_heads(lm_output) - return logits, heads_output - - -def unwrap_float16_module(module): - if isinstance(module, Float16Module): - return module.module - return module - - -def reshard_for_pipeline_parallelism(num_layers, state_dict): - """Filter out the layers that are not in the current pipeline stage - and shift the layer ids to match the local stage layer ids.""" - pp_rank = parallel_state.get_pipeline_model_parallel_rank() - pp_size = parallel_state.get_pipeline_model_parallel_world_size() - - stage_layers = num_layers // pp_size - pp_offset = pp_rank * stage_layers - - encoder_layers_key = "model.language_model.encoder.layers." 
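To make the stage arithmetic just computed (and the key renaming performed by the helpers that follow) concrete, a short illustrative sketch; the layer count, ranks, and parameter key are assumed values:
```python
# Assume 24 layers split over a pipeline-parallel world of size 4: each stage
# owns 6 layers, and rank 2 holds global layers 12..17 as local layers 0..5.
num_layers, pp_size, pp_rank = 24, 4, 2
stage_layers = num_layers // pp_size  # 6
pp_offset = pp_rank * stage_layers    # 12

key = "model.language_model.encoder.layers.14.self_attention.query.weight"
layer_idx = int(key.split(".")[4])    # 14, i.e. inside [12, 18) -> kept on this rank
local_key = key.replace(f"layers.{layer_idx}.", f"layers.{layer_idx - pp_offset}.")
print(local_key)  # ...encoder.layers.2.self_attention.query.weight
```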
- - def filter_in_pp_rank(key): - if key.startswith(encoder_layers_key): - layer_idx = int(key.split(".")[4]) - return pp_offset <= layer_idx < (pp_offset + stage_layers) - elif key.startswith("model.language_model.encoder.final_layernorm") and not pp_rank == (pp_size - 1): - return False - else: - return True - - def shift_layer_idx(key): - """If the key is for a transformer layer, shift down the layer index to select the - correct layer for this pipeline stage.""" - if key.startswith(encoder_layers_key): - layer_idx = int(key.split(".")[4]) - return f"{encoder_layers_key}{str(layer_idx - pp_offset)}.{'.'.join(key.split('.')[5:])}" - else: - return key - - state_dict = {shift_layer_idx(k): v for k, v in state_dict.items() if filter_in_pp_rank(k)} - - return state_dict - - -class ILQLGPT(MegatronGPTModel): - ilql_config: ILQLConfig - - def __init__(self, ilql_config, metric_fn=None, **kwargs): - self.ilql_config = ilql_config - self.metric_fn = metric_fn - super().__init__(**kwargs) - if len(list(self.parameters())) == 0: - raise ValueError("No parameters in model") - - self._ori_activations_checkpoint_granularity = self.cfg.get("activations_checkpoint_granularity", None) - self._ori_activations_checkpoint_method = self.cfg.get("activations_checkpoint_method", None) - self._ori_activations_checkpoint_num_layers = self.cfg.get("activations_checkpoint_num_layers", None) - - @classmethod - def list_available_models(cls) -> Optional[Mapping[str, str]]: - return None - - def build_train_valid_test_datasets(self): - pass - - def build_data_loader(self, dataset, collate_fn, consumed_samples=0): - dp_rank = parallel_state.get_data_parallel_rank() - dp_size = parallel_state.get_data_parallel_world_size() - print( - f"Building data loader for {type(dataset)=} {len(dataset)=} {dp_rank=} {dp_size=}", - file=sys.stderr, - ) - batch_sampler = MegatronPretrainingBatchSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=self.cfg.micro_batch_size, - global_batch_size=self.cfg.global_batch_size, - data_parallel_rank=dp_rank, - data_parallel_size=dp_size, - drop_last=True, - ) - return torch.utils.data.DataLoader( - dataset, - batch_sampler=batch_sampler, - # For some reason this causes a crash when using >0 workers - # with grad accumulation > 1 - num_workers=0, - pin_memory=True, - collate_fn=collate_fn, - ) - - def set_train_dataset(self, train_dataset, collate_fn): - self._train_dataset = train_dataset - self._train_collate_fn = collate_fn - - def set_valid_dataset(self, valid_dataset, collate_fn): - self._valid_dataset = valid_dataset - self._valid_collate_fn = collate_fn - - # Called by superclass to build data loaders - def setup_training_data(self, _): - if hasattr(self, "_train_dataset"): - self._train_dl = self.build_data_loader(self._train_dataset, self._train_collate_fn) - - def setup_validation_data(self, _): - if hasattr(self, "_valid_dataset"): - self._validation_dl = self.build_data_loader(self._valid_dataset, self._valid_collate_fn) - - def load_from_pretrained(self, checkpoint_dir): - mp_rank = parallel_state.get_tensor_model_parallel_rank() - rank_subfolder = f"mp_rank_{mp_rank:02d}" - rank_params = Path(checkpoint_dir) / rank_subfolder / "model_weights.ckpt" - print(f"Loading from {rank_params}") - state_dict = torch.load(rank_params) - - state_dict = reshard_for_pipeline_parallelism(self.cfg.num_layers, state_dict) - - def trim_key(key, prefix): - assert key.startswith(prefix), f"key {key} in state_dict does not start with {prefix}" - return 
key[len(prefix) :] - - lm_state_dict = {trim_key(k, "model.language_model."): v for k, v in state_dict.items()} - - encoder_state_dict = {trim_key(k, "encoder."): v for k, v in lm_state_dict.items() if k.startswith("encoder.")} - - lm_state_dict = {**lm_state_dict, "encoder": encoder_state_dict} - - unwrap_float16_module(self.model).load_state_dict(lm_state_dict, strict=True) - print(f"Loaded from pretrained {rank_params}") - - def model_provider_func(self, pre_process: bool, post_process: bool): - """ - Model construction for Apex Pipeline Parallelism. - Each rank will construct the model but inside the model, - only the relevant layers for that rank should be constructed. - On the first rank, pre_process will be True - On the last rank, post_process will be True - """ - gpt = super().model_provider_func(pre_process, post_process=post_process) - # This disables post-processing the lm output to the vocab - gpt.post_process = False - # This enables the final layernorm in the GPT model if there is one - gpt.language_model.post_process = post_process - # If running on the last pipeline stage, add the ILQL heads - if post_process: - parallel_ilql_heads = ParallelILQLHeads( - self.ilql_config, - self.cfg.hidden_size, - self.padded_vocab_size, - self.cfg.sequence_parallel, - ) - - return LMHeads( - gpt, - parallel_ilql_heads, - ) - else: - return gpt - - # Adapted from NeMo - # https://github.com/NVIDIA/NeMo/blob/r1.13.0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L259 - def training_step(self, batch: ILQLBatch, batch_idx: int): # noqa: C901 - """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. - """ - # we zero grads here because we also call backward in the apex fwd/bwd functions - self._optimizer.zero_grad() - - if parallel_state.is_pipeline_first_stage(ignore_virtual=True) or parallel_state.is_pipeline_last_stage( - ignore_virtual=True - ): - # we prepare the micro batches for the apex fwd/bwd function - batch_for_pipeline = batch - else: - # The intermediate pipeline stages do not need any inputs from data loader - # GPT3 uses decoder with AttnMask:causal, thus doesn't need attention_mask - batch_for_pipeline = None - - # Pipeline stages will transfer this shape tensor to and from the - # previous and next stages - # The model must output a tensor of this shape if not the last pipeline - # stage. 
The model is given input of this shape if not the first pipeline - # stage via .set_input_tensor - tensor_shape = [ - self.cfg.encoder_seq_length, - self.cfg.micro_batch_size, - self.cfg.hidden_size, - ] - - # handle asynchronous grad reduction - if self.with_distributed_adam: - if self.megatron_amp_o2: - # copy grads to main grad - def custom_sync_context_handler(): - return self._optimizer.no_sync(greedy_grad_copy=True) - - else: - # keep grad tensors around - def custom_sync_context_handler(): - return self._optimizer.no_sync(greedy_grad_copy=False) - - else: - if self.megatron_amp_o2 and not self.cfg.get("sequence_parallel", False): - custom_sync_context_handler = self._optimizer.no_sync - else: - # TODO: enable async grad all reduce for O1/autocast mixed precision training - custom_sync_context_handler = None - - # run forward and backwards passes for an entire global batch - # we do this inside training_step to support pipeline parallelism - # This gets the correct fwd/bwd pipeline step depending on the pipeline - # parallelism configuration - fwd_bwd_function = self._get_fwd_bwd_function() - - last_stage_output = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.model, - forward_only=False, - tensor_shape=tensor_shape, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - custom_sync_context_handler=custom_sync_context_handler, - sequence_parallel_enabled=self.cfg.get("sequence_parallel", False), - sync_batch_comm=self.cfg.get("sync_batch_comm", False), - num_micro_batches_with_partial_activation_checkpoints=self.cfg.get( - "num_micro_batches_with_partial_activation_checkpoints", None - ), - ) - - # only the last stages of the pipeline return losses - if last_stage_output: - # average loss across micro batches - outputs = {k: [output[k] for output in last_stage_output] for k in last_stage_output[0].keys()} - outputs = {k: torch.concat([torch.as_tensor(vi).unsqueeze(0) for vi in v]) for k, v in outputs.items()} - - mean_outputs = {k: v.mean() for k, v in outputs.items()} - loss_mean = mean_outputs["avg_loss"] - else: - mean_outputs = {} - loss_mean = torch.tensor(0.0).cuda() - - # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced - if self.cfg.get("tensor_model_parallel_size", 1) > 1 and self.cfg.get("sequence_parallel", False): - self.allreduce_sequence_parallel_gradients() - if self.with_distributed_adam: - # launch grad reductions - # Note: grads in first pipeline stage have already been - # reduced - if not parallel_state.is_pipeline_first_stage(): - self.reduce_overlap_gradients() - elif self.megatron_amp_o2: - # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) - if self.cfg.get("pipeline_model_parallel_size", 1) > 1 or self.cfg.get("sequence_parallel", False): - # main grads are stored in the MainParamsOptimizer wrapper - self._optimizer.allreduce_main_grads() - else: - # async grad allreduce is not currently implemented for O1/autocasting mixed precision training - # so we all-reduce gradients after the pipeline - self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) - - if self.cfg.get("pipeline_model_parallel_size", 1) > 1: - # when using pipeline parallelism the first and last stage must keep embeddings in sync - self.allreduce_first_last_embeddings() - - # we can only log on one rank if it is rank zero so we 
broadcast from last rank - # we can avoid this broadcast by updating the PTL log function to accept specific ranks - torch.distributed.broadcast(loss_mean, get_last_rank()) - - if self.cfg.precision == 16: - loss_scale = self.trainer.precision_plugin.scaler._scale - if loss_scale is not None: - self.log("loss_scale", loss_scale) - - self.log( - "reduced_train_loss", - loss_mean, - prog_bar=True, - rank_zero_only=True, - ) - - for k, v in mean_outputs.items(): - if k != "avg_loss": - self.log(k, v) - - self.log( - "global_step", - float(self.trainer.global_step), - prog_bar=True, - rank_zero_only=True, - ) - - if self.trainer.global_step % self.ilql_config.steps_for_target_q_sync == 0 and self.trainer.global_step > 0: - if parallel_state.is_pipeline_last_stage(): - unwrap_float16_module(self.model).other_heads.sync_target_q_heads() - - return loss_mean - - def activation_checkpointing_(self, enable: bool): - def toggle_checkpointing(module): - if hasattr(module, "activations_checkpoint_granularity"): - if enable: - module.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity - else: - module.activations_checkpoint_granularity = None - - if hasattr(module, "activations_checkpoint_method"): - if enable: - module.activations_checkpoint_method = self._ori_activations_checkpoint_method - else: - module.activations_checkpoint_method = None - - if hasattr(module, "activations_checkpoint_num_layers"): - if enable: - module.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers - else: - module.activations_checkpoint_num_layers = None - - self.model.apply(toggle_checkpointing) - - if enable: - self.cfg.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity - self.cfg.activations_checkpoint_method = self._ori_activations_checkpoint_method - self.cfg.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers - else: - self.cfg.activations_checkpoint_granularity = None - self.cfg.activations_checkpoint_method = None - self.cfg.activations_checkpoint_num_layers = None - - # TODO: replace this with less magical code - def sequence_parallel_(self, enabled: bool): - self.cfg.sequence_parallel = enabled - - def toggle_sp(m): - if hasattr(m, "sequence_parallel"): - m.sequence_parallel = enabled - - # for the Row/ColumnParallelLinear layers - if hasattr(m, "sequence_parallel_enabled"): - if hasattr(m, "input_is_parallel"): - m.sequence_parallel_enabled = enabled and m.input_is_parallel - elif hasattr(m, "gather_output"): - m.sequence_parallel_enabled = enabled and not m.gather_output - else: - m.sequence_parallel_enabled = enabled - - self.model.apply(toggle_sp) - - def validation_step(self, batch: Tuple[List[int], List[int]], batch_idx: int): - if self.metric_fn is None: - raise ValueError("Must set metric_fn to use validation") - - sp_was_enabled = self.cfg.get("sequence_parallel", False) - if sp_was_enabled: - self.sequence_parallel_(False) - - activations_checkpointing_was_enabled = self.cfg.get("activations_checkpoint_granularity", None) is not None - - if activations_checkpointing_was_enabled: - self.activation_checkpointing_(False) - - input_ids, lengths = batch - input_ids, lengths = torch.as_tensor(input_ids), torch.as_tensor(lengths) - - input_ids, lengths = to_device((input_ids, lengths), torch.cuda.current_device(), non_blocking=True) - - max_new_tokens = self.ilql_config.gen_kwargs.get("max_new_tokens", 64) - - gen = self.generate((input_ids, lengths), dict(max_length=max_new_tokens, 
min_length=0)) - - metrics = self.metric_fn(gen["sentences"]) - - metric_keys, metric_values = zip(*metrics.items()) - - columns = ["sentences", *metric_keys] - rows = list(zip(gen["sentences"], *metric_values)) - - avg_metrics = {f"avg_{k}": torch.as_tensor(v).mean() for k, v in metrics.items()} - - if activations_checkpointing_was_enabled: - self.activation_checkpointing_(True) - - if sp_was_enabled: - self.sequence_parallel_(True) - - # NeMo generate resets the microbatch calculator - from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - ) - from nemo.utils import AppState - - _reconfigure_microbatch_calculator( - rank=AppState().global_rank, - rampup_batch_size=None, - global_batch_size=self.cfg.global_batch_size, - micro_batch_size=self.cfg.micro_batch_size, - data_parallel_size=AppState().data_parallel_size, - ) - - return avg_metrics, (rows, columns) - - def validation_epoch_end(self, outputs: List[Tuple[dict, Tuple[List[str], List[str]]]]): - metrics, tables = zip(*outputs) - _, columns = tables[0] - rows = [r for trows, _ in tables for r in trows] - - self.logger.log_text(key="samples", columns=columns, data=rows) - - outputs_soa = {k: torch.as_tensor([d[k] for d in metrics]) for k in metrics[0].keys()} - # this assumes all validation microbatches are the same size - avg_outputs = {k: v.mean() for k, v in outputs_soa.items()} - for k, v in avg_outputs.items(): - self.log( - f"val_metrics/{k}", - v, - prog_bar=True, - rank_zero_only=True, - sync_dist=True, - ) - - # Need to override this otherwise distributed fused adam won't work - # with frozen layers - def parameters(self): - return (p for p in self.model.parameters() if p.requires_grad) - - def get_forward_output_and_loss_func(self, validation_step=False): - def fwd_output_and_loss_func(batch: List[torch.Tensor], model, checkpoint_activations_all_layers=None): - # On first and last pipeline stages, the input data is passed in - if batch is not None: - batch = unflatten_dataclass(ILQLBatch)(batch) - batch = to_device(batch, torch.cuda.current_device(), non_blocking=True) - - inputs = batch.input_ids - pad_by = self.cfg.encoder_seq_length - inputs.shape[1] - inputs = torch.nn.functional.pad(inputs, (0, pad_by), value=self.tokenizer.eos_id) - - ( - attention_mask, - loss_mask, - position_ids, - ) = get_ltor_masks_and_position_ids( - data=inputs, - eod_token=self.tokenizer.eos_id, - reset_position_ids=False, - reset_attention_mask=False, - eod_mask_loss=False, - ) - - model_output = model( - input_ids=inputs, - position_ids=position_ids.long(), - attention_mask=attention_mask, - ) - else: - # In-between stages are given data via the pipeline engine - # Still need to specify thes arguments to avoid errors - model_output = model(input_ids=None, position_ids=None, attention_mask=None) - - def gather_ntc(t: torch.Tensor): - """Gather sequence parallel tensor [batch, seq, hidden]""" - t = rearrange(t, "N T ... -> T N ...") - t = gather_from_sequence_parallel_region(t, to_model_parallel=False) - t = rearrange(t, "T N ... 
-> N T ...") - return t - - def loss_func(model_output): - # # TODO: implement this in a sequence parallel way - logits, (qs, target_qs, vs) = model_output - - if self.cfg.sequence_parallel: - qs, target_qs, vs = tree_map(gather_ntc, (qs, target_qs, vs)) - - qs = tree_map( - lambda t: batched_index_select(t, batch.actions_ixs, 1), - qs, - ) - - target_qs = tree_map( - lambda t: batched_index_select(t, batch.actions_ixs, 1), - target_qs, - ) - - vs = batched_index_select(vs, batch.states_ixs, 1) - - model_output = (logits, (qs, target_qs, vs)) - loss_for_mb, stats = self.ilql_config.loss(model_output, batch) - - reduced_loss = average_losses_across_data_parallel_group([loss_for_mb]) - - # TODO: figure out why this sync is needed (crashes otherwise) - torch.cuda.synchronize() - - return loss_for_mb, {"avg_loss": reduced_loss, **stats} - - return model_output, loss_func - - return fwd_output_and_loss_func - - def get_forward_output_only_func( - self, - set_inference_key_value_memory=False, - inference_max_sequence_len=None, - checkpoint_activations_all_layers=None, - ): - def fwd_output_only_func( - batch: torch.Tensor, - model, - ): - if batch is not None: - batch = to_device(batch, torch.cuda.current_device(), non_blocking=True) - - extra_arg = {} - - if len(batch) == 3: - tokens, attention_mask, position_ids = batch - else: - ( - tokens, - attention_mask, - position_ids, - set_inference_key_value_memory, - inference_max_sequence_len, - ) = batch - - extra_arg["set_inference_key_value_memory"] = set_inference_key_value_memory[0].item() - extra_arg["inference_max_sequence_len"] = inference_max_sequence_len[0].item() - - model_output = model( - input_ids=tokens, - position_ids=position_ids.long(), - attention_mask=attention_mask, - **extra_arg, - ) - else: - model_output = model(input_ids=None, position_ids=None, attention_mask=None) - - def ilql_postprocess(model_output): - model_output = tree_map(lambda t: t.float(), model_output) - - logits, (_, target_qs, vs) = model_output - - target_q = reduce(torch.minimum, target_qs) - advantage = target_q - vs - pi_beta = F.log_softmax(logits, -1) - beta = self.ilql_config.gen_kwargs.get("beta", 1.0) - - logits = pi_beta + beta * advantage - - return logits, {"logits": logits} - - return model_output, ilql_postprocess - - return fwd_output_only_func - - def generate( - self, - inputs: Union[List[str], torch.Tensor, List[dict]], - length_params: LengthParam, - sampling_params: SamplingParam = None, - ) -> OutputType: - if sampling_params is None: - sampling_params = { - "use_greedy": False, - "temperature": self.ilql_config.gen_kwargs.get("temperature", 1.0), - "top_k": self.ilql_config.gen_kwargs.get("top_k", 0), - "top_p": 0.9, - "repetition_penalty": 1.2, - "add_BOS": False, - "all_probs": False, - "compute_logprob": False, - } - - return super().generate(inputs, length_params, sampling_params) diff --git a/trlx/trainer/nemo_ilql_trainer.py b/trlx/trainer/nemo_ilql_trainer.py index c58cc3249..a23a94f43 100644 --- a/trlx/trainer/nemo_ilql_trainer.py +++ b/trlx/trainer/nemo_ilql_trainer.py @@ -22,14 +22,14 @@ from trlx.data.configs import TRLConfig from trlx.data.ilql_types import ILQLBatch, ILQLElement, flatten_dataclass +from trlx.models.modeling_ilql import ILQLConfig +from trlx.models.modeling_nemo_ilql import ILQLGPT from trlx.pipeline.offline_pipeline import ( ILQLRolloutStorage, ilql_collate_fn, tokenize_dialogue, ) from trlx.trainer import register_trainer -from trlx.trainer.nemo.gpt import ILQLGPT -from trlx.trainer.nn.ilql_models import 
ILQLConfig from . import BaseRLTrainer diff --git a/trlx/trainer/nn/__init__.py b/trlx/trainer/nn/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trlx/trainer/nn/ilql_models.py b/trlx/trainer/nn/ilql_models.py deleted file mode 100644 index aba7b4eff..000000000 --- a/trlx/trainer/nn/ilql_models.py +++ /dev/null @@ -1,333 +0,0 @@ -import inspect -import os -from copy import deepcopy -from dataclasses import dataclass -from functools import reduce -from itertools import chain -from typing import Any, Dict, Union - -import deepspeed # type: ignore -import numpy as np -import torch -import torch.nn.functional as F -import transformers -from torch import nn -from torchtyping import TensorType - -from trlx.data.ilql_types import ILQLBatch -from trlx.data.method_configs import MethodConfig, register_method -from trlx.utils.modeling import ( - flatten_dict, - freeze_bottom_causal_layers, - get_tensor_stats, - hf_get_causal_base_model, - hf_get_hidden_size, - hf_get_lm_head, - make_head, -) - - -def topk_mask(xs: torch.FloatTensor, k: int): - if k > xs.shape[-1]: - return xs - mintop = torch.topk(xs, k)[0][:, -1].unsqueeze(-1) - return torch.where(xs < mintop, -np.inf * torch.ones_like(xs, dtype=xs.dtype), xs) - - -def batched_index_select( - x: TensorType["batch", "seq_len", "hidden"], - idxs: TensorType["batch", "index_len"], - dim: int, -) -> TensorType["batch", "index_len", "hidden"]: - """ - Gather vectors at idxs along dim from x - """ - idxs = idxs.unsqueeze(-1).expand(idxs.shape[0], idxs.shape[1], x.shape[-1]) - return x.gather(dim=dim, index=idxs) - - -@dataclass -@register_method -class ILQLConfig(MethodConfig): - tau: float - gamma: float - cql_scale: float - awac_scale: float - alpha: float - beta: float - steps_for_target_q_sync: float - two_qs: bool - gen_kwargs: dict - - def heads(self, hidden_size: int, vocab_size: int, dtype: type): - return ILQLHeads(self, hidden_size, vocab_size, dtype) - - def loss(self, outputs, labels: ILQLBatch): - logits, (qs, target_qs, vs) = outputs - terminal_mask = labels.dones[:, :-1] - n_nonterminal = max(1, terminal_mask.sum()) - - actions = labels.input_ids[:, 1:].gather(dim=1, index=labels.actions_ixs).unsqueeze(-1) - nactions = actions.shape[1] - bsize, _, dsize = logits.shape - - Q = [q.gather(-1, actions).squeeze(-1) for q in qs] - targetQs = [q.gather(-1, actions).squeeze(-1).detach() for q in target_qs] - targetQ = reduce(torch.minimum, targetQs) - - # values of current states - V = vs[:, :-1].squeeze() - # values of next states - Vnext = vs[:, 1:].squeeze() * labels.dones[:, 1:] - # target to fit Q - Q_ = labels.rewards + self.gamma * Vnext.detach() - - loss_qs = [((Qi - Q_) * terminal_mask).pow(2).sum() / n_nonterminal for Qi in Q] - loss_q = sum(loss_qs) - - targetQ = targetQ.detach() - - loss_v = ( - ( - (targetQ >= V).int() * self.tau * (targetQ - V).pow(2) - + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2) - ) - * terminal_mask - ).sum() / n_nonterminal - - def cql_loss(q): - loss = F.cross_entropy(q.reshape(-1, dsize), actions.reshape(-1), reduction="none") - loss = loss.reshape(bsize, nactions) * terminal_mask - loss = loss.sum() / n_nonterminal - return loss - - loss_cql = sum(cql_loss(q) for q in qs) - - # select logits from continuations - action_logits = batched_index_select(logits, labels.actions_ixs, dim=1) - cross_entropy = F.cross_entropy( - action_logits.reshape(-1, dsize), - actions.reshape(-1), - reduction="none", - ).reshape(bsize, nactions) - - with torch.no_grad(): - awac_weight = 
torch.exp(self.beta * (targetQ - V)) - - loss_awac = torch.sum(cross_entropy * awac_weight * terminal_mask) / n_nonterminal - loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac - - stats = dict( - losses=dict( - loss=loss.item(), - loss_q=loss_q.item(), - loss_v=loss_v.item(), - loss_cql=loss_cql.item(), - loss_awac=loss_awac.item(), - ), - values=get_tensor_stats(V, terminal_mask, n_nonterminal), - qvalues={str(ix): get_tensor_stats(Q[ix], terminal_mask, n_nonterminal) for ix in range(len(Q))}, - awac_weight=get_tensor_stats(awac_weight, terminal_mask, n_nonterminal), - ) - - return loss, flatten_dict(stats) - - -class ILQLHeads(nn.Module): - def __init__(self, config: ILQLConfig, hidden_size: int, vocab_size: int, dtype: type): - super().__init__() - - self.hidden_size = hidden_size - self.vocab_size = vocab_size - self.v_head = make_head(self.hidden_size, 1, dtype) - self.config = config - - n_qs = 2 if self.config.two_qs else 1 - - self.q_heads = nn.ModuleList(make_head(self.hidden_size, self.vocab_size, dtype) for _ in range(n_qs)) - self.target_q_heads = nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads) - - for target_q_head in self.target_q_heads: - target_q_head.requires_grad_(False) - - def forward( - self, - hs: torch.Tensor, - states_ixs: torch.Tensor = None, - actions_ixs: torch.Tensor = None, - **kwargs, - ): - if states_ixs is not None: - states_hs = batched_index_select(hs, states_ixs, 1) - actions_hs = batched_index_select(hs, actions_ixs, 1) - else: - states_hs = actions_hs = hs - - qs = tuple(q_head(actions_hs) for q_head in self.q_heads) - target_qs = tuple(q_head(actions_hs) for q_head in self.target_q_heads) - vs = self.v_head(states_hs) - - return qs, target_qs, vs - - def _sync_target_q_heads(self, alpha): - for target_q_head, q_head in zip(self.target_q_heads, self.q_heads): - for target_param, copy_param in zip(target_q_head.parameters(), q_head.parameters()): - target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data) - - def sync_target_q_heads(self): - if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3": - params = chain( - chain(q_head.parameters() for q_head in self.q_heads), - chain(q_head.parameters() for q_head in self.target_q_heads), - ) - - with deepspeed.zero.GatheredParameters(list(params), modifier_rank=0): - if deepspeed.comm.get_rank() == 0: - self._sync_target_q_heads(self.config.alpha) - else: - self._sync_target_q_heads(self.config.alpha) - - -class CausalLMWithValueHeads(nn.Module): - """This is a wrapper around huggingface AutoModelForCausalLM with two additional scalar heads""" - - def __init__( - self, - config: Union[transformers.PretrainedConfig, str], - ilql_config: ILQLConfig, - num_layers_unfrozen=-1, - ): - super().__init__() - - # enable zero3 init within from_pretrained - if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3": - config_path = os.environ.get("DEEPSPEED_CONFIG_FILE", "") - if config_path: - _hfconfig = transformers.deepspeed.HfDeepSpeedConfig(config_path) # noqa: F841 - - if isinstance(config, str): - self.config = transformers.AutoConfig.from_pretrained(config) - self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config) - else: - self.config = config - self.base_model = transformers.AutoModelForCausalLM.from_config(config) - - self.base_model.transformer = hf_get_causal_base_model(self.base_model) - self.base_model.lm_head = hf_get_lm_head(self.base_model) - freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen) - - # 
Cache `transformer.forward` args for general use (avoids incompatible args across architectures) - self.base_model_transformer_args = inspect.getfullargspec(self.base_model.transformer.forward).args - - dtype = next(self.base_model.lm_head.parameters()).dtype - self.hidden_size = hf_get_hidden_size(self.config) - self.ilql_heads = ilql_config.heads(self.hidden_size, self.config.vocab_size, dtype) - self.ilql_config = ilql_config - - def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: - """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" - return {k: v for k, v in kwargs.items() if k in self.base_model_transformer_args} - - def sync_target_q_heads(self): - self.ilql_heads.sync_target_q_heads() - - def forward( - self, - input_ids, - attention_mask=None, - position_ids=None, - past_key_values=None, - actions_ixs=None, - states_ixs=None, - ): - forward_kwargs = self._get_compatible_forward_kwargs( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - ) - out = self.base_model.transformer(**forward_kwargs) - hs = out.last_hidden_state - - logits = self.base_model.lm_head(hs) - qs, target_qs, vs = self.ilql_heads(hs, states_ixs=states_ixs, actions_ixs=actions_ixs) - - return logits, qs, target_qs, vs, out.past_key_values - - def generate( - self, - input_ids, - attention_mask=None, - position_ids=None, - past_key_values=None, - beta=1, - max_new_tokens=32, - max_length=1024, - temperature=1, - top_k=20, - logit_mask=None, - pad_token_id=None, - eos_token_id=None, - ): - """ - Generates samples akin to hf's `.generate` but with custom logp prepossessing: - changing token probabilities as to how advantageous they would be - according to value functions estimations. 
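        A minimal usage sketch (illustrative only; `ilql_config` and
        `tokenizer` are assumed to be constructed elsewhere, and "gpt2" is
        just a stand-in checkpoint):

            model = CausalLMWithValueHeads("gpt2", ilql_config=ilql_config)
            samples = model.generate(
                input_ids,
                beta=1,
                max_new_tokens=16,
                top_k=20,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )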
- """ - if attention_mask is None: - attention_mask = input_ids.not_equal(pad_token_id) - - if position_ids is None: - position_ids = attention_mask.cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask.eq(0), 0) - - samples = input_ids.clone() - max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1]) - - finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device) - for _ in range(max_new_tokens): - out = self.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - ) - - logits, _, target_qs, vs, past_key_values = out - if self.ilql_config.two_qs: - qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :]) - else: - qs = target_qs[:, -1, :] - - logits = logits[:, -1, :] - vs = vs[:, -1, :] - - if logit_mask is not None: - mask = logit_mask[input_ids[:, -1].squeeze().to(logit_mask.device)] - logits[torch.where(mask)] = -np.inf - - adv = qs - vs - pi_beta = F.log_softmax(logits, -1) - pi_top_k = topk_mask(pi_beta + beta * adv, top_k) - pi = F.softmax(pi_top_k / temperature, -1) - - input_ids = torch.multinomial(pi, num_samples=1) - input_ids = (1 - finished) * input_ids + finished * eos_token_id - finished = (input_ids == eos_token_id).long() - - samples = torch.hstack((samples, input_ids)) - attention_mask = torch.hstack((attention_mask, (input_ids != eos_token_id).long())) - position_ids = (position_ids[:, -1] + 1).view(-1, 1) - - if torch.all(finished): - break - - return samples - - @property - def dummy_inputs(self): - return {"input_ids": torch.ones(1, 1, device=self.base_model.device, dtype=torch.long)} - - @property - def device(self): - return self.base_model.device diff --git a/trlx/trainer/nn/ppo_models.py b/trlx/trainer/nn/ppo_models.py deleted file mode 100644 index 6cbb64d25..000000000 --- a/trlx/trainer/nn/ppo_models.py +++ /dev/null @@ -1,1171 +0,0 @@ -import inspect -from copy import deepcopy -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import transformers -from torchtyping import TensorType -from transformers.modeling_outputs import ModelOutput -from transformers.models.bloom import modeling_bloom -from transformers.models.opt import modeling_opt - -from trlx.data.method_configs import MethodConfig, register_method -from trlx.utils.modeling import ( - flatten_dict, - get_tensor_stats, - hf_get_causal_base_model, - hf_get_causal_final_norm, - hf_get_causal_hidden_layers, - hf_get_hidden_size, - hf_get_lm_head, - hf_get_num_hidden_layers, - make_head, - whiten, -) - -# KL Controllers - - -class AdaptiveKLController: - """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences" - Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 - Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py - """ - - def __init__(self, init_kl_coef: float, target: float, horizon: int): - self.value = init_kl_coef - self.target = target - self.horizon = horizon - - def update(self, current: float, n_steps: int): - """Returns adaptively updated KL coefficient, βₜ₊₁. - Arguments: - current: The current KL value between the newest policy and the initial policy. 
- """ - proportional_error = np.clip(current / self.target - 1, -0.2, 0.2) # ϵₜ - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult # βₜ₊₁ - - -class FixedKLController: - """Fixed KL controller.""" - - def __init__(self, kl_coef): - self.value = kl_coef - - def update(self, current: float, n_steps: int): - """Returns updated KL coefficient, βₜ₊₁. - Arguments: - current: The current KL value between the newest policy and the initial policy. - """ - pass - - -# PPO Configs - - -@dataclass -@register_method -class PPOConfig(MethodConfig): - """ - Config for PPO method - - :param ppo_epochs: Number of updates per batch - :type ppo_epochs: int - - :param num_rollouts: Number of experiences to observe before learning - :type num_rollouts: int - - :param init_kl_coef: Initial value for KL coefficient - :type init_kl_coef: float - - :param target: Target value for KL coefficient - :type target: float - - :param horizon: Number of steps for KL coefficient to reach target - :type horizon: int - - :param gamma: Discount factor - :type gamma: float - - :param lam: GAE lambda - :type lam: float - - :param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange) - :type cliprange: float - - :param cliprange_value: Clipping range for predicted values - (observed values - cliprange_value, observed values + cliprange_value) - :type cliprange_value: float - - :param vf_coef: Value loss scale w.r.t policy loss - :type vf_coef: float - - :param gen_kwargs: Additioanl kwargs for the generation - :type gen_kwargs: Dict[str, Any] - - :param gen_experience_kwargs: if this is not None, then the experience is generated using this - :type gen_experience_kwargs: Dict[str, Any] - """ - - ppo_epochs: int - num_rollouts: int - chunk_size: int - init_kl_coef: float - target: float - horizon: int - gamma: float - lam: float - cliprange: float - cliprange_value: float - vf_coef: float - scale_reward: str - ref_mean: Optional[float] - ref_std: Optional[float] - cliprange_reward: float - gen_kwargs: dict - gen_experience_kwargs: Optional[dict] = None - - def get_advantages_and_returns( - self, - values: TensorType["batch_size", "response_size"], - rewards: TensorType["batch_size", "response_size"], - response_length: int, - use_whitening: Optional[bool] = True, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Function that computes advantages and returns from rewards and values. - Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 - Note that rewards may include a KL divergence loss term. - - Advantages looks like this: - Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... - - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... - - Returns looks like this: - Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... - + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... - - Input: - - values: Tensor of shape (batch_size, response_size) - - rewards: Tensor of shape (batch_size, response_size) - - response_length: Length of the response sequence - - use_whitening: Whether to use whitening (ie. 
normalize advantages) or not - """ - lastgaelam = 0 - advantages_reversed = [] - for t in reversed(range(response_length)): - nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 - delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] - lastgaelam = delta + self.gamma * self.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], dim=1) - returns = advantages + values - if use_whitening: - advantages = whiten(advantages) - return advantages.detach(), returns - - def loss( - self, - logprobs: TensorType["batch_size", "response_size"], - values: TensorType["batch_size", "response_size"], - old_logprobs: TensorType["batch_size", "response_size"], - old_values: TensorType["batch_size", "response_size"], - advantages: TensorType["batch_size", "response_size"], - returns: TensorType["batch_size", "response_size"], - mask: TensorType["batch_size", "response_size"], - ): - """PPO objective function. - References: - - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html - """ - values_clipped = torch.clamp( - values, - old_values - self.cliprange_value, - old_values + self.cliprange_value, - ) - n = mask.sum() - - vf_loss1 = (values - returns) ** 2 - vf_loss2 = (values_clipped - returns) ** 2 - vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n - vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n - - log_ratio = (logprobs - old_logprobs) * mask - ratio = torch.exp(log_ratio) - # Unbiased KL-div estimates (`k3`). Ref: http://joschu.net/blog/kl-approx.html - with torch.no_grad(): - approx_kl = torch.mean((ratio - 1) - log_ratio) - - pg_loss1 = -advantages * ratio - pg_loss2 = -advantages * torch.clamp( - ratio, - 1.0 - self.cliprange, - 1.0 + self.cliprange, - ) - pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n - pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n - - loss = pg_loss + self.vf_coef * vf_loss - - stats = dict( - losses=dict( - total_loss=loss.item(), - policy_loss=pg_loss.item(), - value_loss=vf_loss.item(), - ), - values=dict( - get_tensor_stats(values, mask, n), - values_error=torch.sum(((values - returns) * mask) ** 2) / n, - clipfrac=vf_clipfrac, - ), - old_values=get_tensor_stats(old_values, mask, n), - returns=get_tensor_stats(returns, mask, n), - policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()), - ratio=(ratio * mask).sum() / n, - padding_percentage=n / mask.numel(), - ) - - return loss, flatten_dict(stats) - - -# PPO Layers - - -@dataclass -class CausalLMOutputWithCrossAttentions(ModelOutput): - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor]] = None - value: Optional[torch.FloatTensor] = None - - -class CausalLMWithValueHead(nn.Module): - """The CausalLMWithValueModel class implements a causal language model with - a secondary, scalar head. 
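    The scalar head maps each final hidden state to one value estimate per
    token. A minimal usage sketch (illustrative; "gpt2" is a stand-in
    checkpoint):

        model = CausalLMWithValueHead("gpt2")
        out = model(input_ids, return_dict=True)
        logits, values = out.logits, out.value  # values: (batch, seq_len)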
- """ - - def __init__(self, config: Union[transformers.PretrainedConfig, str]): - super().__init__() - if isinstance(config, str): - self.config = transformers.AutoConfig.from_pretrained(config) - self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config) - else: - self.config = config - self.base_model = transformers.AutoModelForCausalLM.from_config(config) - - self.base_model.transformer = hf_get_causal_base_model(self.base_model) - self.base_model.lm_head = hf_get_lm_head(self.base_model) - dtype = next(self.base_model.lm_head.parameters()).dtype - self.v_head = make_head(hf_get_hidden_size(self.config), 1, dtype) - - # Cache `transformer.forward` args for general use (avoids incompatible args across architectures) - self.base_model_transformer_args = inspect.getfullargspec(self.base_model.transformer.forward).args - - def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: - """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" - return {k: v for k, v in kwargs.items() if k in self.base_model_transformer_args} - - def generate(self, input_ids, **kwargs): - return self.base_model.generate(input_ids, **kwargs) - - def forward( - self, - input_ids=None, - attention_mask=None, - past_key_values=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - return_dict=False, - ): - forward_kwargs = self._get_compatible_forward_kwargs( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - ) - transformer_outputs = self.base_model.transformer(**forward_kwargs) - last_hidden_state = transformer_outputs.last_hidden_state - lm_logits = self.base_model.lm_head(last_hidden_state) - value = self.v_head(last_hidden_state).squeeze(-1) - - if not return_dict: - outputs = (lm_logits,) + transformer_outputs[1:] + (value,) - return outputs - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=transformer_outputs.cross_attentions, - value=value, - ) - - -class CausalLMHydraWithValueHead(nn.Module): - """The CausalLMHydraWithValueHead class implements a causal language model - with a secondary, scalar head. 
- """ - - def __init__( - self, - config: Union[transformers.PretrainedConfig, str], - num_layers_unfrozen: int = -1, - ): - super().__init__() - - if isinstance(config, str): - self.config = transformers.AutoConfig.from_pretrained(config) - self.base_model = transformers.AutoModelForCausalLM.from_pretrained(config) - else: - self.config = config - self.base_model = transformers.AutoModelForCausalLM.from_config(config) - - self.base_model.transformer = hf_get_causal_base_model(self.base_model) - self.base_model.lm_head = hf_get_lm_head(self.base_model) - dtype = next(self.base_model.lm_head.parameters()).dtype - self.v_head = make_head(hf_get_hidden_size(self.config), 1, dtype) - - self.num_layers_unfrozen = num_layers_unfrozen - if self.num_layers_unfrozen > 0: - transformer_blocks = list(hf_get_causal_hidden_layers(self.base_model)) - branch_class = hf_get_causal_lm_branch_class(self.config) - self.frozen_head = branch_class( - self.config, - transformer_blocks[-self.num_layers_unfrozen :], - final_norm=hf_get_causal_final_norm(self.base_model), - lm_head=self.base_model.lm_head, - ) - # Cache `transformer.forward` args for general use (avoids incompatible args across architectures) - self.base_model_transformer_args = inspect.getfullargspec(self.base_model.transformer.forward).args - - def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: - """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" - return {k: v for k, v in kwargs.items() if k in self.base_model_transformer_args} - - def generate(self, input_ids, **x): - return self.base_model.generate(input_ids, **x) - - def forward_hydra(self, input_ids, **forward_kwargs): - forward_kwargs = self._get_compatible_forward_kwargs(**forward_kwargs) - if forward_kwargs.get("return_dict") is not None: - return_dict = forward_kwargs["return_dict"] - else: - return_dict = True - forward_kwargs["return_dict"] = True - forward_kwargs["output_hidden_states"] = True - output = self.forward(input_ids, **forward_kwargs) - all_hidden_states = output.hidden_states - # Get output of last frozen hidden layer - # Select hidden state before first layer of branch. 
- input_hidden_state = all_hidden_states[-(self.num_layers_unfrozen + 1)] - # Get size of last hidden state - output_shape = all_hidden_states[-1].size() - outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs) - if not return_dict: - return outputs.logits - return outputs - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, - output_hidden_states: Optional[bool] = True, - return_dict: Optional[bool] = None, - ): - forward_kwargs = self._get_compatible_forward_kwargs( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - token_type_ids=token_type_ids, - ) - transformer_outputs = self.base_model.transformer(**forward_kwargs) - last_hidden_state = transformer_outputs.last_hidden_state - lm_logits = self.base_model.lm_head(last_hidden_state) - value = self.v_head(last_hidden_state).squeeze(-1) - - if not return_dict: - outputs = (lm_logits,) + transformer_outputs[1:] + (value,) - return outputs - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - cross_attentions=None, - value=value, - ) - - -@dataclass -class Seq2SeqLMOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None - value: Optional[torch.FloatTensor] = None - - -class Seq2SeqLMHydraWithValueHead(nn.Module): - def __init__( - self, - config: Union[transformers.PretrainedConfig, str], - num_layers_unfrozen: int = -1, - ): - super().__init__() - if isinstance(config, str): - self.config = transformers.AutoConfig.from_pretrained(config) - else: - self.config = config - self.base_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.config.name_or_path) - self.v_head = make_head(hf_get_hidden_size(self.config), 1) - - self.num_layers_unfrozen = num_layers_unfrozen - if self.num_layers_unfrozen > 0: - self.frozen_head = T5Branch(self.config, self.base_model, self.num_layers_unfrozen) - # Cache `transformer.forward` args for general use (avoids incompatible args across architectures) - self.base_model_args = inspect.getfullargspec(self.base_model.forward).args - - def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: - """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" - return {k: v for k, v in kwargs.items() if k in self.base_model_args} - - def generate(self, input_ids, **x): - return 
self.base_model.generate(input_ids, **x) - - def forward_hydra(self, input_ids, attention_mask, decoder_input_ids, **forward_kwargs): - forward_kwargs = self._get_compatible_forward_kwargs(**forward_kwargs) - forward_kwargs["return_dict"] = True - output = self.forward(input_ids, attention_mask, decoder_input_ids, **forward_kwargs) - all_hidden_states = output.decoder_hidden_states - # Get output of last frozen hidden layer - # Select hidden state before first layer of branch. - input_hidden_state = all_hidden_states[-(self.num_layers_unfrozen + 1)] - encoder_hidden_states = output.encoder_last_hidden_state - # Get size of last hidden state - outputs = self.frozen_head( - decoder_input_ids, - input_hidden_state, - encoder_hidden_states, - attention_mask, - False, - False, - ) - return outputs.logits - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.FloatTensor] = None, - encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - decoder_head_mask: Optional[torch.FloatTensor] = None, - cross_attn_head_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = True, - output_hidden_states: Optional[bool] = True, - return_dict: Optional[bool] = None, - ): - forward_kwargs = self._get_compatible_forward_kwargs( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_outputs, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - t5_outputs = self.base_model(**forward_kwargs) - lm_logits = t5_outputs.logits - last_hidden_state = t5_outputs.decoder_hidden_states[-1] - value = self.v_head(last_hidden_state).squeeze(-1) - - return Seq2SeqLMOutput( - loss=None, - logits=lm_logits, - decoder_hidden_states=t5_outputs.decoder_hidden_states, - decoder_attentions=t5_outputs.decoder_attentions, - cross_attentions=t5_outputs.cross_attentions, - encoder_last_hidden_state=t5_outputs.encoder_last_hidden_state, - encoder_hidden_states=t5_outputs.encoder_hidden_states, - encoder_attentions=t5_outputs.encoder_attentions, - past_key_values=t5_outputs.past_key_values, - value=value, - ) - - -class T5Branch(transformers.PreTrainedModel): - # Decoder branch only - def __init__( - self, - config: transformers.PretrainedConfig, - base_model: transformers.PreTrainedModel, - num_layers_unfrozen: int, - ): - super().__init__(config) - - # Defined by the main trunk - self.hidden_size = hf_get_hidden_size(config) - self.decoder = deepcopy(base_model.decoder) - self.decoder.block = nn.ModuleList(self.decoder.block[-num_layers_unfrozen:]) - self.lm_head = deepcopy(base_model.lm_head) - # Model parallel - self.model_parallel = False - self.device_map = None - self.last_device = None - self.gradient_checkpointing = False - - for parameter in self.parameters(): - 
parameter.requires_grad = False - - def forward( - self, - input_ids, - hidden_states, - encoder_hidden_states, - encoder_attention_mask, - use_cache: bool = False, - output_attentions: bool = False, - ): - input_shape = input_ids.size() - batch_size, seq_length = input_shape - - attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device) - - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) - position_bias = None - encoder_decoder_position_bias = None - - for i, layer_module in enumerate(self.decoder.block): - layer_outputs = layer_module( - hidden_states, # size: (batch_size, seq_length, hidden_size) - attention_mask=extended_attention_mask, # size: (batch_size, 1, seq_length, seq_length) - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - # layer_outputs is a tuple with: - # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - if use_cache is False: - layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - - hidden_states, present_key_value_state = layer_outputs[:2] - - # We share the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), - # (cross-attention position bias), (cross-attention weights) - position_bias = layer_outputs[2] - encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - - hidden_states = self.decoder.final_layer_norm(hidden_states) - hidden_states = self.decoder.dropout(hidden_states) - lm_logits = self.lm_head(hidden_states) - - return Seq2SeqLMOutput(logits=lm_logits) - - -class GPTModelBranch(transformers.PreTrainedModel): - """ - GPTModelBranch implements the frozen upper trunk of the reference model - used when computing the PPO KL-divergence penalty. Expects a list of - frozen transformer blocks and an lm_head from the base model. 
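    Because it sits on top of the layers shared with the trained model, its
    `forward` takes the `hidden_states` they produce rather than `input_ids`.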
- """ - - def __init__( - self, - config: transformers.PretrainedConfig, - transformer_blocks: nn.ModuleList, - final_norm: nn.Module, - lm_head: nn.Module, - ): - super().__init__(config) - - # Defined by the main trunk - self.hidden_size = hf_get_hidden_size(config) - self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks)) - self.final_norm = deepcopy(final_norm) - self.lm_head = deepcopy(lm_head) - - # Model parallel - self.model_parallel = False - self.device_map = None - self.gradient_checkpointing = False - - # Turning off grad saves memory - - for parameter in self.parameters(): - parameter.requires_grad_(False) - - def forward( # noqa: max-complexity - self, - hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids - output_shape: torch.Tensor, # output_size given by main trunk - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = False, - position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - batch_size = hidden_states.size()[0] - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - device = hidden_states.device - - if past_key_values is None: - past_key_values = tuple([None] * len(self.transformer_blocks)) - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. 
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config)) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.transformer_blocks, past_key_values)): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Assumes we are never training the branch - block_params = inspect.getfullargspec(block.forward).args - if "encoder_hidden_states" in block_params: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.final_norm(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # last_hidden_state = hidden_states - # past_key_values = presents - # hidden_states = 
all_hidden_states - # attentions = all_self_attentions - # cross_attentions = all_cross_attentions - - # START OF CAUSAL HEAD # - # hidden_states = hidden_states.to(torch.float32) Present for gptj - - if self.model_parallel: - torch.cuda.set_device(self.transformer.first_device) - hidden_states = hidden_states.to(self.lm_head.weight.device) - - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - outputs = (lm_logits,) + (None,) + (None,) - return outputs - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - value=None, - ) - - -class OPTModelBranch(transformers.PreTrainedModel): - """ - OPTModelBranch implements the frozen upper trunk of the reference model - used when computing the PPO KL-divergence penalty. Expects a list of - frozen transformer blocks and an lm_head from the base model. - """ - - def __init__( - self, - config: transformers.PretrainedConfig, - transformer_blocks: nn.ModuleList, - final_norm: nn.Module, - lm_head: nn.Module, - ): - super().__init__(config) - - # Defined by the main trunk - self.hidden_size = hf_get_hidden_size(config) - self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks)) - self.final_norm = deepcopy(final_norm) - self.lm_head = deepcopy(lm_head) - - # Model parallel - self.model_parallel = False - self.device_map = None - self.gradient_checkpointing = False - - # Turning off grad saves memory - for parameter in self.parameters(): - parameter.requires_grad_(False) - - def forward( # noqa: max-complexity - self, - hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids - output_shape: torch.Tensor, # output_size given by main trunk - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = False, - position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - """Override OPTForCausalLM""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - ####################################################################### - # Modififed OPTDecoder.forward - ####################################################################### - - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if attention_mask is None: - attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device) - - input_shape = hidden_states.size()[:-1] - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = modeling_opt._make_causal_mask( - input_shape, - hidden_states.dtype, - past_key_values_length=past_key_values_length, - ).to(hidden_states.device) 
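            # (Both masks here are additive: masked entries hold the dtype
            # minimum and open entries hold 0.0, so summing them masks a
            # position if either the causal or the padding mask does.)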
- - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = modeling_opt._expand_mask( - attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] - ).to(hidden_states.device) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - attention_mask = combined_attention_mask - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - # check if head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask], ["head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.transformer_blocks)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.transformer_blocks)} layers, but it is for" - f" {head_mask.size()[0]}." - ) - - for idx, decoder_layer in enumerate(self.transformer_blocks): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if self.final_norm is not None: - hidden_states = self.final_norm(hidden_states) - - # TODO: Add output projection support - # https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/opt/modeling_opt.py#L499 # noqa: E501 - # if self.project_out is not None: - # hidden_states = self.project_out(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - ####################################################################### - # End of modified OPTDecoder.forward - ####################################################################### - - lm_logits = self.lm_head(hidden_states).contiguous() - - if not return_dict: - return tuple( - v - for v in [ - lm_logits, - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=None, - value=None, - ) - - -class BloomModelBranch(transformers.PreTrainedModel): - """ - BloomModelBranch implements the frozen upper trunk of the reference model - used when computing the PPO KL-divergence penalty. Expects a list of - frozen transformer blocks and an lm_head from the base model. 
- """ - - def __init__( - self, - config: transformers.PretrainedConfig, - transformer_blocks: nn.ModuleList, - final_norm: nn.Module, - lm_head: nn.Module, - ): - super().__init__(config) - - # Defined by the main trunk - self.hidden_size = hf_get_hidden_size(config) - self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks)) - self.final_norm = deepcopy(final_norm) - self.lm_head = deepcopy(lm_head) - - # Model parallel - self.model_parallel = False - self.device_map = None - self.gradient_checkpointing = False - - # Turning off grad saves memory - for parameter in self.parameters(): - parameter.requires_grad_(False) - - def forward( # noqa: C901 - self, - hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids - output_shape: torch.Tensor, # output_size given by main trunk - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = False, - position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - ####################################################################### - # Modififed BloomModel.forward - ####################################################################### - - batch_size, seq_length = hidden_states.shape[:2] - - if past_key_values is None: - past_key_values = tuple([None] * len(self.transformer_blocks)) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config)) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - # Compute alibi tensor: check modeling_bloom.build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - alibi = modeling_bloom.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype) - - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - input_shape = (batch_size, seq_length) - _, src_length = input_shape - - if src_length > 1: - combined_attention_mask = modeling_bloom._make_causal_mask( - input_shape, - device=device, - 
past_key_values_length=past_key_values_length, - ) - - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - expanded_attn_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - causal_mask = combined_attention_mask - - for i, (block, layer_past) in enumerate(zip(self.transformer_blocks, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - # Add last hidden state - hidden_states = self.final_norm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - ####################################################################### - # End of modified BloomModel.forward - ####################################################################### - - lm_logits = self.lm_head(hidden_states) - - if not return_dict: - return tuple( - v - for v in [ - lm_logits, - hidden_states, - presents, - all_hidden_states, - all_self_attentions, - ] - if v is not None - ) - - return CausalLMOutputWithCrossAttentions( - loss=None, - logits=lm_logits, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=None, - value=None, - ) - - -def hf_get_causal_lm_branch_class( - config: transformers.PretrainedConfig, -) -> "ModelBranch": - """Returns the CausalLM branch class for the given config.""" - gpt_branch_supported_archs = [ - "GPTJForCausalLM", - "GPT2LMHeadModel", - "GPTNeoForCausalLM", - "GPTNeoXForCausalLM", - ] - opt_branch_supported_archs = ["OPTForCausalLM"] - bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"] - arch = config.architectures[0] - if arch in gpt_branch_supported_archs: - return GPTModelBranch - elif arch in opt_branch_supported_archs: - return OPTModelBranch - elif arch in bloom_branch_supported_archs: - return BloomModelBranch - else: - all_supported_archs = sum( - [ - gpt_branch_supported_archs, - opt_branch_supported_archs, - bloom_branch_supported_archs, - ], - [], - ) - raise ValueError( - f"Unsupported architecture: `{arch}`. 
The following architectures are " - f"available for model branching:\n{all_supported_archs}" - ) diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py index 4f33d1502..803557724 100644 --- a/trlx/utils/__init__.py +++ b/trlx/utils/__init__.py @@ -29,7 +29,7 @@ def significant(x: Number, ndigits=2) -> Number: if isinstance(x, torch.Tensor): x = x.item() - if not isinstance(x, Number) or x == 0: + if not isinstance(x, Number) or math.isnan(x) or x == 0: return x return round(x, ndigits - int(math.floor(math.log10(abs(x))))) diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 8aff33df0..d810dacc4 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -33,7 +33,7 @@ def make_head(n_embd: int, out: int, dtype: type = torch.float32) -> nn.Sequenti def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0): """Freezes the bottom transformer block layers of the specified model.""" - hidden_layers = hf_get_causal_hidden_layers(model) + hidden_layers = hf_get_decoder_blocks(model) if num_layers_unfrozen == 0: hidden_layers_to_freeze = list(hidden_layers) elif num_layers_unfrozen > 0: @@ -102,7 +102,7 @@ def findattr(obj, attrs: Tuple[str]) -> Union[object, None]: raise ValueError(f"Could not find an attribute from `{attrs}` in `{obj}`") -def hf_get_causal_base_model(model: transformers.AutoModelForCausalLM) -> nn.Module: +def hf_get_decoder(model: nn.Module) -> nn.Module: """Returns the causal decoder backbone of the specified HuggingFace transformers model. NOTE: Different model configurations have different causal decoder attribute @@ -111,12 +111,12 @@ def hf_get_causal_base_model(model: transformers.AutoModelForCausalLM) -> nn.Mod - model.decoder: (OPTConfig, BloomConfig) - gpt_neox: (GPTNeoXConfig) """ - decoder_attrs = ("transformer", "model.decoder", "gpt_neox") + decoder_attrs = ("transformer", "model.decoder", "gpt_neox", "decoder") return findattr(model, decoder_attrs) -def hf_get_causal_final_norm(model: nn.Module) -> float: - """Returns the final (layer) norm of the specified model. +def hf_get_decoder_final_norm(model: nn.Module) -> float: + """Returns the final (layer) norm of the specified decoder. NOTE: Different model configurations have different final norm attribute names. - transformer.ln_f: (GPT2LMHeadModel, GPTJForCausalLM) - model.decoder.final_layer_norm: (OPTForCausalLM) @@ -125,17 +125,19 @@ def hf_get_causal_final_norm(model: nn.Module) -> float: norm_attrs = ( "transformer.ln_f", "model.decoder.final_layer_norm", + "decoder.final_layer_norm", "gpt_neox.final_layer_norm", ) return findattr(model, norm_attrs) -def hf_get_causal_hidden_layers(model: nn.Module) -> Tuple[nn.Module]: - """Returns the hidden layers of the specified model. +def hf_get_decoder_blocks(model: nn.Module) -> Tuple[nn.Module]: + """Returns the decoder hidden layers of the specified model. NOTE: Different model configurations have different hidden layer attribute names. 
- transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM) - model.decoder.layers: (OPTForCausalLM) - gpt_neox.layers: (GPTNeoXForCausalLM) + - decoder.block: (T5ForConditionalGeneration) """ hidden_layers_attrs = ( "h", @@ -144,11 +146,12 @@ def hf_get_causal_hidden_layers(model: nn.Module) -> Tuple[nn.Module]: "transformer.h", "model.decoder.layers", "gpt_neox.layers", + "decoder.block", ) return findattr(model, hidden_layers_attrs) -def hf_get_lm_head(model: transformers.AutoModelForCausalLM) -> nn.Module: +def hf_get_lm_head(model: nn.Module) -> nn.Module: """Returns the language modeling (lm) head of the specified HuggingFace transformers model. NOTE: Different model configurations have different `lm_head` attribute names. From 786709b7257218b1aa5246314fbab4ef0348c7aa Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Mon, 13 Mar 2023 17:31:32 +0200 Subject: [PATCH 34/57] merge(models): merge renaming of the directory --- trlx/data/default_configs.py | 119 +++ trlx/models/README.md | 344 +++++++++ trlx/models/__init__.py | 0 trlx/models/modeling_base.py | 223 ++++++ trlx/models/modeling_ilql.py | 488 +++++++++++++ trlx/models/modeling_nemo_ilql.py | 786 ++++++++++++++++++++ trlx/models/modeling_ppo.py | 1127 +++++++++++++++++++++++++++++ 7 files changed, 3087 insertions(+) create mode 100644 trlx/data/default_configs.py create mode 100644 trlx/models/README.md create mode 100644 trlx/models/__init__.py create mode 100644 trlx/models/modeling_base.py create mode 100644 trlx/models/modeling_ilql.py create mode 100644 trlx/models/modeling_nemo_ilql.py create mode 100644 trlx/models/modeling_ppo.py diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py new file mode 100644 index 000000000..1f9297db2 --- /dev/null +++ b/trlx/data/default_configs.py @@ -0,0 +1,119 @@ +from trlx.models.modeling_ilql import ILQLConfig +from trlx.models.modeling_ppo import PPOConfig +from trlx.trainer.accelerate_sft_trainer import SFTConfig + +from .configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) + + +def default_ppo_config(): + return TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=10000, + batch_size=32, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer", + ), + model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig( + name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) + ), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)), + method=PPOConfig( + name="PPOConfig", + num_rollouts=128, + chunk_size=128, + ppo_epochs=4, + init_kl_coef=0.05, + target=6, + horizon=10000, + gamma=1, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=1, + scale_reward="ignored", + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs=dict( + max_new_tokens=40, + top_k=0, + top_p=1.0, + do_sample=True, + ), + ), + ) + + +def default_ilql_config(): + return TRLConfig( + train=TrainConfig( + seq_length=64, + batch_size=32, + epochs=100, + total_steps=1000, + checkpoint_interval=1000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateILQLTrainer", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + 
tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
+        optimizer=OptimizerConfig(
+            name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
+        ),
+        scheduler=SchedulerConfig(
+            name="cosine_annealing", kwargs=dict(T_max=1000, eta_min=5.0e-5)  # train.total_steps
+        ),
+        method=ILQLConfig(
+            name="ilqlconfig",
+            tau=0.7,
+            gamma=0.99,
+            cql_scale=0.1,
+            awac_scale=1,
+            alpha=0.001,
+            beta=0,
+            steps_for_target_q_sync=5,
+            two_qs=True,
+            gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0),
+        ),
+    )
+
+
+def default_sft_config():
+    return TRLConfig(
+        train=TrainConfig(
+            seq_length=1024,
+            epochs=100,
+            total_steps=1000,
+            batch_size=8,
+            checkpoint_interval=10000,
+            eval_interval=100,
+            pipeline="PromptPipeline",
+            trainer="AccelerateSFTTrainer",
+        ),
+        model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1),
+        tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
+        optimizer=OptimizerConfig(
+            name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
+        ),
+        scheduler=SchedulerConfig(
+            name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)  # train.total_steps
+        ),
+        method=SFTConfig(
+            name="sftconfig",
+            gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True),
+        ),
+    )
diff --git a/trlx/models/README.md b/trlx/models/README.md
new file mode 100644
index 000000000..fdcb44597
--- /dev/null
+++ b/trlx/models/README.md
@@ -0,0 +1,344 @@
+## Using pretrained NeMo models
+To use a NeMo model in `.nemo` format, like [NeMo Megatron-GPT-20B](https://huggingface.co/nvidia/nemo-megatron-gpt-20B), download and un-tar it:
+```
+tar xvf nemo_gpt20B_bf16_tp4.nemo
+```
+This will extract the model weights and the model config.
+
+Then set `train.trainer_kwargs.pretrained_model` to the path to the directory containing the parameters. The model hyperparameters in `train.trainer_kwargs.megatron_cfg` should match the ones in the model config.
+
+## Inference with ILQL-trained NeMo models
+To load a checkpoint, run
+```
+python examples/nemo_ilql_inference.py configs/nemo_configs/megatron_20b.yaml "/path/to/ilql_sentiments_logs/checkpoints"
+```
+To save checkpoints, ensure the following is set in the NeMo config:
+```
+exp_manager:
+  explicit_log_dir: ilql_sentiments_logs
+  create_checkpoint_callback: True
+```
+
+## Resume Training
+To resume training, ensure the following is set in the NeMo config:
+```
+exp_manager:
+  resume_if_exists: True
+```
+
+## NeMo Megatron setup
+Clone https://github.com/NVIDIA/NeMo/tree/r1.15.0 (currently only up to `r1.15.0` is supported) and apex from https://github.com/NVIDIA/apex/.
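+
+For example, to pin that release when cloning (a sketch; `-b r1.15.0` simply checks out the branch named above):
+```
+git clone -b r1.15.0 https://github.com/NVIDIA/NeMo/
+```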
+
+1) install conda (or mamba/micromamba)
+
+2) srun into a compute node with a gpu (if running on an HPC cluster)
+```
+srun --pty bash -i
+```
+
+3) copy the conda env export below and change the name and prefix
+```
+conda env create -f env.yaml
+```
+
+4) install nemo
+```
+git clone https://github.com/NVIDIA/NeMo/
+cd NeMo && pip install '.[all]'
+```
+
+5) install apex (or clone the GitHub repo)
+```
+git clone https://github.com/NVIDIA/apex/
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
+```
+
+
+# conda env export
+```
+name: nemo-113
+prefix: /mnt/nvme/jobs/nemo/nemo-source
+channels:
+  - anaconda
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=conda_forge
+  - _openmp_mutex=4.5=2_gnu
+  - bzip2=1.0.8=h7f98852_4
+  - c-ares=1.18.1=h7f8727e_0
+  - ca-certificates=2022.9.24=ha878542_0
+  - curl=7.84.0=h5eee18b_0
+  - expat=2.4.4=h295c915_0
+  - gettext=0.21.1=h27087fc_0
+  - git=2.34.1=pl5262hc120c5b_0
+  - krb5=1.19.2=hac12032_0
+  - lame=3.100=h166bdaf_1003
+  - ld_impl_linux-64=2.39=hcc3a1bd_1
+  - libcurl=7.84.0=h91b91d3_0
+  - libedit=3.1.20210910=h7f8727e_0
+  - libev=4.33=h7f8727e_1
+  - libffi=3.2.1=he1b5a44_1007
+  - libflac=1.4.2=h27087fc_0
+  - libgcc-ng=12.2.0=h65d4601_19
+  - libgomp=12.2.0=h65d4601_19
+  - libnghttp2=1.46.0=hce63b2e_0
+  - libnsl=2.0.0=h7f98852_0
+  - libogg=1.3.4=h7f98852_1
+  - libopus=1.3.1=h7f98852_1
+  - libsndfile=1.1.0=h27087fc_0
+  - libsqlite=3.39.4=h753d276_0
+  - libssh2=1.10.0=h8f2d780_0
+  - libstdcxx-ng=12.2.0=h46fd767_19
+  - libuuid=2.32.1=h7f98852_1000
+  - libvorbis=1.3.7=h9c3ff4c_0
+  - libzlib=1.2.12=h166bdaf_2
+  - mpg123=1.30.2=h27087fc_1
+  - ncurses=6.3=h27087fc_1
+  - openssl=1.1.1q=h7f8727e_0
+  - pcre2=10.37=he7ceb23_1
+  - perl=5.26.2=h14c3975_0
+  - pip=22.3.1=pyhd8ed1ab_0
+  - python=3.8.2=he5300dc_7_cpython
+  - readline=8.1.2=h0f457ee_0
+  - sqlite=3.39.4=h4ff8645_0
+  - tk=8.6.12=h1ccaba5_0
+  - wheel=0.38.4=pyhd8ed1ab_0
+  - xz=5.2.6=h166bdaf_0
+  - zlib=1.2.12=h7f8727e_2
+  - pip:
+    - absl-py==1.3.0
+    - aiohttp==3.8.3
+    - aiosignal==1.3.1
+    - alabaster==0.7.12
+    - aniso8601==9.0.1
+    - antlr4-python3-runtime==4.9.3
+    - appdirs==1.4.4
+    - asttokens==2.1.0
+    - async-timeout==4.0.2
+    - attrdict==2.0.1
+    - attrs==22.1.0
+    - audioread==3.0.0
+    - babel==2.11.0
+    - backcall==0.2.0
+    - beautifulsoup4==4.11.1
+    - black==19.10b0
+    - boto3==1.26.13
+    - botocore==1.29.13
+    - braceexpand==0.1.7
+    - cachetools==5.2.0
+    - certifi==2022.9.24
+    - cffi==1.15.1
+    - charset-normalizer==2.1.1
+    - click==8.0.2
+    - colorama==0.4.6
+    - commonmark==0.9.1
+    - contourpy==1.0.6
+    - cycler==0.11.0
+    - cython==0.29.32
+    - debugpy==1.6.3
+    - decorator==5.1.1
+    - distance==0.1.3
+    - docker-pycreds==0.4.0
+    - docopt==0.6.2
+    - docutils==0.19
+    - editdistance==0.6.1
+    - einops==0.6.0
+    - entrypoints==0.4
+    - exceptiongroup==1.0.4
+    - executing==1.2.0
+    - faiss-cpu==1.7.3
+    - fasttext==0.9.2
+    - filelock==3.8.0
+    - flask==2.2.2
+    - flask-restful==0.3.9
+    - fonttools==4.38.0
+    - frozenlist==1.3.3
+    - fsspec==2022.11.0
+    - ftfy==6.1.1
+    - g2p-en==2.1.0
+    - gdown==4.5.3
+    - gitdb==4.0.9
+    - gitpython==3.1.29
+    - google-auth==2.14.1
+    - google-auth-oauthlib==0.4.6
+    - grpcio==1.50.0
+    - h5py==3.7.0
+    - huggingface-hub==0.11.0
+    - hydra-core==1.2.0
+    - idna==3.4
+    - ijson==3.1.4
+    - imagesize==1.4.1
+    - importlib-metadata==5.0.0
+    - importlib-resources==5.10.0
+    - inflect==6.0.2
+    - 
iniconfig==1.1.1 + - ipadic==1.0.0 + - ipykernel==6.17.1 + - ipython==8.6.0 + - ipywidgets==8.0.2 + - isort==4.3.21 + - itsdangerous==2.1.2 + - jedi==0.18.1 + - jieba==0.42.1 + - jinja2==3.1.2 + - jiwer==2.5.1 + - jmespath==1.0.1 + - joblib==1.2.0 + - jupyter-client==7.4.7 + - jupyter-core==5.0.0 + - jupyterlab-widgets==3.0.3 + - kaldi-python-io==1.2.2 + - kaldiio==2.17.2 + - kiwisolver==1.4.4 + - latexcodec==2.0.1 + - levenshtein==0.20.2 + - librosa==0.9.2 + - llvmlite==0.39.1 + - loguru==0.6.0 + - lxml==4.9.1 + - markdown==3.4.1 + - markupsafe==2.1.1 + - marshmallow==3.19.0 + - matplotlib==3.6.2 + - matplotlib-inline==0.1.6 + - mecab-python3==1.0.5 + - mpmath==1.2.1 + - multidict==6.0.2 + - nest-asyncio==1.5.6 + - nltk==3.7 + - numba==0.56.4 + - numpy==1.23.4 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - oauthlib==3.2.2 + - omegaconf==2.2.3 + - onnx==1.12.0 + - opencc==1.1.4 + - packaging==21.3 + - pandas==1.5.1 + - pangu==4.0.6.1 + - parameterized==0.8.1 + - parso==0.8.3 + - pathspec==0.10.2 + - pathtools==0.1.2 + - pesq==0.0.4 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - pillow==9.3.0 + - pip-api==0.0.30 + - pipreqs==0.4.11 + - plac==1.3.5 + - platformdirs==2.5.4 + - pluggy==1.0.0 + - pooch==1.6.0 + - portalocker==2.6.0 + - progress==1.6 + - promise==2.3 + - prompt-toolkit==3.0.32 + - protobuf==3.20.1 + - psutil==5.9.4 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyannote-core==4.5 + - pyannote-database==4.1.3 + - pyannote-metrics==3.2.1 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pybind11==2.10.1 + - pybtex==0.24.0 + - pybtex-docutils==1.0.2 + - pycparser==2.21 + - pydantic==1.10.2 + - pydeprecate==0.3.2 + - pydub==0.25.1 + - pygments==2.13.0 + - pynini==2.1.5 + - pyparsing==3.0.9 + - pypinyin==0.47.1 + - pysocks==1.7.1 + - pystoi==0.3.3 + - pytest==7.2.0 + - pytest-runner==6.0.0 + - python-dateutil==2.8.2 + - pytorch-lightning==1.7.7 + - pytz==2022.6 + - pyyaml==5.4.1 + - pyzmq==24.0.1 + - rapidfuzz==2.13.2 + - regex==2022.10.31 + - requests==2.28.1 + - requests-oauthlib==1.3.1 + - resampy==0.4.2 + - rich==12.6.0 + - rsa==4.9 + - ruamel-yaml==0.17.21 + - ruamel-yaml-clib==0.2.7 + - s3transfer==0.6.0 + - sacremoses==0.0.53 + - scikit-learn==1.1.3 + - scipy==1.9.3 + - sentence-transformers==2.2.2 + - sentencepiece==0.1.97 + - sentry-sdk==1.11.0 + - setproctitle==1.3.2 + - setuptools==59.5.0 + - shellingham==1.5.0 + - shortuuid==1.0.11 + - simplejson==3.18.0 + - six==1.16.0 + - smmap==5.0.0 + - snowballstemmer==2.2.0 + - sortedcontainers==2.4.0 + - soundfile==0.11.0 + - soupsieve==2.3.2.post1 + - sox==1.4.1 + - sphinx==5.3.0 + - sphinxcontrib-applehelp==1.0.2 + - sphinxcontrib-bibtex==2.5.0 + - sphinxcontrib-devhelp==1.0.2 + - sphinxcontrib-htmlhelp==2.0.0 + - sphinxcontrib-jsmath==1.0.1 + - sphinxcontrib-qthelp==1.0.3 + - sphinxcontrib-serializinghtml==1.1.5 + - stack-data==0.6.1 + - sympy==1.11.1 + - tabulate==0.9.0 + - tensorboard==2.11.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - termcolor==2.1.0 + - text-unidecode==1.3 + - textdistance==4.5.0 + - texterrors==0.4.4 + - threadpoolctl==3.1.0 + - tokenizers==0.12.1 + - toml==0.10.2 + - tomli==2.0.1 + - torch==1.13.0 + - torchaudio==0.13.0 + - torchmetrics==0.10.3 + - torchvision==0.14.0 + - tornado==6.2 + - tqdm==4.64.1 + - traitlets==5.5.0 + - transformers==4.21.2 + - typed-ast==1.5.4 + - typer==0.7.0 + - typing-extensions==4.4.0 + - urllib3==1.26.12 + - wandb==0.13.5 + - wcwidth==0.2.5 + - 
webdataset==0.1.62 + - werkzeug==2.2.2 + - wget==3.2 + - widgetsnbextension==4.0.3 + - wrapt==1.14.1 + - yarg==0.1.9 + - yarl==1.8.1 + - youtokentome==1.0.6 + - zipp==3.10.0 +``` diff --git a/trlx/models/__init__.py b/trlx/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trlx/models/modeling_base.py b/trlx/models/modeling_base.py new file mode 100644 index 000000000..6e6d10d1e --- /dev/null +++ b/trlx/models/modeling_base.py @@ -0,0 +1,223 @@ +# Copyright 2022 CarperAI & The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# NOTE: This file contains a modified version of the `PreTrainedModelWrapper` class from +# HuggingFace's `trl` library. The original source code can be found here: +# https://github.com/lvwerra/trl/blob/78c13226bf8ea1ccd9b1c091f03a938098521f6c/trl/models/modeling_base.py + +import inspect +import json +import os +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.nn as nn +import transformers +from huggingface_hub import hf_hub_download + + +class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin): + """A wrapper around `transformers.PreTrainedModel` + + Reference: @younesbelkada's `PreTrainedModelWrapper` + https://github.com/lvwerra/trl/blob/4f5c16fafde42d9aca971952bcdcc1f5a0a68cf0/trl/models/modeling_base.py#L2 + + Attributes: + _auto_model_parent_class (transformers.AutoModel): The `transformers.AutoModel` + type to base the wrapping behavior off of, e.g. `transformers.AutoModelForCausalLM`. + _supported_modules (List[str]): A list of attribute names for modules of + the underlying architecture model. This is used, for example, to save + and load any additional modules by manipulating the state dict. + _supported_args (List[str]): A list of arguments specific to the underlying + architecture to separate from arguments that are supported by the + parent `AutoModel` class. Any arguments that are not supported by the + underlying model will be passed to the parent `AutoModel` class. + """ + + _auto_model_parent_class: transformers.AutoModel = None + _supported_modules: List[str] = None + # TODO (jon-tow): Supported args should come from a `PretrainedConfig` of the + # specific underlying type similar to how config instances can be used to instantiate + # `transformers.PreTrainedModel`s. 
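+    # (For example, the ILQL wrappers in `trlx/models/modeling_ilql.py` set
+    # `_supported_args = ["two_qs", "alpha"]`.)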
+ _supported_args: List[str] = None + + def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, **kwargs): + super().__init__() + self.base_model = base_model + # cache `forward` args for general use (avoids incompatible args across architectures) + self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args + + @classmethod + def _split_kwargs(cls, kwargs: Dict[str, Any]): + """Separates the kwargs from the supported arguments within `supported_args` + and those that are not + """ + supported_kwargs = {} + unsupported_kwargs = {} + for key, value in kwargs.items(): + if key in cls._supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + return supported_kwargs, unsupported_kwargs + + @classmethod + def from_config(cls, config: transformers.PretrainedConfig, **kwargs): + """Instantiate the pretrained pytorch model from a configuration. + + Args: + config (transformers.PretrainedConfig): The configuration to use to + instantiate the base model. + + NOTE: Loading a model from its configuration file does **not** load the + model weights. It only affects the model's configuration. Use + `~transformers.AutoModel.from_pretrained` to load the model weights. + """ + if kwargs is not None: + wrapped_model_kwargs, from_config_kwargs = cls._split_kwargs(kwargs) + else: + from_config_kwargs = {} + wrapped_model_kwargs = {} + base_model = cls._auto_model_parent_class.from_config(config, **from_config_kwargs) + model = cls(base_model, **wrapped_model_kwargs) + return model + + @classmethod + def from_pretrained( # noqa: max-complexity + cls, + pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel], + *model_args, + **kwargs, + ): + """Instantiate a pretrained pytorch model from a pretrained model configuration. + This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`. + Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained` + for more information. + + Args: + pretrained_model_name_or_path (str or `transformers.PreTrainedModel`): + The identifier of the pretrained model to load or the pretrained model itself. + *model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the `_auto_model_parent_class`. + **kwargs (dict, *optional*): + Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class` + call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific + instance of the wrapped model. + + NOTE: You must pass in arguments specific to the wrapped model as keyword arguments. + """ + if kwargs is not None: + wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs) + else: + from_pretrained_kwargs = {} + wrapped_model_kwargs = {} + + if isinstance(pretrained_model_name_or_path, str): + # Load the base model using the `transformers` AutoClass (e.g. AutoModelForCausalLM) + base_model = cls._auto_model_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs + ) + elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel): + base_model = pretrained_model_name_or_path + else: + raise ValueError( + f"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}" + "Expected `str` or `transformers.PreTrainedModel`." 
+ ) + + model = cls(base_model, **wrapped_model_kwargs) + + if isinstance(pretrained_model_name_or_path, str): + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + is_sharded = False + + if not os.path.exists(filename): + try: + filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin") + # Sharded + except Exception: + if os.path.exists(sharded_index_filename): + index_file_name = sharded_index_filename + else: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + "pytorch_model.bin.index.json", + ) + with open(index_file_name, "r") as f: + index = json.load(f) + # Collect files containing weights from supported modules + files_to_download = set() + for k, v in index["weight_map"].items(): + if any([module in k for module in cls._supported_modules]): + files_to_download.add(v) + is_sharded = True + + if is_sharded: + # Merge each shard into a state dict + # TODO: Optimize this to avoid wasting RAM + state_dict = {} + for shard_file in files_to_download: + filename = os.path.join(pretrained_model_name_or_path, shard_file) + # Download if shard file doesn't exist locally + if not os.path.exists(filename): + filename = hf_hub_download(pretrained_model_name_or_path, shard_file) + state_dict.update(torch.load(filename, map_location="cpu")) + else: + state_dict = torch.load(filename, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.post_init(state_dict=state_dict) + return model + + def save_pretrained(self, *args, **kwargs): + """Save the pretrained model to a directory. This method is a wrapper + around `transformers.PreTrainedModel.save_pretrained`. Please refer to + the documentation of `transformers.PreTrainedModel.save_pretrained` for + more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. + """ + state_dict = kwargs.pop("state_dict", None) + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + return self.base_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + """Return the state_dict of the pretrained model.""" + raise NotImplementedError + + def post_init(self, *args, **kwargs): + """Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: + """Filter out arguments not supported by the specific instance of + `base_model.transformer.forward` + """ + # FIXME: This is a hack to get around the fact that the `transformers` + # architectures we use don't have a consistent API for `forward` parameters. 
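+        # For example, GPT-2's `forward` accepts `position_ids` while OPT's does not,
+        # so any unsupported keys are simply dropped from `kwargs` here.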
+ return {k: v for k, v in kwargs.items() if k in self.forward_kwargs} diff --git a/trlx/models/modeling_ilql.py b/trlx/models/modeling_ilql.py new file mode 100644 index 000000000..d9e614a7f --- /dev/null +++ b/trlx/models/modeling_ilql.py @@ -0,0 +1,488 @@ +import gc +import os +from copy import deepcopy +from dataclasses import dataclass +from functools import reduce +from itertools import chain + +import deepspeed # type: ignore +import numpy as np +import torch +import torch.nn.functional as F +import transformers +from torch import nn +from torchtyping import TensorType + +from trlx.data.ilql_types import ILQLBatch +from trlx.data.method_configs import MethodConfig, register_method +from trlx.models.modeling_base import PreTrainedModelWrapper +from trlx.utils.modeling import ( + flatten_dict, + get_tensor_stats, + hf_get_hidden_size, + hf_get_lm_head, + make_head, +) + + +def topk_mask(xs: torch.FloatTensor, k: int): + if k > xs.shape[-1]: + return xs + mintop = torch.topk(xs, k)[0][:, -1].unsqueeze(-1) + return torch.where(xs < mintop, -np.inf * torch.ones_like(xs, dtype=xs.dtype), xs) + + +def batched_index_select( + x: TensorType["batch", "seq_len", "hidden"], + idxs: TensorType["batch", "index_len"], + dim: int, +) -> TensorType["batch", "index_len", "hidden"]: + """ + Gather vectors at idxs along dim from x + """ + idxs = idxs.unsqueeze(-1).expand(idxs.shape[0], idxs.shape[1], x.shape[-1]) + return x.gather(dim=dim, index=idxs) + + +@dataclass +@register_method +class ILQLConfig(MethodConfig): + tau: float + gamma: float + cql_scale: float + awac_scale: float + alpha: float + beta: float + steps_for_target_q_sync: float + two_qs: bool + gen_kwargs: dict + + def loss(self, outputs, labels): + logits, (qs, target_qs, vs) = outputs + terminal_mask = labels.dones[:, :-1] + n_nonterminal = max(1, terminal_mask.sum()) + # check type of labels + if isinstance(labels, ILQLBatch): + actions = labels.input_ids[:, 1:].gather(dim=1, index=labels.actions_ixs).unsqueeze(-1) + else: + actions = labels.decoder_input_ids[:, 1:].unsqueeze(-1) + nactions = actions.shape[1] + bsize, _, dsize = logits.shape + + Q = [q.gather(-1, actions).squeeze(-1) for q in qs] + targetQs = [q.gather(-1, actions).squeeze(-1).detach() for q in target_qs] + targetQ = reduce(torch.minimum, targetQs) + + # values of current states + V = vs[:, :-1].squeeze() + # values of next states + Vnext = vs[:, 1:].squeeze() * labels.dones[:, 1:] + # target to fit Q + Q_ = labels.rewards + self.gamma * Vnext.detach() + + loss_qs = [((Qi - Q_) * terminal_mask).pow(2).sum() / n_nonterminal for Qi in Q] + loss_q = sum(loss_qs) + + targetQ = targetQ.detach() + + loss_v = ( + ( + (targetQ >= V).int() * self.tau * (targetQ - V).pow(2) + + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2) + ) + * terminal_mask + ).sum() / n_nonterminal + + def cql_loss(q): + loss = F.cross_entropy(q.reshape(-1, dsize), actions.reshape(-1), reduction="none") + loss = loss.reshape(bsize, nactions) * terminal_mask + loss = loss.sum() / n_nonterminal + return loss + + loss_cql = sum(cql_loss(q) for q in qs) + + # select logits from continuations + action_logits = batched_index_select(logits, labels.actions_ixs, dim=1) + cross_entropy = F.cross_entropy( + action_logits.reshape(-1, dsize), + actions.reshape(-1), + reduction="none", + ).reshape(bsize, nactions) + + with torch.no_grad(): + awac_weight = torch.exp(self.beta * (targetQ - V)) + + loss_awac = torch.sum(cross_entropy * awac_weight * terminal_mask) / n_nonterminal + loss = loss_q + loss_v + 
self.cql_scale * loss_cql + self.awac_scale * loss_awac
+
+        stats = dict(
+            losses=dict(
+                loss=loss.item(),
+                loss_q=loss_q.item(),
+                loss_v=loss_v.item(),
+                loss_cql=loss_cql.item(),
+                loss_awac=loss_awac.item(),
+            ),
+            values=get_tensor_stats(V, terminal_mask, n_nonterminal),
+            qvalues={str(ix): get_tensor_stats(Q[ix], terminal_mask, n_nonterminal) for ix in range(len(Q))},
+            awac_weight=get_tensor_stats(awac_weight, terminal_mask, n_nonterminal),
+        )
+
+        return loss, flatten_dict(stats)
+
+
+class ILQLHeads(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        vocab_size: int,
+        two_qs: bool,
+        alpha: float,
+        dtype: type,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.two_qs = two_qs
+        self.alpha = alpha
+        self.v_head = make_head(self.hidden_size, 1, dtype)
+
+        n_qs = 2 if self.two_qs else 1
+        self.q_heads = nn.ModuleList(make_head(self.hidden_size, self.vocab_size, dtype) for _ in range(n_qs))
+        self.target_q_heads = nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads)
+
+        for target_q_head in self.target_q_heads:
+            target_q_head.requires_grad_(False)
+
+    def forward(
+        self,
+        hs: torch.Tensor,
+        states_ixs: torch.Tensor = None,
+        actions_ixs: torch.Tensor = None,
+        **kwargs,
+    ):
+        if states_ixs is not None:
+            states_hs = batched_index_select(hs, states_ixs, 1)
+            actions_hs = batched_index_select(hs, actions_ixs, 1)
+        else:
+            states_hs = actions_hs = hs
+
+        qs = tuple(q_head(actions_hs) for q_head in self.q_heads)
+        target_qs = tuple(q_head(actions_hs) for q_head in self.target_q_heads)
+        vs = self.v_head(states_hs)
+
+        return qs, target_qs, vs
+
+    def _sync_target_q_heads(self, alpha):
+        for target_q_head, q_head in zip(self.target_q_heads, self.q_heads):
+            for target_param, copy_param in zip(target_q_head.parameters(), q_head.parameters()):
+                target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data)
+
+    def sync_target_q_heads(self):
+        if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3":
+            params = chain(
+                chain(q_head.parameters() for q_head in self.q_heads),
+                chain(q_head.parameters() for q_head in self.target_q_heads),
+            )
+
+            with deepspeed.zero.GatheredParameters(list(params), modifier_rank=0):
+                if deepspeed.comm.get_rank() == 0:
+                    self._sync_target_q_heads(self.alpha)
+        else:
+            self._sync_target_q_heads(self.alpha)
+
+
+class AutoModelForCausalLMWithILQLHeads(PreTrainedModelWrapper):
+    """An `AutoModel` class wrapper for `transformers` causal models with a language
+    modeling head and ILQL heads.
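+
+    A minimal usage sketch (the model name and hyperparameter values are illustrative):
+
+        >>> model = AutoModelForCausalLMWithILQLHeads.from_pretrained("gpt2", two_qs=True, alpha=0.99)
+        >>> logits, qs, target_qs, vs, past_key_values = model(input_ids)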
+
+    References:
+        [1] Snell et al., "Offline RL for Natural Language Generation with Implicit Language Q Learning",
+            https://arxiv.org/abs/2206.11871, 2022
+    """
+
+    _auto_model_parent_class = transformers.AutoModelForCausalLM
+    _supported_modules = ["ilql_heads"]
+    _supported_args = ["two_qs", "alpha"]
+
+    def __init__(
+        self,
+        base_model: transformers.PreTrainedModel,
+        *,
+        two_qs: bool = True,
+        alpha: float = 0.99,
+    ):
+        super().__init__(base_model)
+        hidden_size = hf_get_hidden_size(self.base_model.config)
+        vocab_size = self.base_model.config.vocab_size
+        dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype
+        self.two_qs = two_qs
+        self.alpha = alpha
+        self.ilql_heads = ILQLHeads(hidden_size, vocab_size, self.two_qs, self.alpha, dtype=dtype)
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask=None,
+        position_ids=None,
+        past_key_values=None,
+        actions_ixs=None,
+        states_ixs=None,
+    ):
+        forward_kwargs = self.get_compatible_forward_kwargs(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+        )
+        forward_kwargs["output_hidden_states"] = True
+
+        outputs = self.base_model(**forward_kwargs)
+        qs, target_qs, vs = self.ilql_heads(outputs.hidden_states[-1], states_ixs=states_ixs, actions_ixs=actions_ixs)
+
+        return outputs.logits, qs, target_qs, vs, outputs.past_key_values
+
+    def generate(
+        self,
+        input_ids,
+        attention_mask=None,
+        position_ids=None,
+        past_key_values=None,
+        beta=1,
+        max_new_tokens=32,
+        max_length=1024,
+        temperature=1,
+        top_k=20,
+        logit_mask=None,
+        pad_token_id=None,
+        eos_token_id=None,
+    ):
+        """
+        Generates samples akin to hf's `.generate` but with custom logp preprocessing:
+        token probabilities are reweighted by how advantageous each token would be
+        according to the value-function estimates.
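+
+        A sketch of the intended call, using the default ILQL gen_kwargs values
+        from `trlx/data/default_configs.py` (illustrative, not required):
+
+            samples = model.generate(input_ids, beta=4, max_new_tokens=56, top_k=20)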
+        """
+        pad_token_id = pad_token_id if pad_token_id is not None else self.base_model.config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.base_model.config.eos_token_id
+
+        if attention_mask is None:
+            attention_mask = input_ids.not_equal(pad_token_id)
+
+        if position_ids is None:
+            position_ids = attention_mask.cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask.eq(0), 0)
+
+        samples = input_ids.clone()
+        max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1])
+
+        finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device)
+        for _ in range(max_new_tokens):
+            out = self.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+            )
+
+            logits, _, target_qs, vs, past_key_values = out
+            if self.two_qs:
+                qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :])
+            else:
+                qs = target_qs[:, -1, :]
+
+            logits = logits[:, -1, :]
+            vs = vs[:, -1, :]
+
+            if logit_mask is not None:
+                mask = logit_mask[input_ids[:, -1].squeeze().to(logit_mask.device)]
+                logits[torch.where(mask)] = -np.inf
+
+            adv = qs - vs
+            pi_beta = F.log_softmax(logits, -1)
+            pi_top_k = topk_mask(pi_beta + beta * adv, top_k)
+            pi = F.softmax(pi_top_k / temperature, -1)
+
+            input_ids = torch.multinomial(pi, num_samples=1)
+            input_ids = (1 - finished) * input_ids + finished * eos_token_id
+            finished = (input_ids == eos_token_id).long()
+
+            samples = torch.hstack((samples, input_ids))
+            attention_mask = torch.hstack((attention_mask, (input_ids != eos_token_id).long()))
+            position_ids = (position_ids[:, -1] + 1).view(-1, 1)
+
+            if torch.all(finished):
+                break
+
+        return samples
+
+    def sync_target_q_heads(self):
+        self.ilql_heads.sync_target_q_heads()
+
+    def state_dict(self, *args, **kwargs):
+        """
+        Returns the state dictionary of the model. We add the state dictionary of the ilql heads
+        to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`.
+        """
+        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)
+        ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs)
+        for k, v in ilql_heads_state_dict.items():
+            base_model_state_dict[f"ilql_heads.{k}"] = v
+        return base_model_state_dict
+
+    def post_init(self, state_dict):
+        """
+        We add the state dictionary of the ilql heads to the state dictionary of the wrapped model
+        by prepending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the
+        keys of the value head state dictionary.
+        """
+        for k in list(state_dict.keys()):
+            if "ilql_heads." in k:
+                state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k)
+        self.ilql_heads.load_state_dict(state_dict, strict=False)
+        del state_dict
+        gc.collect()
+
+
+class AutoModelForSeq2SeqLMWithILQLHeads(PreTrainedModelWrapper):
+    """A wrapper around huggingface AutoModelForSeq2SeqLM with additional ILQL heads."""
+
+    _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM
+    _supported_modules = ["ilql_heads"]
+    _supported_args = ["two_qs", "alpha"]
+
+    def __init__(
+        self,
+        base_model: transformers.PreTrainedModel,
+        *,
+        two_qs: bool = True,
+        alpha: float = 0.99,
+    ):
+        super().__init__(base_model)
+        hidden_size = hf_get_hidden_size(self.base_model.config)
+        vocab_size = self.base_model.config.vocab_size
+        dtype = next(hf_get_lm_head(self.base_model).parameters()).dtype
+        self.two_qs = two_qs
+        self.alpha = alpha
+        self.ilql_heads = ILQLHeads(hidden_size, vocab_size, self.two_qs, self.alpha, dtype=dtype)
+
+    def sync_target_q_heads(self):
+        self.ilql_heads.sync_target_q_heads()
+
+    def state_dict(self, *args, **kwargs):
+        """
+        Returns the state dictionary of the model. We add the state dictionary of the ilql heads
+        to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`.
+        """
+        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)
+        ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs)
+        for k, v in ilql_heads_state_dict.items():
+            base_model_state_dict[f"ilql_heads.{k}"] = v
+        return base_model_state_dict
+
+    def post_init(self, state_dict):
+        """
+        We add the state dictionary of the ilql heads to the state dictionary of the wrapped model
+        by prepending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the
+        keys of the value head state dictionary.
+        """
+        for k in list(state_dict.keys()):
+            if "ilql_heads." in k:
+                state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k)
+        self.ilql_heads.load_state_dict(state_dict, strict=False)
+        del state_dict
+        gc.collect()
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask=None,
+        decoder_input_ids=None,
+        past_key_values=None,
+        encoder_outputs=None,
+        actions_ixs=None,
+        states_ixs=None,
+        output_attentions=True,
+        output_hidden_states=True,
+    ):
+        forward_kwargs = self.get_compatible_forward_kwargs(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            past_key_values=past_key_values,
+            encoder_outputs=encoder_outputs,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+        out = self.base_model(**forward_kwargs)
+
+        hs = out.decoder_hidden_states[-1]
+
+        logits = self.base_model.lm_head(hs)
+        qs, target_qs, vs = self.ilql_heads(hs, states_ixs=states_ixs, actions_ixs=actions_ixs)
+        encoder_outputs = (out.encoder_last_hidden_state, out.encoder_hidden_states, out.encoder_attentions)
+        return logits, qs, target_qs, vs, out.past_key_values, encoder_outputs
+
+    def generate(
+        self,
+        input_ids,
+        attention_mask=None,
+        decoder_input_ids=None,
+        past_key_values=None,
+        encoder_outputs=None,
+        beta=1,
+        max_new_tokens=32,
+        max_length=1024,
+        temperature=1,
+        top_k=20,
+        logit_mask=None,
+        pad_token_id=None,
+        eos_token_id=None,
+    ):
+        """
+        Generates samples akin to hf's `.generate` but with custom logp preprocessing:
+        token probabilities are reweighted by how advantageous each token would be
+        according to the value-function estimates.
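+
+        Unlike the causal variant above, `eos_token_id` and `pad_token_id` must be
+        passed explicitly here.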
+ """ + + if eos_token_id is None or pad_token_id is None: + raise ValueError("eos_token_id and pad_token_id must be provided") + + if attention_mask is None: + attention_mask = input_ids.not_equal(pad_token_id) + + samples = input_ids.clone() + max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1]) + if decoder_input_ids is None: + decoder_input_ids = input_ids.new_zeros(input_ids.shape[0], 1) + + finished = torch.zeros(input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device) + for _ in range(max_new_tokens): + out = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids[:, -1].unsqueeze(-1), + past_key_values=past_key_values, + encoder_outputs=encoder_outputs, + ) + logits, _, target_qs, vs, past_key_values, encoder_outputs = out + if self.two_qs: + qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :]) + else: + qs = target_qs[:, -1, :] + + logits = logits[:, -1, :] + vs = vs[:, -1, :] + adv = qs - vs + pi_beta = F.log_softmax(logits, -1) + pi_top_k = topk_mask(pi_beta + beta * adv, top_k) + pi = F.softmax(pi_top_k / temperature, -1) + next_tokens = torch.multinomial(pi, num_samples=1) + next_tokens = (1 - finished) * next_tokens + finished * eos_token_id + finished = (next_tokens == eos_token_id).long() | (next_tokens == pad_token_id).long() + decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) + samples = decoder_input_ids + if torch.all(finished): + break + + return samples diff --git a/trlx/models/modeling_nemo_ilql.py b/trlx/models/modeling_nemo_ilql.py new file mode 100644 index 000000000..31ac49a8a --- /dev/null +++ b/trlx/models/modeling_nemo_ilql.py @@ -0,0 +1,786 @@ +# Extensible version of the GPT model +import sys +from copy import deepcopy +from functools import partial, reduce +from math import sqrt +from pathlib import Path +from typing import List, Mapping, Optional, Tuple, Union + +import torch +import torch.distributed +import torch.nn as nn +import torch.nn.functional as F +from apex.transformer import parallel_state, tensor_parallel +from apex.transformer.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, +) +from einops import rearrange +from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( + MegatronPretrainingBatchSampler, +) +from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import ( + post_language_model_processing, +) +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import ( + MegatronGPTModel, +) +from nemo.collections.nlp.modules.common.megatron.module import ( + Float16Module, + MegatronModule, +) +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.modules.common.transformer.text_generation import ( + LengthParam, + OutputType, + SamplingParam, +) +from nemo.collections.nlp.parts.utils_funcs import get_last_rank + +from trlx.data.ilql_types import ILQLBatch, unflatten_dataclass +from trlx.models.modeling_ilql import ILQLConfig, batched_index_select +from trlx.utils import to_device, tree_map + + +class ParallelLinear(nn.Module): + """Linear layer parallelized over the longer dimension.""" + + def __init__( + self, + in_size: int, + out_size: int, + init_method=partial(nn.init.kaiming_uniform_, a=sqrt(5), nonlinearity="relu"), + use_cpu_initialization=False, + bias=True, + sequence_parallel=False, + 
gradient_accumulation_fusion=False, + gather_output=True, + input_is_parallel=False, + ): + super().__init__() + + no_async_tensor_model_parallel_allreduce = ( + parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel + ) + + if in_size < out_size: + self.layer = tensor_parallel.ColumnParallelLinear( + in_size, + out_size, + gather_output=gather_output, + init_method=init_method, + skip_bias_add=False, + use_cpu_initialization=use_cpu_initialization, + bias=bias, + sequence_parallel_enabled=sequence_parallel, + no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, + gradient_accumulation_fusion=gradient_accumulation_fusion, + ) + else: + self.layer = tensor_parallel.RowParallelLinear( + in_size, + out_size, + input_is_parallel=input_is_parallel, + init_method=init_method, + skip_bias_add=False, + use_cpu_initialization=use_cpu_initialization, + bias=bias, + sequence_parallel_enabled=sequence_parallel, + gradient_accumulation_fusion=gradient_accumulation_fusion, + ) + + def forward(self, x): + output, bias = self.layer(x) + if bias is not None: + return output + bias + return output + + +def make_parallel_head(n_embd: int, out: int, sequence_parallel=False) -> nn.Sequential: + """Returns a generic sequential model parallel MLP head.""" + parallel_intermediate = out < (n_embd * 2) + return nn.Sequential( + ParallelLinear( + n_embd, + n_embd * 2, + sequence_parallel=sequence_parallel, + gather_output=not parallel_intermediate, + ), + nn.ReLU(), + ParallelLinear( + n_embd * 2, + out, + sequence_parallel=sequence_parallel, + input_is_parallel=parallel_intermediate, + ), + ) + + +class ParallelILQLHeads(nn.Module): + def __init__( + self, + config: ILQLConfig, + hidden_size: int, + vocab_size: int, + sequence_parallel=False, + ): + super().__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.v_head = make_parallel_head(hidden_size, 1, sequence_parallel=sequence_parallel) + self.config = config + + n_qs = 2 if self.config.two_qs else 1 + + self.q_heads = nn.ModuleList(make_parallel_head(self.hidden_size, self.vocab_size) for _ in range(n_qs)) + + self.target_q_heads = nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads) + self.target_q_heads.requires_grad_(False) + + def forward(self, hidden_states): + qs = tuple(q_head(hidden_states) for q_head in self.q_heads) + target_qs = tuple(q_head(hidden_states) for q_head in self.target_q_heads) + vs = self.v_head(hidden_states) + + qs, target_qs, vs = tree_map(lambda t: rearrange(t, "T N ... 
-> N T ..."), (qs, target_qs, vs)) + + return qs, target_qs, vs + + def _sync_target_q_heads(self, alpha: float): + for target_q_head, q_head in zip(self.target_q_heads, self.q_heads): + for target_param, copy_param in zip(target_q_head.parameters(), q_head.parameters()): + target_param.data.copy_((alpha * copy_param.data) + (1.0 - alpha) * target_param.data) + + def sync_target_q_heads(self): + self._sync_target_q_heads(self.config.alpha) + + +class LMHeads(MegatronModule): + def __init__(self, language_model, other_heads): + super().__init__() + # must be this attribute name + self.pre_process = language_model.pre_process + self.post_process = language_model.post_process + self.language_model = language_model + + self.other_heads = other_heads + + if hasattr(language_model, "word_embeddings"): + self.word_embeddings = language_model.word_embeddings + + # The tensor from the previous pipeline rank arrives via this method + def set_input_tensor(self, input_tensor): + return self.language_model.set_input_tensor(input_tensor) + + def word_embeddings_weight(self): + return self.language_model.word_embeddings_weight() + + def load_state_dict(self, lm_state_dict, strict=True): + """Load GPTModel state dict.""" + self.language_model.language_model.load_state_dict(lm_state_dict, strict=strict) + + def forward( + self, + *args, + get_key_value=False, + forward_method_parallel_output=None, + **kwargs, + ): + lm_output = self.language_model(*args, get_key_value=get_key_value, **kwargs) + logits = post_language_model_processing( + lm_output, + labels=None, + logit_weights=self.language_model.word_embeddings_weight(), + get_key_value=get_key_value, + parallel_output=False, # self.language_model.parallel_output, + forward_method_parallel_output=forward_method_parallel_output, + fp16_lm_cross_entropy=self.language_model.fp16_lm_cross_entropy, + return_logits=True, + sequence_parallel=self.language_model.sequence_parallel, + gradient_accumulation_fusion=self.language_model.gradient_accumulation_fusion, + ) + + if get_key_value: + logits, presents = logits + lm_output, lm_output_presents = lm_output + + heads_output = self.other_heads(lm_output) + return logits, heads_output + + +def unwrap_float16_module(module): + if isinstance(module, Float16Module): + return module.module + return module + + +def reshard_for_pipeline_parallelism(num_layers, state_dict): + """Filter out the layers that are not in the current pipeline stage + and shift the layer ids to match the local stage layer ids.""" + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + + stage_layers = num_layers // pp_size + pp_offset = pp_rank * stage_layers + + encoder_layers_key = "model.language_model.encoder.layers." 
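+    # For example, with num_layers=32 and pp_size=4, each stage keeps 8 layers;
+    # pp_rank=1 keeps global layers 8..15, and a key for global layer 9 is
+    # renamed to local layer 1 by shift_layer_idx below.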
+ + def filter_in_pp_rank(key): + if key.startswith(encoder_layers_key): + layer_idx = int(key.split(".")[4]) + return pp_offset <= layer_idx < (pp_offset + stage_layers) + elif key.startswith("model.language_model.encoder.final_layernorm") and not pp_rank == (pp_size - 1): + return False + else: + return True + + def shift_layer_idx(key): + """If the key is for a transformer layer, shift down the layer index to select the + correct layer for this pipeline stage.""" + if key.startswith(encoder_layers_key): + layer_idx = int(key.split(".")[4]) + return f"{encoder_layers_key}{str(layer_idx - pp_offset)}.{'.'.join(key.split('.')[5:])}" + else: + return key + + state_dict = {shift_layer_idx(k): v for k, v in state_dict.items() if filter_in_pp_rank(k)} + + return state_dict + + +class ILQLGPT(MegatronGPTModel): + ilql_config: ILQLConfig + + def __init__(self, ilql_config, metric_fn=None, **kwargs): + self.ilql_config = ilql_config + self.metric_fn = metric_fn + super().__init__(**kwargs) + if len(list(self.parameters())) == 0: + raise ValueError("No parameters in model") + + self._ori_activations_checkpoint_granularity = self.cfg.get("activations_checkpoint_granularity", None) + self._ori_activations_checkpoint_method = self.cfg.get("activations_checkpoint_method", None) + self._ori_activations_checkpoint_num_layers = self.cfg.get("activations_checkpoint_num_layers", None) + + @classmethod + def list_available_models(cls) -> Optional[Mapping[str, str]]: + return None + + def build_train_valid_test_datasets(self): + pass + + def build_data_loader(self, dataset, collate_fn, consumed_samples=0): + dp_rank = parallel_state.get_data_parallel_rank() + dp_size = parallel_state.get_data_parallel_world_size() + print( + f"Building data loader for {type(dataset)=} {len(dataset)=} {dp_rank=} {dp_size=}", + file=sys.stderr, + ) + batch_sampler = MegatronPretrainingBatchSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self.cfg.micro_batch_size, + global_batch_size=self.cfg.global_batch_size, + data_parallel_rank=dp_rank, + data_parallel_size=dp_size, + drop_last=True, + ) + return torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + # For some reason this causes a crash when using >0 workers + # with grad accumulation > 1 + num_workers=0, + pin_memory=True, + collate_fn=collate_fn, + ) + + def set_train_dataset(self, train_dataset, collate_fn): + self._train_dataset = train_dataset + self._train_collate_fn = collate_fn + + def set_valid_dataset(self, valid_dataset, collate_fn): + self._valid_dataset = valid_dataset + self._valid_collate_fn = collate_fn + + # Called by superclass to build data loaders + def setup_training_data(self, _): + if hasattr(self, "_train_dataset"): + self._train_dl = self.build_data_loader(self._train_dataset, self._train_collate_fn) + + def setup_validation_data(self, _): + if hasattr(self, "_valid_dataset"): + self._validation_dl = self.build_data_loader(self._valid_dataset, self._valid_collate_fn) + + def load_from_pretrained(self, checkpoint_dir): + mp_rank = parallel_state.get_tensor_model_parallel_rank() + rank_subfolder = f"mp_rank_{mp_rank:02d}" + rank_params = Path(checkpoint_dir) / rank_subfolder / "model_weights.ckpt" + print(f"Loading from {rank_params}") + state_dict = torch.load(rank_params) + + state_dict = reshard_for_pipeline_parallelism(self.cfg.num_layers, state_dict) + + def trim_key(key, prefix): + assert key.startswith(prefix), f"key {key} in state_dict does not start with {prefix}" + return 
key[len(prefix) :]
+
+        lm_state_dict = {trim_key(k, "model.language_model."): v for k, v in state_dict.items()}
+
+        encoder_state_dict = {trim_key(k, "encoder."): v for k, v in lm_state_dict.items() if k.startswith("encoder.")}
+
+        lm_state_dict = {**lm_state_dict, "encoder": encoder_state_dict}
+
+        unwrap_float16_module(self.model).load_state_dict(lm_state_dict, strict=True)
+        print(f"Loaded from pretrained {rank_params}")
+
+    def model_provider_func(self, pre_process: bool, post_process: bool):
+        """
+        Model construction for Apex Pipeline Parallelism.
+        Each rank will construct the model, but inside the model
+        only the relevant layers for that rank should be constructed.
+        On the first rank, pre_process will be True.
+        On the last rank, post_process will be True.
+        """
+        gpt = super().model_provider_func(pre_process, post_process=post_process)
+        # This disables post-processing the lm output to the vocab
+        gpt.post_process = False
+        # This enables the final layernorm in the GPT model if there is one
+        gpt.language_model.post_process = post_process
+        # If running on the last pipeline stage, add the ILQL heads
+        if post_process:
+            parallel_ilql_heads = ParallelILQLHeads(
+                self.ilql_config,
+                self.cfg.hidden_size,
+                self.padded_vocab_size,
+                self.cfg.sequence_parallel,
+            )
+
+            return LMHeads(
+                gpt,
+                parallel_ilql_heads,
+            )
+        else:
+            return gpt
+
+    # Adapted from NeMo
+    # https://github.com/NVIDIA/NeMo/blob/r1.13.0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L259
+    def training_step(self, batch: ILQLBatch, batch_idx: int):  # noqa: C901
+        """
+        Our dataloaders produce a micro-batch and then we fetch
+        a number of microbatches depending on the global batch size and model parallel size
+        from the dataloader to produce a list of microbatches.
+        The batch should be a list of microbatches, and those microbatches should be on CPU.
+        Microbatches are then moved to GPU during the pipeline.
+        The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
+        """
+        # we zero grads here because we also call backward in the apex fwd/bwd functions
+        self._optimizer.zero_grad()
+
+        if parallel_state.is_pipeline_first_stage(ignore_virtual=True) or parallel_state.is_pipeline_last_stage(
+            ignore_virtual=True
+        ):
+            # we prepare the micro batches for the apex fwd/bwd function
+            batch_for_pipeline = batch
+        else:
+            # The intermediate pipeline stages do not need any inputs from data loader
+            # GPT3 uses decoder with AttnMask:causal, thus doesn't need attention_mask
+            batch_for_pipeline = None
+
+        # Pipeline stages will transfer this shape tensor to and from the
+        # previous and next stages
+        # The model must output a tensor of this shape if not the last pipeline
+        # stage. 
The model is given input of this shape if not the first pipeline + # stage via .set_input_tensor + tensor_shape = [ + self.cfg.encoder_seq_length, + self.cfg.micro_batch_size, + self.cfg.hidden_size, + ] + + # handle asynchronous grad reduction + if self.with_distributed_adam: + if self.megatron_amp_o2: + # copy grads to main grad + def custom_sync_context_handler(): + return self._optimizer.no_sync(greedy_grad_copy=True) + + else: + # keep grad tensors around + def custom_sync_context_handler(): + return self._optimizer.no_sync(greedy_grad_copy=False) + + else: + if self.megatron_amp_o2 and not self.cfg.get("sequence_parallel", False): + custom_sync_context_handler = self._optimizer.no_sync + else: + # TODO: enable async grad all reduce for O1/autocast mixed precision training + custom_sync_context_handler = None + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + # This gets the correct fwd/bwd pipeline step depending on the pipeline + # parallelism configuration + fwd_bwd_function = self._get_fwd_bwd_function() + + last_stage_output = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + batch=batch_for_pipeline, + model=self.model, + forward_only=False, + tensor_shape=tensor_shape, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + custom_sync_context_handler=custom_sync_context_handler, + sequence_parallel_enabled=self.cfg.get("sequence_parallel", False), + sync_batch_comm=self.cfg.get("sync_batch_comm", False), + num_micro_batches_with_partial_activation_checkpoints=self.cfg.get( + "num_micro_batches_with_partial_activation_checkpoints", None + ), + ) + + # only the last stages of the pipeline return losses + if last_stage_output: + # average loss across micro batches + outputs = {k: [output[k] for output in last_stage_output] for k in last_stage_output[0].keys()} + outputs = {k: torch.concat([torch.as_tensor(vi).unsqueeze(0) for vi in v]) for k, v in outputs.items()} + + mean_outputs = {k: v.mean() for k, v in outputs.items()} + loss_mean = mean_outputs["avg_loss"] + else: + mean_outputs = {} + loss_mean = torch.tensor(0.0).cuda() + + # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced + if self.cfg.get("tensor_model_parallel_size", 1) > 1 and self.cfg.get("sequence_parallel", False): + self.allreduce_sequence_parallel_gradients() + if self.with_distributed_adam: + # launch grad reductions + # Note: grads in first pipeline stage have already been + # reduced + if not parallel_state.is_pipeline_first_stage(): + self.reduce_overlap_gradients() + elif self.megatron_amp_o2: + # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + if self.cfg.get("pipeline_model_parallel_size", 1) > 1 or self.cfg.get("sequence_parallel", False): + # main grads are stored in the MainParamsOptimizer wrapper + self._optimizer.allreduce_main_grads() + else: + # async grad allreduce is not currently implemented for O1/autocasting mixed precision training + # so we all-reduce gradients after the pipeline + self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + + if self.cfg.get("pipeline_model_parallel_size", 1) > 1: + # when using pipeline parallelism the first and last stage must keep embeddings in sync + self.allreduce_first_last_embeddings() + + # we can only log on one rank if it is rank zero so we 
broadcast from last rank + # we can avoid this broadcast by updating the PTL log function to accept specific ranks + torch.distributed.broadcast(loss_mean, get_last_rank()) + + if self.cfg.precision == 16: + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log("loss_scale", loss_scale) + + self.log( + "reduced_train_loss", + loss_mean, + prog_bar=True, + rank_zero_only=True, + ) + + for k, v in mean_outputs.items(): + if k != "avg_loss": + self.log(k, v) + + self.log( + "global_step", + float(self.trainer.global_step), + prog_bar=True, + rank_zero_only=True, + ) + + if self.trainer.global_step % self.ilql_config.steps_for_target_q_sync == 0 and self.trainer.global_step > 0: + if parallel_state.is_pipeline_last_stage(): + unwrap_float16_module(self.model).other_heads.sync_target_q_heads() + + return loss_mean + + def activation_checkpointing_(self, enable: bool): + def toggle_checkpointing(module): + if hasattr(module, "activations_checkpoint_granularity"): + if enable: + module.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity + else: + module.activations_checkpoint_granularity = None + + if hasattr(module, "activations_checkpoint_method"): + if enable: + module.activations_checkpoint_method = self._ori_activations_checkpoint_method + else: + module.activations_checkpoint_method = None + + if hasattr(module, "activations_checkpoint_num_layers"): + if enable: + module.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers + else: + module.activations_checkpoint_num_layers = None + + self.model.apply(toggle_checkpointing) + + if enable: + self.cfg.activations_checkpoint_granularity = self._ori_activations_checkpoint_granularity + self.cfg.activations_checkpoint_method = self._ori_activations_checkpoint_method + self.cfg.activations_checkpoint_num_layers = self._ori_activations_checkpoint_num_layers + else: + self.cfg.activations_checkpoint_granularity = None + self.cfg.activations_checkpoint_method = None + self.cfg.activations_checkpoint_num_layers = None + + # TODO: replace this with less magical code + def sequence_parallel_(self, enabled: bool): + self.cfg.sequence_parallel = enabled + + def toggle_sp(m): + if hasattr(m, "sequence_parallel"): + m.sequence_parallel = enabled + + # for the Row/ColumnParallelLinear layers + if hasattr(m, "sequence_parallel_enabled"): + if hasattr(m, "input_is_parallel"): + m.sequence_parallel_enabled = enabled and m.input_is_parallel + elif hasattr(m, "gather_output"): + m.sequence_parallel_enabled = enabled and not m.gather_output + else: + m.sequence_parallel_enabled = enabled + + self.model.apply(toggle_sp) + + def validation_step(self, batch: Tuple[List[int], List[int]], batch_idx: int): + if self.metric_fn is None: + raise ValueError("Must set metric_fn to use validation") + + sp_was_enabled = self.cfg.get("sequence_parallel", False) + if sp_was_enabled: + self.sequence_parallel_(False) + + activations_checkpointing_was_enabled = self.cfg.get("activations_checkpoint_granularity", None) is not None + + if activations_checkpointing_was_enabled: + self.activation_checkpointing_(False) + + input_ids, lengths = batch + input_ids, lengths = torch.as_tensor(input_ids), torch.as_tensor(lengths) + + input_ids, lengths = to_device((input_ids, lengths), torch.cuda.current_device(), non_blocking=True) + + max_new_tokens = self.ilql_config.gen_kwargs.get("max_new_tokens", 64) + + gen = self.generate((input_ids, lengths), dict(max_length=max_new_tokens, 
min_length=0))
+
+        metrics = self.metric_fn(gen["sentences"])
+
+        metric_keys, metric_values = zip(*metrics.items())
+
+        columns = ["sentences", *metric_keys]
+        rows = list(zip(gen["sentences"], *metric_values))
+
+        avg_metrics = {f"avg_{k}": torch.as_tensor(v).mean() for k, v in metrics.items()}
+
+        if activations_checkpointing_was_enabled:
+            self.activation_checkpointing_(True)
+
+        if sp_was_enabled:
+            self.sequence_parallel_(True)
+
+        # NeMo generate resets the microbatch calculator
+        from apex.transformer.pipeline_parallel.utils import (
+            _reconfigure_microbatch_calculator,
+        )
+        from nemo.utils import AppState
+
+        _reconfigure_microbatch_calculator(
+            rank=AppState().global_rank,
+            rampup_batch_size=None,
+            global_batch_size=self.cfg.global_batch_size,
+            micro_batch_size=self.cfg.micro_batch_size,
+            data_parallel_size=AppState().data_parallel_size,
+        )
+
+        return avg_metrics, (rows, columns)
+
+    def validation_epoch_end(self, outputs: List[Tuple[dict, Tuple[List[str], List[str]]]]):
+        metrics, tables = zip(*outputs)
+        _, columns = tables[0]
+        rows = [r for trows, _ in tables for r in trows]
+
+        self.logger.log_text(key="samples", columns=columns, data=rows)
+
+        outputs_soa = {k: torch.as_tensor([d[k] for d in metrics]) for k in metrics[0].keys()}
+        # this assumes all validation microbatches are the same size
+        avg_outputs = {k: v.mean() for k, v in outputs_soa.items()}
+        for k, v in avg_outputs.items():
+            self.log(
+                f"val_metrics/{k}",
+                v,
+                prog_bar=True,
+                rank_zero_only=True,
+                sync_dist=True,
+            )
+
+    # Need to override this, otherwise distributed fused adam won't work
+    # with frozen layers
+    def parameters(self):
+        return (p for p in self.model.parameters() if p.requires_grad)
+
+    def get_forward_output_and_loss_func(self, validation_step=False):
+        def fwd_output_and_loss_func(batch: List[torch.Tensor], model, checkpoint_activations_all_layers=None):
+            # On first and last pipeline stages, the input data is passed in
+            if batch is not None:
+                batch = unflatten_dataclass(ILQLBatch)(batch)
+                batch = to_device(batch, torch.cuda.current_device(), non_blocking=True)
+
+                inputs = batch.input_ids
+                pad_by = self.cfg.encoder_seq_length - inputs.shape[1]
+                inputs = torch.nn.functional.pad(inputs, (0, pad_by), value=self.tokenizer.eos_id)
+
+                (
+                    attention_mask,
+                    loss_mask,
+                    position_ids,
+                ) = get_ltor_masks_and_position_ids(
+                    data=inputs,
+                    eod_token=self.tokenizer.eos_id,
+                    reset_position_ids=False,
+                    reset_attention_mask=False,
+                    eod_mask_loss=False,
+                )
+
+                model_output = model(
+                    input_ids=inputs,
+                    position_ids=position_ids.long(),
+                    attention_mask=attention_mask,
+                )
+            else:
+                # In-between stages are given data via the pipeline engine
+                # Still need to specify these arguments to avoid errors
+                model_output = model(input_ids=None, position_ids=None, attention_mask=None)
+
+            def gather_ntc(t: torch.Tensor):
+                """Gather sequence parallel tensor [batch, seq, hidden]"""
+                t = rearrange(t, "N T ... -> T N ...")
+                t = gather_from_sequence_parallel_region(t, to_model_parallel=False)
+                t = rearrange(t, "T N ... 
-> N T ...") + return t + + def loss_func(model_output): + # # TODO: implement this in a sequence parallel way + logits, (qs, target_qs, vs) = model_output + + if self.cfg.sequence_parallel: + qs, target_qs, vs = tree_map(gather_ntc, (qs, target_qs, vs)) + + qs = tree_map( + lambda t: batched_index_select(t, batch.actions_ixs, 1), + qs, + ) + + target_qs = tree_map( + lambda t: batched_index_select(t, batch.actions_ixs, 1), + target_qs, + ) + + vs = batched_index_select(vs, batch.states_ixs, 1) + + model_output = (logits, (qs, target_qs, vs)) + loss_for_mb, stats = self.ilql_config.loss(model_output, batch) + + reduced_loss = average_losses_across_data_parallel_group([loss_for_mb]) + + # TODO: figure out why this sync is needed (crashes otherwise) + torch.cuda.synchronize() + + return loss_for_mb, {"avg_loss": reduced_loss, **stats} + + return model_output, loss_func + + return fwd_output_and_loss_func + + def get_forward_output_only_func( + self, + set_inference_key_value_memory=False, + inference_max_sequence_len=None, + checkpoint_activations_all_layers=None, + ): + def fwd_output_only_func( + batch: torch.Tensor, + model, + ): + if batch is not None: + batch = to_device(batch, torch.cuda.current_device(), non_blocking=True) + + extra_arg = {} + + if len(batch) == 3: + tokens, attention_mask, position_ids = batch + else: + ( + tokens, + attention_mask, + position_ids, + set_inference_key_value_memory, + inference_max_sequence_len, + ) = batch + + extra_arg["set_inference_key_value_memory"] = set_inference_key_value_memory[0].item() + extra_arg["inference_max_sequence_len"] = inference_max_sequence_len[0].item() + + model_output = model( + input_ids=tokens, + position_ids=position_ids.long(), + attention_mask=attention_mask, + **extra_arg, + ) + else: + model_output = model(input_ids=None, position_ids=None, attention_mask=None) + + def ilql_postprocess(model_output): + model_output = tree_map(lambda t: t.float(), model_output) + + logits, (_, target_qs, vs) = model_output + + target_q = reduce(torch.minimum, target_qs) + advantage = target_q - vs + pi_beta = F.log_softmax(logits, -1) + beta = self.ilql_config.gen_kwargs.get("beta", 1.0) + + logits = pi_beta + beta * advantage + + return logits, {"logits": logits} + + return model_output, ilql_postprocess + + return fwd_output_only_func + + def generate( + self, + inputs: Union[List[str], torch.Tensor, List[dict]], + length_params: LengthParam, + sampling_params: SamplingParam = None, + ) -> OutputType: + if sampling_params is None: + sampling_params = { + "use_greedy": False, + "temperature": self.ilql_config.gen_kwargs.get("temperature", 1.0), + "top_k": self.ilql_config.gen_kwargs.get("top_k", 0), + "top_p": 0.9, + "repetition_penalty": 1.2, + "add_BOS": False, + "all_probs": False, + "compute_logprob": False, + } + + return super().generate(inputs, length_params, sampling_params) diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py new file mode 100644 index 000000000..787286123 --- /dev/null +++ b/trlx/models/modeling_ppo.py @@ -0,0 +1,1127 @@ +import gc +import inspect +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import transformers +from torchtyping import TensorType +from transformers.modeling_outputs import ModelOutput +from transformers.models.bloom import modeling_bloom +from transformers.models.opt import modeling_opt + +from trlx.data.method_configs import MethodConfig, 
register_method
+from trlx.models.modeling_base import PreTrainedModelWrapper
+from trlx.utils.modeling import (
+    flatten_dict,
+    get_tensor_stats,
+    hf_get_decoder,
+    hf_get_decoder_blocks,
+    hf_get_decoder_final_norm,
+    hf_get_hidden_size,
+    hf_get_lm_head,
+    hf_get_num_hidden_layers,
+    make_head,
+    whiten,
+)
+
+# KL Controllers
+
+
+class AdaptiveKLController:
+    """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences"
+    Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
+    Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
+    """
+
+    def __init__(self, init_kl_coef: float, target: float, horizon: int):
+        self.value = init_kl_coef
+        self.target = target
+        self.horizon = horizon
+
+    def update(self, current: float, n_steps: int):
+        """Updates the KL coefficient in place, yielding βₜ₊₁.
+        Arguments:
+            current: The current KL value between the newest policy and the initial policy.
+        """
+        proportional_error = np.clip(current / self.target - 1, -0.2, 0.2)  # ϵₜ
+        mult = 1 + proportional_error * n_steps / self.horizon
+        self.value *= mult  # βₜ₊₁
+
+
+class FixedKLController:
+    """Fixed KL controller."""
+
+    def __init__(self, kl_coef):
+        self.value = kl_coef
+
+    def update(self, current: float, n_steps: int):
+        """Keeps the KL coefficient fixed; a no-op counterpart to AdaptiveKLController.update.
+        Arguments:
+            current: The current KL value between the newest policy and the initial policy.
+        """
+        pass
+
+
+# PPO Configs
+
+
+@dataclass
+@register_method
+class PPOConfig(MethodConfig):
+    """
+    Config for PPO method
+
+    :param ppo_epochs: Number of updates per batch
+    :type ppo_epochs: int
+
+    :param num_rollouts: Number of experiences to observe before learning
+    :type num_rollouts: int
+
+    :param init_kl_coef: Initial value for KL coefficient
+    :type init_kl_coef: float
+
+    :param target: Target value for KL coefficient
+    :type target: float
+
+    :param horizon: Number of steps for KL coefficient to reach target
+    :type horizon: int
+
+    :param gamma: Discount factor
+    :type gamma: float
+
+    :param lam: GAE lambda
+    :type lam: float
+
+    :param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange)
+    :type cliprange: float
+
+    :param cliprange_value: Clipping range for predicted values
+        (observed values - cliprange_value, observed values + cliprange_value)
+    :type cliprange_value: float
+
+    :param vf_coef: Value loss scale w.r.t policy loss
+    :type vf_coef: float
+
+    :param gen_kwargs: Additional kwargs for the generation
+    :type gen_kwargs: Dict[str, Any]
+
+    :param gen_experience_kwargs: if not None, these kwargs are used instead of gen_kwargs when generating experience
+    :type gen_experience_kwargs: Dict[str, Any]
+    """
+
+    ppo_epochs: int
+    num_rollouts: int
+    chunk_size: int
+    init_kl_coef: float
+    target: float
+    horizon: int
+    gamma: float
+    lam: float
+    cliprange: float
+    cliprange_value: float
+    vf_coef: float
+    scale_reward: Optional[str]
+    ref_mean: Optional[float]
+    ref_std: Optional[float]
+    cliprange_reward: float
+    gen_kwargs: dict
+    gen_experience_kwargs: Optional[dict] = None
+
+    def get_advantages_and_returns(
+        self,
+        values: TensorType["batch_size", "response_size"],
+        rewards: TensorType["batch_size", "response_size"],
+        response_length: int,
+        use_whitening: Optional[bool] = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Function that computes advantages and returns from rewards and values.
+        Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347
+        Note that rewards may include a KL divergence loss term.
+
+        The advantage estimates look like this:
+        Adv1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
+              - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
+
+        The return estimates look like this:
+        Ret1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
+                   + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...
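+
+        Example (hypothetical numbers, with whitening disabled): with gamma = 1,
+        lam = 1, values = [[0., 0.]] and rewards = [[1., 1.]], each per-step delta
+        is 1., so advantages = [[2., 1.]] and returns = advantages + values = [[2., 1.]].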
+
+        Args:
+            values: Tensor of shape (batch_size, response_size)
+            rewards: Tensor of shape (batch_size, response_size)
+            response_length: Length of the response sequence
+            use_whitening: Whether to use whitening (i.e. normalize advantages) or not
+        """
+        lastgaelam = 0
+        advantages_reversed = []
+        for t in reversed(range(response_length)):
+            nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0
+            delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
+            lastgaelam = delta + self.gamma * self.lam * lastgaelam
+            advantages_reversed.append(lastgaelam)
+        advantages = torch.stack(advantages_reversed[::-1], dim=1)
+        returns = advantages + values
+        if use_whitening:
+            advantages = whiten(advantages)
+        return advantages.detach(), returns
+
+    def loss(
+        self,
+        logprobs: TensorType["batch_size", "response_size"],
+        values: TensorType["batch_size", "response_size"],
+        old_logprobs: TensorType["batch_size", "response_size"],
+        old_values: TensorType["batch_size", "response_size"],
+        advantages: TensorType["batch_size", "response_size"],
+        returns: TensorType["batch_size", "response_size"],
+        mask: TensorType["batch_size", "response_size"],
+    ):
+        """PPO objective function.
+        References:
+        - https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html
+        """
+        values_clipped = torch.clamp(
+            values,
+            old_values - self.cliprange_value,
+            old_values + self.cliprange_value,
+        )
+        n = mask.sum()
+
+        vf_loss1 = (values - returns) ** 2
+        vf_loss2 = (values_clipped - returns) ** 2
+        vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n
+        vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n
+
+        log_ratio = (logprobs - old_logprobs) * mask
+        ratio = torch.exp(log_ratio)
+        # Unbiased KL-div estimates (`k3`).
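+        # For instance, with a hypothetical ratio of 1.2:
+        #   k3 = (1.2 - 1) - log(1.2) ≈ 0.018
+        # k3 is always non-negative, since exp(x) >= 1 + x implies
+        # ratio - 1 >= log(ratio).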
Ref: http://joschu.net/blog/kl-approx.html + with torch.no_grad(): + approx_kl = torch.mean((ratio - 1) - log_ratio) + + pg_loss1 = -advantages * ratio + pg_loss2 = -advantages * torch.clamp( + ratio, + 1.0 - self.cliprange, + 1.0 + self.cliprange, + ) + pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n + pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n + + loss = pg_loss + self.vf_coef * vf_loss + + stats = dict( + losses=dict( + total_loss=loss.item(), + policy_loss=pg_loss.item(), + value_loss=vf_loss.item(), + ), + values=dict( + get_tensor_stats(values, mask, n), + values_error=torch.sum(((values - returns) * mask) ** 2) / n, + clipfrac=vf_clipfrac, + ), + old_values=get_tensor_stats(old_values, mask, n), + returns=get_tensor_stats(returns, mask, n), + policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()), + ratio=(ratio * mask).sum() / n, + padding_percentage=n / mask.numel(), + ) + + return loss, flatten_dict(stats) + + +# CausalLM architectures + + +@dataclass +class CausalLMOutputWithValue(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + value: Optional[torch.FloatTensor] = None + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + """An `AutoModel` class wrapper for `transformers` causal models that have a + language modeling head and a value head + """ + + _auto_model_parent_class = transformers.AutoModelForCausalLM + _supported_modules = ["v_head"] + _supported_args = [] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + ): + super().__init__(base_model) + self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + position_ids: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithValue]: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + forward_kwargs["output_hidden_states"] = True + forward_kwargs["return_dict"] = True + + outputs = self.base_model(**forward_kwargs) + value = self.v_head(outputs.hidden_states[-1]).squeeze(-1) + + if not return_dict: + outputs = (outputs.logits,) + outputs[1:] + (value,) + return outputs + + return CausalLMOutputWithValue(**outputs, value=value) + + def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: + return self.base_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + """ + Returns the state dictionary of the model. 
We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + base_model_state_dict = self.base_model.state_dict(*args, **kwargs) + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + base_model_state_dict[f"v_head.{k}"] = v + return base_model_state_dict + + def post_init(self, state_dict): + """ + Adds the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() # noqa: E702 + + +class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead): + _supported_modules = ["v_head", "frozen_head"] + _supported_args = ["num_layers_unfrozen"] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int = -1, + ): + super().__init__(base_model) + self.num_layers_unfrozen = num_layers_unfrozen + if self.num_layers_unfrozen > 0: + config = self.base_model.config + branch_class = hf_get_branch_class(config) + self.frozen_head = branch_class( + self.base_model, + num_layers_unfrozen=self.num_layers_unfrozen, + ).eval() + + def forward_hydra( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + position_ids: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[torch.FloatTensor, CausalLMOutputWithValue]: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return_dict = forward_kwargs.get("return_dict", True) + forward_kwargs["return_dict"] = True + forward_kwargs["output_hidden_states"] = True + + outputs = self.forward(**forward_kwargs) + # Select the hidden state before the first branching layer + input_hidden_state = outputs.hidden_states[-(self.num_layers_unfrozen + 1)] + + output_shape = outputs.hidden_states[-1].size() + forward_kwargs.pop("input_ids", None) # Ignore `input_ids` for branch head + forward_kwargs.pop("inputs_embeds", None) # Ignore `inputs_embeds` for branch head + hydra_outputs = self.frozen_head(input_hidden_state, output_shape, **forward_kwargs) + + if not return_dict: + return hydra_outputs.logits + return hydra_outputs + + +class ModelBranch(transformers.PreTrainedModel): + """Implements the frozen upper trunk of the pretrained reference model used + when computing the PPO KL-divergence penalty. 
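+    For example, with num_layers_unfrozen=2 (a hypothetical setting), the branch
+    keeps frozen copies of the last two decoder blocks, the final layer norm and
+    the LM head, and is fed the hidden states produced just before those blocks.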
+ """ + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int, + ): + """ + Args: + base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from + num_layers_unfrozen (int): The number of trainable layers + """ + super().__init__(base_model.config) + + # The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model + decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model)) + self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:]) + self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model)) + self.lm_head = deepcopy(hf_get_lm_head(base_model)) + + self.hidden_size = hf_get_hidden_size(self.config) + self.model_parallel = False + self.device_map = None + self.last_device = None + self.gradient_checkpointing = False + + # Freeze the entire branch + for parameter in self.parameters(): + parameter.requires_grad_(False) + + +class GPTModelBranch(ModelBranch): + def forward( # noqa: max-complexity + self, + hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids + output_shape: torch.Tensor, # output_size given by main trunk + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743 # noqa: E501 + """ + batch_size = hidden_states.size()[0] + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + device = hidden_states.device + + if past_key_values is None: + past_key_values = tuple([None] * len(self.decoder_blocks)) + + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + if self.config.add_cross_attention and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config)) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if 
output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)): + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Assumes we are never training the branch + block_params = inspect.getfullargspec(block.forward).args + if "encoder_hidden_states" in block_params: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_norm(hidden_states) + + hidden_states = hidden_states.view(output_shape) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + outputs = (lm_logits,) + (None,) + (None,) + return outputs + + return CausalLMOutputWithValue( + logits=lm_logits, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class OPTModelBranch(ModelBranch): + def forward( # noqa: max-complexity + self, + hidden_states: torch.Tensor, + output_shape: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/bdb84e2bada3658f99c6a81c963ec562f8485151/src/transformers/models/opt/modeling_opt.py#L840 # noqa: E501 + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device) + + input_shape = hidden_states.size()[:-1] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = modeling_opt._make_causal_mask( + input_shape, + hidden_states.dtype, + past_key_values_length=past_key_values_length, + ).to(hidden_states.device) + + if attention_mask is not None: + expanded_attn_mask = modeling_opt._expand_mask( + attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] + ).to(hidden_states.device) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + attention_mask = combined_attention_mask + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.decoder_blocks)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.decoder_blocks)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.decoder_blocks): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_norm is not None: + hidden_states = self.final_norm(hidden_states) + + # TODO: Add output projection support + # https://github.com/huggingface/transformers/blob/699e90437f984d69ad3c9b891dd2e9d0fc2cffe4/src/transformers/models/opt/modeling_opt.py#L499 # noqa: E501 + # if self.project_out is not None: + # hidden_states = self.project_out(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + lm_logits = self.lm_head(hidden_states).contiguous() + + if not return_dict: + return tuple( + v + for v in [ + lm_logits, + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + + return CausalLMOutputWithValue( + logits=lm_logits, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class BloomModelBranch(ModelBranch): + def forward( # noqa: max-complexity + self, + hidden_states: torch.Tensor, # Takes as input hidden_states instead of input_ids + output_shape: torch.Tensor, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + 
encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/bloom/modeling_bloom.py#L623 # noqa: E501 + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = hidden_states.shape[:2] + + if past_key_values is None: + past_key_values = tuple([None] * len(self.decoder_blocks)) + + head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config)) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = modeling_bloom.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype) + + combined_attention_mask = None + device = attention_mask.device + input_shape = (batch_size, seq_length) + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = modeling_bloom._make_causal_mask( + input_shape, + device=device, + past_key_values_length=past_key_values_length, + ) + + expanded_attn_mask = modeling_bloom._expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + causal_mask = combined_attention_mask + + for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.final_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return tuple( + v + for v in [ + lm_logits, + hidden_states, + presents, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + + return CausalLMOutputWithValue( + logits=lm_logits, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Seq2Seq architectures + + +@dataclass +class Seq2SeqLMOutputWithValue(ModelOutput): + loss: 
Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + value: Optional[torch.FloatTensor] = None + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + """An `AutoModel` class wrapper for `transformers` sequence-to-sequence + models that have a language modeling head and a value head + """ + + _auto_model_parent_class = transformers.AutoModelForSeq2SeqLM + _supported_modules = ["v_head"] + _supported_args = [] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + ): + super().__init__(base_model) + self.v_head = make_head(hf_get_hidden_size(self.base_model.config), 1) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = True, + output_hidden_states: Optional[bool] = True, + return_dict: Optional[bool] = None, + ) -> Seq2SeqLMOutputWithValue: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + forward_kwargs["output_hidden_states"] = True + forward_kwargs["return_dict"] = True + + outputs = self.base_model(**forward_kwargs) + last_hidden_state = outputs.decoder_hidden_states[-1] + value = self.v_head(last_hidden_state).squeeze(-1) + + return Seq2SeqLMOutputWithValue(**outputs, value=value) + + def generate(self, *args, **kwargs) -> Union[ModelOutput, torch.LongTensor]: + return self.base_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + """ + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. 
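+        For example, a value-head parameter whose own key is "0.weight" (the exact
+        names depend on make_head) is saved under "v_head.0.weight".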
+ """ + base_model_state_dict = self.base_model.state_dict(*args, **kwargs) + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + base_model_state_dict[f"v_head.{k}"] = v + return base_model_state_dict + + def post_init(self, state_dict): + """ + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() # noqa: E702 + + +class AutoModelForSeq2SeqLMWithHydraValueHead(AutoModelForSeq2SeqLMWithValueHead): + _supported_modules = ["v_head", "frozen_head"] + _supported_args = ["num_layers_unfrozen"] + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int = -1, + ): + super().__init__(base_model) + self.num_layers_unfrozen = num_layers_unfrozen + if self.num_layers_unfrozen > 0: + branch_class = T5Branch # TODO: Add support for other model branches + self.frozen_head = branch_class( + self.base_model, + num_layers_unfrozen=self.num_layers_unfrozen, + ).eval() + + def forward_hydra( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Seq2SeqLMOutputWithValue: + forward_kwargs = self.get_compatible_forward_kwargs( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return_dict = forward_kwargs.get("return_dict", True) + forward_kwargs["output_hidden_states"] = True + forward_kwargs["return_dict"] = True + + outputs = self.forward(**forward_kwargs) + # Select the hidden state before the first branching layer + input_hidden_state = outputs.decoder_hidden_states[-(self.num_layers_unfrozen + 1)] + hydra_outputs = self.frozen_head( + hidden_states=input_hidden_state, + attention_mask=decoder_attention_mask, + encoder_hidden_states=outputs.encoder_last_hidden_state, + encoder_attention_mask=attention_mask, + use_cache=False, + output_attentions=False, + output_hidden_states=True, + return_dict=return_dict, + ) + + if not return_dict: + return hydra_outputs.logits + return hydra_outputs + + +class T5Branch(ModelBranch): + 
"""Decoder only T5 branch""" + + def __init__( + self, + base_model: transformers.PreTrainedModel, + *, + num_layers_unfrozen: int, + ): + super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen) + self.dropout = hf_get_decoder(base_model).dropout + self.is_decoder = True + + def forward( # noqa: max-complexity + self, + hidden_states: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqLMOutputWithValue]: + """Reference: + https://github.com/huggingface/transformers/blob/bc21aaca789f1a366c05e8b5e111632944886393/src/transformers/models/t5/modeling_t5.py#L899 # noqa: E501 + """ + batch_size, seq_length = hidden_states.shape[:2] + input_shape = (batch_size, seq_length) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is None: + attention_mask = torch.ones(batch_size, seq_length, device=hidden_states.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=hidden_states.device, dtype=torch.long + ) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=hidden_states.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + position_bias = None + encoder_decoder_position_bias = None + + for _, layer_module in enumerate(self.decoder_blocks): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + + hidden_states = self.final_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + sequence_output = hidden_states + + if 
self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa: E501 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + lm_logits = self.lm_head(sequence_output) + + if not return_dict: + return (lm_logits,) + + return Seq2SeqLMOutputWithValue( + logits=lm_logits, + decoder_hidden_states=all_hidden_states, + decoder_attentions=all_attentions, + ) + + +# Branch class utils + + +def hf_get_branch_class( + config: transformers.PretrainedConfig, +) -> "ModelBranch": + """Returns the model branch class for the given config.""" + gpt_branch_supported_archs = [ + "GPTJForCausalLM", + "GPT2LMHeadModel", + "GPTNeoForCausalLM", + "GPTNeoXForCausalLM", + ] + opt_branch_supported_archs = ["OPTForCausalLM"] + bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"] + arch = config.architectures[0] + if arch in gpt_branch_supported_archs: + return GPTModelBranch + elif arch in opt_branch_supported_archs: + return OPTModelBranch + elif arch in bloom_branch_supported_archs: + return BloomModelBranch + else: + all_supported_archs = sum( + [ + gpt_branch_supported_archs, + opt_branch_supported_archs, + bloom_branch_supported_archs, + ], + [], + ) + raise ValueError( + f"Unsupported architecture: `{arch}`. The following architectures are " + f"available for model branching:\n{all_supported_archs}" + ) From 3377293010ace8d7ffddfbecebfd80fec206eaf2 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Tue, 14 Mar 2023 16:26:39 +0200 Subject: [PATCH 35/57] feat(base_trainer): enable w&b logging under ray --- trlx/trainer/accelerate_base_trainer.py | 12 +++++------- trlx/trainer/accelerate_ppo_trainer.py | 4 +--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 1526cde41..9d74c96cf 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -80,7 +80,7 @@ def __init__(self, config, **kwargs): # noqa: C901 run_name = "/".join([script_name, model_name, num_gpus]) + f":{branch}" - if self.accelerator.is_main_process and not ray.is_initialized(): + if self.accelerator.is_main_process: config_dict = self.config.to_dict() dist_config = get_distributed_config(self.accelerator) config_dict["distributed"] = dist_config @@ -413,11 +413,10 @@ def evaluate(self): # noqa: C901 rich_table.add_row(*[str(significant(x)) for x in rows[ix]]) Console().print(rich_table) - if not ray.is_initialized(): - if self.config.train.tracker == "wandb": - import wandb + if self.config.train.tracker == "wandb": + import wandb - stats["samples"] = wandb.Table(columns, rows) + stats["samples"] = wandb.Table(columns, rows) self.nth_evaluation += 1 return stats @@ -527,8 +526,7 @@ def learn(self): # noqa: C901 checkpoint = Checkpoint.from_directory("state") session.report(filter_non_scalars(stats), checkpoint=checkpoint) - if not ray.is_initialized(): - self.accelerator.log(stats, step=self.iter_count) + self.accelerator.log(stats, step=self.iter_count) desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) tbar.set_description(f"[{desc}]") diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index 7d1c34f45..63572d978 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ 
b/trlx/trainer/accelerate_ppo_trainer.py @@ -4,7 +4,6 @@ from time import time from typing import Callable, List -import ray import torch import torch.nn.functional as F import transformers @@ -505,8 +504,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq stats["kl_ctl_value"] = self.kl_ctl.value stats["time/exp"] = exp_time - if not ray.is_initialized(): - self.accelerator.log(stats, step=iter_count) + self.accelerator.log(stats, step=iter_count) # Push samples and rewards to trainer's rollout storage self.push_to_store(ppo_rl_elements) From 0cf350ed5bbc36060b7d4153828729ea9a8c11e5 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Tue, 14 Mar 2023 16:27:07 +0200 Subject: [PATCH 36/57] feat(sweep): remove `default_config` from argparse --- trlx/sweep.py | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/trlx/sweep.py b/trlx/sweep.py index d5ea40048..b2ec34374 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -13,7 +13,6 @@ from ray.tune.logger import CSVLoggerCallback from trlx.ray_train.accelerate_trainer import AccelerateTrainer -from trlx.utils import get_git_tag def get_param_space(config: dict): # noqa: C901 @@ -248,13 +247,13 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ if best_config: best_config = best_config["train_loop_config"] - config = best_config.pop("default_config") + config = {} for name, value in best_config.items(): *layers, var = name.split(".") if layers: - d = config[layers[0]] + d = config.setdefault(layers[0], {}) for layer in layers[1:]: - d = d[layer] + d = d.setdefault(layer, {}) d[var] = value report.blocks = report.blocks + [ @@ -275,12 +274,7 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ required=True, help="The config file defining the param_space.", ) - parser.add_argument( - "--default_config", - type=str, - required=True, - help="The default config file for the script.", - ) + parser.add_argument( "--accelerate_config", type=str, @@ -304,8 +298,8 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ config = yaml.safe_load(f) tune_config = get_tune_config(config.pop("tune_config")) param_space = get_param_space(config) - with open(args.default_config) as f: - default_config = yaml.safe_load(f) + column_names = list(param_space.keys()) + target_metric = tune_config["metric"] if args.server_address: ray.init(address=f"ray://{args.server_address}") @@ -325,10 +319,8 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ script = importlib.import_module(script_path) project_name = "sweep_" + script_path.split(".")[-1] - default_config["train"]["project_name"] = project_name - default_config["train"]["group_name"] = datetime.now().replace(microsecond=0).isoformat() - param_space["default_config"] = default_config.copy() - param_space["default_config"]["train"]["git_tag"] = get_git_tag() + param_space["train.project_name"] = project_name + param_space["train.group_name"] = datetime.now().replace(microsecond=0).isoformat() param_space_train = {"train_loop_config": param_space} tuner = tune.Tuner( @@ -349,12 +341,8 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ ) results = tuner.fit() - group_name = default_config["train"]["group_name"] - entity_name = default_config["train"].get("entity_name", None) - - column_names = param_space.pop("default_config") - 
column_names = param_space.keys() - target_metric = tune_config["metric"] + group_name = param_space["train.group_name"] + entity_name = param_space.get("train.entity_name", None) create_report(target_metric, column_names, entity_name, project_name, group_name, results.get_best_result().config) From 722be0d9fcc45db66844643b8f45a6c0124883cc Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:30:53 +0200 Subject: [PATCH 37/57] fix(default_configs): disable schedulers --- trlx/data/default_configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 1f9297db2..3fa1b5c2b 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -29,7 +29,7 @@ def default_ppo_config(): optimizer=OptimizerConfig( name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) ), - scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), method=PPOConfig( name="PPOConfig", num_rollouts=128, @@ -75,7 +75,7 @@ def default_ilql_config(): name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) ), scheduler=SchedulerConfig( - name="cosine_annealing", kwargs=dict(T_max=1000, eta_min=5.0e-5) # train.total_steps + name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=5.0e-5) # train.total_steps ), method=ILQLConfig( name="ilqlconfig", @@ -110,7 +110,7 @@ def default_sft_config(): name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) ), scheduler=SchedulerConfig( - name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4) # train.total_steps + name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4) # train.total_steps ), method=SFTConfig( name="sftconfig", From 4cd3e6136363520672daabde0bedbf74cef3902a Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:57:56 +0200 Subject: [PATCH 38/57] feat(setup.cfg): pin `ray` wheel, update `accelerate` `deepspeed` --- setup.cfg | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 9c0ef02f2..1116f5008 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,9 +11,9 @@ license = MIT [options] packages = find: install_requires = - accelerate>=0.12.0 + accelerate>=0.17.1 datasets - deepspeed>=0.7.3 + deepspeed>=0.8.1 einops>=0.4.1 numpy>=1.23.2 torchtyping @@ -21,10 +21,10 @@ install_requires = tqdm rich wandb>=0.13.5 - ray>=2.0.1 tabulate>=0.9.0 networkx tritonclient + ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl [options.extras_require] bnb = bitsandbytes From a7e7bb4f984dc1f0810c21be70a5dc382802258d Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Tue, 14 Mar 2023 23:08:03 +0200 Subject: [PATCH 39/57] merge: revert to upstream changes --- examples/notebooks/trlx_simulacra.ipynb | 3 +- examples/simulacra.py | 2 + .../configs/ppo_config_cnn_daily.yml | 61 -------------- .../t5_summarize_daily_cnn.py | 79 ++++++++++++++++- setup.cfg | 2 +- tests/test_ppo.py | 84 ------------------- tests/test_utils.py | 28 ++++++- trlx/trlx.py | 13 +-- 8 files changed, 113 insertions(+), 159 deletions(-) delete mode 100755 
examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml delete mode 100644 tests/test_ppo.py diff --git a/examples/notebooks/trlx_simulacra.ipynb b/examples/notebooks/trlx_simulacra.ipynb index 28e36aaa1..407d624f1 100644 --- a/examples/notebooks/trlx_simulacra.ipynb +++ b/examples/notebooks/trlx_simulacra.ipynb @@ -1158,7 +1158,8 @@ "source": [ "trlx.train(\n", " \"gpt2\",\n", - " dataset=(prompts, ratings),\n", + " samples=prompts,\n", + " rewards=ratings,\n", " eval_prompts=[\"Hatsune Miku, Red Dress\"] * 64,\n", ")" ] diff --git a/examples/simulacra.py b/examples/simulacra.py index cc28520d6..f4d6f82d8 100644 --- a/examples/simulacra.py +++ b/examples/simulacra.py @@ -6,6 +6,7 @@ from urllib.request import urlretrieve import trlx +from trlx.data.default_configs import default_ilql_config url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite" dbpath = "sac_public_2022_06_29.sqlite" @@ -26,6 +27,7 @@ prompts, ratings = tuple(map(list, zip(*c.fetchall()))) trlx.train( + config=default_ilql_config(), samples=prompts, rewards=ratings, eval_prompts=["Hatsune Miku, Red Dress"] * 64, diff --git a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml deleted file mode 100755 index 2134beadd..000000000 --- a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml +++ /dev/null @@ -1,61 +0,0 @@ -train: - seq_length: 612 - epochs: 100 - total_steps: 100000 - batch_size: 12 - - checkpoint_interval: 10000 - eval_interval: 500 - save_best: False - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "google/flan-t5-large" - model_arch_type: "seq2seq" - num_layers_unfrozen: 2 - -tokenizer: - tokenizer_path: "google/flan-t5-large" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 1.0e-5 - betas: [0.9, 0.999] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 10000 - eta_min: 1.0e-6 - -method: - name: "ppoconfig" - num_rollouts: 512 - chunk_size: 12 - ppo_epochs: 4 - init_kl_coef: 0.05 - target: 6 - horizon: 10000 - gamma: 0.99 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 1.0 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 100 - gen_experience_kwargs: - max_new_tokens: 100 - do_sample: True - temperature: 1.0 - top_k: 50 - top_p: 0.95 diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py index 67863bf7d..4c3a56758 100755 --- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py +++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py @@ -1,4 +1,3 @@ -import pathlib from typing import List from datasets import load_dataset @@ -6,7 +5,15 @@ from transformers import AutoTokenizer import trlx -from trlx.data.configs import TRLConfig +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.models.modeling_ppo import PPOConfig try: import evaluate @@ -15,8 +22,72 @@ "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" ) -config_path = pathlib.Path(__file__).parent / "configs/ppo_config_cnn_daily.yml" -config = TRLConfig.load_yaml(config_path) +config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=12, + 
checkpoint_interval=10000, + eval_interval=500, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer", + ), + model=ModelConfig( + model_path="google/flan-t5-large", + model_arch_type="seq2seq", + num_layers_unfrozen=2, + ), + tokenizer=TokenizerConfig( + tokenizer_path="google/flan-t5-large", + truncation_side="right", + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 1.0e-5, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=PPOConfig( + name="PPOConfig", + num_rollouts=512, + chunk_size=12, + ppo_epochs=4, + init_kl_coef=0.05, + target=6, + horizon=10000, + gamma=0.99, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=1.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ + "max_new_tokens": 100, + }, + gen_experience_kwargs={ + "max_new_tokens": 100, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, + "top_p": 0.95, + }, + ), +) + meteor = evaluate.load("meteor") # use meteor as the reward function diff --git a/setup.cfg b/setup.cfg index 1116f5008..d1ab5d199 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [metadata] name = trlx author = Alex Havrilla -version = 0.3.0 +version = 0.5.0 url = https://github.com/CarperAI/trlx description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) long_description = file: README.md diff --git a/tests/test_ppo.py b/tests/test_ppo.py deleted file mode 100644 index d5be5fdbe..000000000 --- a/tests/test_ppo.py +++ /dev/null @@ -1,84 +0,0 @@ -import unittest - -import torch -from transformers import AutoTokenizer - -from trlx.data.configs import TRLConfig -from trlx.trainer.nn.ppo_models import CausalLMHydraWithValueHead -from trlx.utils.modeling import RunningMoments - - -# Note tests must start with "test_" -class TestHydraHead(unittest.TestCase): - @classmethod - def setUpClass(cls): - print("Testing Hydra model...") - config = TRLConfig.load_yaml("configs/test_config.yml") - cls.hydra_model = CausalLMHydraWithValueHead(config.model.model_path, config.model.num_layers_unfrozen) - - tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - - cls.dummy_inputs = tokenizer( - "Once upon a time there was a happy goose named Louis. 
He liked to eat bananas.", - truncation=True, - padding="max_length", - max_length=4, - return_tensors="pt", - ) - - def test_lm_heads(self): - with torch.no_grad(): - unfrozen_outputs = TestHydraHead.hydra_model( - **TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True - ) - unfrozen_logits = unfrozen_outputs.logits - last_hidden_states = unfrozen_outputs.hidden_states[-1].to(torch.float32) - frozen_logits = TestHydraHead.hydra_model.frozen_head.lm_head(last_hidden_states) - diff = torch.sum(unfrozen_logits - frozen_logits).item() - self.assertEqual(diff, 0) - - def test_frozen_head(self): - # Ensure that all parameters of the `hydra_model.frozen_head` are actually frozen - for parameter in TestHydraHead.hydra_model.frozen_head.parameters(): - self.assertTrue(parameter.requires_grad is False) - - def test_forward(self): - with torch.no_grad(): - unfrozen_outputs = TestHydraHead.hydra_model( - **TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True - ) - unfrozen_last_hidden_states = unfrozen_outputs.hidden_states[-1] - unfrozen_logits = unfrozen_outputs.logits - - frozen_outputs = TestHydraHead.hydra_model.forward_hydra( - **TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True - ) - frozen_last_hidden_states = frozen_outputs.hidden_states[-1] - frozen_logits = frozen_outputs.logits - - hs_diff = torch.sum(unfrozen_last_hidden_states - frozen_last_hidden_states).item() - logits_diff = torch.sum(unfrozen_logits - frozen_logits).item() - self.assertEqual(hs_diff, 0) - self.assertEqual(logits_diff, 0) - - -class TestStatistics(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.m = RunningMoments() - cls.a1 = torch.arange(100, dtype=float) - cls.a2 = torch.ones(100, dtype=float) - cls.a3 = torch.exp(torch.arange(10, dtype=float)) - cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) - - def test_running_moments(self): - assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) - assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6) - assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) - assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) - - a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) - assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) - assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7a0af9959..f3c09c23b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,5 @@ +import unittest + import accelerate import pytest import torch @@ -68,9 +70,9 @@ def test_hf_attr_getters(model_name: str): arch = transformers.AutoModelForCausalLM.from_config(config) arch_getters = [ - modeling_utils.hf_get_causal_base_model, - modeling_utils.hf_get_causal_final_norm, - modeling_utils.hf_get_causal_hidden_layers, + modeling_utils.hf_get_decoder, + modeling_utils.hf_get_decoder_final_norm, + modeling_utils.hf_get_decoder_blocks, modeling_utils.hf_get_lm_head, ] for get in arch_getters: @@ -125,3 +127,23 @@ def test_parse_delta_kwargs(model_name): ) for kwarg_mod in delta_kwargs["modified_modules"]: assert kwarg_mod.endswith("a") or kwarg_mod.endswith("b"), "Parsed modified module should contain ['a', 'b']" + + +class TestStatistics(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.m = modeling_utils.RunningMoments() + cls.a1 = torch.arange(100, dtype=float) + cls.a2 = torch.ones(100, 
dtype=float) + cls.a3 = torch.exp(torch.arange(10, dtype=float)) + cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) + + def test_running_moments(self): + assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) + assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) + + a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) + assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) + assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) diff --git a/trlx/trlx.py b/trlx/trlx.py index fadbc2b83..f50753d14 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -3,6 +3,11 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple from trlx.data.configs import TRLConfig +from trlx.data.default_configs import ( + default_ilql_config, + default_ppo_config, + default_sft_config, +) from trlx.utils import set_seed from trlx.utils.loading import get_pipeline, get_trainer @@ -17,7 +22,6 @@ def train( # noqa: C901 eval_prompts: Optional[List[str]] = None, metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None, config: Optional[TRLConfig] = None, - logit_mask: Optional[List[List[bool]]] = None, stop_sequences: Optional[List[str]] = [], ): """ @@ -45,7 +49,6 @@ def train( # noqa: C901 Function to compute statistics on batches of generated samples. Its arguments are the same as in `reward_fn` (`samples`, `prompts`, `outputs`) but the return is dictionary with keys as metric's name and values and lists of numeric values per each sample in batch - logit_mask (Optional[List]): Bigram masking matrix stop_sequences (Optional[List[str]]): String sequences to trim generations (both for generating of experience and evaluation) up to its encounter in them. 
Generations will not contain them and also will also be right-stripped @@ -56,11 +59,11 @@ def train( # noqa: C901 "adapt some from `trlx/data/default_configs.py` instead" ) if reward_fn: - config = TRLConfig.load_yaml("configs/ppo_config.yml") + config = default_ppo_config() elif rewards: - config = TRLConfig.load_yaml("configs/ilql_config.yml") + config = default_ilql_config() else: - config = TRLConfig.load_yaml("configs/sft_config.yml") + config = default_sft_config() set_seed(config.train.seed) From 6fd5aefc0e9fa7a8cb0a869ee5176effcf25df63 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Wed, 15 Mar 2023 10:28:16 +0200 Subject: [PATCH 40/57] fix(scripts/sweep): remove `default_config` --- scripts/sweep-cw.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/sweep-cw.sh b/scripts/sweep-cw.sh index e0c8881e5..5e573a839 100644 --- a/scripts/sweep-cw.sh +++ b/scripts/sweep-cw.sh @@ -2,7 +2,7 @@ #SBATCH --job-name=trlx-sweep #SBATCH --account=trlx #SBATCH --partition=a100-cu117 -#SBATCH --nodes=4 +#SBATCH --nodes=2 #SBATCH --ntasks-per-node=1 #SBATCH --mem=0 #SBATCH --output=%j @@ -22,7 +22,7 @@ export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) cd $TRLX -source $TRLX/.env/bin/activate +source $TRLX/venv-with-pinned-ray/bin/activate ray start --head --port=6379 & @@ -36,5 +36,5 @@ sleep 10 ray status NUM_GPUS=16 -python -m trlx.sweep -y --config configs/sweeps/ppo_sweep.yml --default_config configs/ppo_config.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ppo_sentiments.py +python -m trlx.sweep -y --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ppo_sentiments.py # python -m trlx.sweep -y --config configs/sweeps/ilql_sweep.yml --default_config configs/ilql_config.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ilql_sentiments.py From ab5a86074d95def4d8c12b9b11f9e6e3a7f6e373 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Wed, 15 Mar 2023 10:29:18 +0200 Subject: [PATCH 41/57] merge(configs): remove yml files --- configs/ilql_config.yml | 50 -------------------------------- configs/nemo_ilql_config.yml | 52 --------------------------------- configs/ppo_config.yml | 56 ------------------------------------ configs/ppo_gptj.yml | 56 ------------------------------------ configs/sft_config.yml | 41 -------------------------- docs/source/trainer.rst | 15 ---------- 6 files changed, 270 deletions(-) delete mode 100644 configs/ilql_config.yml delete mode 100644 configs/nemo_ilql_config.yml delete mode 100644 configs/ppo_config.yml delete mode 100644 configs/ppo_gptj.yml delete mode 100644 configs/sft_config.yml diff --git a/configs/ilql_config.yml b/configs/ilql_config.yml deleted file mode 100644 index 4a1f4706a..000000000 --- a/configs/ilql_config.yml +++ /dev/null @@ -1,50 +0,0 @@ -train: - seq_length: 64 - batch_size: 128 - epochs: 100 - total_steps: 1000 - - checkpoint_interval: 1000 - eval_interval: 100 - - pipeline: "PromptPipeline" - trainer: "AccelerateILQLTrainer" - seed: 1000 - -model: - model_path: "gpt2" - num_layers_unfrozen: -1 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 5.0e-5 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - 
-scheduler: - name: "cosine_annealing" - kwargs: - T_max: 100000000000 - eta_min: 0 - -method: - name: "ilqlconfig" - tau: 0.7 - gamma: 0.99 - cql_scale: 0.1 - awac_scale: 1 - alpha: 0.001 - beta: 0 - steps_for_target_q_sync: 5 - two_qs: true - gen_kwargs: - max_new_tokens: 56 - top_k: 20 - beta: 4 - temperature: 1.0 diff --git a/configs/nemo_ilql_config.yml b/configs/nemo_ilql_config.yml deleted file mode 100644 index 1d4cc71e2..000000000 --- a/configs/nemo_ilql_config.yml +++ /dev/null @@ -1,52 +0,0 @@ -train: - seq_length: 1024 - batch_size: 512 - epochs: 100 - total_steps: 200 - checkpoint_interval: 200 - eval_interval: 20 - - pipeline: "PromptPipeline" - trainer: "NeMoILQLTrainer" - trainer_kwargs: - pretrained_model: "/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/" - megatron_cfg: "megatron_20b.yaml" - seed: 1000 - -model: - model_path: "gpt2" - num_layers_unfrozen: -1 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 5.0e-5 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 2000 # train.total_steps - eta_min: 1.0e-6 - -method: - name: "ilqlconfig" - tau: 0.7 - gamma: 0.99 - cql_scale: 0.1 - awac_scale: 1 - alpha: 0.001 - beta: 0 - steps_for_target_q_sync: 5 - two_qs: True - gen_kwargs: - max_new_tokens: 56 - top_k: 20 - beta: 2 - temperature: 0.9 diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml deleted file mode 100644 index aea895ffc..000000000 --- a/configs/ppo_config.yml +++ /dev/null @@ -1,56 +0,0 @@ -train: - seq_length: 1024 - epochs: 100 - total_steps: 10000 - batch_size: 64 - - checkpoint_interval: 10000 - eval_interval: 100 - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "lvwerra/gpt2-imdb" - num_layers_unfrozen: 2 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 1.0e-4 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 100000000000 - eta_min: 0 - -method: - name: "ppoconfig" - num_rollouts: 128 - chunk_size: 128 - ppo_epochs: 4 - init_kl_coef: 0.05 - target: 6 - horizon: 10000 - gamma: 1 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 1 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 40 - top_k: 0 - top_p: 1.0 - do_sample: True diff --git a/configs/ppo_gptj.yml b/configs/ppo_gptj.yml deleted file mode 100644 index 0595f7ded..000000000 --- a/configs/ppo_gptj.yml +++ /dev/null @@ -1,56 +0,0 @@ -train: - seq_length: 48 - epochs: 10 - total_steps: 80000 - batch_size: 8 - - checkpoint_interval: 1000000 - eval_interval: 16 - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "EleutherAI/gpt-j-6B" - num_layers_unfrozen: 2 - -tokenizer: - tokenizer_path: "gpt2" - -optimizer: - name: "adamw" - kwargs: - lr: 1.412e-4 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 80000 # train.total_steps - eta_min: 1.412e-4 - -method: - name: "ppoconfig" - num_rollouts: 8 - chunk_size: 8 - ppo_epochs: 4 - init_kl_coef: 0.2 - target: 6 - horizon: 10000 - gamma: 1 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 0.2 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 48 - top_k: 0.0 - top_p: 0.7 - do_sample: True - temperature: 0.5 diff --git 
a/configs/sft_config.yml b/configs/sft_config.yml deleted file mode 100644 index 710c4c1b9..000000000 --- a/configs/sft_config.yml +++ /dev/null @@ -1,41 +0,0 @@ -train: - seq_length: 1024 - epochs: 100 - total_steps: 1000 - batch_size: 8 - - checkpoint_interval: 10000 - eval_interval: 100 - - pipeline: "PromptPipeline" - trainer: "AccelerateSFTTrainer" - -model: - model_path: "gpt2" - num_layers_unfrozen: -1 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 1.0e-4 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 100000000000 - eta_min: 0 - -method: - name: "sftconfig" - gen_kwargs: - max_new_tokens: 40 - top_k: 0 - top_p: 1.0 - do_sample: True diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index 0972cc5ff..6259c8b21 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -19,22 +19,7 @@ Note that new trainers must be registered with ``trlx.trainer.register_trainer`` .. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer :members: -.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMWithValueHead - :members: - -.. autoclass:: trlx.trainer.nn.ppo_models.GPTModelBranch - :members: - -.. autoclass:: trlx.trainer.nn.ppo_models.OPTModelBranch - :members: - -.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMHydraWithValueHead - :members: - **ILQL** .. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer :members: - -.. autoclass:: trlx.trainer.nn.ilql_models.CausalLMWithValueHeads - :members: From 9b262a3352db8345ec24891164c876bac1c9c51c Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Wed, 15 Mar 2023 10:39:44 +0200 Subject: [PATCH 42/57] merge(examples): upstream config usage --- README.md | 21 ++++++-- examples/architext.py | 14 ++--- examples/nemo_ilql_inference.py | 20 +++++-- examples/nemo_ilql_sentiments.py | 25 ++++++--- .../configs/ppo_config_summ_gptj.yml | 53 ------------------- 5 files changed, 54 insertions(+), 79 deletions(-) delete mode 100755 examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml diff --git a/README.md b/README.md index da9ba405d..e140c21fb 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count(' #### Using a reward-labeled dataset ```python -trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)]) +trainer = trlx.train('EleutherAI/gpt-j-6B', samples=['dolphins', 'geese'], rewards=[1.0, 100.0]) ``` #### Trainers provide a wrapper over their underlying model @@ -57,14 +57,25 @@ trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0 trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True) ``` +#### Configure Hyperparameters + +```python +from trlx.data.default_configs import default_ppo_config, TrainConfig + +config = default_ppo_config() +config.model.model_path = 'EleutherAI/gpt-neox-20b' +config.train.seq_length = 32 +config.train.batch_size = 16 + +trainer = trlx.train(config=config, reward_fn=lambda samples, **kwargs: [float(int(sample)) for sample in samples]) +``` + #### Save the resulting model to a Hugging Face pretrained language model. (Ready to upload to the Hub!) 
```python trainer.save_pretrained('/path/to/output/folder/') ``` -🩹 Warning: Only the `AcceleratePPOTrainer` can write HuggingFace transformers to disk with `save_pretrained` at the moment, as ILQL trainers require inference behavior currently unsupported by available `transformers` architectures. - #### Use 🤗 Accelerate to launch distributed training ```bash @@ -74,13 +85,13 @@ accelerate launch examples/simulacra.py #### Use NeMo-Megatron to launch distributed training -Follow the setup instructions in the [NeMo README](./trlx/trainer/nemo). +Follow the setup instructions in the [NeMo README](./trlx/models/). ```bash python examples/nemo_ilql_sentiments.py ``` -For more usage see the [NeMo README](./trlx/trainer/nemo) +For more usage see the [NeMo README](./trlx/models) #### Use Ray Tune to launch hyperparameter sweep diff --git a/examples/architext.py b/examples/architext.py index d854c4858..6e31f3497 100644 --- a/examples/architext.py +++ b/examples/architext.py @@ -1,11 +1,7 @@ # Toy example of optimizing textual interior designs to output the least number of rooms # Also see https://architext.design/ -import pathlib - -import yaml - import trlx -from trlx.data.configs import TRLConfig +from trlx.data.default_configs import default_ppo_config def reward_fn(samples, **kwargs): @@ -30,13 +26,9 @@ def reward_fn(samples, **kwargs): "[prompt] the kitchen is not adjacent to the bathroom [layout]", ] -config_path = pathlib.Path(__file__).parent.joinpath("../configs/ppo_config.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) - -def main(hparams={}): - config = TRLConfig.update(default_config, hparams) +def main(): + config = default_ppo_config() trlx.train(model_path="architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config) diff --git a/examples/nemo_ilql_inference.py b/examples/nemo_ilql_inference.py index 425a8cdb2..f172f6fbb 100644 --- a/examples/nemo_ilql_inference.py +++ b/examples/nemo_ilql_inference.py @@ -2,7 +2,6 @@ import sys from glob import glob -import yaml from nemo.collections.nlp.modules.common.megatron.megatron_init import ( fake_initialize_model_parallel, ) @@ -10,10 +9,24 @@ from nemo.utils.model_utils import inject_model_parallel_rank from omegaconf.omegaconf import OmegaConf -from trlx.data.configs import TRLConfig +from trlx.data.configs import TrainConfig +from trlx.data.default_configs import default_ilql_config from trlx.trainer.nemo_ilql_trainer import ILQLGPT, megatron_trainer -default_config = yaml.safe_load(open(os.path.dirname(__file__) + "/../configs/nemo_ilql_config.yml")) +default_config = default_ilql_config() + +trl_config = default_config.evolve( + train=TrainConfig( + **dict( + default_config.train.__dict__, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ), + ), + ) +) def find_checkpoints(checkpoint_dir): @@ -23,7 +36,6 @@ def find_checkpoints(checkpoint_dir): def main(megatron_cfg_path, checkpoint_path): - trl_config = TRLConfig.update(default_config, {}) ilql_config = trl_config.method megatron_cfg = OmegaConf.load(megatron_cfg_path) diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index 82abe3b7b..34044622a 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -1,12 +1,10 @@ -import os from typing import Dict, List -import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.configs import 
TRLConfig +from trlx.data.default_configs import default_ilql_config def get_positive_score(scores): @@ -14,11 +12,25 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -default_config = yaml.safe_load(open(os.path.dirname(__file__) + "/../configs/nemo_ilql_config.yml")) +default_config = default_ilql_config() def main(hparams={}): - config = TRLConfig.update(default_config, hparams) + # Merge sweep config with default config if given + + config = default_config.evolve( + train=dict( + seq_length=1024, + batch_size=512, + total_steps=200, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ), + ) + ) + config = config.evolve(**hparams) sentiment_fn = pipeline( "sentiment-analysis", @@ -36,7 +48,8 @@ def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: imdb = load_dataset("imdb", split="train+test") trlx.train( - dataset=(imdb["text"], imdb["label"]), + samples=imdb["text"], + rewards=imdb["label"], eval_prompts=["I don't know much about Hungarian underground"] * 128, metric_fn=metric_fn, config=config, diff --git a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml deleted file mode 100755 index 8055a49b5..000000000 --- a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml +++ /dev/null @@ -1,53 +0,0 @@ -train: - seq_length: 550 - epochs: 50 - total_steps: 100000 - batch_size: 4 - - checkpoint_interval: 10000 - eval_interval: 200 - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "CarperAI/openai_summarize_tldr_sft" - num_layers_unfrozen: 8 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 5.0e-6 - betas: [0.9, 0.999] - eps: 1.0e-8 - weight_decay: 0.01 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 100000 - eta_min: 5.0e-6 - -method: - name: "ppoconfig" - num_rollouts: 128 - chunk_size: 16 - ppo_epochs: 4 - init_kl_coef: 0.1 - target: 6 - horizon: 10000 - gamma: 1 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 0.2 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 50 From c7ac679b357e12f4ac9c2968d2c3ba4a534f948b Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Wed, 15 Mar 2023 10:48:29 +0200 Subject: [PATCH 43/57] fix(setup.cfg): condition ray's pinned wheel --- setup.cfg | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index d1ab5d199..0078d0095 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,9 @@ install_requires = tabulate>=0.9.0 networkx tritonclient - ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl + ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl ; python_version=="3.8" + ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl ; python_version=="3.9" + ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl ; 
python_version=="3.10" [options.extras_require] bnb = bitsandbytes From f9875f028088c1f3498607fc14f247d77361bfea Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 23 Mar 2023 14:59:36 -0700 Subject: [PATCH 44/57] Use `AccelerateTrainer` from Ray (#386) * Use `AccelerateTrainer` from Ray Signed-off-by: Antoni Baum * fix(sweep): `accelerate_config_path` -> `accelerate_config` --------- Signed-off-by: Antoni Baum Co-authored-by: reciprocated <56548574+reciprocated@users.noreply.github.com> --- setup.cfg | 6 +- trlx/ray_train/__init__.py | 0 trlx/ray_train/accelerate_trainer.py | 177 --------------------------- trlx/ray_train/launch.py | 94 -------------- trlx/sweep.py | 5 +- 5 files changed, 5 insertions(+), 277 deletions(-) delete mode 100644 trlx/ray_train/__init__.py delete mode 100644 trlx/ray_train/accelerate_trainer.py delete mode 100644 trlx/ray_train/launch.py diff --git a/setup.cfg b/setup.cfg index 0078d0095..b0b9b12f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,9 +24,9 @@ install_requires = tabulate>=0.9.0 networkx tritonclient - ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl ; python_version=="3.8" - ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl ; python_version=="3.9" - ray@https://ray-ci-artifact-pr-public.s3.amazonaws.com/5cceaa7c1216fea7de79c5bd0c84c96265dd1a7a/tmp/artifacts/.whl/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl ; python_version=="3.10" + ray@https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl ; python_version=="3.8" + ray@https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp39-cp39-manylinux2014_x86_64.whl ; python_version=="3.9" + ray@https://ray-ci-artifact-branch-public.s3.amazonaws.com/42bb0357a6fb13e4994789c824f3623f32869ad8/tmp/artifacts/.whl/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl ; python_version=="3.10" [options.extras_require] bnb = bitsandbytes diff --git a/trlx/ray_train/__init__.py b/trlx/ray_train/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/trlx/ray_train/accelerate_trainer.py b/trlx/ray_train/accelerate_trainer.py deleted file mode 100644 index e1003a68f..000000000 --- a/trlx/ray_train/accelerate_trainer.py +++ /dev/null @@ -1,177 +0,0 @@ -import os -import tempfile -from argparse import Namespace -from functools import wraps -from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union - -from ray.air import session -from ray.air.checkpoint import Checkpoint -from ray.air.config import DatasetConfig, RunConfig, ScalingConfig -from ray.train.torch import TorchConfig -from ray.train.trainer import GenDataset - -if TYPE_CHECKING: - from ray.data.preprocessor import Preprocessor - from ray.tune.trainable import Trainable - -from accelerate.commands.config import default_config_file, load_config_from_file -from ray.train.torch import TorchTrainer, get_device - -from .launch import launch_command, launch_command_parser - - -class _AccelerateDefaultNamespace(Namespace): - @property - def parser(self): - return launch_command_parser() - - def __getattr__(self, name: str): - return self.parser.get_default(name) - - -class 
_AccelerateConfigWrapper: - """ - Lets Trainables know to treat this as already loaded file content instead of path. - """ - - def __init__(self, config_raw: str, deepspeed_config_raw: Optional[str] = None) -> None: - self.config_raw = config_raw - self.deepspeed_config_raw = deepspeed_config_raw - - def __bool__(self) -> bool: - return bool(self.config_raw) - - -class AccelerateTrainer(TorchTrainer): - def __init__( - self, - train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], - *, - accelerate_config_path: Union[str, Path, os.PathLike], - train_loop_config: Optional[Dict] = None, - torch_config: Optional[TorchConfig] = None, - scaling_config: Optional[ScalingConfig] = None, - dataset_config: Optional[Dict[str, DatasetConfig]] = None, - run_config: Optional[RunConfig] = None, - datasets: Optional[Dict[str, GenDataset]] = None, - preprocessor: Optional["Preprocessor"] = None, - resume_from_checkpoint: Optional[Checkpoint] = None, - ): - self.accelerate_config_path = accelerate_config_path or default_config_file - if isinstance(self.accelerate_config_path, _AccelerateConfigWrapper): - self._accelerate_config_raw = self.accelerate_config_path.config_raw - self._deepspeed_config_file_raw = self.accelerate_config_path.deepspeed_config_raw - else: - ( - self._accelerate_config_raw, - self._deepspeed_config_file_raw, - ) = self._load_accelerate_config() - super().__init__( - train_loop_per_worker, - train_loop_config=train_loop_config, - torch_config=torch_config, - scaling_config=scaling_config, - dataset_config=dataset_config, - run_config=run_config, - datasets=datasets, - preprocessor=preprocessor, - resume_from_checkpoint=resume_from_checkpoint, - ) - - def training_loop(self) -> None: - old_train_loop_per_worker = self._train_loop_per_worker - self._train_loop_per_worker = self._wrap_train_loop_per_worker( - self._train_loop_per_worker, - self._accelerate_config_raw, - self._deepspeed_config_file_raw, - ) - try: - ret = super().training_loop() - finally: - self._train_loop_per_worker = old_train_loop_per_worker - return ret - - def as_trainable(self) -> Type["Trainable"]: - # We want to load the config when the Trainer is first instantiated, - # and share the contents with the Trainables (which may be on different) - # nodes - old_accelerate_config_path = self._param_dict["accelerate_config_path"] - self._param_dict["accelerate_config_path"] = _AccelerateConfigWrapper( - self._accelerate_config_raw, self._deepspeed_config_file_raw - ) - try: - ret = super().as_trainable() - finally: - self._param_dict["accelerate_config_path"] = old_accelerate_config_path - return ret - - def _load_accelerate_config(self) -> Tuple[str, Optional[str]]: - # We only load config to dict to obtain the deepspeed_config_file - config = load_config_from_file(self.accelerate_config_path) - deepspeed_config_file = getattr(config, "deepspeed_config_file", None) - deepspeed_config_file_raw = None - - if deepspeed_config_file: - with open(deepspeed_config_file, "r") as f: - deepspeed_config_file_raw = f.read() - - # Otherwise, we want to pass raw contents to Trainables for maximum - # compatibility. 
- with open(self.accelerate_config_path, "r") as f: - return f.read(), deepspeed_config_file_raw - - @classmethod - def _wrap_train_loop_per_worker( - cls, - train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], - accelerate_config_raw: str, - deepspeed_config_file_raw: str, - ): - """Wrap around train_loop_per_worker to set necessary Accelerate env vars.""" - - @wraps(train_loop_per_worker) - def wrapped_train_loop_per_worker(*args, **kwargs): - with tempfile.TemporaryDirectory() as tempdir: - temp_config_file = os.path.join(tempdir, "default_config.yaml") - with open(temp_config_file, "w") as f: - f.write(accelerate_config_raw) - - os.environ["ACCELERATE_TORCH_DEVICE"] = str(get_device()) - - # Set by TorchBackend - master_addr = os.environ["MASTER_ADDR"] - master_port = os.environ["MASTER_PORT"] - - namespace = _AccelerateDefaultNamespace() - namespace.config_file = temp_config_file - namespace.num_processes = 1 - namespace.num_machines = session.get_world_size() - namespace.machine_rank = session.get_world_rank() - namespace.num_cpu_threads_per_process = session.get_trial_resources().bundles[-1]["CPU"] - namespace.gpu_ids = None - namespace.main_process_ip = master_addr - namespace.main_process_port = master_port - - if deepspeed_config_file_raw: - deepspeed_config_file = os.path.join(tempdir, "deepspeed_config.json") - with open(deepspeed_config_file, "w") as f: - f.write(deepspeed_config_file_raw) - namespace.deepspeed_config_file = deepspeed_config_file - - launch_command(namespace) - - os.environ["MASTER_ADDR"] = master_addr - os.environ["MASTER_PORT"] = master_port - os.environ["RANK"] = str(session.get_world_rank()) - os.environ["WORLD_RANK"] = str(session.get_world_rank()) - os.environ["CROSS_RANK"] = str(session.get_world_rank()) - os.environ["CROSS_SIZE"] = str(session.get_world_size()) - os.environ["WORLD_SIZE"] = str(session.get_world_size()) - os.environ["LOCAL_RANK"] = str(session.get_local_rank()) - os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size()) - os.environ["LOCAL_SIZE"] = str(session.get_local_world_size()) - - return train_loop_per_worker(*args, **kwargs) - - return wrapped_train_loop_per_worker diff --git a/trlx/ray_train/launch.py b/trlx/ray_train/launch.py deleted file mode 100644 index 137673e70..000000000 --- a/trlx/ray_train/launch.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2021 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import os - -try: - from packaging.version import Version -except ImportError: - from distutils.version import LooseVersion as Version - -import accelerate - -if Version(accelerate.__version__) < Version("0.17.0.dev0"): - raise RuntimeError(f"AccelerateTrainer requires accelerate>=0.17.0, got {accelerate.__version__}") - -from accelerate.commands.launch import ( - ComputeEnvironment, - _validate_launch_command, - launch_command_parser, - prepare_deepspeed_cmd_env, - prepare_multi_gpu_env, - prepare_simple_launcher_cmd_env, -) -from accelerate.utils import is_deepspeed_available - -logger = logging.getLogger(__name__) - - -def simple_launcher(args): - _, current_env = prepare_simple_launcher_cmd_env(args) - - os.environ.update(current_env) - - -def multi_gpu_launcher(args): - current_env = prepare_multi_gpu_env(args) - os.environ.update(current_env) - - -def deepspeed_launcher(args): - if not is_deepspeed_available(): - raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.") - - _, current_env = prepare_deepspeed_cmd_env(args) - - os.environ.update(current_env) - - -def launch_command(args): - args, defaults, mp_from_config_flag = _validate_launch_command(args) - - # Use the proper launcher - if args.use_deepspeed and not args.cpu: - args.deepspeed_fields_from_accelerate_config = list(defaults.deepspeed_config.keys()) if defaults else [] - if mp_from_config_flag: - args.deepspeed_fields_from_accelerate_config.append("mixed_precision") - args.deepspeed_fields_from_accelerate_config = ",".join(args.deepspeed_fields_from_accelerate_config) - deepspeed_launcher(args) - elif args.use_fsdp and not args.cpu: - multi_gpu_launcher(args) - elif args.use_megatron_lm and not args.cpu: - multi_gpu_launcher(args) - elif args.multi_gpu and not args.cpu: - multi_gpu_launcher(args) - elif args.tpu and not args.cpu: - raise NotImplementedError() - elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: - raise NotImplementedError() - else: - simple_launcher(args) - - -def main(): - parser = launch_command_parser() - args = parser.parse_args() - launch_command(args) - - -if __name__ == "__main__": - main() diff --git a/trlx/sweep.py b/trlx/sweep.py index b2ec34374..c9172dd07 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -10,10 +10,9 @@ import yaml from ray import tune from ray.air import ScalingConfig +from ray.train.huggingface.accelerate import AccelerateTrainer from ray.tune.logger import CSVLoggerCallback -from trlx.ray_train.accelerate_trainer import AccelerateTrainer - def get_param_space(config: dict): # noqa: C901 """Get the param space from the config file.""" @@ -327,7 +326,7 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ AccelerateTrainer( script.main, # Mandatory arg. 
None means use Accelerate default path - accelerate_config_path=args.accelerate_config, + accelerate_config=args.accelerate_config, scaling_config=ScalingConfig( trainer_resources={"CPU": 0}, num_workers=args.num_gpus, From ee63dd98b832cd11148075ab5401da56db3eb304 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 16:46:28 +0300 Subject: [PATCH 45/57] chore(sweep): explicitly pin a GPU per worker --- trlx/sweep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/sweep.py b/trlx/sweep.py index c9172dd07..615cb7361 100644 --- a/trlx/sweep.py +++ b/trlx/sweep.py @@ -331,7 +331,7 @@ def create_report(target_metric, column_names, entity_name, project_name, group_ trainer_resources={"CPU": 0}, num_workers=args.num_gpus, use_gpu=True, - resources_per_worker={"CPU": args.num_cpus}, + resources_per_worker={"CPU": args.num_cpus, "GPU": 1}, ), ), param_space=param_space_train, From a677e256dce94e4c1f646011b245eab88dddb42a Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 22:29:35 +0300 Subject: [PATCH 46/57] fix(base_trainer): remove checkpointing while under ray --- trlx/trainer/accelerate_base_trainer.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index ec8f5922a..612d1a2a9 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -508,6 +508,8 @@ def learn(self): # noqa: C901 if self.iter_count % self.config.train.eval_interval == 0: results = self.evaluate() stats.update(results) + if ray.is_initialized(): + session.report(filter_non_scalars(stats), checkpoint=checkpoint) # always save checkpoint with the greatest mean reward if self.config.train.save_best: @@ -528,14 +530,6 @@ def learn(self): # noqa: C901 logger.info(f"Saving the best state so far into {best_path}") self.save(best_path) - # Report the metrics to Ray Tune. 
- if ray.is_initialized(): - self.save("state") - with open("state/state.json", "w") as f: - json.dump(dict(iter_count=self.iter_count), f) - checkpoint = Checkpoint.from_directory("state") - session.report(filter_non_scalars(stats), checkpoint=checkpoint) - self.accelerator.log(stats, step=self.iter_count) desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) From 23d2c696fc4afccf7df47fba69fddc1f3cc80a28 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 22:41:47 +0300 Subject: [PATCH 47/57] chore(README): update sweep instructions --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 11963c314..93a62b499 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ For more usage see the [NeMo README](./trlx/models) #### Use Ray Tune to launch hyperparameter sweep ```bash -python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py +python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.py ``` #### Benchmark your trlX fork against trlX's `main` branch From c69166a26863bc472134254251eceba34e78e597 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 22:45:25 +0300 Subject: [PATCH 48/57] feat(configs/sweeps): update with more values --- configs/sweeps/ilql_sweep.yml | 16 +++++++++++++++- configs/sweeps/ppo_sweep.yml | 29 +++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/configs/sweeps/ilql_sweep.yml b/configs/sweeps/ilql_sweep.yml index fa23296a6..d45e4d1e5 100644 --- a/configs/sweeps/ilql_sweep.yml +++ b/configs/sweeps/ilql_sweep.yml @@ -3,7 +3,7 @@ tune_config: metric: "metrics/sentiments" search_alg: "random" scheduler: "fifo" - num_samples: 8 + num_samples: 64 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs optimizer.kwargs.lr: @@ -18,3 +18,17 @@ method.steps_for_target_q_sync: method.alpha: strategy: "loguniform" values: [0.001, 1.0] + +# disable checkpointing for storage sake +train.checkpoint_interval: + strategy: "choice" + values: [10000000] +train.save_best: + strategy: "choice" + values: [false] + + + + + + diff --git a/configs/sweeps/ppo_sweep.yml b/configs/sweeps/ppo_sweep.yml index 84fa627d9..fa9ec9696 100644 --- a/configs/sweeps/ppo_sweep.yml +++ b/configs/sweeps/ppo_sweep.yml @@ -3,18 +3,35 @@ tune_config: metric: "reward/mean" search_alg: "random" scheduler: "fifo" - num_samples: 8 + num_samples: 32 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs optimizer.kwargs.lr: strategy: "loguniform" values: [0.000001, 0.001] method.init_kl_coef: - strategy: "uniform" - values: [0, 0.2] -method.vf_coef: - strategy: "uniform" - values: [0.25, 1.5] + strategy: "loguniform" + values: [0.0001, 0.2] model.num_layers_unfrozen: strategy: "choice" values: [-1, 2, 6] +method.num_rollouts: + strategy: "choice" + values: [32, 128, 512] +method.target: + strategy: "choice" + values: [null, 1] + +# disable checkpointing for storage sake +train.checkpoint_interval: + strategy: "choice" + values: [10000000] +train.save_best: + strategy: "choice" + values: [false] + + + + + + From e14488b98c9ad695193f50aef8e63b9503e52c25 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 22:53:59 +0300 Subject: [PATCH 49/57] 
style: satisfy flake --- configs/sweeps/ilql_sweep.yml | 6 ------ configs/sweeps/ppo_sweep.yml | 6 ------ 2 files changed, 12 deletions(-) diff --git a/configs/sweeps/ilql_sweep.yml b/configs/sweeps/ilql_sweep.yml index d45e4d1e5..b63319961 100644 --- a/configs/sweeps/ilql_sweep.yml +++ b/configs/sweeps/ilql_sweep.yml @@ -26,9 +26,3 @@ train.checkpoint_interval: train.save_best: strategy: "choice" values: [false] - - - - - - diff --git a/configs/sweeps/ppo_sweep.yml b/configs/sweeps/ppo_sweep.yml index fa9ec9696..3fc0b07b5 100644 --- a/configs/sweeps/ppo_sweep.yml +++ b/configs/sweeps/ppo_sweep.yml @@ -29,9 +29,3 @@ train.checkpoint_interval: train.save_best: strategy: "choice" values: [false] - - - - - - From 81fd0e79fdcabe33edf84dad0ab060fbf56be4c6 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 23:00:30 +0300 Subject: [PATCH 50/57] style: satisfy flake --- trlx/trainer/accelerate_base_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 612d1a2a9..164b42d59 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -9,7 +9,6 @@ import torch from accelerate import Accelerator # type: ignore from ray.air import session -from ray.air.checkpoint import Checkpoint from rich.console import Console from rich.table import Table from transformers import AutoTokenizer From 55a5aa97a31a102e0423380ee545639fed77788b Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 23:05:02 +0300 Subject: [PATCH 51/57] revert(examples): remove redundant device selection --- examples/hh/ppo_hh.py | 4 +--- examples/hh/to_triton.py | 5 +---- examples/summarize_rlhf/trlx_inference_gptj.py | 4 +--- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/examples/hh/ppo_hh.py b/examples/hh/ppo_hh.py index acfec4d8e..6c71ce94b 100644 --- a/examples/hh/ppo_hh.py +++ b/examples/hh/ppo_hh.py @@ -164,9 +164,7 @@ def forward(self, input_ids): reward_model.load_state_dict(torch.load(checkpoint)) reward_model.eval() reward_model.requires_grad_(False) - device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) - if device is None: - device = torch.cuda.device_count() - 1 + device = torch.cuda.device_count() - 1 reward_model = reward_model.half().to(device) def reward_fn(samples, prompts, outputs): diff --git a/examples/hh/to_triton.py b/examples/hh/to_triton.py index c2a541517..6559e0998 100644 --- a/examples/hh/to_triton.py +++ b/examples/hh/to_triton.py @@ -24,10 +24,7 @@ args = parser.parse_args() model_name = args.checkpoint.split("/")[-1] -device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) -if device is None: - device = torch.device(args.device) - +device = torch.device(args.device) class RewardModel(nn.Module): def __init__(self, checkpoint_path, eos_token_id): diff --git a/examples/summarize_rlhf/trlx_inference_gptj.py b/examples/summarize_rlhf/trlx_inference_gptj.py index 3032289e3..f5a54365a 100644 --- a/examples/summarize_rlhf/trlx_inference_gptj.py +++ b/examples/summarize_rlhf/trlx_inference_gptj.py @@ -32,9 +32,7 @@ def load_model(path): rw_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH)) rw_model.half() rw_model.eval() -rw_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) -if rw_device is None: - rw_device = torch.device("cuda:{}".format(1)) +rw_device = torch.device("cuda:{}".format(1)) rw_model.to(rw_device) From 
962dfa32f3ee3e18e48fc17a7d1e429d4eb0ed87 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Thu, 30 Mar 2023 23:13:09 +0300 Subject: [PATCH 52/57] style: satisfy black --- examples/hh/to_triton.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/hh/to_triton.py b/examples/hh/to_triton.py index 6559e0998..ad26009d3 100644 --- a/examples/hh/to_triton.py +++ b/examples/hh/to_triton.py @@ -26,6 +26,7 @@ model_name = args.checkpoint.split("/")[-1] device = torch.device(args.device) + class RewardModel(nn.Module): def __init__(self, checkpoint_path, eos_token_id): super().__init__() From fd6f9c18a368b02f3d340a1d5a59036aa78e88c0 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Fri, 31 Mar 2023 13:38:43 +0300 Subject: [PATCH 53/57] fix(base_trainer): report evaluation stats at the end --- trlx/trainer/accelerate_base_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 164b42d59..27a834a9b 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -538,8 +538,13 @@ def learn(self): # noqa: C901 if self.iter_count >= self.total_steps: subfolder = f"checkpoint_{self.iter_count:0{len(str(self.total_steps))}d}" directory = os.path.join(self.config.train.checkpoint_dir, subfolder) + results = self.evaluate() + stats.update(results) + + if ray.is_initialized(): + session.report(filter_non_scalars(stats), checkpoint=checkpoint) self.save(directory) - return self.evaluate() + return results self.post_backward_callback() From 69b0bb87364d2bb3b7f91fbfd9e7905212f7af2e Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Fri, 31 Mar 2023 19:01:28 +0300 Subject: [PATCH 54/57] fix(base_trainer): log final stats for w&b --- trlx/trainer/accelerate_base_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index 27a834a9b..8151cac9d 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -529,7 +529,6 @@ def learn(self): # noqa: C901 logger.info(f"Saving the best state so far into {best_path}") self.save(best_path) - self.accelerator.log(stats, step=self.iter_count) desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) tbar.set_description(f"[{desc}]") @@ -543,9 +542,13 @@ def learn(self): # noqa: C901 if ray.is_initialized(): session.report(filter_non_scalars(stats), checkpoint=checkpoint) + self.accelerator.log(stats, step=self.iter_count) + self.save(directory) return results + self.accelerator.log(stats, step=self.iter_count) + self.post_backward_callback() self.post_epoch_callback() From 45feae6b06dac5e47ed04a0aa293250defd80387 Mon Sep 17 00:00:00 2001 From: reciprocated <56548574+reciprocated@users.noreply.github.com> Date: Fri, 31 Mar 2023 19:03:07 +0300 Subject: [PATCH 55/57] fix(examples/sentiments): improve hyperparameters --- examples/ilql_sentiments.py | 2 +- examples/ppo_sentiments.py | 2 +- trlx/data/default_configs.py | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py index d8737a87f..931010c59 100644 --- a/examples/ilql_sentiments.py +++ b/examples/ilql_sentiments.py @@ -37,7 +37,7 @@ def metric_fn(samples: List[str], 
From 45feae6b06dac5e47ed04a0aa293250defd80387 Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Fri, 31 Mar 2023 19:03:07 +0300
Subject: [PATCH 55/57] fix(examples/sentiments): improve hyperparameters

---
 examples/ilql_sentiments.py  |  2 +-
 examples/ppo_sentiments.py   |  2 +-
 trlx/data/default_configs.py | 12 ++++++------
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py
index d8737a87f..931010c59 100644
--- a/examples/ilql_sentiments.py
+++ b/examples/ilql_sentiments.py
@@ -37,7 +37,7 @@ def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
     trlx.train(
         samples=imdb["text"],
         rewards=imdb["label"],
-        eval_prompts=["I don't know much about Hungarian underground"] * 64,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
         metric_fn=metric_fn,
         config=config,
     )

diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py
index 74508540d..47b764a91 100644
--- a/examples/ppo_sentiments.py
+++ b/examples/ppo_sentiments.py
@@ -47,7 +47,7 @@ def reward_fn(samples: List[str], **kwargs) -> List[float]:
     trlx.train(
         reward_fn=reward_fn,
         prompts=prompts,
-        eval_prompts=["I don't know much about Hungarian underground"] * 64,
+        eval_prompts=["I don't know much about Hungarian underground"] * 256,
         config=config,
     )

diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py
index 3fa1b5c2b..b8ab22c87 100644
--- a/trlx/data/default_configs.py
+++ b/trlx/data/default_configs.py
@@ -27,16 +27,16 @@ def default_ppo_config():
         model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2),
         tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
         optimizer=OptimizerConfig(
-            name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
+            name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
         ),
-        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),
+        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)),
         method=PPOConfig(
             name="PPOConfig",
             num_rollouts=128,
             chunk_size=128,
             ppo_epochs=4,
-            init_kl_coef=0.05,
-            target=6,
+            init_kl_coef=0.001,
+            target=None,
             horizon=10000,
             gamma=1,
             lam=0.95,
@@ -61,7 +61,7 @@ def default_ilql_config():
     return TRLConfig(
         train=TrainConfig(
             seq_length=64,
-            batch_size=32,
+            batch_size=128,
             epochs=100,
             total_steps=1000,
             checkpoint_interval=1000,
@@ -87,7 +87,7 @@ def default_ilql_config():
             beta=0,
             steps_for_target_q_sync=5,
             two_qs=True,
-            gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0),
+            gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=1, temperature=1.0),
         ),
     )

From ec638c3a93fb67093c27512895b13e9067f17541 Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Fri, 31 Mar 2023 19:06:24 +0300
Subject: [PATCH 56/57] style: satisfy black

---
 trlx/trainer/accelerate_base_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 8151cac9d..251a1943d 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -529,7 +529,6 @@ def learn(self):  # noqa: C901
                         logger.info(f"Saving the best state so far into {best_path}")
                         self.save(best_path)

-
                 desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss"))
                 tbar.set_description(f"[{desc}]")
                 tbar.update()
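The `target=None` change in the PPO defaults above switches the KL penalty from an adaptive schedule to a constant one: with a target set, the coefficient is nudged toward keeping the measured KL near that target; with `target=None`, it stays fixed at `init_kl_coef`. A self-contained sketch of the two behaviours, in the spirit of the controllers trlx uses (names and exact signatures here are illustrative assumptions):

class FixedKLController:
    # what target=None selects: the coefficient never moves
    def __init__(self, init_kl_coef: float):
        self.value = init_kl_coef

    def update(self, current_kl: float, n_steps: int) -> None:
        pass

class AdaptiveKLController:
    # Ziegler et al.-style controller: raise the coefficient when measured
    # KL overshoots the target, lower it when it undershoots
    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        proportional_error = max(min(current_kl / self.target - 1, 0.2), -0.2)
        self.value *= 1 + proportional_error * n_steps / self.horizon
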
From 6c7e6f9272d59d3145eba466ca480afd3852204c Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Fri, 31 Mar 2023 20:00:53 +0300
Subject: [PATCH 57/57] fix(README): add ray cluster manual creation
 instruction

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 93a62b499..1ede7bd3a 100644
--- a/README.md
+++ b/README.md
@@ -104,6 +104,7 @@ For more usage see the [NeMo README](./trlx/models)
 #### Use Ray Tune to launch hyperparameter sweep

 ```bash
+ray start --head --port=6379
 python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/ddp.yaml --num_gpus 4 examples/ppo_sentiments.py
 ```
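With the head node started manually as in the README change, the sweep driver attaches to the already-running cluster instead of spinning up a private one. A short usage sketch of attaching from Python; the address must match your `ray start` invocation:

import ray

# "auto" discovers the locally running cluster started by `ray start --head`;
# a remote head can also be targeted explicitly via Ray Client, e.g.
# ray.init(address="ray://<head-node-ip>:10001").
ray.init(address="auto")
print(ray.cluster_resources())  # sanity-check the CPUs/GPUs the sweep can use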