From f02f2272a73b871b2feb53b30aefbcfd58d4bb1a Mon Sep 17 00:00:00 2001 From: Sid Sahai Date: Tue, 7 Feb 2023 03:26:04 -0800 Subject: [PATCH] Add gpt-neo-x support (#745) * add gpt-neo-x configs * fixes * name fix --- .../exporters/onnx/package_reference/configuration.mdx | 1 + optimum/exporters/onnx/model_configs.py | 5 +++++ optimum/exporters/tasks.py | 7 +++++++ optimum/utils/normalized_config.py | 1 + tests/exporters/exporters_utils.py | 2 ++ tests/onnxruntime/test_modeling.py | 2 ++ 6 files changed, 18 insertions(+) diff --git a/docs/source/exporters/onnx/package_reference/configuration.mdx b/docs/source/exporters/onnx/package_reference/configuration.mdx index f17b66701fa..5417ed95583 100644 --- a/docs/source/exporters/onnx/package_reference/configuration.mdx +++ b/docs/source/exporters/onnx/package_reference/configuration.mdx @@ -92,6 +92,7 @@ They specify which input generators should be used for the dummy inputs, but rem - GPT-2 - GPT-J - GPT-Neo +- GPT-NeoX - GroupVit - Hubert - IBert diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 5f107a4816e..c91447ea768 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -191,6 +191,11 @@ class GPTNeoOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads") +class GPTNeoXOnnxConfig(TextDecoderOnnxConfig): + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt"): past_key_shape = ( diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index ebd8a414e19..d994177baf8 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -383,6 +383,13 @@ class TasksManager: "sequence-classification", onnx="GPTNeoOnnxConfig", ), + "gpt-neox": supported_tasks_mapping( + "default", + "default-with-past", + "causal-lm", + "causal-lm-with-past", + onnx="GPTNeoXOnnxConfig", + ), "groupvit": supported_tasks_mapping( "default", onnx="GroupViTOnnxConfig", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 6d8385a3613..efb5ee6589c 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -193,6 +193,7 @@ class NormalizedConfigManager: "electra": NormalizedTextConfig, "gpt2": GPT2LikeNormalizedTextConfig, "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"), + "gpt_neox": NormalizedTextConfig, "gptj": GPT2LikeNormalizedTextConfig, "longt5": T5LikeNormalizedTextConfig, "marian": BartLikeNormalizedTextConfig, diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c93ca2ef249..4819707746e 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -53,6 +53,7 @@ "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel", + "gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", "groupvit": "hf-internal-testing/tiny-random-groupvit", "ibert": "hf-internal-testing/tiny-random-IBertModel", @@ -151,6 +152,7 @@ "flaubert": "hf-internal-testing/tiny-random-flaubert", # TODO "gpt2": "gpt2", "gpt-neo": "EleutherAI/gpt-neo-125M", + "gpt-neox": "EleutherAI/gpt-neox-20b", "gptj": "anton-l/gpt-j-tiny-random", # TODO "groupvit": "nvidia/groupvit-gcc-yfcc", "ibert": "kssteven/ibert-roberta-base", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 4705b4d5953..b84183f01da 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -111,6 +111,7 @@ "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", + "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", "groupvit": "hf-internal-testing/tiny-random-groupvit", "ibert": "hf-internal-testing/tiny-random-IBertModel", @@ -1731,6 +1732,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): "codegen", "gpt2", "gpt_neo", + "gpt_neox", "gptj", ]