Skip to content

Commit

Permalink
avoid using chunk to avoid bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
bicycle1885 committed Nov 8, 2023
1 parent 00f84e8 commit 041110d
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/MessagePassingIPA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,22 @@ function (ipa::InvariantPointAttention)(
bias = ipa.map_pairs(z)

# split into queries, keys and values
nodes_q, nodes_k, nodes_v = chunk(nodes, size=[c, c, c], dims=2)
points_q, points_k, points_v = chunk(
points,
size=n_heads * [n_query_points, n_query_points, n_point_values],
dims=2,
)
# NOTE: workaround to avoid bugs associated with the chunk function
#nodes_q, nodes_k, nodes_v = chunk(nodes, size=[c, c, c], dims=2)
i = firstindex(nodes, 2)
nodes_q = nodes[:,i:i+c-1,:]; i += size(nodes_q, 2)
nodes_k = nodes[:,i:i+c-1,:]; i += size(nodes_k, 2)
nodes_v = nodes[:,i:i+c-1,:]
#points_q, points_k, points_v = chunk(
# points,
# size=n_heads * [n_query_points, n_query_points, n_point_values],
# dims=2,
#)
i = firstindex(points, 2)
points_q = points[:,i:i+n_heads*n_query_points-1,:]; i += size(points_q, 2)
points_k = points[:,i:i+n_heads*n_query_points-1,:]; i += size(points_k, 2)
points_v = points[:,i:i+n_heads*n_point_values-1,:]

points_q = reshape(points_q, 3, n_heads, :, n_residues)
points_k = reshape(points_k, 3, n_heads, :, n_residues)
points_v = reshape(points_v, 3, n_heads, :, n_residues)
Expand Down

0 comments on commit 041110d

Please sign in to comment.