diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index b8ffef7..482935b 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -176,12 +176,22 @@ function (ipa::InvariantPointAttention)( bias = ipa.map_pairs(z) # split into queries, keys and values - nodes_q, nodes_k, nodes_v = chunk(nodes, size=[c, c, c], dims=2) - points_q, points_k, points_v = chunk( - points, - size=n_heads * [n_query_points, n_query_points, n_point_values], - dims=2, - ) + # NOTE: workaround to avoid bugs associated with the chunk function + #nodes_q, nodes_k, nodes_v = chunk(nodes, size=[c, c, c], dims=2) + i = firstindex(nodes, 2) + nodes_q = nodes[:,i:i+c-1,:]; i += size(nodes_q, 2) + nodes_k = nodes[:,i:i+c-1,:]; i += size(nodes_k, 2) + nodes_v = nodes[:,i:i+c-1,:] + #points_q, points_k, points_v = chunk( + # points, + # size=n_heads * [n_query_points, n_query_points, n_point_values], + # dims=2, + #) + i = firstindex(points, 2) + points_q = points[:,i:i+n_heads*n_query_points-1,:]; i += size(points_q, 2) + points_k = points[:,i:i+n_heads*n_query_points-1,:]; i += size(points_k, 2) + points_v = points[:,i:i+n_heads*n_point_values-1,:] + points_q = reshape(points_q, 3, n_heads, :, n_residues) points_k = reshape(points_k, 3, n_heads, :, n_residues) points_v = reshape(points_v, 3, n_heads, :, n_residues)