Disable logging during precompilation #539

Merged: 10 commits, Mar 29, 2024
38 changes: 21 additions & 17 deletions optimum/neuron/trainers.py
@@ -118,6 +118,13 @@
else:
IS_SAGEMAKER_MP_POST_1_10 = False


# `neuron_parallel_compile` relies on the logs to retrieve the HLO graphs to compile.
# For some reason, the logger logs strange characters that make `neuron_parallel_compile` fail when it tries to load
# the log file to extract the graphs to compile. To avoid that, we disable logging when doing precompilation.
if is_precompilation():
logging.logging.disable(sys.maxsize)

logger = logging.get_logger("transformers.trainer")

KEEP_HF_HUB_PROGRESS_BARS = os.environ.get("KEEP_HF_HUB_PROGRESS_BARS")
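
For context, here is a minimal standalone sketch of the pattern added above: the stdlib call `logging.disable(sys.maxsize)` silences every logger at every level. The detection mechanism of the real `is_precompilation()` helper is not shown in this diff, so the environment variable below is only a placeholder, not the library's check.

```python
import logging
import os
import sys


def is_precompilation_sketch() -> bool:
    # Placeholder: the real is_precompilation() is provided by optimum-neuron;
    # the environment variable name below is illustrative only.
    return os.environ.get("NEURON_PARALLEL_COMPILE_SKETCH") == "1"


if is_precompilation_sketch():
    # Disables every logger at every level, so nothing extra ends up in the
    # log file that neuron_parallel_compile parses to extract the HLO graphs.
    logging.disable(sys.maxsize)
```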
@@ -168,6 +175,13 @@ def __init__(self, *args, **kwargs):
prepare_environment_for_neuron()
super().__init__(*args, **kwargs)

# We need to change which process is considered "world process zero" to make sure the proper metrics
# (e.g. the loss) are logged and sent to the callbacks (for instance WandbCallback).
self.state = TrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=is_main_worker_for_metrics(),
)

# That's the case for Transformers < 4.30.0
if not hasattr(self, "is_fsdp_enabled"):
self.is_fsdp_enabled = False
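
A hedged illustration of why the override matters: Hugging Face callbacks such as `WandbCallback` gate their reporting on `state.is_world_process_zero`, so changing which rank sets that flag changes which rank ships the metrics. The toy callback below is not part of the PR; it only shows the gating.

```python
# Toy example (not part of the PR): callbacks gate their work on is_world_process_zero.
from transformers import TrainerCallback


class MetricsPrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        # With the override above, only the rank selected by
        # is_main_worker_for_metrics() has state.is_world_process_zero == True,
        # so only that rank reports (the same gate WandbCallback uses).
        if state.is_world_process_zero:
            print(logs)
```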
@@ -393,29 +407,17 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_group,
get_data_parallel_size,
get_pipeline_model_parallel_rank,
get_pipeline_model_parallel_size,
)

if self.args.mp_plugin.should_parallelize:
dp_size = get_data_parallel_size()
pp_size = get_pipeline_model_parallel_size()
pp_rank = get_pipeline_model_parallel_rank()

tr_loss_div = tr_loss / dp_size

if pp_size > 1 and pp_rank == pp_size - 1:
tr_loss_div = xm.all_reduce(
xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True)
)
tr_loss_scalar = tr_loss_div.detach().item()
else:
tr_loss_scalar = xm.all_reduce(
xm.REDUCE_SUM,
tr_loss_div,
groups=get_data_parallel_group(as_list=True),
)
tr_loss_scalar = tr_loss_scalar.detach().item()
# This also works for PP because, under PP, the main process that logs for the callbacks is the one
# with dp_rank = tp_rank = 0 and pp_rank = pp_size - 1.
tr_loss_div = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True))
tr_loss_scalar = tr_loss_div.detach().item()
else:
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
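
A small sketch of the reduction the new branch performs, assuming a torch_xla environment; the group and size arguments are stand-ins for the `neuronx_distributed` getters used above. Dividing by the data-parallel size and then SUM-reducing across the data-parallel group yields the mean loss.

```python
# Sketch, assuming torch_xla is installed and the tensor lives on an XLA device.
import torch
import torch_xla.core.xla_model as xm


def mean_loss_over_dp_group(tr_loss: torch.Tensor, dp_size: int, dp_groups) -> float:
    # Divide first, then SUM-all_reduce across the data-parallel replicas:
    # sum(loss_i / dp_size) == mean(loss_i).
    tr_loss_div = tr_loss / dp_size
    tr_loss_div = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=dp_groups)
    return tr_loss_div.detach().item()
```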
@@ -868,7 +870,9 @@ def _inner_training_loop(
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
# We need to change which process is considered "world process zero" to make sure the proper metrics
# (e.g. the loss) are logged and sent to the callbacks (for instance WandbCallback).
self.state.is_world_process_zero = is_main_worker_for_metrics()

# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(args.device)
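
A quick sketch of the point made in the comment above: keep the running loss as a tensor and only call `.item()` when logging, so individual steps do not force a host/device synchronization. Plain CPU tensors are used here so the snippet runs anywhere; in the trainer the device is the XLA device.

```python
import torch

device = torch.device("cpu")  # stand-in for the XLA device used by the trainer

tr_loss = torch.tensor(0.0).to(device)  # running loss stays a tensor
for _ in range(10):
    step_loss = torch.rand(())          # stand-in for a real training-step loss
    tr_loss += step_loss.detach()       # no .item() here: no forced sync per step
# Only at a logging step do we materialize the scalar and pay for the sync.
print(tr_loss.item())
```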
2 changes: 0 additions & 2 deletions optimum/neuron/utils/training_utils.py
@@ -16,7 +16,6 @@

import os
import re
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import torch
@@ -433,7 +432,6 @@ def numel(parameter_name, parameter) -> int:
return param_count


@lru_cache
@requires_neuronx_distributed
def is_main_worker_for_metrics() -> bool:
from neuronx_distributed.parallel_layers.parallel_state import (
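
Finally, a minimal sketch of what `is_main_worker_for_metrics()` could look like, following the rule stated in the comments of this PR (dp_rank = tp_rank = 0 and, under PP, pp_rank = pp_size - 1). The getter names come from `neuronx_distributed.parallel_layers.parallel_state`, but treat the exact body as an assumption rather than the merged implementation, which is truncated in the diff above.

```python
# Sketch only; the merged implementation is truncated in the diff above.
def is_main_worker_for_metrics_sketch() -> bool:
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_data_parallel_rank,
        get_pipeline_model_parallel_rank,
        get_pipeline_model_parallel_size,
        get_tensor_model_parallel_rank,
    )

    # The last pipeline stage holds the loss, so it is the rank that should
    # log metrics; dp_rank and tp_rank are pinned to 0 to pick a single worker.
    return (
        get_data_parallel_rank() == 0
        and get_tensor_model_parallel_rank() == 0
        and get_pipeline_model_parallel_rank() == get_pipeline_model_parallel_size() - 1
    )
```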