From cf46d438b835ba287556ca2bbca0a249a898042e Mon Sep 17 00:00:00 2001 From: dianaml0 <82468439+dianaml0@users.noreply.github.com> Date: Mon, 8 Nov 2021 19:07:43 -0500 Subject: [PATCH] update util to work for additive attention mask (#86) Co-authored-by: Diana Liskovich --- xformers/components/attention/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xformers/components/attention/utils.py b/xformers/components/attention/utils.py index 476204bc89..9952a5a530 100644 --- a/xformers/components/attention/utils.py +++ b/xformers/components/attention/utils.py @@ -49,7 +49,10 @@ def maybe_merge_masks( if att_mask is None: att_mask = key_padding_mask # Assumption is that False means to mask. - att_mask = att_mask.logical_and(key_padding_mask) + elif att_mask.dtype == torch.bool: + att_mask = att_mask.logical_and(key_padding_mask) + else: + att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf")) return att_mask