Add llama onnx export & onnxruntime support #975

Merged
merged 14 commits on Apr 17, 2023
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -53,6 +53,7 @@ Supported architectures:
- LayoutLM-v3
- Levit
- LongT5
- Llama
- M2M100
- Marian
- MBart
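With Llama added to the supported-architecture list above, the export can be driven programmatically as well as from the CLI. A minimal sketch, assuming the `main_export` helper from `optimum.exporters.onnx` and reusing the tiny test checkpoint referenced later in this PR; the output directory name is arbitrary:

```python
from optimum.exporters.onnx import main_export

# Export the tiny llama test checkpoint to ONNX, including the decoder-with-past variant.
main_export(
    "fxmarty/tiny-llama-fast-tokenizer",
    output="llama_onnx",
    task="text-generation-with-past",
)
```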
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -208,6 +208,11 @@ class OPTOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class LlamaOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_shape = (
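`LlamaOnnxConfig` mirrors `OPTOnnxConfig`: a `TextDecoderOnnxConfig` with ONNX opset 13 and the generic `NormalizedTextConfig`. A minimal inspection sketch, assuming the config class can be instantiated directly from a `transformers` config and a task string:

```python
from transformers import AutoConfig

from optimum.exporters.onnx.model_configs import LlamaOnnxConfig

config = AutoConfig.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")
onnx_config = LlamaOnnxConfig(config, task="text-generation")

# Dynamic axes declared for the exported graph, e.g. batch_size / sequence_length on input_ids.
print(onnx_config.inputs)
print(onnx_config.outputs)
```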
8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
@@ -659,6 +659,14 @@ class TasksManager:
"text-classification",
onnx="OPTOnnxConfig",
),
"llama": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="LlamaOnnxConfig",
),
"pegasus": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
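The new `tasks.py` entry registers the same task set for llama as for OPT. A quick check, assuming the `TasksManager.get_supported_tasks_for_model_type` helper:

```python
from optimum.exporters.tasks import TasksManager

# Should list feature-extraction, text-generation and text-classification,
# each with and without past key/values, all mapped to LlamaOnnxConfig.
print(TasksManager.get_supported_tasks_for_model_type("llama", exporter="onnx"))
```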
1 change: 0 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -682,7 +682,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
"use_cache": use_cache,
"position_ids": None,
"attention_mask": attention_mask,
"token_type_ids": None,
}

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
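The hunk above drops the `token_type_ids` entry from the inputs prepared for generation: the llama tokenizer returns `token_type_ids`, but the exported decoder graph has no such input, which is also why the test changes below pass `return_token_type_ids=False`. A usage sketch, assuming `from_pretrained(..., export=True)` converts the checkpoint on the fly (older optimum versions used `from_transformers=True`):

```python
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model_id = "fxmarty/tiny-llama-fast-tokenizer"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

# Drop token_type_ids, which the ONNX decoder does not declare as an input.
tokens = tokenizer("This is a sample output", return_tensors="pt", return_token_type_ids=False)
outputs = model.generate(**tokens)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```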
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
@@ -114,6 +114,7 @@ class ORTConfigManager:
"gptj": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
"marian": "bart",
"mbart": "bart",
"mt5": "bart",
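Mapping `"llama"` to `"gpt2"` in `ORTConfigManager` lets the ONNX Runtime graph optimizer treat llama like the other GPT-style decoders. A sketch of running graph optimization on an already exported model (the `llama_onnx` directory is a hypothetical path from the earlier export sketch):

```python
from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import AutoOptimizationConfig

# Load a previously exported llama ONNX model from a local directory.
model = ORTModelForCausalLM.from_pretrained("llama_onnx")
optimizer = ORTOptimizer.from_pretrained(model)

# "llama" now resolves to the "gpt2" optimizer model type via ORTConfigManager.
optimizer.optimize(save_dir="llama_onnx_optimized", optimization_config=AutoOptimizationConfig.O2())
```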
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -202,6 +202,7 @@ class NormalizedConfigManager:
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
"longt5": T5LikeNormalizedTextConfig,
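Registering `"llama"` with the plain `NormalizedTextConfig` works because `LlamaConfig` already uses the standard attribute names. A minimal sketch, assuming the `get_normalized_config_class` helper:

```python
from transformers import AutoConfig

from optimum.utils.normalized_config import NormalizedConfigManager

config = AutoConfig.from_pretrained("fxmarty/tiny-llama-fast-tokenizer")
normalized_config = NormalizedConfigManager.get_normalized_config_class("llama")(config)

# Uniform accessors regardless of the model-specific attribute names.
print(normalized_config.num_layers, normalized_config.num_attention_heads, normalized_config.hidden_size)
```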
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -63,6 +63,7 @@
"levit": "hf-internal-testing/tiny-random-LevitModel",
"layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
"layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"longt5": "fxmarty/tiny-random-working-LongT5Model",
# "longformer": "allenai/longformer-base-4096",
"m2m-100": "hf-internal-testing/tiny-random-m2m_100",
@@ -171,6 +172,7 @@
"levit": "facebook/levit-128S",
"layoutlm": "microsoft/layoutlm-base-uncased",
"layoutlmv3": "microsoft/layoutlmv3-base",
"llama": "decapoda-research/llama-65b-hf",
"longt5": "fxmarty/tiny-random-working-LongT5Model", # Not using google/long-t5-local-base because it takes too much time for testing.
# "longformer": "allenai/longformer-base-4096",
"m2m-100": "hf-internal-testing/tiny-random-m2m_100", # Not using facebook/m2m100_418M because it takes too much time for testing.
35 changes: 26 additions & 9 deletions tests/onnxruntime/test_modeling.py
@@ -1970,6 +1970,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"gpt_neo",
"gpt_neox",
"gptj",
"llama",
]

FULL_GRID = {
@@ -2021,7 +2022,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name])
tokenizer = get_preprocessor(model_id)
text = "This is a sample output"
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

# General case
outputs = model.generate(**tokens)
@@ -2030,7 +2031,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
self.assertTrue(len(res[0]) > len(text))

# With input ids
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)
outputs = model.generate(input_ids=tokens["input_ids"])
res = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertIsInstance(res[0], str)
@@ -2118,7 +2119,11 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
set_seed(SEED)
transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = get_preprocessor(model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt")
tokens = tokenizer(
"This is a sample output",
return_tensors="pt",
return_token_type_ids=False if model_arch == "llama" else None,
)
onnx_outputs = onnx_model(**tokens)

self.assertTrue("logits" in onnx_outputs)
@@ -2217,12 +2222,16 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st

# build engine for a short sequence
text = ["short"]
encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
encoded_input = tokenizer(
text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
).to("cuda")
_ = onnx_model(**encoded_input)

# build engine for a long sequence
text = [" a very long input just for demo purpose, this is very long" * 10]
encoded_input = tokenizer(text, return_tensors="pt").to("cuda")
encoded_input = tokenizer(
text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
).to("cuda")
_ = onnx_model(**encoded_input)

pipe = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer, device=0)
@@ -2235,7 +2244,11 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
self.assertTrue(isinstance(outputs[0]["generated_text"], str))
self.assertTrue(len(outputs[0]["generated_text"]) > len(text))

encoded_input = tokenizer(["Replace me by any text you'd like."], return_tensors="pt").to("cuda")
encoded_input = tokenizer(
["Replace me by any text you'd like."],
return_tensors="pt",
return_token_type_ids=False if model_arch == "llama" else None,
).to("cuda")
_ = onnx_model.generate(**encoded_input)

gc.collect()
@@ -2251,7 +2264,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch):
model_id = MODEL_NAMES[model_arch]
tokenizer = get_preprocessor(model_id)
text = "My Name is Philipp and i live"
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

model_with_pkv = ORTModelForCausalLM.from_pretrained(
self.onnx_model_dirs[model_arch + "_True"], use_cache=True
@@ -2302,7 +2315,7 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode
model_id = MODEL_NAMES[model_arch]
tokenizer = get_preprocessor(model_id)
text = "My Name is Philipp and i live"
tokens = tokenizer(text, return_tensors="pt")
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)

model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"]
model_merged_dir = self.onnx_model_dirs[test_name + "_True"]
@@ -2372,7 +2385,11 @@ def test_compare_generation_to_io_binding(self, test_name: str, model_arch: str,
io_model = ORTModelForCausalLM.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to("cuda")

tokenizer = get_preprocessor(model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")
tokens = tokenizer(
"This is a sample output",
return_tensors="pt",
return_token_type_ids=False if model_arch == "llama" else None,
).to("cuda")
onnx_outputs = onnx_model.generate(**tokens)
io_outputs = io_model.generate(**tokens)

1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -49,6 +49,7 @@
"layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
"layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model",
"longt5": "hf-internal-testing/tiny-random-LongT5Model",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
"m2m_100": "hf-internal-testing/tiny-random-m2m_100",
"marian": "sshleifer/tiny-marian-en-de", # hf-internal-testing ones are broken
"mbart": "hf-internal-testing/tiny-random-mbart",