Skip to content

Commit

Permalink
Merge 07bd3cc into bb4eb8b
Browse files Browse the repository at this point in the history
  • Loading branch information
agunapal authored Aug 24, 2023
2 parents bb4eb8b + 07bd3cc commit 951f83c
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 0 deletions.
64 changes: 64 additions & 0 deletions examples/object_detector/yolo/yolov8/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Object Detection using Ultralytics's pretrained YOLOv8(yolov8n) model.


Install `ultralytics` using
```
python -m pip install -r requirements.txt
```

In this example, we are using the YOLOv8 Nano model from ultralytics.Downlaod the pretrained weights from [Ultralytics](https://docs.ultralytics.com/models/yolov8/#supported-modes)

```
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt
```

We need a custom handler to load the YOLOv8n model. The default `initialize` function loads `.pt` file using `torch.jit.load`. This doesn't work for YOLOv8n model. Hence, we need a custom handler with an `initialize` method where we load the model using ultralytics.

## Create a model archive file for Yolov8n model

```
torch-model-archiver --model-name yolov8n --version 1.0 --serialized-file yolov8n.pt --handler custom_handler.py
```

```
mkdir model_store
mv yolov8n.mar model_store/.
```

## Start TorchServe and register the model


```
torchserve --start --model-store model_store --ncs
curl -X POST "localhost:8081/models?model_name=yolov8n&url=yolov8n.mar&initial_workers=4&batch_size=2"
```

results in

```
{
"status": "Model \"yolov8n\" Version: 1.0 registered with 4 initial workers"
}
```

## Run Inference

Here we are counting the number of detected objects in the image. You can change the post-process method in the handler to return the bounding box coordinates

```
curl http://127.0.0.1:8080/predictions/yolov8n -T persons.jpg & curl http://127.0.0.1:8080/predictions/yolov8n -T bus.jpg
```

gives the output

```
{
"person": 4,
"handbag": 3,
"bench": 3
}{
"person": 4,
"bus": 1,
"stop sign": 1
}
```
Binary file added examples/object_detector/yolo/yolov8/bus.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
79 changes: 79 additions & 0 deletions examples/object_detector/yolo/yolov8/custom_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging
import os
from collections import Counter

import torch
from torchvision import transforms
from ultralytics import YOLO

from ts.torch_handler.object_detector import ObjectDetector

logger = logging.getLogger(__name__)

try:
import torch_xla.core.xla_model as xm

XLA_AVAILABLE = True
except ImportError as error:
XLA_AVAILABLE = False


class Yolov8Handler(ObjectDetector):
image_processing = transforms.Compose(
[transforms.Resize(640), transforms.CenterCrop(640), transforms.ToTensor()]
)

def __init__(self):
super(Yolov8Handler, self).__init__()

def initialize(self, context):
# Set device type
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif XLA_AVAILABLE:
self.device = xm.xla_device()
else:
self.device = torch.device("cpu")

# Load the model
properties = context.system_properties
self.manifest = context.manifest
model_dir = properties.get("model_dir")
self.model_pt_path = None
if "serializedFile" in self.manifest["model"]:
serialized_file = self.manifest["model"]["serializedFile"]
self.model_pt_path = os.path.join(model_dir, serialized_file)
self.model = self._load_torchscript_model(self.model_pt_path)
logger.debug("Model file %s loaded successfully", self.model_pt_path)

self.initialized = True

def _load_torchscript_model(self, model_pt_path):
"""Loads the PyTorch model and returns the NN model object.
Args:
model_pt_path (str): denotes the path of the model file.
Returns:
(NN Model Object) : Loads the model object.
"""
# TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved

model = YOLO(model_pt_path)
model.to(self.device)
return model

def postprocess(self, res):
output = []
for data in res:
classes = data.boxes.cls.tolist()
names = data.names

# Map to class names
classes = map(lambda cls: names[int(cls)], classes)

# Get a count of objects detected
result = Counter(classes)
output.append(dict(result))

return output
Binary file added examples/object_detector/yolo/yolov8/persons.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions examples/object_detector/yolo/yolov8/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ultralytics>=8.0.144

0 comments on commit 951f83c

Please sign in to comment.