Skip to content

Commit

Permalink
Update global_architectures.py
Browse files Browse the repository at this point in the history
Removed double check of embed_dim/num_heads, discussed in PR #494
  • Loading branch information
DomInvivo authored Dec 24, 2023
1 parent f698df4 commit b69aced
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions graphium/nn/architectures/global_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,12 +1339,8 @@ def _recursive_divide_dim(x: collections.abc.Mapping):
elif k in ["in_dim", "out_dim", "in_dim_edges", "out_dim_edges"]:
x[k] = round(v / divide_factor)
elif k in ["embed_dim"]:
num_heads = x.get("num_heads", 1)
x[k] = round(v / divide_factor)
assert (
x[k] % num_heads == 0
), f"embed_dim={x[k]} is not divisible by num_heads={num_heads}"


_recursive_divide_dim(kwargs["layer_kwargs"])

return kwargs
Expand Down

0 comments on commit b69aced

Please sign in to comment.