From e722dbdb22a443327111a5df496d237b5b661e33 Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 20 Dec 2022 12:06:51 +0100 Subject: [PATCH] add pegasus exporters --- .../exporters/onnx/package_reference/configuration.mdx | 1 + optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 9 +++++++++ optimum/utils/normalized_config.py | 1 + tests/exporters/exporters_utils.py | 1 + 5 files changed, 16 insertions(+) diff --git a/docs/source/exporters/onnx/package_reference/configuration.mdx b/docs/source/exporters/onnx/package_reference/configuration.mdx index 462e9560026..76721c66959 100644 --- a/docs/source/exporters/onnx/package_reference/configuration.mdx +++ b/docs/source/exporters/onnx/package_reference/configuration.mdx @@ -103,6 +103,7 @@ They specify which input generators should be used for the dummy inputs, but rem - MobileBert - MobileVit - OwlVit +- Pegasus - Perceiver - PoolFormer - ResNet diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index b6fc6b57e44..23b516349d3 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -447,6 +447,10 @@ def generate_dummy_inputs_for_validation(self, reference_model_inputs: Mapping[s return super().generate_dummy_inputs_for_validation(reference_model_inputs) +class PegasusOnnxConfig(BartOnnxConfig): + pass + + class MarianOnnxConfig(BartOnnxConfig): pass diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 0d9f54c06f6..29742e67f47 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -467,6 +467,15 @@ class TasksManager: # "zero-shot-object-detection", # onnx="OwlViTOnnxConfig", # ), + "pegasus": supported_tasks_mapping( + "default", + "default-with-past", + "causal-lm", + "causal-lm-with-past", + "seq2seq-lm", + "seq2seq-lm-with-past", + onnx="PegasusOnnxConfig", + ), "perceiver": supported_tasks_mapping( "masked-lm", "image-classification", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 8c57f1163d2..4bf3579da9e 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -192,6 +192,7 @@ class NormalizedConfigManager: "mbart": BartLikeNormalizedTextConfig, "mt5": T5LikeNormalizedTextConfig, "m2m_100": BartLikeNormalizedTextConfig, + "pegasus": BartLikeNormalizedTextConfig, "poolformer": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index e0739c1fb5e..b8ff222032f 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -69,6 +69,7 @@ "mobilevit": "hf-internal-testing/tiny-random-mobilevit", "mt5": "lewtun/tiny-random-mt5", # "owlvit": "google/owlvit-base-patch32", + "pegasus": "hf-internal-testing/tiny-random-PegasusModel", "perceiver": { "hf-internal-testing/tiny-random-language_perceiver": ["masked-lm", "sequence-classification"], "hf-internal-testing/tiny-random-vision_perceiver_conv": ["image-classification"],