From 732d6833ccc1a519ce01fa2769e4735784581609 Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Mon, 8 Nov 2021 14:49:17 -0800
Subject: [PATCH] update util to work for additive attention mask

---
 xformers/components/attention/utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/xformers/components/attention/utils.py b/xformers/components/attention/utils.py
index 476204bc89..9952a5a530 100644
--- a/xformers/components/attention/utils.py
+++ b/xformers/components/attention/utils.py
@@ -49,7 +49,10 @@ def maybe_merge_masks(
         if att_mask is None:
             att_mask = key_padding_mask
         # Assumption is that False means to mask.
-        att_mask = att_mask.logical_and(key_padding_mask)
+        elif att_mask.dtype == torch.bool:
+            att_mask = att_mask.logical_and(key_padding_mask)
+        else:
+            att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf"))

     return att_mask
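
For context, the behavior after this patch can be sketched as a small standalone function: a boolean attention mask is still merged with the key padding mask via logical AND, while an additive (float) mask gets -inf written into the padded key positions so they receive zero weight after softmax. This is a minimal sketch only; the name merge_masks_sketch and the two-argument signature are illustrative, and the real maybe_merge_masks helper takes additional shape arguments not shown here.

    import torch

    def merge_masks_sketch(att_mask, key_padding_mask):
        # key_padding_mask is boolean: False marks padded positions to ignore.
        if att_mask is None:
            att_mask = key_padding_mask
        # Boolean attention mask: combine the two masks with a logical AND.
        elif att_mask.dtype == torch.bool:
            att_mask = att_mask.logical_and(key_padding_mask)
        # Additive (float) attention mask: set padded positions to -inf so they
        # contribute zero weight once the mask is added to the attention logits.
        else:
            att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf"))
        return att_mask

    # One query row over four keys, where the last key is padding.
    additive_mask = torch.zeros(1, 4)                        # additive mask, no prior masking
    key_padding = torch.tensor([[True, True, True, False]])  # False = padded key
    print(merge_masks_sketch(additive_mask, key_padding))    # tensor([[0., 0., 0., -inf]])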