diff --git a/examples/micro_batching/README.md b/examples/micro_batching/README.md new file mode 100644 index 0000000000..373a708520 --- /dev/null +++ b/examples/micro_batching/README.md @@ -0,0 +1,111 @@ +# Micro Batching +Accelerators like GPUs can be used most cost-efficiently for inference if they are steadily fed with incoming data. +TorchServe currently allows a single batch to be processed per backend worker. +In each worker the three computation steps (preprocess, inference, postprocess) are executed sequentially. +Because pre- and postprocessing are often carried out on the CPU, the GPU sits idle until the two CPU-bound steps have finished and the worker receives a new batch. +The following example shows how to make better use of an accelerator in high-load scenarios. + +For this we assume that there are plenty of incoming client requests and that we can potentially fill a bigger batch within the batch delay time frame in which the frontend collects requests for the next batch. +Given this precondition we increase the batch size which the backend worker receives and subsequently split the big batch up into smaller *micro* batches for processing. +We can then perform the computation on the micro batches in parallel, as more than one batch is available to the worker. +This way we can already process a micro batch on the GPU while the preprocessing is applied to the remaining micro batches. +The pros and cons of this approach are as follows: + +Pros: + +* Higher throughput by better utilizing the available accelerator +* Lower overall latency when enough requests are available for computation + +Cons: + +* Potentially higher latency and lower throughput if not enough requests are available + +## Implementation +This example implements micro batching using a custom handler which overrides the *handle* method with a MicroBatching object defined in __ts.handler_utils.micro_batching__. +```python +class MicroBatchingHandler(ImageClassifier): + def __init__(self): + mb_handle = MicroBatching(self) + self.handle = mb_handle +``` +The MicroBatching object takes the custom handler as an input and spins up a number of threads. +Each thread works on one of the processing steps (preprocess, inference, postprocess) of the custom handler, and multiple threads can be assigned to process the same step in parallel. +The number of threads as well as the micro batch size are configurable through the [model yaml config](config.yaml): +```yaml +batchSize: 32 +micro_batching: + micro_batch_size: 4 + parallelism: + preprocess: 2 + inference: 1 + postprocess: 2 +``` +Each number in the *parallelism* dictionary represents the number of threads created for the respective step on initialization. +The micro_batch_size parameter should be chosen much smaller than the batch size configured through the TorchServe API (e.g. a batch size of 64 with a micro batch size of 4). + +## Example +The following example takes a ResNet18 image classification model and runs the pre- and postprocessing steps, which include resizing and cropping the image, in parallel.
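The handler used in this example, [micro_batching_handler.py](micro_batching_handler.py), also overrides `initialize` to forward the yaml settings shown above to the MicroBatching object. The following is a condensed sketch of that handler; refer to the file itself for the exact code:

```python
import logging

from ts.handler_utils.micro_batching import MicroBatching
from ts.torch_handler.image_classifier import ImageClassifier

logger = logging.getLogger(__name__)


class MicroBatchingHandler(ImageClassifier):
    def __init__(self):
        # Replace the default handle entry point with the micro batching pipeline
        self.handle = MicroBatching(self)

    def initialize(self, ctx):
        super().initialize(ctx)
        # Pick up micro_batch_size and parallelism from the model yaml config
        mb_config = ctx.model_yaml_config.get("micro_batching", {})
        parallelism = mb_config.get("parallelism")
        if parallelism:
            logger.info(f"Setting micro batching parallelism: {parallelism}")
            self.handle.parallelism = parallelism
        self.handle.micro_batch_size = mb_config.get("micro_batch_size", 1)
```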
+ +First, we need to download the model weights: +```bash +$ cd +$ wget https://download.pytorch.org/models/resnet18-f37072fd.pth +``` +Second, we create the MAR file while including the necessary source and config files as additional files: +```bash +$ torch-model-archiver --model-name resnet-18_mb --version 1.0 --model-file ./examples/image_classifier/resnet_18/model.py --serialized-file resnet18-f37072fd.pth --handler examples/micro_batching/micro_batching_handler.py --extra-files ./examples/image_classifier/index_to_name.json --config-file examples/micro_batching/config.yaml +``` +Our MicroBatchingHandler defined in [micro_batching_handler.py](micro_batching_handler.py) inherits from ImageClassifier, which already defines the necessary pre- and postprocessing. + +Third, we move the MAR file into our model_store and start TorchServe: +```bash +$ mkdir model_store +$ mv resnet-18_mb.mar model_store/ +$ torchserve --start --ncs --model-store model_store --models resnet-18_mb.mar +``` + +Finally, we test the registered model with a request: +```bash +$ curl http://127.0.0.1:8080/predictions/resnet-18_mb -T ./examples/image_classifier/kitten.jpg +``` +In the next section we will have a look at how the throughput and latency of the model behave by benchmarking it with TorchServe's benchmark tool. + +## Results +For the following benchmark we use [benchmark-ab.py](../../benchmarks/benchmark-ab.py) and a ResNet50 instead of the smaller ResNet18. +We ran this benchmark on an AWS g4dn.4xlarge instance with a single T4 GPU. +After creating the MAR file as described above (this time for ResNet50 and named resnet-50_mb), we extract it into the model_store so we do not need to upload the file: +```bash +$ unzip -d model_store/resnet-50_mb model_store/resnet-50_mb.mar +``` +Subsequently, we can run the benchmark with: +```bash +$ python3 benchmarks/benchmark-ab.py --config benchmarks/config.json +``` +The config.json for the benchmark has the following content: +```json +{ + "url":"/home/ubuntu/serve/model_store/resnet-50_mb/", + "requests": 50000, + "concurrency": 200, + "input": "/home/ubuntu/serve/examples/image_classifier/kitten.jpg", + "workers": "1", + "batch_size": 64 +} +``` +This will run the model with a batch size of 64 and a micro batch size of 4 as configured in the config.yaml. +For this section we ran the benchmark with different batch sizes and micro batch sizes (marked with "MBS=X") as well as different numbers of threads to create the following diagrams. +As a baseline we also ran the vanilla ImageClassifier handler without micro batching, which is marked as "NO MB". +![](assets/throughput_latency.png) +In the diagrams we see the throughput and P99 latency plotted over the batch size (as configured through the TorchServe API). +Each curve represents a different micro batch size as configured through [config.yaml](config.yaml). +We can see that the throughput stays flat for the vanilla ImageClassifier (NO MB), which suggests that the inference is preprocessing-bound and the GPU is underutilized; this can be confirmed with a look at the nvidia-smi output. +By interleaving the three compute steps and using two threads for pre- and postprocessing, we see that the micro-batched variants (MBS=4-16) achieve a higher throughput and even a lower batch latency, as the GPU is better utilized due to the introduction of micro batches. +For this particular model we can achieve a throughput of up to 250 QPS by increasing the number of preprocessing threads to 4 and choosing 128 and 8 as batch size and micro batch size, respectively.
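For reference, the micro batching section of the model yaml config for this best-performing run could look like the sketch below. Only the preprocessing thread count and the micro batch size are taken from the measurement above; the inference and postprocessing thread counts are assumed to stay at their earlier values, and the batch size of 128 itself is set through the benchmark config as shown earlier:

```yaml
micro_batching:
  micro_batch_size: 8   # best micro batch size found in this run
  parallelism:
    preprocess: 4       # increased from 2 to relieve the preprocessing bottleneck
    inference: 1        # assumed unchanged from the earlier configuration
    postprocess: 2      # assumed unchanged from the earlier configuration
```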
+The actual achieved speedup will depend on the specific model as well as the intensity of the pre- and postprocessing steps. +Image scaling and decompression, for example, are usually more compute-intensive than text preprocessing. + +## Summary +In summary, we can see that micro batching can help to increase the throughput of a model while decreasing its latency. +This is especially true for workloads with compute-intensive pre- or postprocessing as well as for smaller models. +The micro batching approach can also be used to save memory in a CPU use case by scaling the number of inference threads to more than one, which allows multiple instances of the model to run while all sharing the same underlying weights. +This is in contrast to running multiple TorchServe workers, each of which creates its own model instance; these instances cannot share their weights as they reside in different processes. diff --git a/examples/micro_batching/assets/throughput_latency.png b/examples/micro_batching/assets/throughput_latency.png new file mode 100644 index 0000000000..8c00c575da Binary files /dev/null and b/examples/micro_batching/assets/throughput_latency.png differ diff --git a/examples/micro_batching/config.yaml b/examples/micro_batching/config.yaml new file mode 100644 index 0000000000..e97e58cc60 --- /dev/null +++ b/examples/micro_batching/config.yaml @@ -0,0 +1,8 @@ +batchSize: 32 + +micro_batching: + micro_batch_size: 4 + parallelism: + preprocess: 2 + inference: 1 + postprocess: 2 diff --git a/examples/micro_batching/micro_batching_handler.py b/examples/micro_batching/micro_batching_handler.py new file mode 100644 index 0000000000..bef34ed513 --- /dev/null +++ b/examples/micro_batching/micro_batching_handler.py @@ -0,0 +1,30 @@ +import logging + +from ts.handler_utils.micro_batching import MicroBatching +from ts.torch_handler.image_classifier import ImageClassifier + +logger = logging.getLogger(__name__) + + +class MicroBatchingHandler(ImageClassifier): + def __init__(self): + mb_handle = MicroBatching(self) + self.handle = mb_handle + + def initialize(self, ctx): + super().initialize(ctx) + + parallelism = ctx.model_yaml_config.get("micro_batching", {}).get( + "parallelism", None + ) + if parallelism: + logger.info( + f"Setting micro batching parallelism from model_config_yaml: {parallelism}" + ) + self.handle.parallelism = parallelism + + micro_batch_size = ctx.model_yaml_config.get("micro_batching", {}).get( + "micro_batch_size", 1 + ) + logger.info(f"Setting micro batching size: {micro_batch_size}") + self.handle.micro_batch_size = micro_batch_size diff --git a/test/pytest/test_example_micro_batching.py b/test/pytest/test_example_micro_batching.py new file mode 100644 index 0000000000..b07c99ea7d --- /dev/null +++ b/test/pytest/test_example_micro_batching.py @@ -0,0 +1,233 @@ +import asyncio +import json +import random +import shutil +from argparse import Namespace +from pathlib import Path + +import pytest +import requests +import test_utils +import yaml +from torchvision.models.resnet import ResNet18_Weights + +from ts.torch_handler.unit_tests.test_utils.model_dir import download_model + +CURR_FILE_PATH = Path(__file__).parent +REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent + +EXAMPLE_ROOT_DIR = REPO_ROOT_DIR.joinpath("examples", "micro_batching") + + +def read_image_bytes(filename): + with open( + filename, + "rb", + ) as fin: + image_bytes = fin.read() + return image_bytes + + +@pytest.fixture(scope="module") +def kitten_image_bytes(): + return read_image_bytes( + REPO_ROOT_DIR.joinpath(
"examples/image_classifier/resnet_152_batch/images/kitten.jpg" + ).as_posix() + ) + + +@pytest.fixture(scope="module") +def dog_image_bytes(): + return read_image_bytes( + REPO_ROOT_DIR.joinpath( + "examples/image_classifier/resnet_152_batch/images/dog.jpg" + ).as_posix() + ) + + +@pytest.fixture(scope="module", params=[4, 16]) +def mixed_batch(kitten_image_bytes, dog_image_bytes, request): + batch_size = request.param + labels = [ + "tiger_cat" if random.random() > 0.5 else "golden_retriever" + for _ in range(batch_size) + ] + test_data = [] + for l in labels: + test_data.append(kitten_image_bytes if l == "tiger_cat" else dog_image_bytes) + return test_data, labels + + +@pytest.fixture(scope="module") +def model_name(): + yield "image_classifier" + + +@pytest.fixture(scope="module") +def work_dir(tmp_path_factory, model_name): + return tmp_path_factory.mktemp(model_name) + + +@pytest.fixture(scope="module") +def serialized_file(work_dir): + model_url = ResNet18_Weights.DEFAULT.url + + download_model(model_url, work_dir) + + yield Path(work_dir) / "model.pt" + + +@pytest.fixture( + scope="module", name="mar_file_path", params=["yaml_config", "no_config"] +) +def create_mar_file( + work_dir, session_mocker, serialized_file, model_archiver, model_name, request +): + mar_file_path = Path(work_dir).joinpath(model_name + ".mar") + + name_file = REPO_ROOT_DIR.joinpath( + "examples/image_classifier/resnet_18/index_to_name.json" + ).as_posix() + + config_file = None + if request.param == "yaml_config": + micro_batching_params = { + "micro_batching": { + "micro_batch_size": 2, + "parallelism": { + "preprocess": 2, + "inference": 2, + "postprocess": 2, + }, + }, + } + + config_file = Path(work_dir).joinpath("model_config.yaml") + + with open(config_file, "w") as f: + yaml.dump(micro_batching_params, f) + config_file = REPO_ROOT_DIR.joinpath( + "examples", "micro_batching", "config.yaml" + ) + + extra_files = [name_file] + + args = Namespace( + model_name=model_name, + version="1.0", + serialized_file=str(serialized_file), + model_file=REPO_ROOT_DIR.joinpath( + "examples", "image_classifier", "resnet_18", "model.py" + ).as_posix(), + handler=REPO_ROOT_DIR.joinpath( + "examples", "micro_batching", "micro_batching_handler.py" + ).as_posix(), + extra_files=",".join(extra_files), + export_path=work_dir, + requirements_file=None, + runtime="python", + force=False, + archive_format="default", + config_file=config_file, + ) + + mock = session_mocker.MagicMock() + mock.parse_args = session_mocker.MagicMock(return_value=args) + session_mocker.patch( + "archiver.ArgParser.export_model_args_parser", return_value=mock + ) + + # Using ZIP_STORED instead of ZIP_DEFLATED reduces test runtime from 54 secs to 10 secs + from zipfile import ZIP_STORED, ZipFile + + session_mocker.patch( + "model_archiver.model_packaging_utils.zipfile.ZipFile", + lambda x, y, _: ZipFile(x, y, ZIP_STORED), + ) + + model_archiver.generate_model_archive() + + assert mar_file_path.exists() + + yield mar_file_path.as_posix() + + # Clean up files + mar_file_path.unlink(missing_ok=True) + + +@pytest.fixture(scope="module", name="model_name") +def register_model(mar_file_path, model_store, torchserve): + """ + Register the model in torchserve + """ + shutil.copy(mar_file_path, model_store) + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + params = ( + ("model_name", model_name), + ("url", file_name), + ("initial_workers", "1"), + ("synchronous", "true"), + ("batch_size", "32"), + ) + + test_utils.reg_resp = 
test_utils.register_model_with_params(params) + + yield model_name + + test_utils.unregister_model(model_name) + + +def test_single_example_inference(model_name, kitten_image_bytes): + """ + Full circle test with torchserve + """ + + response = requests.post( + url=f"http://localhost:8080/predictions/{model_name}", data=kitten_image_bytes + ) + + import inspect + + print(inspect.getmembers(response)) + + assert response.status_code == 200 + + +async def issue_request(model_name, data): + return requests.post( + url=f"http://localhost:8080/predictions/{model_name}", data=data + ) + + +async def issue_multi_requests(model_name, data): + tasks = [] + for d in data: + tasks.append(asyncio.create_task(issue_request(model_name, d))) + + ret = [] + for t in tasks: + ret.append(await t) + + return ret + + +def test_multi_example_inference(model_name, mixed_batch): + """ + Full circle test with torchserve + """ + test_data, labels = mixed_batch + + responses = asyncio.run(issue_multi_requests(model_name, test_data)) + + status_codes = [r.status_code for r in responses] + + assert status_codes == [200] * len(status_codes) + + result_entries = [json.loads(r.text) for r in responses] + + assert all(l in r.keys() for l, r in zip(labels, result_entries)) diff --git a/ts/handler_utils/micro_batching.py b/ts/handler_utils/micro_batching.py new file mode 100644 index 0000000000..ae9671f707 --- /dev/null +++ b/ts/handler_utils/micro_batching.py @@ -0,0 +1,177 @@ +import os +import queue +import threading +import time +from copy import copy +from dataclasses import dataclass +from typing import Dict + +try: + PROFILER_AVAILABLE = True +except ImportError: + PROFILER_AVAILABLE = False + + +HANDLER_METHODS = ["preprocess", "inference", "postprocess"] + + +def execute_call(in_queue, out_queue, handle, event): + while not event.is_set(): + try: + idx, in_data = in_queue.get(timeout=0.5) + except queue.Empty: + continue + out_data = handle(in_data) + out_queue.put((idx, out_data)) + + +@dataclass +class WorkerThread: + event: threading.Event + thread: threading.Thread + + +class MicroBatching(object): + def __init__( + self, parent_handler, micro_batch_size: int = 1, parallelism: Dict = None + ): + self.handler = parent_handler + self.micro_batch_size = micro_batch_size + self._parallelism = parallelism if parallelism is not None else {} + self.thread_groups = {c: [] for c in HANDLER_METHODS} + self.queues = {} + self.terminate = threading.Event() + self._create_queues() + self._update_threads() + + def __del__(self): + self.shutdown() + + @property + def parallelism(self) -> Dict: + return copy(self._parallelism) + + @parallelism.setter + def parallelism(self, parallelism: Dict): + """Set number of threads for each of the processing steps. + + Args: + parallelism (Dict): New number of threads per processing step + + Returns: + None + """ + assert all(k in HANDLER_METHODS for k in parallelism.keys()) + + self._parallelism.update(parallelism) + self._update_threads() + + def shutdown(self): + """Shuts down all running threads. 
+ + Args: + None + + Returns: + None + """ + for _, tg in self.thread_groups.items(): + for t in tg: + t.event.set() + t.thread.join() + + def _create_queues(self): + # Set up processing queues + self.queues[HANDLER_METHODS[0] + "_in"] = queue.Queue() + for i in range(len(HANDLER_METHODS) - 1): + # Each "out" queue is the "in" queue of the next processing step + self.queues[HANDLER_METHODS[i] + "_out"] = queue.Queue() + self.queues[HANDLER_METHODS[i + 1] + "_in"] = self.queues[ + HANDLER_METHODS[i] + "_out" + ] + self.queues[HANDLER_METHODS[-1] + "_out"] = queue.Queue() + + def _update_threads(self): + for c in HANDLER_METHODS: + tgt_parallelism = self._parallelism.get(c, 1) + assert tgt_parallelism >= 0 + cur_parallelism = lambda: len(self.thread_groups[c]) + + # Scale up threads if necessary + while tgt_parallelism > cur_parallelism(): + in_queue = self.queues[c + "_in"] + out_queue = self.queues[c + "_out"] + call = getattr(self.handler, c) + event = threading.Event() + + t = threading.Thread( + target=execute_call, + args=(in_queue, out_queue, call, event), + ) + t.start() + self.thread_groups[c].append(WorkerThread(event, t)) + + # Scale down threads if necessary + while tgt_parallelism < cur_parallelism(): + self.thread_groups[c][-1].event.set() + self.thread_groups[c][-1].thread.join() + self.thread_groups[c].pop() + + def handle(self, data): + num_batches = 0 + for idx, i in enumerate(range(0, len(data), self.micro_batch_size)): + self.queues[HANDLER_METHODS[0] + "_in"].put_nowait( + (idx, data[i : i + self.micro_batch_size]) + ) + num_batches += 1 + + output = [] + while len(output) != num_batches: + output.append(self.queues[HANDLER_METHODS[-1] + "_out"].get()) + + return [item for batch in sorted(output) for item in batch[1]] + + def __call__(self, data, context): + """Entry point for default handler. It takes the data from the input request and returns + the predicted outcome for the input. This method is a modified variant from the BaseHandler. + It calls the MicroBatching handle method instead of running the single processing steps. + + Args: + data (list): The input data that needs to be made a prediction request on. + context (Context): It is a JSON Object containing information pertaining to + the model artefacts parameters. + + Returns: + list : Returns a list of dictionary with the predicted response. + """ + + # It can be used for pre or post processing if needed as additional request + # information is available in context + start_time = time.time() + + self.handler.context = context + metrics = self.handler.context.metrics + + is_profiler_enabled = os.environ.get("ENABLE_TORCH_PROFILER", None) + if is_profiler_enabled: + if PROFILER_AVAILABLE: + output, _ = self.handler._infer_with_profiler(data=data) + else: + raise RuntimeError( + "Profiler is enabled but current version of torch does not support." + "Install torch>=1.8.1 to use profiler." 
+ ) + else: + if self.handler._is_describe(): + output = [self.handler.describe_handle()] + elif self.handler._is_explain(): + data_preprocess = self.handler.preprocess(data) + output = self.handler.explain_handle(data_preprocess, data) + else: + output = self.handle(data) + + stop_time = time.time() + metrics.add_time( + "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" + ) + return output diff --git a/ts/tests/unit_tests/test_micro_batching.py b/ts/tests/unit_tests/test_micro_batching.py new file mode 100644 index 0000000000..001de6411b --- /dev/null +++ b/ts/tests/unit_tests/test_micro_batching.py @@ -0,0 +1,181 @@ +""" +Unit test for MicroBatchHandler class. +""" +import json +import random +import sys +from pathlib import Path + +import pytest +from torchvision.models.resnet import ResNet18_Weights + +from ts.torch_handler.image_classifier import ImageClassifier +from ts.torch_handler.unit_tests.test_utils.mock_context import MockContext +from ts.torch_handler.unit_tests.test_utils.model_dir import copy_files, download_model + +REPO_DIR = Path(__file__).parents[3] + + +def read_image_bytes(filename): + with open( + filename, + "rb", + ) as fin: + image_bytes = fin.read() + return image_bytes + + +@pytest.fixture(scope="module") +def kitten_image_bytes(): + return read_image_bytes( + REPO_DIR.joinpath( + "examples/image_classifier/resnet_152_batch/images/kitten.jpg" + ).as_posix() + ) + + +@pytest.fixture(scope="module") +def dog_image_bytes(): + return read_image_bytes( + REPO_DIR.joinpath( + "examples/image_classifier/resnet_152_batch/images/dog.jpg" + ).as_posix() + ) + + +@pytest.fixture(scope="module") +def model_name(): + return "image_classifier" + + +@pytest.fixture(scope="module") +def model_dir(tmp_path_factory, model_name): + model_dir = tmp_path_factory.mktemp("image_classifier_model_dir") + + src_dir = REPO_DIR.joinpath("examples/image_classifier/resnet_18/") + + model_url = ResNet18_Weights.DEFAULT.url + + download_model(model_url, model_dir) + + files = { + "model.py": model_name + ".py", + "index_to_name.json": "index_to_name.json", + } + + copy_files(src_dir, model_dir, files) + + sys.path.append(model_dir.as_posix()) + yield model_dir + sys.path.pop() + + +@pytest.fixture(scope="module") +def context(model_dir, model_name): + micro_batching_params = { + "mb_size": 2, + "mb_parallelism": { + "preprocess": 1, + "inference": 2, + "postprocess": 3, + }, + } + + config_file = Path(model_dir).joinpath("micro_batching.json") + + with open(config_file, "w") as f: + json.dump(micro_batching_params, f) + + context = MockContext( + model_name="mnist", + model_dir=model_dir.as_posix(), + model_file=model_name + ".py", + ) + context.model_yaml_config = micro_batching_params + yield context + + +@pytest.fixture(scope="module", params=[1, 8]) +def handler(context, request): + handler = ImageClassifier() + + from ts.handler_utils.micro_batching import MicroBatching + + mb_handle = MicroBatching(handler, micro_batch_size=request.param) + handler.initialize(context) + + handler.handle = mb_handle + handler.handle.parallelism = context.model_yaml_config["mb_parallelism"] + + yield handler + + mb_handle.shutdown() + + +@pytest.fixture(scope="module", params=[1, 16]) +def mixed_batch(kitten_image_bytes, dog_image_bytes, request): + batch_size = request.param + labels = [ + "tiger_cat" if random.random() > 0.5 else "golden_retriever" + for _ in range(batch_size) + ] + test_data = [] + for l in labels: + test_data.append( + {"data": kitten_image_bytes} + if l == 
"tiger_cat" + else {"data": dog_image_bytes} + ) + return test_data, labels + + +def test_handle(context, mixed_batch, handler): + test_data, labels = mixed_batch + results = handler.handle(test_data, context) + assert len(results) == len(labels) + for l, r in zip(labels, results): + assert l in r + + +def test_handle_explain(context, kitten_image_bytes, handler): + context.explain = True + test_data = [{"data": kitten_image_bytes, "target": 0}] * 2 + results = handler.handle(test_data, context) + assert len(results) == 2 + assert results[0] + + +def test_micro_batching_handler_threads(handler): + assert len(handler.handle.thread_groups["preprocess"]) == 1 + assert len(handler.handle.thread_groups["inference"]) == 2 + assert len(handler.handle.thread_groups["postprocess"]) == 3 + + +def test_spin_up_down_threads(handler): + assert len(handler.handle.thread_groups["preprocess"]) == 1 + assert len(handler.handle.thread_groups["inference"]) == 2 + assert len(handler.handle.thread_groups["postprocess"]) == 3 + + new_parallelism = { + "preprocess": 2, + "inference": 3, + "postprocess": 4, + } + + handler.handle.parallelism = new_parallelism + + assert len(handler.handle.thread_groups["preprocess"]) == 2 + assert len(handler.handle.thread_groups["inference"]) == 3 + assert len(handler.handle.thread_groups["postprocess"]) == 4 + + new_parallelism = { + "preprocess": 1, + "inference": 2, + "postprocess": 3, + } + + handler.handle.parallelism = new_parallelism + + assert len(handler.handle.thread_groups["preprocess"]) == 1 + assert len(handler.handle.thread_groups["inference"]) == 2 + assert len(handler.handle.thread_groups["postprocess"]) == 3 diff --git a/ts/utils/__init__.py b/ts/utils/__init__.py index fb83331a7e..01d3db1a0d 100644 --- a/ts/utils/__init__.py +++ b/ts/utils/__init__.py @@ -1,7 +1,3 @@ - - """ Util files for TorchServe """ - -from . import timeit_decorator diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 60f1a8acd2..04ed1aa80d 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1032,6 +1032,10 @@ mps deviceIds rpc pippy +MBS +MicroBatching +MicroBatchingHandler +QPS PiPPy Microbatching Micro-batching