From 7bfa99d6866785071474438582e922de11f2dba1 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 21 Sep 2023 16:07:53 +0200
Subject: [PATCH 1/3] Fix weight loading after lazy loading

---
 optimum/neuron/distributed/base.py  | 104 +++++++++++++++-------------
 optimum/neuron/distributed/utils.py |   2 +-
 2 files changed, 58 insertions(+), 48 deletions(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 9e1abba45..61cbebea1 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -149,56 +149,66 @@ def parallelize(
         model = cls._parallelize(
             model, orig_to_parallel=orig_to_parallel, device=device, parallelize_embeddings=parallelize_embeddings
         )
-        weight_map = getattr(model, "_weight_map", {})
+        weight_map = getattr(model, "_weight_map", None)
+
+        # The model was not loaded lazily, it is already ready.
+        if weight_map is None:
+            return model
+
         with torch.no_grad():
             modules_to_initialize = []
             for name, parameter in model.named_parameters():
-                # This must be either a torch.nn.Embedding or a torch.nn.Linear since those are the only
-                # classes that we initialize on the `meta` device.
-                if parameter.device == torch.device("meta"):
-                    if weight_map is None:
-                        raise ValueError(
-                            f"The parameter called {name} of the model is on the `meta` device and no weight map is "
-                            "attached to the model to load the proper weights from file."
-                        )
-                    split = name.rsplit(".", maxsplit=1)
-                    module = model.get_submodule(split[0])
-                    attribute_name = split[1]
-                    current_weight = getattr(module, attribute_name)
-                    try:
-                        weight_info = WeightInformation(weight_map[name], name, device=device)
-                        # The weight might have been parallelized, in which case we must load the proper slice.
-                        if getattr(current_weight, "tensor_model_parallel", False):
-                            num_dims = current_weight.dim()
-                            partition_dim = getattr(current_weight, "partition_dim")
-                            tp_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()
-                            size_per_rank = current_weight.size(partition_dim)
-                            slices = [
-                                None
-                                if idx != partition_dim
-                                else (size_per_rank * tp_rank, size_per_rank * (tp_rank + 1))
-                                for idx in range(num_dims)
-                            ]
-                        else:
-                            slices = None
-                        setattr(
-                            module,
-                            attribute_name,
-                            torch.nn.Parameter(load_tensor_for_weight(weight_info, tensor_slices=slices)),
-                        )
-                    except KeyError:
-                        # This means that there is no information about where to find the weights for this parameter.
-                        device = torch.device("cpu") if device is None else device
-                        setattr(
-                            module,
-                            attribute_name,
-                            torch.nn.Parameter(torch.empty_like(current_weight, device=device)),
-                        )
-                        modules_to_initialize.append(module)
-            for mod in modules_to_initialize:
-                # This module has not pre-trained weights, it must be fine-tuned, we initialize it with the
-                # `reset_parameters()` method.
-                mod.reset_parameters()
+                split = name.rsplit(".", maxsplit=1)
+                module = model.get_submodule(split[0])
+                attribute_name = split[1]
+                current_weight = getattr(module, attribute_name)
+                try:
+                    weight_info = WeightInformation(weight_map[name], name, device=device)
+                except KeyError:
+                    weight_info = None
+
+                if weight_info is not None:
+                    if getattr(current_weight, "tensor_model_parallel", False) and parameter.device == torch.device(
+                        "meta"
+                    ):
+                        # This must either be a torch.nn.Embedding or a torch.nn.Linear that was not handled during
+                        # parallelization since those are the only classes that we initialize on the `meta` device.
+                        # We only load weights for the parameters that are still on the meta device because other
+                        # parallel layers were handled during parallelization.
+                        num_dims = current_weight.dim()
+                        partition_dim = getattr(current_weight, "partition_dim")
+                        tp_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()
+                        size_per_rank = current_weight.size(partition_dim)
+                        slices = [
+                            None if idx != partition_dim else (size_per_rank * tp_rank, size_per_rank * (tp_rank + 1))
+                            for idx in range(num_dims)
+                        ]
+                    else:
+                        slices = None
+
+                    setattr(
+                        module,
+                        attribute_name,
+                        torch.nn.Parameter(load_tensor_for_weight(weight_info, tensor_slices=slices)),
+                    )
+                else:
+                    # This means that there is no information about where to find the weights for this parameter.
+                    device = torch.device("cpu") if device is None else device
+                    setattr(
+                        module,
+                        attribute_name,
+                        torch.nn.Parameter(torch.empty_like(current_weight, device=device)),
+                    )
+                    modules_to_initialize.append(module)
+            for mod in modules_to_initialize:
+                # This module has not pre-trained weights, it must be fine-tuned, we initialize it with the
+                # `reset_parameters()` method.
+                mod.reset_parameters()
+            for name, mod in model.named_parameters():
+                if name in weight_map:
+                    weight_info = WeightInformation(weight_map[name], name, device=device)
+                    tensor = load_tensor_for_weight(weight_info)
+                    print(mod, tensor)
         return model

     @classmethod
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 64cbe9ba0..d10480ecf 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -458,8 +458,8 @@ def gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads(
     return sliced_linear_layer


-@requires_torch_xla
 @classmethod
+@requires_torch_xla
 def from_pretrained_for_tp(
     cls,
     pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],

From 162fcbf8945dcff31ab810fccc893929a386dfd7 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 21 Sep 2023 16:31:19 +0200
Subject: [PATCH 2/3] Fix weight loading after lazy loading

---
 optimum/neuron/distributed/base.py | 33 ++++++++++++++++++---------------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 61cbebea1..a4e52d5b3 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -168,21 +168,24 @@ def parallelize(
                     weight_info = None

                 if weight_info is not None:
-                    if getattr(current_weight, "tensor_model_parallel", False) and parameter.device == torch.device(
-                        "meta"
-                    ):
-                        # This must either be a torch.nn.Embedding or a torch.nn.Linear that was not handled during
-                        # parallelization since those are the only classes that we initialize on the `meta` device.
-                        # We only load weights for the parameters that are still on the meta device because other
-                        # parallel layers were handled during parallelization.
-                        num_dims = current_weight.dim()
-                        partition_dim = getattr(current_weight, "partition_dim")
-                        tp_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()
-                        size_per_rank = current_weight.size(partition_dim)
-                        slices = [
-                            None if idx != partition_dim else (size_per_rank * tp_rank, size_per_rank * (tp_rank + 1))
-                            for idx in range(num_dims)
-                        ]
+                    if getattr(current_weight, "tensor_model_parallel", False):
+                        if parameter.device == torch.device("meta"):
+                            # This must either be a torch.nn.Embedding or a torch.nn.Linear that was not handled during
+                            # parallelization since those are the only classes that we initialize on the `meta` device.
+                            num_dims = current_weight.dim()
+                            partition_dim = getattr(current_weight, "partition_dim")
+                            tp_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()
+                            size_per_rank = current_weight.size(partition_dim)
+                            slices = [
+                                None
+                                if idx != partition_dim
+                                else (size_per_rank * tp_rank, size_per_rank * (tp_rank + 1))
+                                for idx in range(num_dims)
+                            ]
+                        else:
+                            # The parameter is not on the `meta` device, it has been loaded from a checkpoint during
+                            # parallelization, we can skip.
+                            continue
                     else:
                         slices = None


From a2d349c7c3256ff02188f02e115999c2e5d30ae3 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 21 Sep 2023 16:34:59 +0200
Subject: [PATCH 3/3] Remove print

---
 optimum/neuron/distributed/base.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index a4e52d5b3..bd2759e39 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -207,11 +207,6 @@ def parallelize(
                 # This module has not pre-trained weights, it must be fine-tuned, we initialize it with the
                 # `reset_parameters()` method.
                 mod.reset_parameters()
-            for name, mod in model.named_parameters():
-                if name in weight_map:
-                    weight_info = WeightInformation(weight_map[name], name, device=device)
-                    tensor = load_tensor_for_weight(weight_info)
-                    print(mod, tensor)
         return model

     @classmethod
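
The crux of the patches above is the per-rank slice computation: when a parameter was sharded for tensor parallelism along `partition_dim`, only the chunk `[size_per_rank * tp_rank, size_per_rank * (tp_rank + 1))` of the full checkpoint tensor should be loaded on the current rank. The standalone sketch below reproduces that arithmetic with plain `torch` and a hypothetical helper name (`slice_for_tp_rank`); it is not part of the optimum-neuron API. Unlike the patch, where `current_weight` is already the sharded parameter so `current_weight.size(partition_dim)` is the per-rank size, here the per-rank size is derived from the full tensor and the tensor-parallel world size.

    import torch

    def slice_for_tp_rank(full_tensor: torch.Tensor, partition_dim: int, tp_rank: int, tp_size: int) -> torch.Tensor:
        """Return the contiguous shard of `full_tensor` owned by `tp_rank` along `partition_dim`."""
        # Same bounds as in the patch: size_per_rank * tp_rank .. size_per_rank * (tp_rank + 1).
        size_per_rank = full_tensor.size(partition_dim) // tp_size
        slices = [
            slice(None) if idx != partition_dim else slice(size_per_rank * tp_rank, size_per_rank * (tp_rank + 1))
            for idx in range(full_tensor.dim())
        ]
        return full_tensor[tuple(slices)]

    # Example: an (8, 6) linear weight partitioned over 4 ranks along dim 0 gives each rank a (2, 6) shard;
    # rank 1 owns rows 2:4 of the full checkpoint tensor.
    full = torch.arange(48, dtype=torch.float32).reshape(8, 6)
    shard = slice_for_tp_rank(full, partition_dim=0, tp_rank=1, tp_size=4)
    assert shard.shape == (2, 6) and torch.equal(shard, full[2:4])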