From 2ca73435745c3308f187a7f2294ffb736882038a Mon Sep 17 00:00:00 2001 From: Cheng-Ping Hsieh Date: Fri, 12 Jul 2024 13:49:17 -0700 Subject: [PATCH 1/2] Remove mask if use fusion mask Signed-off-by: Cheng-Ping Hsieh --- nemo/collections/nlp/modules/common/text_generation_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 1bd5b618de35..148f6d5783fc 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -793,6 +793,9 @@ def generate( if random_seed is not None: seed_everything(random_seed) + + if hasattr(model, 'get_attention_mask_from_fusion') and model.get_attention_mask_from_fusion: + compute_attention_mask = False output = synced_generate( model, From cf3b35820988d76f72ba0521e6793ec1e090d4e5 Mon Sep 17 00:00:00 2001 From: hsiehjackson Date: Fri, 12 Jul 2024 20:53:17 +0000 Subject: [PATCH 2/2] Apply isort and black reformatting Signed-off-by: hsiehjackson --- nemo/collections/nlp/modules/common/text_generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 148f6d5783fc..d4b67d3e3783 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -793,7 +793,7 @@ def generate( if random_seed is not None: seed_everything(random_seed) - + if hasattr(model, 'get_attention_mask_from_fusion') and model.get_attention_mask_from_fusion: compute_attention_mask = False