Merge branch 'master' into issue_1361

lxning authored Dec 14, 2021
2 parents 45c0777 + c87bfec commit 6c1c0ae
Showing 4 changed files with 290 additions and 1 deletion.
225 changes: 225 additions & 0 deletions examples/intel_extension_for_pytorch/README.md
@@ -0,0 +1,225 @@
# TorchServe with Intel® Extension for PyTorch*

TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give a performance boost on Intel hardware<sup>1</sup>.
Here we show how to use TorchServe with IPEX.

<sup>1. While IPEX benefits all platforms, platforms with AVX512 benefit the most. </sup>

## Contents of this Document
* [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch)
* [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch)
* [TorchServe with Launcher](#torchserve-with-launcher)
* [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex)
* [Benchmarking with Launcher](#benchmarking-with-launcher)


## Install Intel Extension for PyTorch
Refer to the installation instructions [here](https://github.com/intel/intel-extension-for-pytorch#installation).
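
A typical pip-based install might look like the following (a hedged sketch; the exact command, version pins, and wheel index depend on your PyTorch version, so follow the linked page):
```
python -m pip install intel_extension_for_pytorch
```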

## Serving model with Intel Extension for PyTorch
After installation, all that needs to be done to use TorchServe with IPEX is to enable it in `config.properties`.
```
ipex_enable=true
```
Once IPEX is enabled, deploying a PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and run inference.
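
Under the hood, when `ipex_enable=true` the base handler applies IPEX's optimization pass to the loaded model before inference. A minimal sketch of the idea (a simplified illustration assuming the `ipex.optimize` API; see the `ts/torch_handler/base_handler.py` change at the bottom of this diff for the actual hook):
```
import torch
import intel_extension_for_pytorch as ipex
import torchvision.models as models

# any eval-mode PyTorch model; ResNet18 is used here purely for illustration
model = models.resnet18(pretrained=True).eval()
# channels-last memory format tends to help convolution-heavy models on CPU
model = model.to(memory_format=torch.channels_last)
# apply IPEX operator and graph optimizations
model = ipex.optimize(model)
with torch.no_grad():
    output = model(torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last))
```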

## TorchServe with Launcher
Launcher is a script that automates tuning of configuration settings on Intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affinity, and the memory allocator can have a dramatic effect on performance. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher.
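
Launcher ships as a Python module inside IPEX, so you can inspect its tunable knobs directly; this is the same module TorchServe invokes (see the `WorkerLifeCycle` change further down in this diff):
```
python -m intel_extension_for_pytorch.cpu.launch --help
```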

All that needs to be done to use TorchServe with launcher is to set its configuration in `config.properties`.

Add the following lines to `config.properties` to use launcher with its default configuration.
```
ipex_enable=true
cpu_launcher_enable=true
```

Launcher by default uses `numactl` if it is installed, to ensure that each socket is pinned and memory is thus allocated from the local NUMA node. To use launcher without `numactl`, add the following lines to `config.properties`.
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--disable_numactl
```

Launcher by default uses only non-hyperthreaded cores if hyperthreading is present, to avoid sharing core compute resources. To use launcher with all cores, both physical and logical, add the following lines to `config.properties`.
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--use_logical_core
```

Below is an example of passing multiple args to `cpu_launcher_args`.
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--use_logical_core --disable_numactl
```

Some useful `cpu_launcher_args` to note are:
1. Memory Allocator: [ PTMalloc `--use_default_allocator` | *TCMalloc `--enable_tcmalloc`* | JeMalloc `--enable_jemalloc`]
   * PyTorch by default uses PTMalloc. TCMalloc and JeMalloc generally give better performance.
2. OpenMP library: [GNU OpenMP `--disable_iomp` | *Intel OpenMP*]
   * PyTorch by default uses GNU OpenMP. Launcher by default uses Intel OpenMP, which generally gives better performance.
3. Socket id: [`--socket_id`]
   * Launcher by default uses all physical cores across all sockets. Use `--socket_id N` to restrict execution and memory allocation to the Nth socket and avoid remote NUMA memory access (see the example after this list).
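
For example, to pin TorchServe to the first socket and use JeMalloc, the flags can be combined as follows (a hypothetical combination of the options above):
```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--socket_id 0 --enable_jemalloc
```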

Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for the full list of launcher's tunable configurations.
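
When launcher is enabled, TorchServe prepends the launcher module to each worker's command line. Conceptually, a worker is started roughly like this (a simplified illustration based on the `WorkerLifeCycle` changes in this commit; exact arguments and paths vary):
```
python -m intel_extension_for_pytorch.cpu.launch --ninstance 1 <cpu_launcher_args> ts/model_service_worker.py --sock-type unix ...
```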


## Creating and Exporting INT8 model for IPEX
Intel Extension for PyTorch supports both eager mode and TorchScript mode. In this section, we show how to create and deploy an INT8 model with IPEX.

### 1. Creating a serialized file
First, create a `.pt` serialized file using IPEX INT8 inference. Here we show two examples, with BERT and ResNet50.

#### BERT

```
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForSequenceClassification, AutoConfig

# load the model
config = AutoConfig.from_pretrained(
    "bert-base-uncased", return_dict=False, torchscript=True, num_labels=2)
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", config=config)
model = model.eval()

# define a dummy input tensor for the model's forward call to record operations in the model for tracing
N, max_length = 1, 384
dummy_tensor = torch.ones((N, max_length), dtype=torch.long)

# calibration
# IPEX supports two quantization schemes for activations: torch.per_tensor_affine and torch.per_tensor_symmetric
# the default qscheme is torch.per_tensor_affine
conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
n_iter = 100
with torch.no_grad():
    for i in range(n_iter):
        with ipex.quantization.calibrate(conf):
            model(dummy_tensor, dummy_tensor, dummy_tensor)

# optionally save the configuration for later use
# save:
# conf.save("model_conf.json")
# load:
# conf = ipex.quantization.QuantConf("model_conf.json")

# conversion
jit_inputs = (dummy_tensor, dummy_tensor, dummy_tensor)
model = ipex.quantization.convert(model, conf, jit_inputs)

# enable the fusion path (forward propagation needs to run twice)
with torch.no_grad():
    y = model(dummy_tensor, dummy_tensor, dummy_tensor)
    y = model(dummy_tensor, dummy_tensor, dummy_tensor)

# save to .pt
torch.jit.save(model, 'bert_int8_jit.pt')
```
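
The tensor of ones above is enough to record the graph, but calibration statistics are only as meaningful as the inputs they observe. A variation that calibrates on real text might tokenize representative sentences instead (a hedged sketch; `AutoTokenizer` comes from the same `transformers` library):
```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# pad to the same sequence length used for tracing above
encoded = tokenizer("a representative calibration sentence",
                    padding="max_length", max_length=384, return_tensors="pt")
# feed encoded["input_ids"], encoded["attention_mask"], encoded["token_type_ids"]
# into the calibration loop in place of the dummy tensors
```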

#### ResNet50

```
import torch
import torch.fx.experimental.optimization as optimization
import intel_extension_for_pytorch as ipex
import torchvision.models as models

# load the model
model = models.resnet50(pretrained=True)
model = model.eval()
model = optimization.fuse(model)

# define a dummy input tensor for the model's forward call to record operations in the model for tracing
N, C, H, W = 1, 3, 224, 224
dummy_tensor = torch.randn(N, C, H, W).contiguous(memory_format=torch.channels_last)

# calibration
# IPEX supports two quantization schemes for activations: torch.per_tensor_affine and torch.per_tensor_symmetric
# the default qscheme is torch.per_tensor_affine; here we use torch.per_tensor_symmetric
conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_symmetric)
n_iter = 100
with torch.no_grad():
    for i in range(n_iter):
        with ipex.quantization.calibrate(conf):
            model(dummy_tensor)

# optionally save the configuration for later use
# save:
# conf.save("model_conf.json")
# load:
# conf = ipex.quantization.QuantConf("model_conf.json")

# conversion
jit_inputs = (dummy_tensor,)  # note the trailing comma: convert expects a tuple of inputs
model = ipex.quantization.convert(model, conf, jit_inputs)

# enable the fusion path (the model needs to run twice)
with torch.no_grad():
    y = model(dummy_tensor)
    y = model(dummy_tensor)

# save to .pt
torch.jit.save(model, 'rn50_int8_jit.pt')
```
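
Before archiving, it is worth sanity-checking the exported TorchScript file by loading it back and running a forward pass (a minimal sketch using the file name from the example above):
```
import torch

model = torch.jit.load('rn50_int8_jit.pt')
model.eval()
with torch.no_grad():
    y = model(torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last))
print(y.shape)  # expect torch.Size([1, 1000]) for ResNet50
```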

### 2. Creating a Model Archive
Once the serialized file (`.pt`) is created, it can be used with `torch-model-archiver` as usual. Use the following command to package the model.
```
torch-model-archiver --model-name rn50_ipex_int8 --version 1.0 --serialized-file rn50_int8_jit.pt --handler image_classifier
```
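
`torch-model-archiver` writes `rn50_ipex_int8.mar` to the current directory; place it in the model store referenced in the next step (the directory name here is an assumption matching the command below):
```
mkdir -p model_store
mv rn50_ipex_int8.mar model_store/
```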
### 3. Start TorchServe to serve the model
Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start TorchServe with IPEX.
```
torchserve --start --ncs --model-store model_store --ts-config config.properties
```

### 4. Registering and Deploying the model
Registering and deploying the model follows the same steps shown [here](https://pytorch.org/serve/use_cases.html).
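
For example, using TorchServe's standard management and inference APIs on their default ports (model and file names follow the ResNet50 example above; `kitten.jpg` is a placeholder sample image):
```
# register the model and spin up one worker
curl -X POST "http://localhost:8081/models?url=rn50_ipex_int8.mar&initial_workers=1"
# run inference on a sample image
curl http://localhost:8080/predictions/rn50_ipex_int8 -T kitten.jpg
```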

## Benchmarking with Launcher
Launcher can be used with the official TorchServe [benchmark](https://github.com/pytorch/serve/tree/master/benchmarks) to launch the server and benchmark requests with an optimal configuration on Intel hardware.

In this section we provide an example of benchmarking with launcher in its default configuration.

Add the following lines to `config.properties` in the benchmark directory to use launcher with its default setting.
```
ipex_enable=true
cpu_launcher_enable=true
```

The rest of the benchmarking process follows the steps shown [here](https://github.com/pytorch/serve/tree/master/benchmarks).
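
A typical run might look like the following (hypothetical values; the script name and flags are from the benchmarks repository linked above, so check its README for the exact interface):
```
python benchmark-ab.py --url <your_model.mar> --concurrency 10 --requests 1000
```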

`model_log.log` contains the information and commands that were used for this execution launch.


CPU usage on a machine with Intel(R) Xeon(R) Platinum 8180 CPU, 2 sockets, 28 cores per socket, 2 threads per core is shown below:
![launcher_default_2sockets](https://user-images.githubusercontent.com/93151422/144373537-07787510-039d-44c4-8cfd-6afeeb64ac78.gif)

```
$ cat logs/model_log.log
2021-12-01 21:22:40,096 - __main__ - WARNING - Both TCMalloc and JeMalloc are not found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or /home/<user>/.local/lib/ so the LD_PRELOAD environment variable will not be set. This may drop the performance
2021-12-01 21:22:40,096 - __main__ - INFO - OMP_NUM_THREADS=56
2021-12-01 21:22:40,096 - __main__ - INFO - Using Intel OpenMP
2021-12-01 21:22:40,096 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2021-12-01 21:22:40,096 - __main__ - INFO - KMP_BLOCKTIME=1
2021-12-01 21:22:40,096 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so
2021-12-01 21:22:40,096 - __main__ - WARNING - Numa Aware: cores:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55] in different NUMA node
```

CPU usage on a machine with Intel(R) Xeon(R) Platinum 8375C CPU, 1 socket, 2 cores per socket, 2 threads per core is shown below:
![launcher_default_1socket](https://user-images.githubusercontent.com/93151422/144372993-92b2ca96-f309-41e2-a5c8-bf2143815c93.gif)

```
$ cat logs/model_log.log
2021-12-02 06:15:03,981 - __main__ - WARNING - Both TCMalloc and JeMalloc are not found in $CONDA_PREFIX/lib or $VIRTUAL_ENV/lib or /.local/lib/ or /usr/local/lib/ or /usr/local/lib64/ or /usr/lib or /usr/lib64 or /home/<user>/.local/lib/ so the LD_PRELOAD environment variable will not be set. This may drop the performance
2021-12-02 06:15:03,981 - __main__ - INFO - OMP_NUM_THREADS=2
2021-12-02 06:15:03,982 - __main__ - INFO - Using Intel OpenMP
2021-12-02 06:15:03,982 - __main__ - INFO - KMP_AFFINITY=granularity=fine,compact,1,0
2021-12-02 06:15:03,982 - __main__ - INFO - KMP_BLOCKTIME=1
2021-12-02 06:15:03,982 - __main__ - INFO - LD_PRELOAD=<VIRTUAL_ENV>/lib/libiomp5.so
```
frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java
@@ -68,6 +68,8 @@ public final class ConfigManager {

// IPEX config option that can be set at config.properties
private static final String TS_IPEX_ENABLE = "ipex_enable";
private static final String TS_CPU_LAUNCHER_ENABLE = "cpu_launcher_enable";
private static final String TS_CPU_LAUNCHER_ARGS = "cpu_launcher_args";

private static final String TS_ASYNC_LOGGING = "async_logging";
private static final String TS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin";
@@ -337,6 +339,14 @@ public boolean isMetricApiEnable() {
return Boolean.parseBoolean(getProperty(TS_ENABLE_METRICS_API, "true"));
}

public boolean isCPULauncherEnabled() {
return Boolean.parseBoolean(getProperty(TS_CPU_LAUNCHER_ENABLE, "false"));
}

public String getCPULauncherArgs() {
return getProperty(TS_CPU_LAUNCHER_ARGS, null);
}

public int getNettyThreads() {
return getIntProperty(TS_NUMBER_OF_NETTY_THREADS, 0);
}
frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java
@@ -29,6 +29,7 @@ public class WorkerLifeCycle {
private Connector connector;
private ReaderThread errReader;
private ReaderThread outReader;
private String launcherArgs;

public WorkerLifeCycle(ConfigManager configManager, Model model) {
this.configManager = configManager;
@@ -39,6 +40,46 @@ public Process getProcess() {
return process;
}

public ArrayList<String> launcherArgsToList() {
ArrayList<String> arrlist = new ArrayList<String>();
arrlist.add("-m");
arrlist.add("intel_extension_for_pytorch.cpu.launch");
arrlist.add("--ninstance");
arrlist.add("1");
if (launcherArgs != null && launcherArgs.length() > 1) {
String[] argarray = launcherArgs.split(" ");
for (int i = 0; i < argarray.length; i++) {
arrlist.add(argarray[i]);
}
}
return arrlist;
}

public boolean isLauncherAvailable()
throws WorkerInitializationException, InterruptedException {
boolean launcherAvailable = false;
try {
ArrayList<String> cmd = new ArrayList<String>();
cmd.add("python");
ArrayList<String> args = launcherArgsToList();
cmd.addAll(args);
cmd.add("--no_python");
// try launching dummy command to check launcher availability
String dummyCmd = "hostname";
cmd.add(dummyCmd);

String[] cmd_ = new String[cmd.size()];
cmd_ = cmd.toArray(cmd_);

Process process = Runtime.getRuntime().exec(cmd_);
int ret = process.waitFor();
launcherAvailable = (ret == 0);
} catch (IOException | InterruptedException e) {
throw new WorkerInitializationException("Failed to start launcher", e);
}
return launcherAvailable;
}

public void startWorker(int port) throws WorkerInitializationException, InterruptedException {
File workingDir = new File(configManager.getModelServerHome());
File modelPath;
@@ -51,6 +92,19 @@ public void startWorker(int port) throws WorkerInitializationException, InterruptedException {

ArrayList<String> argl = new ArrayList<String>();
argl.add(EnvironmentUtils.getPythonRunTime(model));

if (configManager.isCPULauncherEnabled()) {
launcherArgs = configManager.getCPULauncherArgs();
boolean launcherAvailable = isLauncherAvailable();
if (launcherAvailable) {
ArrayList<String> args = launcherArgsToList();
argl.addAll(args);
} else {
logger.warn(
"CPU launcher is enabled but launcher is not available. Proceeding without launcher.");
}
}

argl.add(new File(workingDir, "ts/model_service_worker.py").getAbsolutePath());
argl.add("--sock-type");
argl.add(connector.getSocketType());
2 changes: 1 addition & 1 deletion ts/torch_handler/base_handler.py
@@ -20,7 +20,7 @@
import intel_extension_for_pytorch as ipex
ipex_enabled = True
except ImportError as error:
logger.warning("IPEX was not installed. Please install IPEX if wanted.")
logger.warning("IPEX is enabled but intel-extension-for-pytorch is not installed. Proceeding without IPEX.")

class BaseHandler(abc.ABC):
"""
