Skip to content

Commit

Permalink
Merge pull request #501 from datamol-io/Patch-for-mup
Browse files Browse the repository at this point in the history
Update global_architectures.py
  • Loading branch information
DomInvivo authored Feb 16, 2024
2 parents 7771164 + 0e6c932 commit 48d26fd
Showing 1 changed file with 0 additions and 4 deletions.
4 changes: 0 additions & 4 deletions graphium/nn/architectures/global_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,11 +1339,7 @@ def _recursive_divide_dim(x: collections.abc.Mapping):
elif k in ["in_dim", "out_dim", "in_dim_edges", "out_dim_edges"]:
x[k] = round(v / divide_factor)
elif k in ["embed_dim"]:
num_heads = x.get("num_heads", 1)
x[k] = round(v / divide_factor)
assert (
x[k] % num_heads == 0
), f"embed_dim={x[k]} is not divisible by num_heads={num_heads}"

_recursive_divide_dim(kwargs["layer_kwargs"])

Expand Down

0 comments on commit 48d26fd

Please sign in to comment.