From 508865d63ec222ae3cda25f3ff11b790d8a55f40 Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Tue, 7 Feb 2023 17:10:25 +0800 Subject: [PATCH] scripts/vsmlrt.py: add param `num_threads` to the `OV_CPU` backend --- scripts/vsmlrt.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index 76c6f8f..150def5 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -1,4 +1,4 @@ -__version__ = "3.15.8" +__version__ = "3.15.9" __all__ = [ "Backend", "BackendV2", @@ -96,6 +96,7 @@ class OV_CPU: bind_thread: bool = True fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None bf16: bool = False + num_threads: int = 0 # internal backend attributes supports_onnx_serialization: bool = True @@ -1323,6 +1324,7 @@ def _inference( config = lambda: dict( CPU_THROUGHPUT_STREAMS=backend.num_streams, CPU_BIND_THREAD="YES" if backend.bind_thread else "NO", + CPU_THREADS_NUM=backend.num_threads, ENFORCE_BF16="YES" if backend.bf16 else "NO" ) @@ -1570,12 +1572,14 @@ def OV_CPU(*, num_streams: typing.Union[int, str] = 1, bf16: bool = False, bind_thread: bool = True, + num_threads: int = 0, **kwargs ) -> Backend.OV_CPU: return Backend.OV_CPU( num_streams=num_streams, bf16=bf16, bind_thread=bind_thread, + num_threads=num_threads, **kwargs )