From fe0e7f4e1118cd8788fccdf33973dfa1d601bd6d Mon Sep 17 00:00:00 2001
From: Alexander Kozlov
Date: Mon, 2 Jan 2023 17:10:54 +0400
Subject: [PATCH] Added support of OpenVINO export (#214)

* Added support of OpenVINO export

* Cleaned up the code

* Added notebook. Fixed device issues

* Some fixes

* Reused ONNX code for OpenVINO export

* Fixed and moved test for OpenVINO

* Fixed style

* Fixed tests

* Fixed environment issue

* Applied comments
---
 notebooks/openvino_inference.ipynb | 242 +++++++++++++++++++++++++++++
 setup.py                           |   5 +-
 src/setfit/exporters/onnx.py       |  26 +++-
 src/setfit/exporters/openvino.py   |  47 ++++++
 tests/exporters/test_openvino.py   |  51 ++++++
 5 files changed, 368 insertions(+), 3 deletions(-)
 create mode 100644 notebooks/openvino_inference.ipynb
 create mode 100644 src/setfit/exporters/openvino.py
 create mode 100644 tests/exporters/test_openvino.py

diff --git a/notebooks/openvino_inference.ipynb b/notebooks/openvino_inference.ipynb
new file mode 100644
index 00000000..14b3dca8
--- /dev/null
+++ b/notebooks/openvino_inference.ipynb
@@ -0,0 +1,242 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "552ec552",
+   "metadata": {},
+   "source": [
+    "# Run a SetFit model with OpenVINO"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3af7f258-5aaf-47c2-b81e-2f10fc349812",
+   "metadata": {},
+   "source": [
+    "In this notebook, we'll learn how to export a SetFit model to OpenVINO and run inference with it."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c5604f73-f395-42cb-8082-9974a87ef9e9",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "## Setup"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "26f09e23-2e1f-41f6-bb40-a30d447a0541",
+   "metadata": {},
+   "source": [
+    "If you're running this notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "712b96e8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#%pip install setfit[openvino] ipywidgets widgetsnbextension"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "059bd547-0e39-43ab-adf0-5f5509217020",
+   "metadata": {},
+   "source": [
+    "## Load a pretrained SetFit model from the Hub"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "135ba8d2-ac13-4329-946a-04226a253d83",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2023-01-02 10:18:34.394949: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+      "2023-01-02 10:18:34.511144: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+      "2023-01-02 10:18:35.007140: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
+      "2023-01-02 10:18:35.007192: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
+      "2023-01-02 10:18:35.007213: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
+      "/home/alex/virt_envs/setfit/lib/python3.8/site-packages/scikit_learn-1.1.3-py3.8-linux-x86_64.egg/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LogisticRegression from version 1.1.1 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
+      "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
+      "  warnings.warn(\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(array([1, 0]), '1.74 seconds')"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from time import perf_counter\n",
+    "from setfit import SetFitModel\n",
+    "\n",
+    "NUM_ITERS = 100\n",
+    "\n",
+    "model_id = \"lewtun/my-awesome-setfit-model\"\n",
+    "model = SetFitModel.from_pretrained(model_id)\n",
+    "\n",
+    "# Run inference\n",
+    "input_text = [\"i loved the spiderman movie!\", \"pineapple on pizza is the worst 🤮\"]\n",
+    "\n",
+    "start = perf_counter()\n",
+    "for i in range(NUM_ITERS):\n",
+    "    preds = model(input_text)\n",
+    "end = perf_counter()\n",
+    "\n",
+    "preds, f\"{end-start:.2f} seconds\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "12ae661d-2236-4eb3-8c52-a8fab714beb9",
+   "metadata": {},
+   "source": [
+    "## Export to OpenVINO\n",
+    "\n",
+    "In this step, the model is exported to the OpenVINO Intermediate Representation (IR). `setfit` also provides a pure PyTorch implementation of `SetFitModel`, where the head is a dense layer instead of a classifier from `scikit-learn`. This allows one to do backprop end-to-end and have more fine-grained control over the training process. The model body (the Transformer) and the head are merged into a single model during the export.\n",
+    "\n",
+    "To use the PyTorch model, we load a pretrained model with `use_differentiable_head=True` and specify the number of classes to include in the head."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "58a69d72-edae-45a8-a423-702062ce75a9",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n",
+      "To disable this warning, you can either:\n",
+      "\t- Avoid using `tokenizers` before the fork if possible\n",
+      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/alex/work/experimental/setfit/src/setfit/exporters/onnx.py:256: UserWarning: sklearn onnx max opset is 11 requested opset 13 using opset 11 for compatibility.\n",
+      "  warnings.warn(\n",
+      "/home/alex/virt_envs/setfit/lib/python3.8/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)\n",
+      "  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)\n",
+      "/home/alex/virt_envs/setfit/lib/python3.8/site-packages/torch/onnx/utils.py:687: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)\n",
+      "  _C._jit_pass_onnx_graph_shape_type_inference(\n",
+      "/home/alex/virt_envs/setfit/lib/python3.8/site-packages/torch/onnx/utils.py:1178: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)\n",
+      "  _C._jit_pass_onnx_graph_shape_type_inference(\n"
+     ]
+    }
+   ],
+   "source": [
+    "from setfit import SetFitModel\n",
+    "from setfit.exporters.openvino import export_to_openvino\n",
+    "\n",
+    "output_path = \"model.xml\"\n",
+    "export_to_openvino(model, output_path)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2959a761-24cc-4d65-a90c-8a12b61d383c",
+   "metadata": {},
+   "source": [
+    "## Run the model with OpenVINO\n",
+    "Now we run the model with the OpenVINO Python API to compare results. 
The elapsed inference time is measured over `NUM_ITERS` iterations so that it can be compared with the PyTorch timing above.\n",
+    "\n",
+    ">**Note:** Text tokenization is a required extra step that must be performed before inference."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "e3d4038c-4426-45d0-b7a4-d1e89019df7a",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(array([1, 0]), '1.25 seconds')"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from transformers import AutoTokenizer\n",
+    "import openvino.runtime as ov\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+    "inputs = tokenizer(\n",
+    "    input_text,\n",
+    "    padding=True,\n",
+    "    truncation=True,\n",
+    "    return_attention_mask=True,\n",
+    "    return_token_type_ids=True,\n",
+    "    return_tensors=\"np\",\n",
+    ")\n",
+    "\n",
+    "inputs_dict = dict(inputs)\n",
+    "\n",
+    "core = ov.Core()\n",
+    "ov_model = core.read_model(output_path)\n",
+    "compiled_model = core.compile_model(ov_model, \"CPU\")\n",
+    "output = compiled_model.output(0)\n",
+    "\n",
+    "start = perf_counter()\n",
+    "for i in range(NUM_ITERS):\n",
+    "    ov_preds = compiled_model(inputs_dict)[output]\n",
+    "end = perf_counter()\n",
+    "\n",
+    "ov_preds, f\"{end-start:.2f} seconds\"\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.8.10 ('setfit')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "1a0b4a52b05e9a0d1fb6e7e4631f54aec2fa21faaaad909c085a4d70095afda4"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/setup.py b/setup.py
index fd78c9b0..de78d342 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,9 @@
 
 ONNX_REQUIRE = ["onnxruntime", "onnx", "skl2onnx"]
 
-TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE
+OPENVINO_REQUIRE = ["hummingbird-ml", "openvino>=2022.3"]
+
+TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE
 
 COMPAT_TESTS_REQUIRE = [requirement.replace(">=", "==") for requirement in REQUIRED_PKGS] + TESTS_REQUIRE
 
@@ -26,6 +28,7 @@
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,
     "onnx": ONNX_REQUIRE,
+    "openvino": ONNX_REQUIRE + OPENVINO_REQUIRE,
    "compat_tests": COMPAT_TESTS_REQUIRE,
 }
 
diff --git a/src/setfit/exporters/onnx.py b/src/setfit/exporters/onnx.py
index 44413297..b5d18df4 100644
--- a/src/setfit/exporters/onnx.py
+++ b/src/setfit/exporters/onnx.py
@@ -1,3 +1,4 @@
+import copy
 import warnings
 from typing import Callable, Optional, Union
 
@@ -164,12 +165,24 @@ def export_sklearn_head_to_onnx(model_head: LogisticRegression, opset: int) -> o
     return onnx_model
 
 
+def hummingbird_export(model, data_sample):
+    try:
+        from hummingbird.ml import convert
+    except ImportError:
+        raise ImportError(
+            "Hummingbird-ML library is not installed. " "Run 'pip install hummingbird-ml' to use this type of export."
+        )
+    onnx_model = convert(model, "onnx", data_sample)
+    return onnx_model._model
+
+
 def export_onnx(
     model_body: SentenceTransformer,
     model_head: Union[torch.nn.Module, LogisticRegression],
     opset: int,
     output_path: str = "model.onnx",
     ignore_ir_version: bool = True,
+    use_hummingbird: bool = False,
 ) -> None:
     """Export a PyTorch backed SetFit model to ONNX Intermediate Representation. 
@@ -228,7 +241,15 @@ def export_onnx(
 
     # Export the sklearn head first to get the minimum opset. sklearn is behind
     # in supported opsets.
-    onnx_head = export_sklearn_head_to_onnx(model_head, opset)
+    # Hummingbird-ML can be used as an option to export the head to a standard opset.
+    if use_hummingbird:
+        with torch.no_grad():
+            test_input = copy.deepcopy(dummy_inputs)
+            head_input = model_body(test_input)["sentence_embedding"]
+            onnx_head = hummingbird_export(model_head, head_input.detach().numpy())
+    else:
+        onnx_head = export_sklearn_head_to_onnx(model_head, opset)
+
     max_opset = max([x.version for x in onnx_head.opset_import])
 
     if max_opset != opset:
@@ -256,10 +277,11 @@ def export_onnx(
         raise ValueError(msg)
 
     # Combine the onnx body and head by mapping the pooled output to the input of the sklearn model.
+    head_input_name = next(iter(onnx_head.graph.input)).name
     onnx_setfit_model = onnx.compose.merge_models(
         onnx_body,
         onnx_head,
-        io_map=[("logits", "model_head")],
+        io_map=[("logits", head_input_name)],
     )
 
     # Save the final model.
diff --git a/src/setfit/exporters/openvino.py b/src/setfit/exporters/openvino.py
new file mode 100644
index 00000000..cc0157e0
--- /dev/null
+++ b/src/setfit/exporters/openvino.py
@@ -0,0 +1,47 @@
+import os
+
+import openvino.runtime as ov
+
+from setfit import SetFitModel
+from setfit.exporters.onnx import export_onnx
+
+
+def export_to_openvino(
+    model: SetFitModel,
+    output_path: str = "model.xml",
+) -> None:
+    """Export a PyTorch backed SetFit model to OpenVINO Intermediate Representation.
+
+    The model body (a `SentenceTransformer`) and the model head (a dense layer SetFitHead or a
+    scikit-learn estimator) are merged into a single graph during the export. The intermediate
+    ONNX model is produced with `export_onnx` (using the Hummingbird-ML path and
+    `ignore_ir_version=True`) and is removed once the conversion is done.
+
+    Args:
+        model (`SetFitModel`): The SetFit model to export. Both the model body and the model head
+            are taken from it.
+        output_path (`str`): The path where the generated OpenVINO model will be stored. At a minimum
+            it needs to contain the name of the final file; the weights are saved next to it with a
+            `.bin` extension.
+    """
+
+    # The ONNX opset to request for the export; OpenVINO reads the resulting ONNX model.
+    OPENVINO_SUPPORTED_OPSET = 13
+
+    model.model_body.cpu()
+    onnx_path = output_path.replace(".xml", ".onnx")
+
+    export_onnx(
+        model.model_body,
+        model.model_head,
+        opset=OPENVINO_SUPPORTED_OPSET,
+        output_path=onnx_path,
+        ignore_ir_version=True,
+        use_hummingbird=True,
+    )
+
+    # Convert the intermediate ONNX model to OpenVINO IR and save it. 
+    ov_model = ov.Core().read_model(onnx_path)
+    ov.serialize(ov_model, output_path)
+
+    os.remove(onnx_path)
diff --git a/tests/exporters/test_openvino.py b/tests/exporters/test_openvino.py
new file mode 100644
index 00000000..0a846c86
--- /dev/null
+++ b/tests/exporters/test_openvino.py
@@ -0,0 +1,51 @@
+import os
+
+import numpy as np
+import openvino.runtime as ov
+from transformers import AutoTokenizer
+
+from setfit import SetFitModel
+from setfit.exporters.openvino import export_to_openvino
+
+
+def test_export_to_openvino():
+    """Test that the exported `OpenVINO` model returns the same predictions as the original model."""
+    model_path = "lewtun/my-awesome-setfit-model"
+    model = SetFitModel.from_pretrained(model_path)
+
+    # Export the sklearn-based model.
+    output_path = "model.xml"
+    export_to_openvino(model, output_path=output_path)
+
+    # Check that the model was saved.
+    assert output_path in os.listdir(), "Model not saved to output_path"
+
+    # Run inference using the original model.
+    input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]
+    pytorch_preds = model(input_text)
+
+    # Run inference using the exported OpenVINO model.
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    inputs = tokenizer(
+        input_text,
+        padding=True,
+        truncation=True,
+        return_attention_mask=True,
+        return_token_type_ids=True,
+        return_tensors="np",
+    )
+
+    inputs_dict = dict(inputs)
+
+    core = ov.Core()
+    ov_model = core.read_model(output_path)
+    compiled_model = core.compile_model(ov_model, "CPU")
+
+    ov_preds = compiled_model(inputs_dict)[compiled_model.outputs[0]]
+
+    # Compare the results and ensure that we get the same predictions.
+    assert np.array_equal(ov_preds, pytorch_preds)
+
+    # Clean up the exported model files (.xml and .bin).
+    os.remove(output_path)
+    os.remove(output_path.replace(".xml", ".bin"))
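
The "Export to OpenVINO" cell above mentions loading the pure PyTorch variant of `SetFitModel` with `use_differentiable_head=True`, but that call is not shown anywhere in the patch. Below is a minimal sketch of the load, assuming the `head_params`/`out_features` keyword arguments documented for `SetFitModel.from_pretrained` (the two-class setting matches the example model used in the notebook):

```python
from setfit import SetFitModel

# Sketch: load a SetFit model whose head is a torch dense layer instead of a
# scikit-learn LogisticRegression; `out_features` sets the number of classes.
model = SetFitModel.from_pretrained(
    "lewtun/my-awesome-setfit-model",
    use_differentiable_head=True,
    head_params={"out_features": 2},
)
```

Per the exporter docstrings, `export_onnx` (and therefore `export_to_openvino`) accepts either head type, so this variant can be fine-tuned end-to-end with backprop and then exported the same way.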