Weight loading after lazy loading fix #238

Merged: 3 commits, Sep 21, 2023
Changes from all commits
optimum/neuron/distributed/base.py: 80 changes (44 additions, 36 deletions)
@@ -149,26 +149,29 @@ def parallelize(
         model = cls._parallelize(
             model, orig_to_parallel=orig_to_parallel, device=device, parallelize_embeddings=parallelize_embeddings
         )
-        weight_map = getattr(model, "_weight_map", {})
+        weight_map = getattr(model, "_weight_map", None)
+
+        # The model was not loaded lazily, it is already ready.
+        if weight_map is None:
+            return model
+
         with torch.no_grad():
             modules_to_initialize = []
             for name, parameter in model.named_parameters():
-                # This must be either a torch.nn.Embedding or a torch.nn.Linear since those are the only
-                # classes that we initialize on the `meta` device.
-                if parameter.device == torch.device("meta"):
-                    if weight_map is None:
-                        raise ValueError(
-                            f"The parameter called {name} of the model is on the `meta` device and no weight map is "
-                            "attached to the model to load the proper weights from file."
-                        )
-                    split = name.rsplit(".", maxsplit=1)
-                    module = model.get_submodule(split[0])
-                    attribute_name = split[1]
-                    current_weight = getattr(module, attribute_name)
-                    try:
-                        weight_info = WeightInformation(weight_map[name], name, device=device)
-                        # The weight might have been parallelized, in which case we must load the proper slice.
-                        if getattr(current_weight, "tensor_model_parallel", False):
+                split = name.rsplit(".", maxsplit=1)
+                module = model.get_submodule(split[0])
+                attribute_name = split[1]
+                current_weight = getattr(module, attribute_name)
+                try:
+                    weight_info = WeightInformation(weight_map[name], name, device=device)
+                except KeyError:
+                    weight_info = None
+
+                if weight_info is not None:
+                    if getattr(current_weight, "tensor_model_parallel", False):
+                        if parameter.device == torch.device("meta"):
+                            # This must either be a torch.nn.Embedding or a torch.nn.Linear that was not handled during
+                            # parallelization since those are the only classes that we initialize on the `meta` device.
                             num_dims = current_weight.dim()
                             partition_dim = getattr(current_weight, "partition_dim")
                             tp_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()
@@ -180,25 +183,30 @@ def parallelize(
                                 for idx in range(num_dims)
                             ]
                         else:
-                            slices = None
-                        setattr(
-                            module,
-                            attribute_name,
-                            torch.nn.Parameter(load_tensor_for_weight(weight_info, tensor_slices=slices)),
-                        )
-                    except KeyError:
-                        # This means that there is no information about where to find the weights for this parameter.
-                        device = torch.device("cpu") if device is None else device
-                        setattr(
-                            module,
-                            attribute_name,
-                            torch.nn.Parameter(torch.empty_like(current_weight, device=device)),
-                        )
-                        modules_to_initialize.append(module)
-            for mod in modules_to_initialize:
-                # This module has not pre-trained weights, it must be fine-tuned, we initialize it with the
-                # `reset_parameters()` method.
-                mod.reset_parameters()
+                            # The parameter is not on the `meta` device, it has been loaded from a checkpoint during
+                            # parallelization, we can skip.
+                            continue
+                    else:
+                        slices = None
+
+                    setattr(
+                        module,
+                        attribute_name,
+                        torch.nn.Parameter(load_tensor_for_weight(weight_info, tensor_slices=slices)),
+                    )
+                else:
+                    # This means that there is no information about where to find the weights for this parameter.
+                    device = torch.device("cpu") if device is None else device
+                    setattr(
+                        module,
+                        attribute_name,
+                        torch.nn.Parameter(torch.empty_like(current_weight, device=device)),
+                    )
+                    modules_to_initialize.append(module)
+            for mod in modules_to_initialize:
+                # This module has not pre-trained weights, it must be fine-tuned, we initialize it with the
+                # `reset_parameters()` method.
+                mod.reset_parameters()
         return model
 
     @classmethod
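In summary, `parallelize()` now treats a missing `_weight_map` as the sign that the model was not loaded lazily and returns early; otherwise it resolves a checkpoint entry for every parameter, skips weights that were already loaded during parallelization, and only falls back to `reset_parameters()` for modules with no entry in the weight map. The standalone sketch below illustrates that pattern with a plain in-memory dict as the weight map; `materialize_from_weight_map` is a made-up helper for illustration, not part of optimum-neuron, which instead resolves weights on disk through `WeightInformation` and `load_tensor_for_weight`.

    from typing import Dict, Optional

    import torch


    def materialize_from_weight_map(
        model: torch.nn.Module, weight_map: Optional[Dict[str, torch.Tensor]]
    ) -> torch.nn.Module:
        # Mirror of the new early return: no weight map means the model was not lazily loaded.
        if weight_map is None:
            return model

        modules_to_initialize = []
        with torch.no_grad():
            for name, parameter in model.named_parameters():
                module_name, _, attribute_name = name.rpartition(".")
                module = model.get_submodule(module_name) if module_name else model
                saved = weight_map.get(name)
                if saved is not None:
                    # A checkpoint entry exists: replace the (possibly meta) parameter with it.
                    setattr(module, attribute_name, torch.nn.Parameter(saved.clone()))
                elif parameter.device == torch.device("meta"):
                    # No checkpoint entry: allocate real storage and mark the module for re-initialization.
                    setattr(module, attribute_name, torch.nn.Parameter(torch.empty(parameter.shape)))
                    modules_to_initialize.append(module)
        for module in modules_to_initialize:
            # Modules without pre-trained weights are freshly initialized, as in the patched code.
            module.reset_parameters()
        return model


    # Usage: build a module on the `meta` device, then materialize it from an (empty) weight map.
    linear = torch.nn.Linear(4, 2, bias=False, device="meta")
    materialize_from_weight_map(linear, {})
    print(linear.weight.device)  # cpu

As in the diff above, the `reset_parameters()` pass runs once per module after the parameter loop, so heads or adapters that are absent from the checkpoint get a single fresh initialization.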
optimum/neuron/distributed/utils.py: 2 changes (1 addition, 1 deletion)
@@ -458,8 +458,8 @@ def gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads(
     return sliced_linear_layer
 
 
-@requires_torch_xla
 @classmethod
+@requires_torch_xla
 def from_pretrained_for_tp(
     cls,
     pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
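The utils.py hunk only reorders the decorators on `from_pretrained_for_tp`. As a general rule, `@classmethod` should be the outermost decorator so that a function-wrapping decorator receives the plain function rather than a classmethod object. The sketch below illustrates that rule with `requires_backend`, an illustrative stand-in for `requires_torch_xla` (not the real implementation):

    import functools


    def requires_backend(fn):
        # A typical function-wrapping decorator: it expects a plain callable.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper


    class Loader:
        @classmethod          # outermost: binds cls to the already-wrapped function
        @requires_backend
        def from_pretrained_for_tp(cls, name):
            return f"{cls.__name__} loading {name}"


    print(Loader.from_pretrained_for_tp("my-model"))  # Loader loading my-model

With the two decorators in the opposite order, `requires_backend` would wrap the classmethod object itself, and calling `Loader.from_pretrained_for_tp(...)` would typically fail because the inner call reaches a non-callable classmethod object.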