Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for nn.Linear layers in StandardParametrizator #277

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/nanotron/scaling/parametrization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,20 @@ def __init__(self, config: ModelArgs):
TensorParallelRowLinear: self._parametrize_row_linear,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
nn.Linear: self._parametrize_nn_linear,
}

self.std = config.init_method.std
self.num_layers = config.model_config.num_hidden_layers

def _parametrize_nn_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]

if param_name == "weight":
init.normal_(module.weight, mean=0.0, std=self.std)
elif param_name == "bias":
module.bias.zero_()

def _parametrize_column_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]

Expand Down Expand Up @@ -89,6 +98,7 @@ def __init__(self, config: ModelArgs):
TensorParallelRowLinear: self._parametrize_mup_weight,
TritonRMSNorm: self._parametrize_layer_norm,
TensorParallelEmbedding: self._parametrize_embedding,
nn.Linear: self._parametrize_nn_linear,
}
self.std = 1.0

Expand All @@ -102,6 +112,15 @@ def _compute_spectral_std(std: float, fan_in: int, fan_out: int):
"""
return (std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in))

def _parametrize_nn_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]

data = module.weight if param_name == "weight" else module.bias
fan_in, fan_out = init._calculate_fan_in_and_fan_out(data)

std = SpectralMupParametrizator._compute_spectral_std(std=self.std, fan_in=fan_in, fan_out=fan_out)
init.normal_(data, mean=0.0, std=std)

def _parametrize_mup_weight(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]

Expand Down