Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix model to save in ppov2 #1776

Merged
merged 6 commits into from
Aug 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 15 additions & 25 deletions trl/trainer/ppov2_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
import os
import time
from collections import OrderedDict, defaultdict
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -142,6 +142,7 @@ def __init__(
if args.stop_token and args.stop_token == "eos":
args.stop_token_id = tokenizer.eos_token_id
self.model = PolicyAndValueWrapper(policy, value_model)
self.model.config = policy.config # needed for pushing to hub
self.create_optimizer_and_scheduler(
num_training_steps=args.num_total_batches
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
Expand Down Expand Up @@ -170,7 +171,6 @@ def __init__(
self.init_hf_repo()
if self.args.should_save:
os.makedirs(self.args.output_dir, exist_ok=True)
self.backup_model = None

#########
### setup dataloader
Expand Down Expand Up @@ -213,30 +213,20 @@ def get_train_dataloader(self) -> DataLoader:
def get_eval_dataloader(self) -> DataLoader:
    """Return the evaluation dataloader that was built during trainer initialization."""
    loader = self.eval_dataloader
    return loader

def push_to_hub(self, **kwargs):
    """Modified from `Trainer.save_model` to only save the policy and not the value network."""
    # Temporarily swap `self.model` (the policy+value wrapper) for the bare
    # policy so the parent implementation uploads only the policy weights.
    self.backup_model = self.model
    self.model = self.accelerator.unwrap_model(self.model).policy  # save only the policy
    # NOTE(review): the parent `push_to_hub` is expected to route through
    # `save_model` with `_internal_call=True` (see the companion `save_model`
    # override, which checks that flag) — confirm against the base Trainer.
    super().push_to_hub(**kwargs)
    # Restore the full wrapper so subsequent training steps see it again.
    self.model = self.backup_model

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """Modified from `Trainer.save_model` to only save the policy and not the value network.

    Args:
        output_dir: Target directory for the checkpoint; when `None`, the
            parent implementation falls back to `self.args.output_dir`.
        _internal_call: Forwarded unchanged to the parent `Trainer.save_model`.
    """
    # Temporarily swap the policy+value wrapper for the bare policy so the
    # parent class serializes only the policy weights.
    backup_model = self.model
    self.model = self.model.policy  # save only the policy

    if self.is_deepspeed_enabled:
        # When DeepSpeed is enabled the parent saves through `self.deepspeed`,
        # so point it at the policy as well for the duration of the save.
        backup_deepspeed = self.deepspeed
        self.deepspeed = self.model

    super().save_model(output_dir, _internal_call)

    # Restore the originals so training continues on the full wrapper.
    self.model = backup_model

    if self.is_deepspeed_enabled:
        self.deepspeed = backup_deepspeed

def train(self):
args = self.args
Expand Down
Loading