[TRL] Rename sft trainer. (#9292)
* rename sft trainer.
ZHUI authored Oct 25, 2024
1 parent 75f44ef commit 2bf3d7f
Showing 10 changed files with 171 additions and 143 deletions.
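
In short, this commit moves paddlenlp/utils/llm_utils.py to paddlenlp/trl/llm_utils.py and splits the CausalLMTrainer defined there into paddlenlp/trl/sft_trainer.py under the new name SFTTrainer. A minimal sketch of the import changes downstream code needs (old paths commented out; the new paths are the ones this commit ships):

    # Before this commit:
    #   from paddlenlp.utils import llm_utils
    #   from paddlenlp.utils.llm_utils import CausalLMTrainer, get_lora_target_modules

    # After this commit:
    from paddlenlp.trl import SFTTrainer, llm_utils
    from paddlenlp.trl.llm_utils import get_lora_target_modules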
2 changes: 1 addition & 1 deletion llm/alignment/dpo/run_dpo.py
@@ -49,7 +49,7 @@
     preference_collate_fn,
     preprocess_preference_data,
 )
-from paddlenlp.utils.llm_utils import get_lora_target_modules
+from paddlenlp.trl.llm_utils import get_lora_target_modules
 from paddlenlp.utils.log import logger
 
 flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
2 changes: 1 addition & 1 deletion llm/alignment/ppo/infer_utils.py
@@ -55,7 +55,7 @@ def __init__(self, config, model: PretrainedModel = None, tokenizer: PretrainedT
 def create_predictor(trainer):
     from predictor import PdArgumentParser, PredictorArgument
 
-    from paddlenlp.utils.llm_utils import get_model_max_position_embeddings
+    from paddlenlp.trl.llm_utils import get_model_max_position_embeddings
 
     # create infer model
     # NOTE: infer model use static name param_attr to create and cannot be
2 changes: 1 addition & 1 deletion llm/predict/export_model.py
@@ -20,7 +20,7 @@
 from paddle.distributed import fleet
 
 from paddlenlp.trainer import PdArgumentParser
-from paddlenlp.utils import llm_utils
+from paddlenlp.trl import llm_utils
 
 from .predictor import ModelArgument, PredictorArgument, create_predictor
 
2 changes: 1 addition & 1 deletion llm/predict/predictor.py
@@ -42,7 +42,7 @@
     PretrainedModel,
     PretrainedTokenizer,
 )
-from paddlenlp.utils import llm_utils
+from paddlenlp.trl import llm_utils
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
 
6 changes: 3 additions & 3 deletions llm/run_finetune.py
@@ -57,8 +57,8 @@
     register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
-from paddlenlp.utils.llm_utils import (
-    CausalLMTrainer,
+from paddlenlp.trl import SFTTrainer
+from paddlenlp.trl.llm_utils import (
     ZeroPaddingIterDatasetCallback,
     compute_metrics,
     get_lora_target_modules,
@@ -541,7 +541,7 @@ def compute_metrics_do_generation(eval_preds):
     else:
         metrics = compute_metrics
 
-    trainer = CausalLMTrainer(
+    trainer = SFTTrainer(
         model=model,
         args=training_args,
         train_dataset=train_ds,
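
SFTTrainer keeps CausalLMTrainer's constructor signature (see sft_trainer.py below), so only the name changes at this call site. A hedged sketch of the full construction; the argument values are assumptions for illustration, while the extra keywords do_generation, gen_args, and data_args come from the trainer's __init__:

    # Sketch only: values are assumed, keyword names match SFTTrainer.__init__.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,      # assumed name for the eval split
        tokenizer=tokenizer,
        compute_metrics=metrics,  # `metrics` is chosen a few lines above
        do_generation=False,      # True switches eval to generation-based prediction
        gen_args=gen_args,        # supplies top_k / top_p for model.generate
        data_args=data_args,      # supplies max_length for generation
    )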
2 changes: 1 addition & 1 deletion llm/tools/split_weights.py
@@ -22,7 +22,7 @@
 from paddlenlp.generation import GenerationConfig
 from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from paddlenlp.transformers.model_utils import load_tp_checkpoint
-from paddlenlp.utils import llm_utils
+from paddlenlp.trl import llm_utils
 
 
 def parse_arguments():
1 change: 1 addition & 0 deletions paddlenlp/trl/__init__.py
@@ -14,5 +14,6 @@
 
 from .dpo_criterion import DPOCriterion
 from .dpo_trainer import DPOTrainer
+from .sft_trainer import *
 from .trl_data import *
 from .trl_utils import *
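
Because sft_trainer.py defines __all__ = ["SFTTrainer"] (see the new file below), this wildcard import re-exports exactly one public name, which is what makes the shorter import path used above work:

    from paddlenlp.trl import SFTTrainer  # resolved via `from .sft_trainer import *`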
137 changes: 3 additions & 134 deletions paddlenlp/utils/llm_utils.py → paddlenlp/trl/llm_utils.py
@@ -17,21 +17,20 @@
 import math
 import os
 import struct
-from typing import Dict, List, Optional
+from typing import List, Optional
 
 import numpy as np
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.fleet.base.topology as tp
 import paddle.incubate.multiprocessing as mp
 from paddle.distributed import fleet
-from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
 from sklearn.metrics import accuracy_score
 
 from paddlenlp.datasets import ZeroPaddingIterableDataset
 from paddlenlp.generation import GenerationConfig
-from paddlenlp.trainer import Trainer, TrainerCallback
-from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length
+from paddlenlp.trainer import TrainerCallback
+from paddlenlp.trainer.trainer_utils import IterableDatasetShard
 from paddlenlp.transformers import (
     AutoTokenizer,
     ChatGLMv2Tokenizer,
@@ -260,136 +259,6 @@ def on_step_end(self, args, state, control, **kwargs):
             state.trial_params["zero_padding_global_step"] = dataset.zero_padding_global_step
 
 
-class CausalLMTrainer(Trainer):
-    def __init__(self, do_generation: bool, gen_args, data_args, **kwargs):
-        super().__init__(**kwargs)
-        self.do_generation = do_generation
-        self.gen_args = gen_args
-        self.data_args = data_args
-
-    def prediction_step(
-        self,
-        model,
-        inputs,
-        prediction_loss_only: bool,
-        ignore_keys=None,
-    ):
-        if prediction_loss_only or self.args.pipeline_parallel_degree > 1:
-            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
-        elif not self.do_generation:
-            loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
-            # argmax here to avoid gather all logits, which is too memory-consuming.
-            # keepdim in order to maintain the same shape as logits
-            if isinstance(logits, (list, tuple)):
-                logits = logits[0]
-            # all gather logits when enabling tensor_parallel_output
-            if self.args.tensor_parallel_degree > 1 and getattr(self.args, "tensor_parallel_output", False):
-                hcg = fleet.get_hybrid_communicate_group()
-                model_parallel_group = hcg.get_model_parallel_group()
-                gathered_logits = []
-                dist.all_gather(gathered_logits, logits, group=model_parallel_group)
-                logits = paddle.concat(gathered_logits, axis=-1)
-            return (loss, logits.argmax(axis=-1, keepdim=True), labels)
-
-        loss = None
-
-        model.eval()
-        with paddle.no_grad():
-            generated_tokens = model.generate(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"] if "attention_mask" in inputs else None,
-                position_ids=inputs["position_ids"] if "position_ids" in inputs else None,
-                max_length=max(self.data_args.max_length - inputs["input_ids"].shape[-1], 1),
-                decode_strategy="sampling",
-                top_k=self.gen_args.top_k,
-                top_p=self.gen_args.top_p,
-                bos_token_id=self.tokenizer.bos_token_id,
-                eos_token_id=self.tokenizer.eos_token_id,
-                pad_token_id=self.tokenizer.pad_token_id,
-                use_cache=True,
-            )[0]
-            all_preds = []
-            for pred_tokens in generated_tokens:
-                pred_tokens = pred_tokens.numpy()
-                pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id].tolist()
-                all_preds.append(pred_tokens)
-            max_pred_length = max([len(x) for x in all_preds])
-            for index, preds in enumerate(all_preds):
-                all_preds[index] = preds + [-100] * (max_pred_length - len(preds))
-            all_preds = paddle.to_tensor(all_preds)
-
-            if "labels" in inputs:
-                all_labels = paddle.to_tensor(inputs["labels"])
-            else:
-                all_labels = None
-
-        return (loss, all_preds, all_labels)
-
-    def log(self, logs: Dict[str, float], **kwargs) -> None:
-        if "loss" in logs:
-            logs["ppl"] = np.exp(logs["loss"])
-        if "eval_loss" in logs:
-            logs["eval_ppl"] = np.exp(logs["eval_loss"])
-
-        super(CausalLMTrainer, self).log(logs, **kwargs)
-
-    def get_ptq_dataloader(self, ptq_ds):
-        if self.args.world_size <= 1:
-            ptq_sampler = BatchSampler(
-                dataset=ptq_ds,
-                shuffle=True,
-                batch_size=self.args.per_device_train_batch_size,
-                drop_last=self.args.dataloader_drop_last,
-            )
-        else:
-            ptq_sampler = DistributedBatchSampler(
-                self.train_dataset,
-                batch_size=self.args.per_device_train_batch_size,
-                shuffle=True,
-                num_replicas=self.args.dataset_world_size,
-                rank=self.args.dataset_rank,
-                drop_last=self.args.dataloader_drop_last,
-            )
-        ptq_dataloader = DataLoader(
-            ptq_ds,
-            batch_sampler=ptq_sampler,
-            collate_fn=self.data_collator,
-            num_workers=self.args.dataloader_num_workers,
-        )
-        return ptq_dataloader
-
-    def ptq_loop(
-        self,
-        dataloader: DataLoader,
-        description: str,
-        max_eval_iters: Optional[int] = -1,
-    ):
-        if isinstance(dataloader, paddle.io.DataLoader):
-            batch_size = dataloader.batch_sampler.batch_size
-        else:
-            raise ValueError("Only support for paddle.io.DataLoader")
-
-        if has_length(dataloader):
-            logger.info(f" Num examples = {self.num_examples(dataloader)}")
-            if max_eval_iters > 0:
-                logger.info(f" Total {description} steps = {max_eval_iters}")
-            else:
-                logger.info(f" Total {description} steps = {len(dataloader)}")
-        else:
-            logger.info(" Num examples: Unknown")
-            if max_eval_iters > 0:
-                logger.info(f" Total {description} steps = {max_eval_iters}")
-
-        logger.info(f" Pre device batch size = {batch_size}")
-        logger.info(f" Total Batch size = {batch_size * self.args.dataset_world_size}")
-        self.model.eval()
-        with paddle.no_grad():
-            for step, inputs in enumerate(dataloader):
-                self.prediction_step(model=self.model, inputs=inputs, prediction_loss_only=True, ignore_keys=None)
-                if max_eval_iters > 0 and step >= max_eval_iters - 1:
-                    break
 
 
 def get_infer_model_path(input_dir, model_prefix):
     if dist.get_world_size() > 1:
         local_rank = dist.get_rank()
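
For code that must run on PaddleNLP releases from both sides of this rename, a hypothetical compatibility shim (not part of this commit) could fall back to the old locations:

    # Hypothetical shim, not part of this commit: prefer the new location and
    # name, fall back to the pre-rename ones on older PaddleNLP versions.
    try:
        from paddlenlp.trl import SFTTrainer, llm_utils
    except ImportError:
        from paddlenlp.utils import llm_utils
        from paddlenlp.utils.llm_utils import CausalLMTrainer as SFTTrainer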
158 changes: 158 additions & 0 deletions paddlenlp/trl/sft_trainer.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Dict, Optional
+
+import numpy as np
+import paddle
+import paddle.distributed as dist
+from paddle.distributed import fleet
+from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
+
+from paddlenlp.trainer import Trainer
+from paddlenlp.trainer.trainer_utils import has_length
+from paddlenlp.utils.log import logger
+
+__all__ = ["SFTTrainer"]
+
+
+class SFTTrainer(Trainer):
+    def __init__(self, do_generation: bool, gen_args, data_args, **kwargs):
+        super().__init__(**kwargs)
+        self.do_generation = do_generation
+        self.gen_args = gen_args
+        self.data_args = data_args
+
+    def prediction_step(
+        self,
+        model,
+        inputs,
+        prediction_loss_only: bool,
+        ignore_keys=None,
+    ):
+        if prediction_loss_only or self.args.pipeline_parallel_degree > 1:
+            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
+        elif not self.do_generation:
+            loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
+            # argmax here to avoid gather all logits, which is too memory-consuming.
+            # keepdim in order to maintain the same shape as logits
+            if isinstance(logits, (list, tuple)):
+                logits = logits[0]
+            # all gather logits when enabling tensor_parallel_output
+            if self.args.tensor_parallel_degree > 1 and getattr(self.args, "tensor_parallel_output", False):
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                gathered_logits = []
+                dist.all_gather(gathered_logits, logits, group=model_parallel_group)
+                logits = paddle.concat(gathered_logits, axis=-1)
+            return (loss, logits.argmax(axis=-1, keepdim=True), labels)
+
+        loss = None
+
+        model.eval()
+        with paddle.no_grad():
+            generated_tokens = model.generate(
+                input_ids=inputs["input_ids"],
+                attention_mask=inputs["attention_mask"] if "attention_mask" in inputs else None,
+                position_ids=inputs["position_ids"] if "position_ids" in inputs else None,
+                max_length=max(self.data_args.max_length - inputs["input_ids"].shape[-1], 1),
+                decode_strategy="sampling",
+                top_k=self.gen_args.top_k,
+                top_p=self.gen_args.top_p,
+                bos_token_id=self.tokenizer.bos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+            )[0]
+            all_preds = []
+            for pred_tokens in generated_tokens:
+                pred_tokens = pred_tokens.numpy()
+                pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id].tolist()
+                all_preds.append(pred_tokens)
+            max_pred_length = max([len(x) for x in all_preds])
+            for index, preds in enumerate(all_preds):
+                all_preds[index] = preds + [-100] * (max_pred_length - len(preds))
+            all_preds = paddle.to_tensor(all_preds)
+
+            if "labels" in inputs:
+                all_labels = paddle.to_tensor(inputs["labels"])
+            else:
+                all_labels = None
+
+        return (loss, all_preds, all_labels)
+
+    def log(self, logs: Dict[str, float], **kwargs) -> None:
+        if "loss" in logs:
+            logs["ppl"] = np.exp(logs["loss"])
+        if "eval_loss" in logs:
+            logs["eval_ppl"] = np.exp(logs["eval_loss"])
+
+        super(SFTTrainer, self).log(logs, **kwargs)
+
+    def get_ptq_dataloader(self, ptq_ds):
+        if self.args.world_size <= 1:
+            ptq_sampler = BatchSampler(
+                dataset=ptq_ds,
+                shuffle=True,
+                batch_size=self.args.per_device_train_batch_size,
+                drop_last=self.args.dataloader_drop_last,
+            )
+        else:
+            ptq_sampler = DistributedBatchSampler(
+                self.train_dataset,
+                batch_size=self.args.per_device_train_batch_size,
+                shuffle=True,
+                num_replicas=self.args.dataset_world_size,
+                rank=self.args.dataset_rank,
+                drop_last=self.args.dataloader_drop_last,
+            )
+        ptq_dataloader = DataLoader(
+            ptq_ds,
+            batch_sampler=ptq_sampler,
+            collate_fn=self.data_collator,
+            num_workers=self.args.dataloader_num_workers,
+        )
+        return ptq_dataloader
+
+    def ptq_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        max_eval_iters: Optional[int] = -1,
+    ):
+        if isinstance(dataloader, paddle.io.DataLoader):
+            batch_size = dataloader.batch_sampler.batch_size
+        else:
+            raise ValueError("Only support for paddle.io.DataLoader")
+
+        if has_length(dataloader):
+            logger.info(f" Num examples = {self.num_examples(dataloader)}")
+            if max_eval_iters > 0:
+                logger.info(f" Total {description} steps = {max_eval_iters}")
+            else:
+                logger.info(f" Total {description} steps = {len(dataloader)}")
+        else:
+            logger.info(" Num examples: Unknown")
+            if max_eval_iters > 0:
+                logger.info(f" Total {description} steps = {max_eval_iters}")
+
+        logger.info(f" Pre device batch size = {batch_size}")
+        logger.info(f" Total Batch size = {batch_size * self.args.dataset_world_size}")
+        self.model.eval()
+        with paddle.no_grad():
+            for step, inputs in enumerate(dataloader):
+                self.prediction_step(model=self.model, inputs=inputs, prediction_loss_only=True, ignore_keys=None)
+                if max_eval_iters > 0 and step >= max_eval_iters - 1:
+                    break
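
Two behaviors of the moved trainer are worth calling out. prediction_step returns per-token argmax ids rather than full logits, so only a [batch, seq, 1] tensor instead of a [batch, seq, vocab] one flows into metric computation, and log derives perplexity as exp(loss). A self-contained sketch with assumed toy shapes:

    import numpy as np
    import paddle

    # Assumed toy shapes: batch=2, seq=8, vocab=32000. Keeping only the argmax
    # ids shrinks the eval tensor by a factor of the vocab size.
    logits = paddle.rand([2, 8, 32000])
    preds = logits.argmax(axis=-1, keepdim=True)
    print(preds.shape)  # [2, 8, 1]

    # log() adds ppl = exp(loss): a loss of 2.0 nats is logged as ppl ~= 7.39.
    print(np.exp(2.0))  # 7.389...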