Init on the xla device #521

Merged 5 commits on Apr 3, 2024
optimum/neuron/accelerate/accelerator.py (3 changes: 1 addition & 2 deletions)
@@ -422,8 +422,7 @@ def _prepare_model_for_mp(
cpu_ids = {name: id(param) for name, param in model.named_parameters()}
tied_parameters_dict = get_tied_parameters_dict(model)
model_main_input_name = getattr(model, "main_input_name", None)
# TODO: use self.device.
model = self.state.mp_plugin.parallelize_model(model, device=None)
model = self.state.mp_plugin.parallelize_model(model, device=self.device)

if model_main_input_name is not None:
setattr(model, "main_input_name", model_main_input_name)
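The change above drops the TODO and passes the accelerator's own device, an XLA device on Neuron hardware, to `parallelize_model` instead of `None`, so parallel layers can be materialized directly on device. A minimal illustrative sketch (not taken from this PR) of the kind of device object `self.device` is assumed to resolve to, using the public torch_xla API:

```python
# Illustrative sketch only: obtaining an XLA device with torch_xla. The accelerator's
# `self.device` is assumed to hold such a device so that `parallelize_model` can
# place parameters on it directly.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                # e.g. device(type='xla', index=0)
tensor = torch.empty(2, 2).to(device)   # tensors can be moved straight onto the XLA device
print(tensor.device)
```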
optimum/neuron/distributed/base.py (16 changes: 11 additions & 5 deletions)
@@ -394,10 +394,10 @@ def _initialize_or_load_weights(
continue
else:
slices = None

new_parameter = torch.nn.Parameter(
load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype)
)
weight_data = load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype)
if device is not None:
weight_data = weight_data.to(device)
new_parameter = torch.nn.Parameter(weight_data)
elif parameter.device != torch.device("meta") and (
was_already_initialized_during_parallelization(parameter)
or not parameter_can_be_initialized(model, module, attribute_name)
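A condensed sketch of the loading path introduced in this hunk: the checkpoint tensor is loaded and cast on the host first, then moved to the target device only when one is provided. `checkpoint_tensor` stands in for the result of `load_tensor_for_weight`, which is not reproduced here.

```python
# Sketch of the load-then-move pattern above; `checkpoint_tensor` is a stand-in
# for load_tensor_for_weight(weight_info, tensor_slices=slices).
from typing import Optional

import torch

def build_parameter(
    checkpoint_tensor: torch.Tensor,
    dtype: torch.dtype,
    device: Optional[torch.device],
) -> torch.nn.Parameter:
    weight_data = checkpoint_tensor.to(dtype)  # cast on the host
    if device is not None:
        weight_data = weight_data.to(device)   # then move to the XLA device if requested
    return torch.nn.Parameter(weight_data)
```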
@@ -407,7 +407,8 @@
continue
else:
# This means that there is no information about where to find the weights for this parameter.
device = torch.device("cpu") if device is None else device
# We first create the module on CPU, initialize it and then move it on device if needed.
device = torch.device("cpu")
new_parameter = torch.nn.Parameter(torch.empty_like(parameter, device=device))
modules_to_initialize[module].append(attribute_name)

@@ -418,6 +419,7 @@
)
tied_weights[parameter] = new_parameter
new_parameters.add(new_parameter)
gc.collect()

for mod, parameter_names in modules_to_initialize.items():
if isinstance(mod, torch.nn.Embedding):
@@ -500,6 +502,10 @@ def initialize(mod: GQAQKVColumnParallelLinear, proj_name: str, output_size: int
if left_uninitialized and hasattr(mod, "reset_parameters"):
initialize_torch_nn_module(mod, parameter_names)

if device is not None:
mod.to(device)
gc.collect()

@classmethod
@requires_neuronx_distributed
def _initialize_for_precompilation(
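Taken together, the remaining hunks in this file follow one pattern: parameters with no checkpoint weights are created on CPU, their module is initialized there so the RNG matches a plain CPU run, and only then is the module moved to the target device, with `gc.collect()` calls to release the host copies early. A standalone sketch of that flow, assuming a module with a standard `reset_parameters` initializer (the real code also dispatches to the Transformers `_init_weights` logic):

```python
# Standalone sketch of the create-on-CPU / initialize / move-to-device flow.
import gc
from typing import Optional

import torch

def initialize_then_move(mod: torch.nn.Module, device: Optional[torch.device]) -> torch.nn.Module:
    # Initialization runs on CPU so results stay reproducible across runs.
    if hasattr(mod, "reset_parameters"):
        mod.reset_parameters()
    if device is not None:
        mod.to(device)  # move the freshly initialized weights onto the target (XLA) device
    gc.collect()        # free the intermediate host tensors sooner
    return mod
```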
optimum/neuron/distributed/utils.py (24 changes: 19 additions & 5 deletions)
@@ -1100,17 +1100,28 @@ def try_to_hf_initialize(
`model._init_weights` method. It returns the names of the parameters that were left uninitialized.

"""
cached_params_data = {name: param.data.detach().clone().to("cpu") for name, param in mod.named_parameters()}
device = torch.device("cpu")
for name in parameter_names:
param_device = getattr(mod, name).device
if param_device != torch.device("meta"):
device = param_device

mod.to("cpu")

cached_params_data = {name: param.data.detach().clone() for name, param in mod.named_parameters()}

# We initialize on cpu to have the same RNG state (mostly useful for tests).
model._init_weights(mod)

if parameter_names_mapping is None:
parameter_names_mapping = {}

reverse_parameter_names_mapping = {v: k for k, v in parameter_names_mapping.items()}

def name_in_mod(name: str):
return parameter_names_mapping.get(name, name)

dummy_mod = copy.deepcopy(mod).to("cpu")
dummy_mod = copy.deepcopy(mod)
for name in parameter_names:
getattr(dummy_mod, name_in_mod(name)).random_()
model._init_weights(dummy_mod)
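The preamble added to `try_to_hf_initialize` can be read as three steps: remember the (non-meta) device the listed parameters currently live on, move the module to CPU so `model._init_weights` runs with the usual CPU RNG, and cache the pre-initialization values for the comparison below. A reduced sketch of just that bookkeeping (the helper name is hypothetical):

```python
# Reduced sketch of the device bookkeeping shown above (helper name is hypothetical).
import torch

def snapshot_and_move_to_cpu(mod: torch.nn.Module, parameter_names):
    # Find the device the listed parameters currently live on, ignoring meta tensors,
    # so the module can be moved back once initialization is done.
    device = torch.device("cpu")
    for name in parameter_names:
        param_device = getattr(mod, name).device
        if param_device != torch.device("meta"):
            device = param_device

    mod.to("cpu")  # initialize on CPU to keep the RNG state comparable across runs

    cached = {name: param.data.detach().clone() for name, param in mod.named_parameters()}
    return device, cached
```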
@@ -1120,15 +1131,15 @@ def name_in_mod(name: str):
for param_name in parameter_names:
name = name_in_mod(param_name)
# The parameter was left unchanged.
param_on_cpu = getattr(mod, name).data.to("cpu")
if torch.all(param_on_cpu == cached_params_data[name]):
param = getattr(mod, name).data
if torch.all(param == cached_params_data[name]):
# There are two possible reasons:
# 1. The model cannot initialize the module that owns the parameter.
# 2. The parameter already had the proper value.

# We check if a dummy copy of the module, filled with random values is modified to know if the model
# can initialize the module.
dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == param_on_cpu)
dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == param)
if not dummy_param_was_changed:
left_uninitialized.append(param_name)

@@ -1138,6 +1149,9 @@ def name_in_mod(name: str):
param = getattr(mod, name)
param.data = cached_data

# We restore the module back to its original device.
mod.to(device)

return left_uninitialized
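Because the module was already moved to CPU, the unchanged-parameter check can now compare tensors directly against the cached copies instead of making extra `.to("cpu")` copies, and the module is moved back to its original device before returning. A hedged outline of that final check-and-restore step, continuing the sketch above (it simplifies the name mapping and omits restoring cached data for parameters outside `parameter_names`):

```python
# Hedged outline of the check-and-restore step (simplified; function name is hypothetical).
import torch

def detect_uninitialized(mod, dummy_mod, parameter_names, cached, device):
    left_uninitialized = []
    for name in parameter_names:
        param = getattr(mod, name).data
        if torch.all(param == cached[name]):
            # Unchanged after _init_weights: either the model cannot initialize this
            # parameter, or it already had the right value. The randomized dummy
            # module, initialized the same way, tells the two cases apart.
            if not torch.all(getattr(dummy_mod, name).data == param):
                left_uninitialized.append(name)
    mod.to(device)  # restore the module to its original device
    return left_uninitialized
```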

