[Trainer] Support skip data intervals (PaddlePaddle#8989)
* support skip data intervals

* add debug_data arg

* fix loss compute

* remove callback while skip data

* remove debug data

* add callback_handler

* remove debug_data

* fix conflict
greycooker authored Sep 23, 2024
1 parent 353d278 commit ad14dc4
Showing 4 changed files with 116 additions and 18 deletions.
20 changes: 18 additions & 2 deletions paddlenlp/trainer/argparser.py
@@ -24,7 +24,17 @@
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
from typing import (
Any,
Dict,
Iterable,
NewType,
Optional,
Tuple,
Union,
get_args,
get_type_hints,
)

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
@@ -129,7 +139,13 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif isclass(origin_type) and issubclass(origin_type, list):
kwargs["type"] = field.type.__args__[0]
# support one-dimensional lists as well as two-dimensional (nested) lists
if hasattr(get_args(field.type)[0], "__args__"):
kwargs["type"] = field.type.__args__[0].__args__[0]
kwargs["action"] = "append"
else:
kwargs["type"] = field.type.__args__[0]

kwargs["nargs"] = "+"
if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
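For reference, the change above exposes a List[List[int]] dataclass field as a repeatable command-line flag. A minimal standalone sketch of that behavior, using plain argparse and an illustrative dataclass (not the PaddleNLP parser itself):

import argparse
from dataclasses import dataclass, field
from typing import List, Optional, get_args

@dataclass
class DemoArgs:  # hypothetical dataclass, used only for this illustration
    skip_data_intervals: Optional[List[List[int]]] = field(default=None)

parser = argparse.ArgumentParser()
# Unwrap the element type twice: Optional[List[List[int]]] -> List[List[int]] -> List[int] -> int
outer = get_args(DemoArgs.__dataclass_fields__["skip_data_intervals"].type)[0]
elem = get_args(get_args(outer)[0])[0]
# action="append" with nargs="+" collects one inner list per occurrence of the flag
parser.add_argument("--skip_data_intervals", type=elem, action="append", nargs="+", default=None)

args = parser.parse_args("--skip_data_intervals 101 200 --skip_data_intervals 501 600".split())
print(args.skip_data_intervals)  # [[101, 200], [501, 600]]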
93 changes: 77 additions & 16 deletions paddlenlp/trainer/trainer.py
@@ -127,6 +127,7 @@
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
EvalPrediction,
IntervalStrategy,
IterableDatasetShard,
OptimizerNames,
PredictionOutput,
@@ -139,6 +140,7 @@
get_scheduler,
has_length,
set_seed,
should_skip_data,
speed_metrics,
)
from .training_args import TrainingArguments
@@ -287,9 +289,16 @@ def __init__(

# Seed must be set before instantiating the model when using model
set_seed(seed=self.args.seed)

self._skip_global_steps = 0  # total number of skipped global steps
self._skip_steps_since_last_logged = 0  # number of steps skipped since the last log
if model is None:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
logger.warning("Model is None.")
self.model = None
self.train_dataset = train_dataset
self.tokenizer = tokenizer
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
return

if self.args.to_static:
model = paddle.jit.to_static(model)
@@ -945,6 +954,7 @@ def _inner_training_loop(
step_control = 0 # used in loop control, reset to 0 after every step
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

step = -1
for step, inputs in enumerate(epoch_iterator):
if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
inputs = split_inputs_sequence_dim(inputs)
@@ -981,6 +991,44 @@
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
# skip this step

if (step_control + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
# update current global step and skip step
self.state.global_step += 1
self._skip_global_steps += 1
self._skip_steps_since_last_logged += 1

self.state.epoch = epoch + (step + 1) / steps_in_epoch

if self.state.global_step == 1 and self.args.logging_first_step:
self.control.should_log = True
if (
self.args.logging_strategy == IntervalStrategy.STEPS
and self.state.global_step % self.args.logging_steps == 0
):
self.control.should_log = True

self.control.should_evaluate = False
self.control.should_save = False

# log loss and memory usage
self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
self._print_timer()
step_control = 0
else:
step_control += 1
if self.state.global_step >= self.state.max_steps:
break

self.timers and self.timers("read-data").start()
continue

if step_control % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()
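A standalone sketch of the bookkeeping in the skip branch above (plain Python, no Trainer state): skipped micro-batches advance the global step and the skip counters only once per full gradient-accumulation window, mirroring a regular optimizer step.

gradient_accumulation_steps = 4
global_step, skip_global_steps, step_control = 0, 0, 0
for _ in range(8):  # eight skipped micro-batches
    if (step_control + 1) % gradient_accumulation_steps == 0:
        # the accumulation window is complete: count one skipped global step
        global_step += 1
        skip_global_steps += 1
        step_control = 0
    else:
        step_control += 1
print(global_step, skip_global_steps)  # 2 2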
Expand Down Expand Up @@ -1202,7 +1250,13 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
)

self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step

# In case all steps were skipped, the total loss is set to 0.
if self.state.global_step == self._skip_global_steps:
logger.info("All steps were skipped, the total loss is set to 0.")
train_loss = 0.0
else:
train_loss = self._total_loss_scalar / (self.state.global_step - self._skip_global_steps)

metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
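The numbers below are illustrative only; they show why skipped global steps are now excluded from the denominator of the final average training loss, with the all-skipped case handled separately.

total_loss_scalar = 160.0          # loss accumulated over the 80 executed steps (made-up value)
global_step, skip_global_steps = 100, 20
if global_step == skip_global_steps:
    train_loss = 0.0               # every step was skipped
else:
    train_loss = total_loss_scalar / (global_step - skip_global_steps)
print(train_loss)  # 2.0; dividing by global_step as before would report 1.6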

@@ -1321,15 +1375,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
if self.control.should_log:

logs: Dict[str, float] = {}

num_steps = self.state.global_step - self._globalstep_last_logged - self._skip_steps_since_last_logged
self._skip_steps_since_last_logged = 0
# all_gather + mean() to get average loss over all processes
avg_loss = self._nested_gather(tr_loss).mean()
tr_loss_scalar = self._get_item_from_loss(avg_loss)

# reset tr_loss to zero
tr_loss.subtract_(tr_loss)
# set the logged loss to zero if all steps since the last log were skipped
if num_steps == 0:
logs["loss"] = 0.0
else:
logs["loss"] = round(tr_loss_scalar / num_steps, 8)

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
logs["global_step"] = int(self.state.global_step)
if in_auto_parallel_align_mode():
@@ -1352,7 +1411,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
total_train_batch_size = (
self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
)
num_steps = self.state.global_step - self._globalstep_last_logged

seq_length = None
model_flops = None
if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
Expand All @@ -1362,16 +1421,18 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
except NotImplementedError:
model_flops = None

logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,
model_flops=model_flops,
# Do not log speed metrics if all steps since the last log were skipped.
if num_steps > 0:
logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,
model_flops=model_flops,
)
)
)

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
@@ -3255,7 +3316,7 @@ def _set_signature_columns_if_needed(self):
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
if not self.args.remove_unused_columns or self.model is None:
return dataset
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
17 changes: 17 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -1105,3 +1105,20 @@ def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0):
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_ and local_seed not in tracker.seeds_:
tracker.add("local_seed", local_seed)


def should_skip_data(global_step, skip_data_intervals):
"""Whether to skip current step data"""

if skip_data_intervals is None:
return False
skip_flag = False
for interval in skip_data_intervals:
if len(interval) != 2 or interval[0] > interval[1] or interval[0] <= 0:
raise ValueError(f"Please check your skip interval {interval}")
start_global_step, end_global_step = interval[0], interval[1]
# start_global_step and end_global_step are 1-based, while global_step is 0-based
if start_global_step <= global_step + 1 <= end_global_step:
skip_flag = True
break
return skip_flag
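A quick illustration of the boundary convention (assuming the function is imported from paddlenlp.trainer.trainer_utils, where this hunk defines it): interval bounds are 1-based and inclusive, while global_step is 0-based, hence the global_step + 1 comparison.

from paddlenlp.trainer.trainer_utils import should_skip_data

intervals = [[1, 3], [10, 10]]
print([should_skip_data(step, intervals) for step in range(5)])
# [True, True, True, False, False] -> 0-based steps 0-2 fall inside the 1-based interval [1, 3]
print(should_skip_data(9, intervals))  # True: step 9 is the 10th global step
print(should_skip_data(0, None))       # False: no intervals configured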
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -867,6 +867,10 @@ class TrainingArguments:
release_grads: Optional[bool] = field(
default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
)
skip_data_intervals: Optional[List[List[int]]] = field(
default=None,
metadata={"help": "The data intervals to skip during training; each interval is a pair [start_global_step, end_global_step] with 1-based, inclusive bounds."},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
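For completeness, a programmatic sketch of the new field (output_dir and the interval values are placeholders; on the command line, repeating --skip_data_intervals START END adds one interval per occurrence, as enabled by the argparser change above):

from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",                    # placeholder path
    skip_data_intervals=[[101, 200], [501, 600]],  # 1-based, inclusive global steps
)
print(args.skip_data_intervals)  # [[101, 200], [501, 600]]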
