test_custom_scale K!=Kv
ghstack-source-id: 6c4893ac3f375b5984b4e380c74f3ef5a1433318
Pull Request resolved: fairinternal/xformers#958

__original_commit__ = fairinternal/xformers@4766126
bottler authored and xFormers Bot committed Dec 6, 2023
1 parent af2f04b commit 1254a16
Showing 1 changed file with 4 additions and 4 deletions.
tests/test_mem_eff_attention.py (4 additions, 4 deletions):

@@ -1116,12 +1116,12 @@ def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
         device,
         dtype,
         _,
-        _,
+        B,
         q_len,
         kv_len,
-        _,
+        H,
         k,
-        _,
+        Kv,
     ) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
     torch.manual_seed(q_len + kv_len + k)
     if device != "cuda":
@@ -1134,7 +1134,7 @@ def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
         query=query, key=key, value=value, attn_bias=attn_bias, scale=scale
     )
     op_fw = sample_random_supported_fw(inputs, seed=q_len * k + kv_len * k)
-    grad_out = torch.ones_like(query)
+    grad_out = query.new_ones(B * H, q_len, Kv)
     query.requires_grad_(True)
     key.requires_grad_(True)
     value.requires_grad_(True)
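Why the fix is needed (explanatory note, not part of the commit): memory-efficient attention computes softmax(Q Kᵀ · scale) V, so the output's trailing dimension is Kv, the embedding dim of value, while query's trailing dimension is K. Once the test is parametrized with K != Kv, torch.ones_like(query) produces a gradient of shape (B * H, q_len, K) that no longer matches the output; query.new_ones(B * H, q_len, Kv) keeps query's dtype and device but sizes the gradient to the actual output. A minimal sketch, assuming a CUDA device with xformers installed and illustrative shapes:

# Minimal sketch (shapes are illustrative assumptions, not from the test).
import torch
import xformers.ops as xops

B, H, q_len, kv_len, K, Kv = 2, 3, 8, 16, 32, 64  # deliberately K != Kv

query = torch.randn(B * H, q_len, K, device="cuda", dtype=torch.float16, requires_grad=True)
key = torch.randn(B * H, kv_len, K, device="cuda", dtype=torch.float16)
value = torch.randn(B * H, kv_len, Kv, device="cuda", dtype=torch.float16)

out = xops.memory_efficient_attention(query, key, value, scale=0.1)
assert out.shape == (B * H, q_len, Kv)  # trailing dim comes from value, not query

# grad_out must match out; torch.ones_like(query) would be (B*H, q_len, K).
grad_out = query.new_ones(B * H, q_len, Kv)
out.backward(grad_out)
assert query.grad.shape == query.shape  # gradient w.r.t. query is (B*H, q_len, K)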
