diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index 1bd5b618de35..d4b67d3e3783 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -794,6 +794,9 @@ def generate( if random_seed is not None: seed_everything(random_seed) + if hasattr(model, 'get_attention_mask_from_fusion') and model.get_attention_mask_from_fusion: + compute_attention_mask = False + output = synced_generate( model, inference_strategy,