Fix mup for the layers with AttentionLayerMup #494

Merged: 8 commits, Dec 24, 2023
6 changes: 6 additions & 0 deletions graphium/nn/architectures/global_architectures.py
@@ -1324,6 +1324,12 @@ def _recursive_divide_dim(x: collections.abc.Mapping):
_recursive_divide_dim(v)
elif k in ["in_dim", "out_dim", "in_dim_edges", "out_dim_edges"]:
x[k] = round(v / divide_factor)
elif k in ["embed_dim"]:
num_heads = x.get("num_heads", 1)
x[k] = round(v / divide_factor)
assert (
x[k] % num_heads == 0
), f"embed_dim={x[k]} is not divisible by num_heads={num_heads}"
Comment on lines +1330 to +1332
Collaborator (PR author):

I don't think it's needed, since there's another assertion in AttentionLayerMup. @maciej-sypetkowski, can you check whether it still works if we remove that part and scale by a factor that's not divisible by num_heads?

Collaborator:

Yes, it's not needed


_recursive_divide_dim(kwargs["layer_kwargs"])

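As context for the thread above: a minimal sketch (not the repository's code) of what happens when the scaled embed_dim is no longer divisible by num_heads, assuming AttentionLayerMup ultimately delegates to torch.nn.MultiheadAttention, whose constructor performs the divisibility check referred to as the "other assertion".

import torch

# Illustrative values only: a base embed_dim of 96 with 4 heads, shrunk by a
# muP divide_factor the way _recursive_divide_dim does.
embed_dim, num_heads = 96, 4

scaled = round(embed_dim / 8.0)  # 12, divisible by 4 -> the layer builds fine
torch.nn.MultiheadAttention(scaled, num_heads, batch_first=True)

scaled = round(embed_dim / 5.0)  # 19, not divisible by 4
try:
    torch.nn.MultiheadAttention(scaled, num_heads, batch_first=True)
except AssertionError as err:
    print(err)  # PyTorch itself reports: embed_dim must be divisible by num_heads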
34 changes: 23 additions & 11 deletions graphium/nn/pyg_layers/gps_pyg.py
@@ -54,6 +54,7 @@ def __init__(
precision: str = "32",
biased_attention_key: Optional[str] = None,
attn_kwargs=None,
force_consistent_in_dim: bool = True,
droppath_rate_attn: float = 0.0,
droppath_rate_ffn: float = 0.0,
hidden_dim_scaling: float = 4.0,
@@ -79,12 +80,6 @@ def __init__(
out_dim:
Output node feature dimensions of the layer

in_dim:
Input edge feature dimensions of the layer

out_dim:
Output edge feature dimensions of the layer

in_dim_edges:
input edge-feature dimensions of the layer

@@ -120,6 +115,11 @@ def __init__(
attn_kwargs:
kwargs for attention layer

force_consistent_in_dim:
whether to force the `embed_dim` to be the same as the `in_dim` for the attention and mpnn.
The argument is only valid if `attn_type` is not None. If `embed_dim` is not provided,
it will be set to `in_dim` by default, so this parameter won't have an effect.

droppath_rate_attn:
stochastic depth drop rate for attention layer https://arxiv.org/abs/1603.09382

@@ -194,7 +194,9 @@ def __init__(
self.biased_attention_key = biased_attention_key
# Initialize the MPNN and Attention layers
self.mpnn = self._parse_mpnn_layer(mpnn_type, mpnn_kwargs)
self.attn_layer = self._parse_attn_layer(attn_type, self.biased_attention_key, attn_kwargs)
self.attn_layer = self._parse_attn_layer(
attn_type, self.biased_attention_key, attn_kwargs, force_consistent_in_dim=force_consistent_in_dim
)

self.output_scale = output_scale
self.use_edges = True if self.in_dim_edges is not None else False
@@ -237,8 +239,6 @@ def forward(self, batch: Batch) -> Batch:
"""
# pe, feat, edge_index, edge_feat = batch.pos_enc_feats_sign_flip, batch.feat, batch.edge_index, batch.edge_feat
feat = batch.feat
if self.use_edges:
edges_feat_in = batch.edge_feat

feat_in = feat # for first residual connection

@@ -309,26 +309,38 @@ def _parse_mpnn_layer(self, mpnn_type, mpnn_kwargs: Dict[str, Any]) -> Optional[
return mpnn_layer

def _parse_attn_layer(
self, attn_type, biased_attention_key: str, attn_kwargs: Dict[str, Any]
self,
attn_type,
biased_attention_key: str,
attn_kwargs: Dict[str, Any],
force_consistent_in_dim: bool = True,
) -> Optional[Module]:
"""
parse the input attention layer and check if it is valid
Parameters:
attn_type: type of the attention layer
biased_attention_key: key for the attention bias
attn_kwargs: kwargs for the attention layer
force_consistent_in_dim: whether to force the `embed_dim` to be the same as the `in_dim`

Returns:
attn_layer: the attention layer
"""

# Set the default values for the Attention layer
if attn_kwargs is None:
attn_kwargs = {}
attn_kwargs.setdefault("embed_dim", self.in_dim)
attn_kwargs.setdefault("num_heads", 1)
attn_kwargs.setdefault("dropout", self.dropout)
attn_kwargs.setdefault("batch_first", True)
self.attn_kwargs = attn_kwargs

# Force the `embed_dim` to be the same as the `in_dim`
attn_kwargs.setdefault("embed_dim", self.in_dim)
if force_consistent_in_dim:
embed_dim = attn_kwargs["embed_dim"]
assert embed_dim == self.in_dim, f"embed_dim={embed_dim} must be equal to in_dim={self.in_dim}"

# Initialize the Attention layer
attn_layer, attn_class = None, None
if attn_type is not None:
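For illustration, a standalone sketch of the default-filling and consistency logic that _parse_attn_layer applies after this change; the helper name resolve_attn_kwargs and the concrete numbers are hypothetical, not part of the PR.

from typing import Any, Dict, Optional

def resolve_attn_kwargs(
    attn_kwargs: Optional[Dict[str, Any]],
    in_dim: int,
    dropout: float,
    force_consistent_in_dim: bool = True,
) -> Dict[str, Any]:
    # Fill the same defaults as _parse_attn_layer.
    attn_kwargs = dict(attn_kwargs or {})
    attn_kwargs.setdefault("num_heads", 1)
    attn_kwargs.setdefault("dropout", dropout)
    attn_kwargs.setdefault("batch_first", True)
    # embed_dim falls back to in_dim; if it was given explicitly and the
    # consistency flag is on, it must match in_dim exactly.
    attn_kwargs.setdefault("embed_dim", in_dim)
    if force_consistent_in_dim:
        embed_dim = attn_kwargs["embed_dim"]
        assert embed_dim == in_dim, f"embed_dim={embed_dim} must be equal to in_dim={in_dim}"
    return attn_kwargs

resolve_attn_kwargs(None, in_dim=128, dropout=0.1)               # fills embed_dim=128
resolve_attn_kwargs({"embed_dim": 64}, in_dim=128, dropout=0.1)  # raises AssertionError

With force_consistent_in_dim=False, a mismatched embed_dim is passed through unchanged, which is the looser behaviour the new flag lets callers opt into.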