From c0b9ca3031bb45aa68bfea3f727fa70c8d0ca064 Mon Sep 17 00:00:00 2001
From: Dawid Motyka
Date: Mon, 23 Dec 2024 17:16:03 +0100
Subject: [PATCH 1/3] Fix calculations of steps, episodes and epochs

---
 trl/trainer/rloo_trainer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 23ea1ca21f..6284b7464d 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -120,12 +120,12 @@ def __init__(
         # calculate various batch sizes
         #########
         if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
-            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
+            args.total_episodes = int(args.num_train_epochs * self.train_dataset_len * args.rloo_k)
         accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
         self.accelerator = accelerator
         args.world_size = accelerator.num_processes
         args.local_batch_size = (
-            args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
+            args.per_device_train_batch_size * args.gradient_accumulation_steps
         )
         args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
         args.batch_size = int(args.local_batch_size * args.world_size)
@@ -275,8 +275,8 @@ def repeat_generator():
         # trainer state initialization
         self.state.global_step = 0
         self.state.episode = 0
-        self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
-        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
+        self.state.max_steps = args.num_total_batches * args.num_mini_batches * args.num_ppo_epochs
+        self.state.num_train_epochs = (args.total_episodes / args.rloo_k) / self.train_dataset_len
         # Compute absolute values for logging, eval, and save if given as ratio
         if args.logging_steps is not None:
             if args.logging_steps < 1:
@@ -480,7 +480,7 @@ def repeat_generator():
                 del kl, mean_kl, mean_entropy, scores

         self.lr_scheduler.step()
-        self.state.global_step += 1
+        self.state.global_step += args.num_ppo_epochs * args.num_mini_batches
         self.control = self.callback_handler.on_step_end(args, self.state, self.control)
         if self.control.should_save:
             self._save_checkpoint(model, trial=None)

From d4c2dc90ce338e0f4c32404e5556a8657fa7cfdc Mon Sep 17 00:00:00 2001
From: Dawid Motyka
Date: Wed, 8 Jan 2025 20:27:19 +0100
Subject: [PATCH 2/3] fix: don't multiply steps by (num_mini_batches * num_ppo_epochs)

---
 trl/trainer/rloo_trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
index 6284b7464d..6fdcf28609 100644
--- a/trl/trainer/rloo_trainer.py
+++ b/trl/trainer/rloo_trainer.py
@@ -275,7 +275,7 @@ def repeat_generator():
         # trainer state initialization
         self.state.global_step = 0
         self.state.episode = 0
-        self.state.max_steps = args.num_total_batches * args.num_mini_batches * args.num_ppo_epochs
+        self.state.max_steps = args.num_total_batches
         self.state.num_train_epochs = (args.total_episodes / args.rloo_k) / self.train_dataset_len
         # Compute absolute values for logging, eval, and save if given as ratio
         if args.logging_steps is not None:
@@ -480,7 +480,7 @@ def repeat_generator():
                 del kl, mean_kl, mean_entropy, scores

         self.lr_scheduler.step()
-        self.state.global_step += args.num_ppo_epochs * args.num_mini_batches
+        self.state.global_step += 1
         self.control = self.callback_handler.on_step_end(args, self.state, self.control)
         if self.control.should_save:
             self._save_checkpoint(model, trial=None)

From 4eb48180deb47a4b5045ad1021b77cfbb93d3308 Mon Sep 17 00:00:00 2001
From: Dawid Motyka
Date: Wed, 8 Jan 2025 20:59:17 +0100
Subject: [PATCH 3/3] update documentation: logged episodes are not global steps

---
 docs/source/rloo_trainer.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md
index 127f297321..b3b4663416 100644
--- a/docs/source/rloo_trainer.md
+++ b/docs/source/rloo_trainer.md
@@ -52,7 +52,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
 * `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
 * `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
 * `lr`: lr: The current learning rate used by the optimizer.
-* `episode`: episode: The current global step or episode count in the training process.
+* `episode`: episode: The current episode count in the training process.

 ## Cookbook
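Note (not part of the patch series): a minimal standalone sketch of the step/episode/epoch arithmetic after these patches, using illustrative hyperparameter values. The variable names mirror RLOOConfig fields, but the script is hypothetical, and the ceil division used for the batch count is an assumption rather than a quote of the trainer's code.

# Illustrative values only; names mirror RLOOConfig fields but this script is hypothetical.
import math

train_dataset_len = 1000            # prompts in the training set
num_train_epochs = 3
rloo_k = 2                          # completions sampled per prompt
per_device_train_batch_size = 4
gradient_accumulation_steps = 8
world_size = 2

# Patch 1: an "episode" is one sampled completion, so an epoch-based budget scales by rloo_k.
total_episodes = int(num_train_epochs * train_dataset_len * rloo_k)

# Patch 1: local_batch_size no longer multiplies in num_mini_batches.
local_batch_size = per_device_train_batch_size * gradient_accumulation_steps
batch_size = local_batch_size * world_size

# Patch 2: one global step per batch, so max_steps is simply the number of batches
# (ceil division here is an assumed approximation of how the trainer counts batches).
num_total_batches = math.ceil(total_episodes / batch_size)
max_steps = num_total_batches

# Epochs are counted over prompts, so divide episodes by rloo_k before dividing by dataset size.
num_train_epochs_logged = (total_episodes / rloo_k) / train_dataset_len

print(total_episodes, batch_size, max_steps, num_train_epochs_logged)
# -> 6000 64 94 3.0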