Fix ONNX Runtime cache usage for decoders, add relevant tests #756

Merged (3 commits) on Feb 9, 2023
4 changes: 2 additions & 2 deletions .github/workflows/test_onnxruntime.yml
@@ -33,5 +33,5 @@ jobs:
       - name: Test with pytest
         working-directory: tests
         run: |
-          python -m pytest -n auto -m "not run_in_series" onnxruntime
-          python -m pytest -m "run_in_series" onnxruntime
+          pytest -n auto -m "not run_in_series" --durations=0 onnxruntime
+          pytest -m "run_in_series" --durations=0 onnxruntime
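A note on the changed flags: `--durations=0` asks pytest to print the duration of every test at the end of the run, which is useful for eyeballing the new generation benchmarks below, while the pre-existing `-n auto` (from pytest-xdist) keeps the tests not marked `run_in_series` running in parallel across available CPUs.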
4 changes: 2 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
@@ -566,15 +566,15 @@ def forward(
         )

     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly

         attention_mask = kwargs.get("attention_mask", None)  # input_ids.new_ones(input_ids.shape)
         use_cache = kwargs.get("use_cache", None)

         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": use_cache,
             "position_ids": None,
             "attention_mask": attention_mask,
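This one-line rename is the heart of the fix: recent transformers releases pass the cache to `prepare_inputs_for_generation` as the keyword argument `past_key_values`, so an override whose parameter is still named `past` lets the cache fall into `**kwargs` and silently drops it, forcing every decoding step to recompute from scratch. A minimal sketch of the failure mode (the two functions below are illustrative stand-ins, not the real ORTModel code):

```python
# Hypothetical stand-ins showing why the parameter rename matters.

def prepare_inputs_old(input_ids, past=None, **kwargs):
    # The caller passes past_key_values=<cache>, which lands in **kwargs;
    # `past` stays None and the cache is silently discarded.
    return {"input_ids": input_ids, "past_key_values": past}

def prepare_inputs_fixed(input_ids, past_key_values=None, **kwargs):
    # The parameter name now matches the caller's keyword, so the cache
    # is actually forwarded to the decoder.
    return {"input_ids": input_ids, "past_key_values": past_key_values}

cache = object()  # stand-in for real past key/value tensors
assert prepare_inputs_old([0], past_key_values=cache)["past_key_values"] is None
assert prepare_inputs_fixed([0], past_key_values=cache)["past_key_values"] is cache
```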
12 changes: 6 additions & 6 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -902,7 +902,7 @@ def forward(
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
@@ -914,7 +914,7 @@ def prepare_inputs_for_generation(

         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1009,7 +1009,7 @@ def forward(
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past=None,
+        past_key_values=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
@@ -1020,7 +1020,7 @@ def prepare_inputs_for_generation(

         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
@@ -1137,7 +1137,7 @@ def forward(
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past=None,
+        past_key_values=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
@@ -1148,7 +1148,7 @@ def prepare_inputs_for_generation(

         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "encoder_outputs": encoder_outputs,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
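The same rename is applied to all three `prepare_inputs_for_generation` overrides in this file, so the encoder-decoder wrappers exercised by the tests below (`ORTModelForSeq2SeqLM`, `ORTModelForSpeechSeq2Seq`, and `ORTModelForVision2Seq`) receive the same cache fix as the causal-LM class above.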
116 changes: 103 additions & 13 deletions tests/onnxruntime/test_modeling.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
-import json
 import os
 import shutil
 import subprocess
 import tempfile
+import time
 import unittest
 from typing import Dict

@@ -84,6 +84,16 @@

 logger = logging.get_logger()

+
+class Timer(object):
+    def __enter__(self):
+        self.elapsed = time.perf_counter()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.elapsed = (time.perf_counter() - self.elapsed) * 1e3
+
+
 MODEL_NAMES = {
     "albert": "hf-internal-testing/tiny-random-AlbertModel",
     "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
@@ -1742,6 +1752,9 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
     ORTMODEL_CLASS = ORTModelForCausalLM
     TASK = "causal-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
             _ = ORTModelForCausalLM.from_pretrained(MODEL_NAMES["vit"], from_transformers=True)
@@ -1898,7 +1911,7 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
         gc.collect()

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch):
+    def test_compare_with_and_without_past_key_values(self, model_arch):
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
         self._setup(model_args)
         model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True}
@@ -1908,15 +1921,33 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         tokenizer = get_preprocessor(model_id)
         text = "My Name is Philipp and i live"
         tokens = tokenizer(text, return_tensors="pt")
+
         model_with_pkv = ORTModelForCausalLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**tokens)
+        _ = model_with_pkv.generate(**tokens)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForCausalLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**tokens)
+        _ = model_without_pkv.generate(**tokens)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )

     @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
     @require_torch_gpu
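A note on the benchmark shape used here and repeated in the seq2seq, speech, and vision tests below: the untimed warmup call keeps one-time session and allocation costs out of the measurement; pinning `min_length` to `max_length` forces both runs to produce sequences of exactly `GENERATION_LENGTH` tokens, so the timed work is identical and the shape assertions hold; and `num_beams=1` keeps decoding greedy, so the two token sequences can be compared with `torch.equal`. The final assertion then requires the cached model to be at least `SPEEDUP_CACHE` (1.2x) faster, e.g. an 80 ms cached run against a 120 ms uncached run gives 120 / 80 = 1.5 > 1.2 and passes.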
@@ -2274,6 +2305,9 @@ class ORTModelForSeq2SeqLMIntegrationTest(ORTModelTestMixin):
     ORTMODEL_CLASS = ORTModelForSeq2SeqLM
     TASK = "seq2seq-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
             _ = ORTModelForSeq2SeqLM.from_pretrained(MODEL_NAMES["bert"], from_transformers=True)
@@ -2459,7 +2493,7 @@ def test_pipeline_on_trt_execution_provider(self, test_name: str, model_arch: st
         gc.collect()

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch: str):
+    def test_compare_with_and_without_past_key_values(self, model_arch: str):
         if model_arch == "m2m_100":
             return  # TODO: this test is failing for m2m_100
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
@@ -2474,12 +2508,30 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         model_with_pkv = ORTModelForSeq2SeqLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**tokens)
+
+        _ = model_with_pkv.generate(**tokens)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForSeq2SeqLM.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**tokens)
+        _ = model_without_pkv.generate(**tokens)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )

     @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
     @require_torch_gpu
@@ -2548,6 +2600,9 @@ class ORTModelForSpeechSeq2SeqIntegrationTest(ORTModelTestMixin):
     ORTMODEL_CLASS = ORTModelForSpeechSeq2Seq
     TASK = "speech2seq-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def _generate_random_audio_data(self):
         np.random.seed(10)
         t = np.linspace(0, 5.0, int(5.0 * 22050), endpoint=False)
@@ -2663,7 +2718,7 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool)
         self.assertTrue(isinstance(outputs["text"], str))

     @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch: str):
+    def test_compare_with_and_without_past_key_values(self, model_arch: str):
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
         self._setup(model_args)
         model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True}
@@ -2678,13 +2733,29 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         model_with_pkv = ORTModelForSpeechSeq2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**features)
+        _ = model_with_pkv.generate(**features)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForSpeechSeq2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**features)
+        _ = model_without_pkv.generate(**features)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )

     @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
     @require_torch_gpu
@@ -2760,6 +2831,9 @@ class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin):

     TASK = "vision2seq-lm"

+    GENERATION_LENGTH = 100
+    SPEEDUP_CACHE = 1.2
+
     def exclude_trocr_with_cache(params):
         if params[0] == "trocr" and params[1] == True:
             return None
@@ -2905,7 +2979,7 @@ def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool)
         self.assertTrue(isinstance(outputs[0]["generated_text"], str))

     @parameterized.expand(SUPPORTED_ARCHITECTURES[:1])
-    def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch: str):
+    def test_compare_with_and_without_past_key_values(self, model_arch: str):
         model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
         self._setup(model_args)
         model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True}
@@ -2920,13 +2994,29 @@ def test_compare_with_and_without_past_key_values_model_outputs(self, model_arch
         model_with_pkv = ORTModelForVision2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_True"], use_cache=True
         )
-        outputs_model_with_pkv = model_with_pkv.generate(**features)
+        _ = model_with_pkv.generate(**features)  # warmup
+        with Timer() as with_pkv_timer:
+            outputs_model_with_pkv = model_with_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         model_without_pkv = ORTModelForVision2Seq.from_pretrained(
             self.onnx_model_dirs[model_arch + "_False"], use_cache=False
         )
-        outputs_model_without_pkv = model_without_pkv.generate(**features)
+        _ = model_without_pkv.generate(**features)  # warmup
+        with Timer() as without_pkv_timer:
+            outputs_model_without_pkv = model_without_pkv.generate(
+                **features, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
+            )
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
+        self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+        self.assertTrue(
+            without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE,
+            f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms,"
+            f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
+        )


 class ORTModelForCustomTasksIntegrationTest(unittest.TestCase):