align config between dy/st and add memory profiler for dy (PaddlePaddle#6982)

* align config between dy and st

* add memory profiler for dy

* comment tiny fix

* delete config 'vocab_size_divisible_unit' when get_model
Wennie396 authored Sep 20, 2023
1 parent c6e0f78 commit f18dcc8
Showing 5 changed files with 37 additions and 4 deletions.
@@ -30,6 +30,7 @@ Engine:
Model:
module: "GPTModuleAuto"
name: "GPT"
vocab_size_divisible_unit: 128
fuse_attn_qkv: True
scale_qk_by_layer_num: True
fused_softmax_with_triangular: True
17 changes: 15 additions & 2 deletions model_zoo/gpt-3/ppfleetx/core/engine/eager_engine.py
@@ -256,6 +256,11 @@ def configure_optimizers(self):
self.profiler.start()
logger.warning("Profiler is enabled, do not enable it in production.")

# Profiler_pretrain configs: memory statistics and nvprof iteration range
self.memory_stats = configs.get("Profiler_pretrain", {}).get("memory_stats", False)
self.nvprof_start = configs.get("Profiler_pretrain", {}).get("nvprof_start", -1)
self.nvprof_end = configs.get("Profiler_pretrain", {}).get("nvprof_end", -1)

def _wrap_with_fleet(self):
if self._sharding_stage in [2, 3]:
assert self._pp_degree == 1, "sharding stage2/3 will support pipeline parallel later"
@@ -317,8 +322,9 @@ def _train_one_epoch(epoch_index, train_data_loader=None, valid_data_loade
if epoch_index == self._load_recovery["epoch"]:
if step < self._load_recovery["step"]:
continue

loss = self._fit_impl(batch)

with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
loss = self._fit_impl(batch)
train_losses.append(loss)

if self._lr_scheduler is not None and self._lr_scheduler_mode == "step":
@@ -341,6 +347,13 @@ def _train_one_epoch(epoch_index, train_data_loader=None, valid_data_loade
}
if self._amp_enable:
log_dict["loss_scale"] = self._scaler._scale.numpy()[0]
if self.memory_stats:
# convert from bytes to MB
log_dict["max_memory_allocated"] = paddle.device.cuda.max_memory_allocated() / (1024**2)
log_dict["max_memory_reserved"] = paddle.device.cuda.max_memory_reserved() / (1024**2)
log_dict["memory_allocated"] = paddle.device.cuda.memory_allocated() / (1024**2)
log_dict["memory_reserved"] = paddle.device.cuda.memory_reserved() / (1024**2)

self._module.training_step_end(log_dict)

train_step_start = get_timestamp()
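For reference, the memory figures come straight from Paddle's CUDA memory APIs, and the nvprof window is driven by the two new Profiler_pretrain keys. Below is a minimal, hypothetical sketch of the same collection logic; the config dict shape and the report_memory_mb helper are illustrative assumptions, not part of this commit:

import paddle

# Illustrative config mirroring the keys read in configure_optimizers (assumed shape).
configs = {
    "Profiler_pretrain": {
        "memory_stats": True,  # log GPU memory statistics alongside the training log
        "nvprof_start": -1,    # first iteration of the nvprof range (-1 keeps it disabled)
        "nvprof_end": -1,      # last iteration of the nvprof range
    }
}
profiler_cfg = configs.get("Profiler_pretrain", {})

def report_memory_mb():
    # Query current and peak GPU memory and convert from bytes to MB.
    if not paddle.device.is_compiled_with_cuda():
        return {}
    to_mb = 1024 ** 2
    return {
        "max_memory_allocated": paddle.device.cuda.max_memory_allocated() / to_mb,
        "max_memory_reserved": paddle.device.cuda.max_memory_reserved() / to_mb,
        "memory_allocated": paddle.device.cuda.memory_allocated() / to_mb,
        "memory_reserved": paddle.device.cuda.memory_reserved() / to_mb,
    }

if profiler_cfg.get("memory_stats", False):
    print(report_memory_mb())

With the defaults of -1 for nvprof_start and nvprof_end, no iteration falls inside the range, so the _nvprof_range wrapper around _fit_impl should be a no-op unless the window is configured explicitly.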
@@ -21,6 +21,7 @@
from ppfleetx.core.module.basic_module import BasicModule
from ppfleetx.data.tokenizers import GPTTokenizer
from ppfleetx.utils.log import logger
from ppfleetx.models.language_model.language_module import vocab_size_with_padding

from paddlenlp.transformers.gpt.tokenizer import GPTChineseTokenizer

@@ -124,6 +125,12 @@ def get_model(self):
tokenizer_class, pretrained_name = MODEL_CLASSES[model_name]
self.tokenizer = tokenizer_class.from_pretrained(pretrained_name)

model_setting["vocab_size"] = vocab_size_with_padding(
model_setting.get("vocab_size", self.tokenizer.vocab_size),
model_setting.pop("vocab_size_divisible_unit", 128),
self.configs.Distributed.get("mp_degree", 1),
)

l = model_setting["num_layers"]
h = model_setting["hidden_size"]
v = model_setting["vocab_size"]
@@ -157,6 +164,12 @@ def get_model(self):
tokenizer_class, pretrained_name = MODEL_CLASSES[model_name]
self.tokenizer = tokenizer_class.from_pretrained(pretrained_name)

model_setting["vocab_size"] = vocab_size_with_padding(
model_setting.get("vocab_size", self.tokenizer.vocab_size),
model_setting.pop("vocab_size_divisible_unit", 128),
self.configs.Distributed.get("mp_degree", 1),
)

with LazyGuard():
model = gpt.GPTForGenerationAuto(gpt.GPTModelAuto(**model_setting), self.generation_cfgs)
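Both get_model paths now pad the tokenizer vocabulary before building the model and pop vocab_size_divisible_unit so it never reaches the GPTModelAuto constructor. The helper itself lives in language_module; a plausible sketch of the Megatron-style rounding it performs (an assumption for illustration, not the helper's actual source):

def vocab_size_with_padding(vocab_size, divisible_unit=128, mp_degree=1):
    # Round the vocabulary up so the embedding table splits evenly
    # across tensor-parallel ranks (assumed Megatron-style padding).
    multiple = divisible_unit * mp_degree
    return ((vocab_size + multiple - 1) // multiple) * multiple

# Example: GPT-2's 50257-token vocabulary with the default unit of 128.
print(vocab_size_with_padding(50257, 128, 1))  # 50304
print(vocab_size_with_padding(50257, 128, 2))  # 50432

This is also why the YAML config above gains a vocab_size_divisible_unit: 128 entry: it is consumed here and removed from model_setting before the model is instantiated.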

@@ -103,9 +103,14 @@ def training_step_end(self, log_dict):
loss_scale_str = (
"loss_scale: %.9f," % (log_dict["loss_scale"]) if log_dict.get("loss_scale", None) is not None else ""
)
memory_str = (
", max_memory_allocated: %.1f MB, max_memory_reserved: %.1f MB, "
"memory_allocated: %.1f MB, memory_reserved: %.1f MB"
% (log_dict["max_memory_allocated"], log_dict["max_memory_reserved"], log_dict["memory_allocated"], log_dict["memory_reserved"])
if "max_memory_allocated" in log_dict else ""
)
logger.info(
"[train] epoch: [%d/%d], batch: [%d/%d], loss: %.9f, avg_batch_cost: %.5f sec, speed: %.2f step/s, "
"ips_total: %.0f tokens/s, ips: %.0f tokens/s, %s learning rate: %.5e, found_inf: %.0f"
"ips_total: %.0f tokens/s, ips: %.0f tokens/s, %s learning rate: %.5e, found_inf: %.0f %s"
% (
log_dict["epoch"],
log_dict["total_epoch"],
@@ -119,6 +124,7 @@ def training_step_end(self, log_dict):
loss_scale_str,
log_dict["lr"],
log_dict["found_inf"],
memory_str,
)
)

2 changes: 1 addition & 1 deletion model_zoo/gpt-3/ppfleetx/utils/config.py
@@ -100,7 +100,7 @@ def process_global_configs(config):
pp_degree = config["Distributed"]["pp_degree"]
sharding_degree = config["Distributed"]["sharding"]["sharding_degree"]

config["Global"]["enable_partial_send_recv"] = True
config["Global"]["enable_partial_send_recv"] = config["Global"]["enable_partial_send_recv"] if "enable_partial_send_recv" in config["Global"] else True
if "sequence_parallel" in config["Model"] and pp_degree > 1:
if config["Model"]["sequence_parallel"]:
config["Global"]["enable_partial_send_recv"] = False
