From 1191f0bda1b8920f8e282698f7cac6c1b61e66ef Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 20 Mar 2024 17:30:15 +0100
Subject: [PATCH 1/4] Load parallel linears directly on device

---
 optimum/neuron/accelerate/accelerator.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py
index 06e4c3660..2330aa5f3 100644
--- a/optimum/neuron/accelerate/accelerator.py
+++ b/optimum/neuron/accelerate/accelerator.py
@@ -422,8 +422,7 @@ def _prepare_model_for_mp(
         cpu_ids = {name: id(param) for name, param in model.named_parameters()}
         tied_parameters_dict = get_tied_parameters_dict(model)
         model_main_input_name = getattr(model, "main_input_name", None)
-        # TODO: use self.device.
-        model = self.state.mp_plugin.parallelize_model(model, device=None)
+        model = self.state.mp_plugin.parallelize_model(model, device=self.device)

         if model_main_input_name is not None:
             setattr(model, "main_input_name", model_main_input_name)

From 9c99e5114595d9f5b44a692f23cc13e1a7742eb3 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 2 Apr 2024 16:29:22 +0200
Subject: [PATCH 2/4] Fix initialization issue

---
 optimum/neuron/distributed/base.py  |  7 ++++++-
 optimum/neuron/distributed/utils.py | 24 +++++++++++++++++++-----
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 39f9a39b4..bc02a4efd 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -407,7 +407,8 @@ def _initialize_or_load_weights(
                 continue
             else:
                 # This means that there is no information about where to find the weights for this parameter.
-                device = torch.device("cpu") if device is None else device
+                # We first create the module on CPU, initialize it and then move it to the device if needed.
+                device = torch.device("cpu")
                 new_parameter = torch.nn.Parameter(torch.empty_like(parameter, device=device))
                 modules_to_initialize[module].append(attribute_name)

@@ -500,6 +501,10 @@ def initialize(mod: GQAQKVColumnParallelLinear, proj_name: str, output_size: int
                 if left_uninitialized and hasattr(mod, "reset_parameters"):
                     initialize_torch_nn_module(mod, parameter_names)

+            if device is not None:
+                mod.to(device)
+                gc.collect()
+
     @classmethod
     @requires_neuronx_distributed
     def _initialize_for_precompilation(
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 3d4d6df27..cfef542d9 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -1100,17 +1100,28 @@ def try_to_hf_initialize(
     `model._init_weights` method. It returns the names of the parameters that were left uninitialized.
     """
-    cached_params_data = {name: param.data.detach().clone().to("cpu") for name, param in mod.named_parameters()}
+    device = torch.device("cpu")
+    for name in parameter_names:
+        param_device = getattr(mod, name).device
+        if param_device != torch.device("meta"):
+            device = param_device
+
+    mod.to("cpu")
+
+    cached_params_data = {name: param.data.detach().clone() for name, param in mod.named_parameters()}
+
+    # We initialize on cpu to have the same RNG state (mostly useful for tests).
     model._init_weights(mod)

     if parameter_names_mapping is None:
         parameter_names_mapping = {}
+
     reverse_parameter_names_mapping = {v: k for k, v in parameter_names_mapping.items()}

     def name_in_mod(name: str):
         return parameter_names_mapping.get(name, name)

-    dummy_mod = copy.deepcopy(mod).to("cpu")
+    dummy_mod = copy.deepcopy(mod)
     for name in parameter_names:
         getattr(dummy_mod, name_in_mod(name)).random_()
     model._init_weights(dummy_mod)
@@ -1120,15 +1131,15 @@ def name_in_mod(name: str):
     for param_name in parameter_names:
         name = name_in_mod(param_name)
         # The parameter was left unchanged.
-        param_on_cpu = getattr(mod, name).data.to("cpu")
-        if torch.all(param_on_cpu == cached_params_data[name]):
+        param = getattr(mod, name).data
+        if torch.all(param == cached_params_data[name]):
             # There are two possible reasons:
             # 1. The model cannot initialize the module that owns the parameter.
             # 2. The parameter already had the proper value.

             # We check if a dummy copy of the module, filled with random values is modified to know if the model
             # can initialize the module.
-            dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == param_on_cpu)
+            dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == param)
             if not dummy_param_was_changed:
                 left_uninitialized.append(param_name)
@@ -1138,6 +1149,9 @@ def name_in_mod(name: str):
             param = getattr(mod, name)
             param.data = cached_data

+    # We restore the module back to its original device.
+    mod.to(device)
+
     return left_uninitialized

From bc419ef5b3d0fed8a153929d48facd6ae91ada84 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 2 Apr 2024 16:47:22 +0200
Subject: [PATCH 3/4] Add moving to device

---
 optimum/neuron/distributed/base.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index bc02a4efd..e09ba43bc 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -396,7 +396,8 @@ def _initialize_or_load_weights(
                     slices = None

                 new_parameter = torch.nn.Parameter(
-                    load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype)
+                    load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype),
+                    device=device,
                 )
             elif parameter.device != torch.device("meta") and (
                 was_already_initialized_during_parallelization(parameter)

From ab2654a638071dd44b140f91b1c47dc3c9f094e7 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 2 Apr 2024 16:53:50 +0200
Subject: [PATCH 4/4] Fix tiny error

---
 optimum/neuron/distributed/base.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index e09ba43bc..d0c73ce4c 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -394,11 +394,10 @@ def _initialize_or_load_weights(
                         continue
                 else:
                     slices = None
-
-                new_parameter = torch.nn.Parameter(
-                    load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype),
-                    device=device,
-                )
+                weight_data = load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype)
+                if device is not None:
+                    weight_data = weight_data.to(device)
+                new_parameter = torch.nn.Parameter(weight_data)
             elif parameter.device != torch.device("meta") and (
                 was_already_initialized_during_parallelization(parameter)
                 or not parameter_can_be_initialized(model, module, attribute_name)
@@ -420,6 +419,7 @@ def _initialize_or_load_weights(
                 )
             tied_weights[parameter] = new_parameter
             new_parameters.add(new_parameter)
+        gc.collect()

         for mod, parameter_names in modules_to_initialize.items():
             if isinstance(mod, torch.nn.Embedding):
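
Taken together, the four patches implement one pattern: a parallel module is created and initialized on CPU, so that the RNG state (and therefore the initial values) matches a plain CPU run and so that `model._init_weights` can be compared against a randomized dummy copy, and only afterwards are the weights moved to the target device. Below is a minimal, self-contained sketch of that pattern, loosely modeled on `try_to_hf_initialize` in `optimum/neuron/distributed/utils.py`. The helper name `init_on_cpu_and_detect_uninitialized`, the toy `init_fn`, and the plain `torch.nn.Linear` stand-in are illustrative assumptions, not code from the patches.

    import copy

    import torch


    def init_on_cpu_and_detect_uninitialized(mod, init_fn, parameter_names, device=None):
        # Hypothetical helper, not optimum-neuron API: run `init_fn` on CPU, detect which
        # parameters it actually touched, restore the others, then move to `device`.
        mod.to("cpu")
        cached = {name: param.data.detach().clone() for name, param in mod.named_parameters()}

        # Initializing on CPU keeps the RNG state identical to a plain CPU run.
        init_fn(mod)

        # A dummy copy filled with random values tells us whether `init_fn` can touch a
        # parameter at all: if the random values survive `init_fn`, it was left uninitialized.
        dummy = copy.deepcopy(mod)
        for name in parameter_names:
            getattr(dummy, name).data.random_()
        init_fn(dummy)

        left_uninitialized = []
        for name in parameter_names:
            param = getattr(mod, name).data
            if torch.all(param == cached[name]):
                dummy_param_was_changed = torch.all(getattr(dummy, name).data == param)
                if not dummy_param_was_changed:
                    left_uninitialized.append(name)

        # Restore the parameters `init_fn` could not handle, then move to the target device.
        for name in left_uninitialized:
            getattr(mod, name).data = cached[name]
        if device is not None:
            mod.to(device)
        return left_uninitialized


    # Usage: init_fn only knows how to initialize the weight, so the bias is reported back.
    linear = torch.nn.Linear(8, 8)
    left = init_on_cpu_and_detect_uninitialized(
        linear,
        init_fn=lambda m: torch.nn.init.zeros_(m.weight),
        parameter_names=["weight", "bias"],
        device=torch.device("cpu"),
    )
    print(left)  # ['bias']

The final `mod.to(device)` mirrors the patches' choice of deferring the device move until after initialization, which keeps CPU and device runs numerically comparable while still ending up with the weights on the requested device.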