Micro batching example #2210

Merged
merged 27 commits into master from feature/microbatching on May 25, 2023

Commits
Commits
466526b
Add unit test for micro batching
mreso Mar 13, 2023
c3c82a4
Adds simple implementation for microbatching with coroutines
mreso Mar 14, 2023
06b69ca
Added testing to microbatching
mreso Mar 21, 2023
c39ae60
Created MBHandler
mreso Mar 22, 2023
9f5d17a
Added more tests to microbatching handler
mreso Mar 22, 2023
07eecb5
Adds configurable parallelism
mreso Mar 22, 2023
169099e
Enables loading of micro batching parameters through config file
mreso Mar 22, 2023
7373866
Moved microbatching into example
mreso Mar 29, 2023
26a7319
Moved micro batching test into test/pytest folder
mreso Mar 29, 2023
061abb6
Rewrote micro batching to use threading
mreso Mar 30, 2023
3645c9a
Implemented method to update parallelism
mreso Mar 31, 2023
9aaec3c
Fix and test spin up spin down fo threads
mreso Mar 31, 2023
88022d2
Clean up and comments
mreso Mar 31, 2023
6309ffb
More comments
mreso Mar 31, 2023
866b2ce
Adds README to micro batching example
mreso Mar 31, 2023
565ccf1
Refined readme + added config.yaml
mreso Mar 31, 2023
6caf30d
Add config_file
mreso Mar 31, 2023
3d5d76d
Fix linting error
mreso Apr 1, 2023
c5fd370
Fix spell check error
mreso Apr 1, 2023
45aa9c2
Fix linting error
mreso Apr 3, 2023
9d06fe6
Move micro_batching.py into ts.utils and use model_yaml_config for co…
mreso Apr 5, 2023
91e523b
Fix links in README
mreso Apr 5, 2023
de05ebf
Merge branch 'master' into feature/microbatching
mreso Apr 21, 2023
2c95274
Merge branch 'master' into feature/microbatching
mreso Apr 21, 2023
5af65df
Merge branch 'master' into feature/microbatching
mreso May 23, 2023
15295bc
Moved to handler_utils
mreso May 24, 2023
b3b487a
remove __all__
mreso May 24, 2023
111 changes: 111 additions & 0 deletions examples/micro_batching/README.md
@@ -0,0 +1,111 @@
# Micro Batching
Accelerators like GPUs are used most cost-efficiently for inference when they are steadily fed with incoming data.
TorchServe currently allows a single batch to be processed per backend worker.
In each worker the three computation steps (preprocess, inference, postprocess) are executed sequentially.
Because pre- and postprocessing are often carried out on the CPU, the GPU sits idle until the two CPU-bound steps have finished and the worker receives a new batch.
The following example shows how to make better use of an accelerator in high-load scenarios.

For this we assume that there are plenty of incoming client requests, so that a bigger batch can potentially be filled within the batch delay time frame in which the frontend collects requests for the next batch.
Given this precondition, we increase the batch size which the backend worker receives and subsequently split the big batch up into smaller *micro* batches for processing.
We can then perform the computation on the micro batches in parallel, as more than one batch is available to the worker.
This way we can already process a micro batch on the GPU while the preprocessing is applied to the remaining micro batches.
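As a rough illustration of the splitting step (the numbers below are hypothetical and not taken from the example code), splitting a frontend batch into micro batches is essentially a slicing operation:
```python
# Hypothetical: a frontend batch of 64 requests split into micro batches of 4.
batch = [f"request_{i}" for i in range(64)]
micro_batch_size = 4

micro_batches = [
    batch[i : i + micro_batch_size] for i in range(0, len(batch), micro_batch_size)
]

assert len(micro_batches) == 16  # 64 / 4 micro batches move through the pipeline
```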
The pros and cons of this approach are as follows:

Pros:

* Higher throughput by better utilizing the available accelerator
* Lower overall latency when enough requests are available for computation

Cons:

* Potentially higher latency and lower throughput if not enough requests are available

## Implementation
This example implements micro batching using a custom handler which overrides the *handle* method with a MicroBatching object defined in __ts.handler_utils.micro_batching__.
```python
class MicroBatchingHandler(ImageClassifier):
def __init__(self):
mb_handle = MicroBatching(self)
self.handle = mb_handle
```
The MicroBatching object takes the custom handler as an input and spins up a number of threads.
Each thread will work on one of the processing steps (preprocess, inference, postprocess) of the custom handler while multiple threads can be assigned to process the same step in parallel.
The number of threads as well as the micro batch size are configurable through the [model yaml config](config.yaml):
```yaml
batchSize: 32
micro_batching:
micro_batch_size: 4
parallelism:
preprocess: 2
inference: 1
postprocess: 2
```
Each number in the *parallelism* dictionary represents the number of threads created for the respective step on initialization.
The micro_batch_size parameter should be chosen much smaller than the batch size configured through the TorchServe API (e.g. a batch size of 64 combined with a micro batch size of 4).
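To make the threading model more concrete, here is a minimal, self-contained sketch of such a three-stage pipeline built from `queue.Queue` and `threading.Thread`. It is not the implementation shipped in `ts.handler_utils.micro_batching`, only an illustration of how micro batches are handed from stage to stage; every function and variable name in it is made up for this sketch.
```python
import queue
import threading


def stage_worker(fn, in_q, out_q):
    """Pull micro batches from in_q, apply fn, and push the results to out_q."""
    while True:
        item = in_q.get()
        if item is None:  # sentinel: stop this thread
            break
        out_q.put(fn(item))


def run_pipeline(micro_batches, preprocess, inference, postprocess, parallelism):
    """Push micro batches through a preprocess -> inference -> postprocess pipeline."""
    q_pre, q_inf, q_post, q_out = (queue.Queue() for _ in range(4))
    stages = [
        ("preprocess", preprocess, q_pre, q_inf),
        ("inference", inference, q_inf, q_post),
        ("postprocess", postprocess, q_post, q_out),
    ]

    threads = []
    for name, fn, in_q, out_q in stages:
        for _ in range(parallelism[name]):
            t = threading.Thread(target=stage_worker, args=(fn, in_q, out_q))
            t.start()
            threads.append(t)

    for mb in micro_batches:
        q_pre.put(mb)

    # Collect one result per micro batch; with parallelism > 1 the completion
    # order may differ from the input order in this simplified sketch.
    results = [q_out.get() for _ in micro_batches]

    # Shut down: one sentinel per thread of each stage, then join.
    for name, _, in_q, _ in stages:
        for _ in range(parallelism[name]):
            in_q.put(None)
    for t in threads:
        t.join()

    return results


if __name__ == "__main__":
    out = run_pipeline(
        micro_batches=[[1, 2, 3, 4], [5, 6, 7, 8]],
        preprocess=lambda mb: [x * 2 for x in mb],
        inference=lambda mb: sum(mb),
        postprocess=lambda mb: mb,
        parallelism={"preprocess": 2, "inference": 1, "postprocess": 2},
    )
    print(out)  # e.g. [20, 52] (order may vary)
```
While the preprocess threads work on later micro batches, the inference thread can already run earlier ones on the GPU, which is the effect this example exploits.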

## Example
The following example takes a ResNet18 image classification model and runs its pre- and postprocessing, which includes resizing and cropping the image, in parallel.

First, we need to download the model weights:
```bash
$ cd <TorchServe main folder>
$ wget https://download.pytorch.org/models/resnet18-f37072fd.pth
```
Second, we create the MAR file while including the necessary source and config files as additional files:
```bash
$ torch-model-archiver --model-name resnet-18_mb --version 1.0 --model-file ./examples/image_classifier/resnet_18/model.py --serialized-file resnet18-f37072fd.pth --handler examples/micro_batching/micro_batching_handler.py --extra-files ./examples/image_classifier/index_to_name.json --config-file examples/micro_batching/config.yaml
```
Our MicroBatchingHandler defined in [micro_batching_handler.py](micro_batching_handler.py) inherits from ImageClassifier which already defines the necessary pre- and postprocessing.

Third, we move the MAR file to our model_store and start TorchServe.
```bash
$ mkdir model_store
$ mv resnet-18_mb.mar model_store/
$ torchserve --start --ncs --model-store model_store --models resnet-18_mb.mar
```

Finally, we test the registered model with a request:
```bash
$ curl http://127.0.0.1:8080/predictions/resnet-18_mb -T ./examples/image_classifier/kitten.jpg
```
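A single request like the one above will not fill a larger batch, so the effect of micro batching is hard to observe this way. Before running the full benchmark you can exercise the handler with many concurrent requests, for example with a small script along the lines of the following sketch (it assumes TorchServe is still running locally as started above; the request count and timeout are arbitrary):
```python
# Sketch: fire concurrent requests so the frontend can assemble a larger batch.
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://127.0.0.1:8080/predictions/resnet-18_mb"

with open("./examples/image_classifier/kitten.jpg", "rb") as f:
    payload = f.read()


def send(_):
    return requests.post(URL, data=payload, timeout=30)


with ThreadPoolExecutor(max_workers=64) as pool:
    responses = list(pool.map(send, range(64)))

print(sum(r.status_code == 200 for r in responses), "of", len(responses), "requests succeeded")
```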
In the next section we will have a look at how the throughput and latency of the model behave by benchmarking it with TorchServe's benchmark tool.

## Results
For the following benchmark we use [benchmark-ab.py](../../benchmarks/benchmark-ab.py) and a ResNet50 instead of the smaller ResNet18.
We ran this benchmark on an AWS g4dn.4xlarge instance with a single T4 GPU.
After creating the MAR file for the ResNet50 analogously to the steps described above, we extract it into the model_store so we do not need to upload the file.
```bash
$ unzip -d model_store/resnet-50_mb model_store/resnet-50_mb.mar
```
Subsequently, we can run the benchmark with:
```bash
$ python3 benchmarks/benchmark-ab.py --config benchmarks/config.json
```
The config.json for the benchmark has the following content:
```json
{
"url":"/home/ubuntu/serve/model_store/resnet-50_mb/",
"requests": 50000,
"concurrency": 200,
"input": "/home/ubuntu/serve/examples/image_classifier/kitten.jpg",
"workers": "1",
"batch_size": 64
}
```
This will run the model with a batch size of 64 and a micro batch size of 4 as configured in the config.yaml.
For this section we ran the benchmark with different batch sizes and micro batch sizes (marked with "MBS=X") as well as different numbers of threads to create the following diagrams.
As a baseline we also ran the vanilla ImageClassifier handler without micro batching, which is marked as "NO MB".
![](assets/throughput_latency.png)
In the diagrams we see the throughput and P99 latency plotted over the batch size (as configured through TorchServe API).
Each curve represents a different micro batch size as configured through [config.yaml](config.yaml).
We can see that the throughput stays flat for the vanilla ImageClassifier (NO MB), which suggests that the inference is preprocessing-bound and the GPU is underutilized; this can be confirmed with a look at the nvidia-smi output.
By interleaving the three compute steps and using two threads for pre- and postprocessing, we see that the micro batched variants (MBS=4-16) achieve a higher throughput and even a lower batch latency, as the GPU is better utilized due to the introduction of micro batches.
For this particular model we can achieve a throughput of up to 250 QPS by increasing the number of preprocessing threads to 4 and choosing 128 and 8 as batch size and micro batch size, respectively.
The actual achieved speedup will depend on the specific model as well as the intensity of the pre- and postprocessing steps.
Image scaling and decompression, for example, are usually more compute-intensive than text preprocessing.

## Summary
In summary, we can see that micro batching can help to increase the throughput of a model while decreasing its latency.
This is especially true for workloads with compute-intensive pre- or postprocessing as well as for smaller models.
The micro batching approach can also be used to save memory in a CPU use case by scaling the number of inference threads to more than one, which allows running multiple instances of the model that all share the same underlying weights.
This is in contrast to running multiple TorchServe workers, each of which creates its own model instance; these cannot share their weights as they reside in different processes.
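The memory argument can be illustrated outside of TorchServe: threads that call the same `torch.nn.Module` share a single copy of its parameters, while separate worker processes each hold their own. A minimal sketch (the model and batch sizes are arbitrary and unrelated to the example handler):
```python
import threading

import torch
import torchvision

# One model instance, and therefore one copy of the weights, in this process.
model = torchvision.models.resnet18()
model.eval()

results = [None, None]


def infer(idx, batch):
    with torch.inference_mode():
        results[idx] = model(batch)  # both threads read the same shared parameters


threads = [
    threading.Thread(target=infer, args=(i, torch.randn(4, 3, 224, 224)))
    for i in range(2)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

print([r.shape for r in results])  # two outputs produced from one set of weights
```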
8 changes: 8 additions & 0 deletions examples/micro_batching/config.yaml
@@ -0,0 +1,8 @@
batchSize: 32

micro_batching:
micro_batch_size: 4
parallelism:
preprocess: 2
inference: 1
postprocess: 2
30 changes: 30 additions & 0 deletions examples/micro_batching/micro_batching_handler.py
@@ -0,0 +1,30 @@
import logging

from ts.handler_utils.micro_batching import MicroBatching
from ts.torch_handler.image_classifier import ImageClassifier

logger = logging.getLogger(__name__)


class MicroBatchingHandler(ImageClassifier):
def __init__(self):
mb_handle = MicroBatching(self)
self.handle = mb_handle

def initialize(self, ctx):
super().initialize(ctx)

parallelism = ctx.model_yaml_config.get("micro_batching", {}).get(
"parallelism", None
)
if parallelism:
logger.info(
f"Setting micro batching parallelism from model_config_yaml: {parallelism}"
)
self.handle.parallelism = parallelism

micro_batch_size = ctx.model_yaml_config.get("micro_batching", {}).get(
"micro_batch_size", 1
)
logger.info(f"Setting micro batching size: {micro_batch_size}")
self.handle.micro_batch_size = micro_batch_size
233 changes: 233 additions & 0 deletions test/pytest/test_example_micro_batching.py
@@ -0,0 +1,233 @@
import asyncio
import json
import random
import shutil
from argparse import Namespace
from pathlib import Path

import pytest
import requests
import test_utils
import yaml
from torchvision.models.resnet import ResNet18_Weights

from ts.torch_handler.unit_tests.test_utils.model_dir import download_model

CURR_FILE_PATH = Path(__file__).parent
REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent

EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "microbatching")


def read_image_bytes(filename):
with open(
filename,
"rb",
) as fin:
image_bytes = fin.read()
return image_bytes


@pytest.fixture(scope="module")
def kitten_image_bytes():
return read_image_bytes(
REPO_ROOT_DIR.joinpath(
"examples/image_classifier/resnet_152_batch/images/kitten.jpg"
).as_posix()
)


@pytest.fixture(scope="module")
def dog_image_bytes():
return read_image_bytes(
REPO_ROOT_DIR.joinpath(
"examples/image_classifier/resnet_152_batch/images/dog.jpg"
).as_posix()
)


@pytest.fixture(scope="module", params=[4, 16])
def mixed_batch(kitten_image_bytes, dog_image_bytes, request):
batch_size = request.param
labels = [
"tiger_cat" if random.random() > 0.5 else "golden_retriever"
for _ in range(batch_size)
]
test_data = []
for l in labels:
test_data.append(kitten_image_bytes if l == "tiger_cat" else dog_image_bytes)
return test_data, labels


@pytest.fixture(scope="module")
def model_name():
yield "image_classifier"


@pytest.fixture(scope="module")
def work_dir(tmp_path_factory, model_name):
return tmp_path_factory.mktemp(model_name)


@pytest.fixture(scope="module")
def serialized_file(work_dir):
model_url = ResNet18_Weights.DEFAULT.url

download_model(model_url, work_dir)

yield Path(work_dir) / "model.pt"


@pytest.fixture(
scope="module", name="mar_file_path", params=["yaml_config", "no_config"]
)
def create_mar_file(
work_dir, session_mocker, serialized_file, model_archiver, model_name, request
):
mar_file_path = Path(work_dir).joinpath(model_name + ".mar")

name_file = REPO_ROOT_DIR.joinpath(
"examples/image_classifier/resnet_18/index_to_name.json"
).as_posix()

config_file = None
if request.param == "yaml_config":
micro_batching_params = {
"micro_batching": {
"micro_batch_size": 2,
"parallelism": {
"preprocess": 2,
"inference": 2,
"postprocess": 2,
},
},
}

config_file = Path(work_dir).joinpath("model_config.yaml")

with open(config_file, "w") as f:
yaml.dump(micro_batching_params, f)
config_file = REPO_ROOT_DIR.joinpath(
"examples", "micro_batching", "config.yaml"
)

extra_files = [name_file]

args = Namespace(
model_name=model_name,
version="1.0",
serialized_file=str(serialized_file),
model_file=REPO_ROOT_DIR.joinpath(
"examples", "image_classifier", "resnet_18", "model.py"
).as_posix(),
handler=REPO_ROOT_DIR.joinpath(
"examples", "micro_batching", "micro_batching_handler.py"
).as_posix(),
extra_files=",".join(extra_files),
export_path=work_dir,
requirements_file=None,
runtime="python",
force=False,
archive_format="default",
config_file=config_file,
)

mock = session_mocker.MagicMock()
mock.parse_args = session_mocker.MagicMock(return_value=args)
session_mocker.patch(
"archiver.ArgParser.export_model_args_parser", return_value=mock
)

# Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs
from zipfile import ZIP_STORED, ZipFile

session_mocker.patch(
"model_archiver.model_packaging_utils.zipfile.ZipFile",
lambda x, y, _: ZipFile(x, y, ZIP_STORED),
)

model_archiver.generate_model_archive()

assert mar_file_path.exists()

yield mar_file_path.as_posix()

# Clean up files
mar_file_path.unlink(missing_ok=True)


@pytest.fixture(scope="module", name="model_name")
def register_model(mar_file_path, model_store, torchserve):
"""
Register the model in torchserve
"""
shutil.copy(mar_file_path, model_store)

file_name = Path(mar_file_path).name

model_name = Path(file_name).stem

params = (
("model_name", model_name),
("url", file_name),
("initial_workers", "1"),
("synchronous", "true"),
("batch_size", "32"),
)

test_utils.reg_resp = test_utils.register_model_with_params(params)

yield model_name

test_utils.unregister_model(model_name)


def test_single_example_inference(model_name, kitten_image_bytes):
"""
Full circle test with torchserve
"""

response = requests.post(
url=f"http://localhost:8080/predictions/{model_name}", data=kitten_image_bytes
)

import inspect

print(inspect.getmembers(response))

assert response.status_code == 200


async def issue_request(model_name, data):
return requests.post(
url=f"http://localhost:8080/predictions/{model_name}", data=data
)


async def issue_multi_requests(model_name, data):
tasks = []
for d in data:
tasks.append(asyncio.create_task(issue_request(model_name, d)))

ret = []
for t in tasks:
ret.append(await t)

return ret


def test_multi_example_inference(model_name, mixed_batch):
"""
Full circle test with torchserve
"""
test_data, labels = mixed_batch

responses = asyncio.run(issue_multi_requests(model_name, test_data))

status_codes = [r.status_code for r in responses]

assert status_codes == [200] * len(status_codes)

result_entries = [json.loads(r.text) for r in responses]

assert all(l in r.keys() for l, r in zip(labels, result_entries))