diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml
index 4bd41f57a6b..da0700d62b9 100644
--- a/.github/workflows/test_onnxruntime.yml
+++ b/.github/workflows/test_onnxruntime.yml
@@ -33,5 +33,5 @@ jobs:
       - name: Test with pytest
         working-directory: tests
         run: |
-          python -m pytest -n auto -m "not run_in_series" onnxruntime
-          python -m pytest -m "run_in_series" onnxruntime
+          pytest -n auto -m "not run_in_series" --durations=0 onnxruntime
+          pytest -m "run_in_series" --durations=0 onnxruntime
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index eedaf420cf9..fa39a5484af 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -566,7 +566,7 @@ def forward(
         )

     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         attention_mask = kwargs.get("attention_mask", None)  # input_ids.new_ones(input_ids.shape)
@@ -574,7 +574,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):

         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
             "position_ids": None,
             "attention_mask": attention_mask,
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index f0744abd411..bdd2367128e 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -902,7 +902,7 @@ def forward(
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -914,7 +914,7 @@ def prepare_inputs_for_generation(

         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1009,7 +1009,7 @@ def forward(
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past=None,
+        past_key_values=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
@@ -1020,7 +1020,7 @@ def prepare_inputs_for_generation(

         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
@@ -1137,7 +1137,7 @@ def forward(
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past=None,
+        past_key_values=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
@@ -1148,7 +1148,7 @@ def prepare_inputs_for_generation(

         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
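Note on the `past` -> `past_key_values` rename above: newer `transformers` releases call `prepare_inputs_for_generation` with the keyword argument `past_key_values` instead of `past`, so an override that still only accepts `past` would catch the cache in `**kwargs` and silently generate without it. A minimal sketch of that failure mode; the `old_prepare`/`new_prepare` helpers are hypothetical, not optimum code:

```python
# Hypothetical helpers illustrating the kwarg mismatch; not optimum's code.
def old_prepare(input_ids, past=None, **kwargs):
    # generate() passes past_key_values=..., which falls into **kwargs,
    # so `past` stays None and the cache is silently dropped
    return {"input_ids": input_ids, "past_key_values": past}

def new_prepare(input_ids, past_key_values=None, **kwargs):
    return {"input_ids": input_ids, "past_key_values": past_key_values}

cache = object()  # stand-in for the real past key/value tensors
assert old_prepare([0], past_key_values=cache)["past_key_values"] is None
assert new_prepare([0], past_key_values=cache)["past_key_values"] is cache
```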
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 4705b4d5953..35312d7a02c 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import gc
-import json
 import os
 import shutil
 import subprocess
 import tempfile
+import time
 import unittest
 from typing import Dict

@@ -84,6 +84,16 @@

 logger = logging.get_logger()

+
+class Timer(object):
+    def __enter__(self):
+        self.elapsed = time.perf_counter()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.elapsed = (time.perf_counter() - self.elapsed) * 1e3
+
+
 MODEL_NAMES = {
     "albert": "hf-internal-testing/tiny-random-AlbertModel",
     "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
@@ -1742,6 +1752,9 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
     ORTMODEL_CLASS = ORTModelForCausalLM
     TASK = "causal-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
             _ = ORTModelForCausalLM.from_pretrained(MODEL_NAMES["vit"], from_transformers=True)
@@ -1898,7 +1911,7 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
         gc.collect()

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch):
+    def test_compare_with_and_without_past_key_values(self, model_arch):
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
         self._setup(model_args)
         model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True}
@@ -1908,15 +1921,33 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         tokenizer = get_preprocessor(model_id)
         text = "My Name is Philipp and i live"
         tokens = tokenizer(text, return_tensors="pt")
+
         model_with_pkv = ORTModelForCausalLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**tokens)
+        _ = model_with_pkv.generate(**tokens)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForCausalLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**tokens)
+        _ = model_without_pkv.generate(**tokens)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )

     @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
     @require_torch_gpu
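The `Timer` helper added above starts a `time.perf_counter()` timestamp on entry and overwrites it with the elapsed wall-clock time in milliseconds on exit. A self-contained usage sketch, standard library only; `work()` is a hypothetical stand-in for `model.generate(...)`:

```python
import time


class Timer(object):
    def __enter__(self):
        self.elapsed = time.perf_counter()  # start timestamp
        return self

    def __exit__(self, type, value, traceback):
        # replace the start timestamp with the elapsed time in milliseconds
        self.elapsed = (time.perf_counter() - self.elapsed) * 1e3


def work():
    time.sleep(0.05)  # hypothetical stand-in for model.generate(...)


with Timer() as t:
    work()
print(f"latency: {t.elapsed:.3f} ms")  # roughly 50 ms
```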
@@ -2274,6 +2305,9 @@ class ORTModelForSeq2SeqLMIntegrationTest(ORTModelTestMixin):
     ORTMODEL_CLASS = ORTModelForSeq2SeqLM
     TASK = "seq2seq-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
             _ = ORTModelForSeq2SeqLM.from_pretrained(MODEL_NAMES["bert"], from_transformers=True)
@@ -2459,7 +2493,7 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
         gc.collect()

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch: str):
+    def test_compare_with_and_without_past_key_values(self, model_arch: str):
         if model_arch == "m2m_100":
             return  # TODO: this test is failing for m2m_100
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
@@ -2474,12 +2508,30 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         model_with_pkv = ORTModelForSeq2SeqLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**tokens)
+
+        _ = model_with_pkv.generate(**tokens)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForSeq2SeqLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**tokens)
+        _ = model_without_pkv.generate(**tokens)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )

     @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
     @require_torch_gpu
@@ -2548,6 +2600,9 @@ class ORTModelForSpeechSeq2SeqIntegrationTest(ORTModelTestMixin):
     ORTMODEL_CLASS = ORTModelForSpeechSeq2Seq
     TASK = "speech2seq-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def _generate_random_audio_data(self):
         np.random.seed(10)
         t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False)
@@ -2663,7 +2718,7 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool)
         self.assertTrue(isinstance(outputs["text"], str))

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch: str):
+    def test_compare_with_and_without_past_key_values(self, model_arch: str):
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
         self._setup(model_args)
         model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True}
@@ -2678,13 +2733,29 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         model_with_pkv = ORTModelForSpeechSeq2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**features)
+        _ = model_with_pkv.generate(**features)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForSpeechSeq2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**features)
+        _ = model_without_pkv.generate(**features)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )

     @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
     @require_torch_gpu
@@ -2760,6 +2831,9 @@ class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin):

     TASK = "vision2seq-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def exclude_trocr_with_cache(params):
         if params[0] == "trocr" and params[1] == True:
             return None
@@ -2905,7 +2979,7 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool)
         self.assertTrue(isinstance(outputs[0]["generated_text"], str))

     @parameterized.expand(SUPPORTED_ARCHITECTURES[:1])
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch: str):
+    def test_compare_with_and_without_past_key_values(self, model_arch: str):
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
         self._setup(model_args)
         model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True}
@@ -2920,13 +2994,29 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         model_with_pkv = ORTModelForVision2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**features)
+        _ = model_with_pkv.generate(**features)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForVision2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**features)
+        _ = model_without_pkv.generate(**features)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )


 class ORTModelForCustomTasksIntegrationTest(unittest.TestCase):
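The tests above pin `min_length == max_length == GENERATION_LENGTH` with greedy search (`num_beams=1`) so both runs decode exactly the same number of tokens, which is what makes the `shape[1]` assertions and the latency ratio meaningful. A standalone sketch of that setup; the tiny checkpoint name is an assumption for illustration, and timings on such a small model may not show the asserted speedup:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

GENERATION_LENGTH = 100  # total output length, prompt tokens included

# assumed tiny test checkpoint, chosen only to keep the example fast
model_id = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

tokens = tokenizer("My Name is Philipp and i live", return_tensors="pt")
# min_length == max_length disables early stopping on EOS, and greedy
# search (num_beams=1) makes every run decode an identical workload
outputs = model.generate(
    **tokens, min_length=GENERATION_LENGTH, max_length=GENERATION_LENGTH, num_beams=1
)
assert outputs.shape[1] == GENERATION_LENGTH  # fixed-length output
```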