diff --git a/test/benchmarks/conftest.py b/test/benchmarks/conftest.py
index d58e9ce55..16d7c3491 100644
--- a/test/benchmarks/conftest.py
+++ b/test/benchmarks/conftest.py
@@ -7,17 +7,18 @@
 @pytest.fixture(scope="session")
-def onnx_adaptive_model_qa(use_gpu, num_processes):
-    model_name_or_path = "deepset/bert-base-cased-squad2"
-    onnx_model_export_path = Path("benchmarks/onnx-export")
-    if not (onnx_model_export_path / "model.onnx").is_file():
+def onnx_adaptive_model_qa(use_gpu, num_processes, model_name_or_path="deepset/bert-base-cased-squad2"):
+    if (Path(model_name_or_path) / "model.onnx").is_file():  # load model directly if in ONNX format
+        onnx_model_path = model_name_or_path
+    else:  # convert to ONNX format
+        onnx_model_path = Path("benchmarks/onnx-export")
         model = AdaptiveModel.convert_from_transformers(
             model_name_or_path, device="cpu", task_type="question_answering"
         )
-        model.convert_to_onnx(onnx_model_export_path)
+        model.convert_to_onnx(onnx_model_path)
     model = Inferencer.load(
-        onnx_model_export_path, task_type="question_answering", batch_size=1, num_processes=num_processes, gpu=use_gpu
+        onnx_model_path, task_type="question_answering", batch_size=1, num_processes=num_processes, gpu=use_gpu
     )
     return model
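
For reference, a minimal standalone sketch of the path resolution the fixture now performs, with the pytest plumbing stripped out. `resolve_onnx_path` is a hypothetical helper invented for illustration, not part of this PR; the `AdaptiveModel` and `Inferencer` calls mirror the fixture body above.

```python
# Sketch of the fixture's new behavior, minus the pytest plumbing.
# `resolve_onnx_path` is a hypothetical helper, not part of this PR; the
# FARM calls (AdaptiveModel, Inferencer) are the same ones used in the diff.
from pathlib import Path

from farm.infer import Inferencer
from farm.modeling.adaptive_model import AdaptiveModel


def resolve_onnx_path(model_name_or_path="deepset/bert-base-cased-squad2"):
    """Return a directory containing model.onnx, exporting one if needed."""
    if (Path(model_name_or_path) / "model.onnx").is_file():
        # the caller already points at an ONNX export: use it directly
        return Path(model_name_or_path)
    # otherwise convert the transformers checkpoint to ONNX first
    onnx_model_path = Path("benchmarks/onnx-export")
    model = AdaptiveModel.convert_from_transformers(
        model_name_or_path, device="cpu", task_type="question_answering"
    )
    model.convert_to_onnx(onnx_model_path)
    return onnx_model_path


# usage: accepts either a transformers model name or a local ONNX export dir
inferencer = Inferencer.load(
    resolve_onnx_path(), task_type="question_answering", batch_size=1, gpu=False
)
```

Net effect of the change: callers can point the fixture at an existing ONNX export directory and skip the conversion step entirely, while the default `deepset/bert-base-cased-squad2` checkpoint keeps the old convert-then-load path.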