Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support of OpenVINO export #214

Merged
merged 12 commits into from
Jan 2, 2023
225 changes: 225 additions & 0 deletions notebooks/openvino_inference.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
{
AlexKoff88 marked this conversation as resolved.
Show resolved Hide resolved
AlexKoff88 marked this conversation as resolved.
Show resolved Hide resolved
AlexKoff88 marked this conversation as resolved.
Show resolved Hide resolved
AlexKoff88 marked this conversation as resolved.
Show resolved Hide resolved
"cells": [
{
"cell_type": "markdown",
"id": "552ec552",
"metadata": {},
"source": [
"# Run SetFit model with OpenVINO"
]
},
{
"cell_type": "markdown",
"id": "3af7f258-5aaf-47c2-b81e-2f10fc349812",
"metadata": {},
"source": [
"In this notebook, we'll learn how to export SetFit model to OpenVINO and run the inference."
]
},
{
"cell_type": "markdown",
"id": "c5604f73-f395-42cb-8082-9974a87ef9e9",
"metadata": {
"tags": []
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"id": "26f09e23-2e1f-41f6-bb40-a30d447a0541",
"metadata": {},
"source": [
"If you're running this Notebook on Colab or some other cloud platform, you will need to install the `setfit` library. Uncomment the following cell and run it:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "712b96e8",
"metadata": {},
"outputs": [],
"source": [
"# %pip install setfit[openvino]"
]
},
{
"cell_type": "markdown",
"id": "059bd547-0e39-43ab-adf0-5f5509217020",
"metadata": {},
"source": [
"## Load pretrained SetFit model from the Hub"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "135ba8d2-ac13-4329-946a-04226a253d83",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/alex/virt_envs/setfit/lib/python3.8/site-packages/tqdm-4.64.1-py3.8.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/home/alex/virt_envs/setfit/lib/python3.8/site-packages/scikit_learn-1.1.3-py3.8-linux-x86_64.egg/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LogisticRegression from version 1.1.1 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
"https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"(array([1, 0]), '1.87 seconds')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from time import perf_counter\n",
"from setfit import SetFitModel\n",
"\n",
"NUM_ITERS=100\n",
"\n",
"model_id = \"lewtun/my-awesome-setfit-model\"\n",
"model = SetFitModel.from_pretrained(\"lewtun/my-awesome-setfit-model\")\n",
"\n",
"# Run inference\n",
"input_text = [\"i loved the spiderman movie!\", \"pineapple on pizza is the worst 🤮\"]\n",
"\n",
"start = perf_counter()\n",
"for i in range(NUM_ITERS):\n",
" preds = model(input_text)\n",
"end = perf_counter()\n",
"\n",
"preds, f\"{end-start:.2f} seconds\""
]
},
{
"cell_type": "markdown",
"id": "12ae661d-2236-4eb3-8c52-a8fab714beb9",
"metadata": {},
"source": [
"## Export to OpenVINO\n",
"\n",
"At this step, the model is exported to OpenVINO Intermediate Representation (IR). `setfit` also provides a pure PyTorch implementation of `SetFitModel`, where the head is a dense layer instead of a classifier from `scikit-learn`. This allows one to do backprop end-to-end and have more fine-grained control over the training process. The model body (Transformer) and head are merged into the one model during the export.\n",
"\n",
"To use the PyTorch model, we load a pretrained model with `use_differentiable_head=True` and specify the number of classes to include in the head:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "58a69d72-edae-45a8-a423-702062ce75a9",
"metadata": {},
"outputs": [
{
"ename": "UnboundLocalError",
"evalue": "local variable 'onnx_path' referenced before assignment",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mUnboundLocalError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msetfit\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mexporters\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mopenvino\u001b[39;00m \u001b[39mimport\u001b[39;00m export_to_openvino\n\u001b[1;32m 4\u001b[0m output_path \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mmodel.xml\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m----> 5\u001b[0m export_to_openvino(model, output_path\u001b[39m=\u001b[39;49moutput_path)\n",
"File \u001b[0;32m~/work/experimental/setfit/src/setfit/exporters/openvino.py:275\u001b[0m, in \u001b[0;36mexport_to_openvino\u001b[0;34m(model, output_path, ignore_ir_version)\u001b[0m\n\u001b[1;32m 268\u001b[0m transformer \u001b[39m=\u001b[39m model_body_module\u001b[39m.\u001b[39mauto_model\n\u001b[1;32m 269\u001b[0m transformer\u001b[39m.\u001b[39meval()\n\u001b[1;32m 271\u001b[0m export_onnx(\n\u001b[1;32m 272\u001b[0m model\u001b[39m.\u001b[39mmodel_body, \n\u001b[1;32m 273\u001b[0m model\u001b[39m.\u001b[39mmodel_head,\n\u001b[1;32m 274\u001b[0m opset\u001b[39m=\u001b[39mOPENVINO_SUPPORTED_OPSET,\n\u001b[0;32m--> 275\u001b[0m output_path\u001b[39m=\u001b[39monnx_path,\n\u001b[1;32m 276\u001b[0m use_hummingbird\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m\n\u001b[1;32m 277\u001b[0m )\n\u001b[1;32m 279\u001b[0m \u001b[39m# Save the final model.\u001b[39;00m\n\u001b[1;32m 280\u001b[0m onnx_path \u001b[39m=\u001b[39m output_path\u001b[39m.\u001b[39mreplace(\u001b[39m\"\u001b[39m\u001b[39m.xml\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m.onnx\u001b[39m\u001b[39m\"\u001b[39m)\n",
"\u001b[0;31mUnboundLocalError\u001b[0m: local variable 'onnx_path' referenced before assignment"
]
}
],
"source": [
"from setfit import SetFitModel\n",
"from setfit.exporters.openvino import export_to_openvino\n",
"\n",
"output_path = \"model.xml\"\n",
"export_to_openvino(model, output_path)"
]
},
{
"cell_type": "markdown",
"id": "2959a761-24cc-4d65-a90c-8a12b61d383c",
"metadata": {},
"source": [
"## Run model with OpenVINO\n",
"Now, we run the model wit OpenVINO Python* API to compare results. The elapsed inference time is computed over `NUM_ITERS` to compare with PyTorch.\n",
"\n",
">**Note:** Text tokenization is a required extra step that should be done before the inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3d4038c-4426-45d0-b7a4-d1e89019df7a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([1, 0]), '1.26 seconds')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoTokenizer\n",
"import openvino.runtime as ov\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"inputs = tokenizer(\n",
" input_text,\n",
" padding=True,\n",
" truncation=True,\n",
" return_attention_mask=True,\n",
" return_token_type_ids=True,\n",
" return_tensors=\"np\",\n",
")\n",
"\n",
"inputs_dict = dict(inputs)\n",
"\n",
"core = ov.Core()\n",
"ov_model = core.read_model(output_path)\n",
"compiled_model = core.compile_model(ov_model, \"CPU\")\n",
"\n",
"start = perf_counter()\n",
"for i in range(NUM_ITERS):\n",
" ov_preds = compiled_model(inputs_dict)[compiled_model.outputs[0]]\n",
"end = perf_counter()\n",
"\n",
"ov_preds, f\"{end-start:.2f} seconds\"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.10 ('setfit')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "1a0b4a52b05e9a0d1fb6e7e4631f54aec2fa21faaaad909c085a4d70095afda4"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

ONNX_REQUIRE = ["onnxruntime", "onnx", "skl2onnx"]

TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE
OPENVINO_REQUIRE = ["hummingbird-ml", "openvino>=2022.3"]

TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE

COMPAT_TESTS_REQUIRE = [requirement.replace(">=", "==") for requirement in REQUIRED_PKGS] + TESTS_REQUIRE

Expand All @@ -23,6 +25,7 @@
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
"onnx": ONNX_REQUIRE,
"openvino": ONNX_REQUIRE + OPENVINO_REQUIRE,
"compat_tests": COMPAT_TESTS_REQUIRE,
}

Expand Down
26 changes: 24 additions & 2 deletions src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import warnings
from typing import Callable, Optional, Union

Expand Down Expand Up @@ -164,12 +165,24 @@ def export_sklearn_head_to_onnx(model_head: LogisticRegression, opset: int) -> o
return onnx_model


def hummingbird_export(model, data_sample):
try:
from hummingbird.ml import convert
except ImportError:
raise ImportError(
"Hummingbird-ML library is not installed." "Run 'pip install hummingbird-ml' to use this type of export."
)
onnx_model = convert(model, "onnx", data_sample)
return onnx_model._model


def export_onnx(
model_body: SentenceTransformer,
model_head: Union[torch.nn.Module, LogisticRegression],
opset: int,
output_path: str = "model.onnx",
ignore_ir_version: bool = True,
use_hummingbird: bool = False,
) -> None:
"""Export a PyTorch backed SetFit model to ONNX Intermediate Representation.

Expand Down Expand Up @@ -228,7 +241,15 @@ def export_onnx(

# Export the sklearn head first to get the minimum opset. sklearn is behind
# in supported opsets.
onnx_head = export_sklearn_head_to_onnx(model_head, opset)
# Hummingbird-ML can be used as an option to export to standard opset
if use_hummingbird:
with torch.no_grad():
test_input = copy.deepcopy(dummy_inputs)
head_input = model_body(test_input)["sentence_embedding"]
onnx_head = hummingbird_export(model_head, head_input.detach().numpy())
else:
onnx_head = export_sklearn_head_to_onnx(model_head, opset)

max_opset = max([x.version for x in onnx_head.opset_import])

if max_opset != opset:
Expand Down Expand Up @@ -256,10 +277,11 @@ def export_onnx(
raise ValueError(msg)

# Combine the onnx body and head by mapping the pooled output to the input of the sklearn model.
head_input_name = next(iter(onnx_head.graph.input)).name
onnx_setfit_model = onnx.compose.merge_models(
onnx_body,
onnx_head,
io_map=[("logits", "model_head")],
io_map=[("logits", head_input_name)],
)

# Save the final model.
Expand Down
47 changes: 47 additions & 0 deletions src/setfit/exporters/openvino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os

import openvino.runtime as ov

from setfit import SetFitModel
from setfit.exporters.onnx import export_onnx


def export_to_openvino(
model: SetFitModel,
output_path: str = "model.xml",
) -> None:
"""Export a PyTorch backed SetFit model to OpenVINO Intermediate Representation.

Args:
model_body (`SentenceTransformer`): The model_body from a SetFit model body. This should be a
SentenceTransformer.
model_head (`torch.nn.Module` or `LogisticRegression`): The SetFit model head. This can be either a
dense layer SetFitHead or a Sklearn estimator.
output_path (`str`): The path where will be stored the generated OpenVINO model. At a minimum it needs to contain
the name of the final file.
ignore_ir_version (`bool`): Whether to ignore the IR version used in sklearn. The version is often missmatched
with the transformer models. Setting this to true coerces the versions to be the same. This might
cause errors but in practice works. If this is set to False you need to ensure that the IR versions
align between the transformer and the sklearn onnx representation.
"""

# Load the model and get all of the parts.
OPENVINO_SUPPORTED_OPSET = 13

model.model_body.cpu()
onnx_path = output_path.replace(".xml", ".onnx")

export_onnx(
model.model_body,
model.model_head,
opset=OPENVINO_SUPPORTED_OPSET,
output_path=onnx_path,
ignore_ir_version=True,
use_hummingbird=True,
)

# Save the final model.
ov_model = ov.Core().read_model(onnx_path)
ov.serialize(ov_model, output_path)

os.remove(onnx_path)
Loading