From 51081f9ae7e60d8dd4ab95c5f7a477c894768d13 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 27 Mar 2024 11:35:06 +0100
Subject: [PATCH 01/10] Add multiple features

---
 examples/language-modeling/run_clm.py |  6 +++---
 optimum/neuron/trainers.py            | 10 ++++++----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index a31b2456a..bedf48ec9 100755
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -466,9 +466,9 @@ def main():
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
-    if len(tokenizer) > embedding_size:
-        model.resize_token_embeddings(len(tokenizer))
+    # embedding_size = model.get_input_embeddings().weight.shape[0]
+    # if len(tokenizer) > embedding_size:
+    #     model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index fe2c305aa..476f8c5ba 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -397,11 +397,11 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
                 get_pipeline_model_parallel_size,
             )
 
-            if self.args.mp_plugin.should_parallelize:
-                dp_size = get_data_parallel_size()
-                pp_size = get_pipeline_model_parallel_size()
-                pp_rank = get_pipeline_model_parallel_rank()
+            dp_size = get_data_parallel_size()
+            pp_size = get_pipeline_model_parallel_size()
+            pp_rank = get_pipeline_model_parallel_rank()
 
+            if self.args.mp_plugin.should_parallelize:
                 tr_loss_div = tr_loss / dp_size
 
                 if pp_size > 1 and pp_rank == pp_size - 1:
@@ -432,6 +432,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
 
             if is_main_worker_for_metrics():
                 self.log(logs)
+            if pp_size > 1:
+                xm.rendezvous("waiting_after_log_metrics")
 
         metrics = None
         if self.control.should_evaluate:

From a543a9cf960afb6e0e77e3dca269ae6d7bfa29e3 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 27 Mar 2024 14:59:57 +0100
Subject: [PATCH 02/10] Fix

---
 optimum/neuron/trainers.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 476f8c5ba..e3c0cc72a 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -432,8 +432,6 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
 
             if is_main_worker_for_metrics():
                 self.log(logs)
-            if pp_size > 1:
-                xm.rendezvous("waiting_after_log_metrics")
 
         metrics = None
         if self.control.should_evaluate:

From 1f1d229f07be2b8f55556ec0f14aa00e0212a966 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 27 Mar 2024 15:42:33 +0100
Subject: [PATCH 03/10] Fix typo

---
 optimum/neuron/distributed/utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 3d4d6df27..b3de2f2dc 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -522,7 +522,9 @@ def compute_query_indices_for_rank(
     queries_indices = [torch.arange(query_group_size_per_rank) for _ in range(num_key_value_heads_per_rank)]
 
     keys_indices = torch.arange(num_key_value_heads).repeat(kv_size_multiplier)
-    keys_indices = torch.repeat_interleave(keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank)
+    keys_indices = torch.repeat_interleave(
+        keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank
+    )
     keys_indices = torch.chunk(keys_indices, tp_size)
 
     shift_per_key = torch.arange(0, num_attention_heads, query_group_size)

From e11c5def3880873a524bf8fb60a7150d0eb906c5 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 27 Mar 2024 15:52:13 +0100
Subject: [PATCH 04/10] Rename variables

---
 optimum/neuron/distributed/utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index b3de2f2dc..3d4d6df27 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -522,9 +522,7 @@ def compute_query_indices_for_rank(
     queries_indices = [torch.arange(query_group_size_per_rank) for _ in range(num_key_value_heads_per_rank)]
 
     keys_indices = torch.arange(num_key_value_heads).repeat(kv_size_multiplier)
-    keys_indices = torch.repeat_interleave(
-        keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank
-    )
+    keys_indices = torch.repeat_interleave(keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank)
     keys_indices = torch.chunk(keys_indices, tp_size)
 
     shift_per_key = torch.arange(0, num_attention_heads, query_group_size)

From 818f5745609ecb7b25f19b59ad2b46b6e76cd734 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 27 Mar 2024 17:35:43 +0100
Subject: [PATCH 05/10] Disable logging when doing precompilation

---
 optimum/neuron/trainers.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index e3c0cc72a..09f7d22bd 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -118,6 +118,13 @@ else:
     IS_SAGEMAKER_MP_POST_1_10 = False
 
 
+# `neuron_parallel_compile` relies on the logs to retrieve the HLO graphs to compile.
+# For some reason, the logger logs strange characters that make `neuron_parallel_compile` fail when it tries to load
+# the log file to extract the graphs to compile. To avoid that, we disable logging when doing precompilation.
+if is_precompilation():
+    logging.logging.disable(sys.maxsize)
+
+
 logger = logging.get_logger("transformers.trainer")
 
 KEEP_HF_HUB_PROGRESS_BARS = os.environ.get("KEEP_HF_HUB_PROGRESS_BARS")

From 7095426fa0203cfdb39e61814f35c558820d8e3f Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 28 Mar 2024 14:42:10 +0100
Subject: [PATCH 06/10] Fix communication

---
 optimum/neuron/trainers.py             | 20 +++++++++++---------
 optimum/neuron/utils/training_utils.py |  2 --
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 09f7d22bd..83a7bd3f9 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -175,6 +175,13 @@ def __init__(self, *args, **kwargs):
         prepare_environment_for_neuron()
         super().__init__(*args, **kwargs)
 
+        # We need to specify that the world process is the main worker for metrics otherwise callbacks, such as the
+        # WandbCallback will not have access to the loss logs when doing PP.
+        self.state = TrainerState(
+            is_local_process_zero=self.is_local_process_zero(),
+            is_world_process_zero=is_main_worker_for_metrics(),
+        )
+
         # That's the case for Transformers < 4.30.0
         if not hasattr(self, "is_fsdp_enabled"):
             self.is_fsdp_enabled = False
@@ -411,18 +418,11 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
             if self.args.mp_plugin.should_parallelize:
                 tr_loss_div = tr_loss / dp_size
 
-                if pp_size > 1 and pp_rank == pp_size - 1:
+                if pp_size == 1 or (pp_size > 1 and pp_rank == pp_size - 1):
                     tr_loss_div = xm.all_reduce(
                         xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True)
                     )
                     tr_loss_scalar = tr_loss_div.detach().item()
-                else:
-                    tr_loss_scalar = xm.all_reduce(
-                        xm.REDUCE_SUM,
-                        tr_loss_div,
-                        groups=get_data_parallel_group(as_list=True),
-                    )
-                    tr_loss_scalar = tr_loss_scalar.detach().item()
             else:
                 # all_gather + mean() to get average loss over all processes
                 tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
@@ -875,7 +875,9 @@ def _inner_training_loop(
         self.state.max_steps = max_steps
         self.state.num_train_epochs = num_train_epochs
         self.state.is_local_process_zero = self.is_local_process_zero()
-        self.state.is_world_process_zero = self.is_world_process_zero()
+        # We need to specify that the world process is the main worker for metrics otherwise callbacks, such as the
+        # WandbCallback will not have access to the loss logs when doing PP.
+        self.state.is_world_process_zero = is_main_worker_for_metrics()
 
         # tr_loss is a tensor to avoid synchronization of TPUs through .item()
         tr_loss = torch.tensor(0.0).to(args.device)
diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index b2877fab1..81cbe15e4 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -16,7 +16,6 @@
 
 import os
 import re
-from functools import lru_cache
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
@@ -433,7 +432,6 @@ def numel(parameter_name, parameter) -> int:
     return param_count
 
 
-@lru_cache
@requires_neuronx_distributed
 def is_main_worker_for_metrics() -> bool:
     from neuronx_distributed.parallel_layers.parallel_state import (

From eeb85c53b8cbf9e7e870d4adaffd203c6f0dd86f Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 28 Mar 2024 16:52:06 +0100
Subject: [PATCH 07/10] Restore run_clm.py

---
 examples/language-modeling/run_clm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py
index bedf48ec9..a31b2456a 100755
--- a/examples/language-modeling/run_clm.py
+++ b/examples/language-modeling/run_clm.py
@@ -466,9 +466,9 @@ def main():
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    # embedding_size = model.get_input_embeddings().weight.shape[0]
-    # if len(tokenizer) > embedding_size:
-    #     model.resize_token_embeddings(len(tokenizer))
+    embedding_size = model.get_input_embeddings().weight.shape[0]
+    if len(tokenizer) > embedding_size:
+        model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.

From a0e38289e567ba48d2938af5b3254a456076d0c1 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 28 Mar 2024 16:53:48 +0100
Subject: [PATCH 08/10] Fix

---
 optimum/neuron/trainers.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 83a7bd3f9..020086722 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -411,11 +411,11 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
                 get_pipeline_model_parallel_size,
             )
 
-            dp_size = get_data_parallel_size()
-            pp_size = get_pipeline_model_parallel_size()
-            pp_rank = get_pipeline_model_parallel_rank()
-
             if self.args.mp_plugin.should_parallelize:
+                dp_size = get_data_parallel_size()
+                pp_size = get_pipeline_model_parallel_size()
+                pp_rank = get_pipeline_model_parallel_rank()
+
                 tr_loss_div = tr_loss / dp_size
 
                 if pp_size == 1 or (pp_size > 1 and pp_rank == pp_size - 1):

From 84ef3bb297def1b7447e68a8da11dc528387e4bd Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 28 Mar 2024 17:17:16 +0100
Subject: [PATCH 09/10] Fix

---
 optimum/neuron/trainers.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 020086722..512180675 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -407,22 +407,17 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
             from neuronx_distributed.parallel_layers.parallel_state import (
                 get_data_parallel_group,
                 get_data_parallel_size,
-                get_pipeline_model_parallel_rank,
-                get_pipeline_model_parallel_size,
             )
 
             if self.args.mp_plugin.should_parallelize:
                 dp_size = get_data_parallel_size()
-                pp_size = get_pipeline_model_parallel_size()
-                pp_rank = get_pipeline_model_parallel_rank()
 
                 tr_loss_div = tr_loss / dp_size
 
-                if pp_size == 1 or (pp_size > 1 and pp_rank == pp_size - 1):
-                    tr_loss_div = xm.all_reduce(
-                        xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True)
-                    )
-                    tr_loss_scalar = tr_loss_div.detach().item()
+                # It works even for PP because under PP we make it so that the main process to log for callbacks is
+                # the one on dp_rank = 0, pp_rank = pp_size -1.
+                tr_loss_div = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True))
+                tr_loss_scalar = tr_loss_div.detach().item()
             else:
                 # all_gather + mean() to get average loss over all processes
                 tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

From fcb5f6ee2994be2a1de6da652b3b4db447475c8f Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Fri, 29 Mar 2024 11:23:48 +0100
Subject: [PATCH 10/10] Update comments

---
 optimum/neuron/trainers.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 512180675..5faac8c80 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -175,8 +175,8 @@ def __init__(self, *args, **kwargs):
         prepare_environment_for_neuron()
         super().__init__(*args, **kwargs)
 
-        # We need to specify that the world process is the main worker for metrics otherwise callbacks, such as the
-        # WandbCallback will not have access to the loss logs when doing PP.
+        # We need to change which process can be seen as "world process zero" to make sure the proper metrics
+        # (e.g. loss) are logged and sent to the callbacks (for instance WandbCallback).
         self.state = TrainerState(
             is_local_process_zero=self.is_local_process_zero(),
             is_world_process_zero=is_main_worker_for_metrics(),
         )
@@ -415,7 +415,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
                 tr_loss_div = tr_loss / dp_size
 
                 # It works even for PP because under PP we make it so that the main process to log for callbacks is
-                # the one on dp_rank = 0, pp_rank = pp_size -1.
+                # the one on dp_rank = tp_rank = 0 and pp_rank = pp_size -1.
                 tr_loss_div = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True))
                 tr_loss_scalar = tr_loss_div.detach().item()
             else:
                 # all_gather + mean() to get average loss over all processes
                 tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
@@ -870,8 +870,8 @@ def _inner_training_loop(
         self.state.max_steps = max_steps
         self.state.num_train_epochs = num_train_epochs
         self.state.is_local_process_zero = self.is_local_process_zero()
-        # We need to specify that the world process is the main worker for metrics otherwise callbacks, such as the
-        # WandbCallback will not have access to the loss logs when doing PP.
+        # We need to change which process can be seen as "world process zero" to make sure the proper metrics
+        # (e.g. loss) are logged and sent to the callbacks (for instance WandbCallback).
         self.state.is_world_process_zero = is_main_worker_for_metrics()
 
         # tr_loss is a tensor to avoid synchronization of TPUs through .item()