Skip to content

Commit

Permalink
Merge branch 'main' of github.com:TrackingLaboratory/tracklab
Browse files Browse the repository at this point in the history
  • Loading branch information
victorjoos committed Jun 26, 2024
2 parents 8c328f0 + b342af6 commit bec590d
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 103 deletions.
Empty file removed hydra_plugins/__init__.py
Empty file.
130 changes: 35 additions & 95 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ lightning = "^2.0"
pytorch_lightning = "^2.0"
numpy = "^1.23.5"
openmim = "^0.3.9"
ultralytics = "8.0.61"
ultralytics = "8.0.100"
sphinx = "^7.2"
sphinx_rtd_theme = "^2.0"
myst-parser = "^2.0"
Expand Down
10 changes: 10 additions & 0 deletions tracklab/configs/modules/pose_bottomup/yolov8-pose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

# Hydra config for the YOLOv8 bottom-up pose estimation wrapper.
_target_: tracklab.wrappers.YOLOv8Pose
batch_size: 4
cfg:
  # Available checkpoints:
  # yolov8n-pose.pt, yolov8s-pose.pt, yolov8m-pose.pt, yolov8l-pose.pt, yolov8x-pose.pt, yolov8x-pose-p6.pt
  # A missing checkpoint file is downloaded automatically by ultralytics.
  path_to_checkpoint: "${model_dir}/yolo/yolov8m-pose.pt"

# NOTE(review): indentation was lost in the rendered diff — confirm whether
# min_confidence belongs at the top level or under `cfg`.
min_confidence: 0.4
16 changes: 9 additions & 7 deletions tracklab/wrappers/detect_multiple/yolov8_pose_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import numpy as np
import pandas as pd

from typing import Any
from tracklab.pipeline.imagelevel_module import ImageLevelModule

os.environ["YOLO_VERBOSE"] = "False"
from ultralytics import YOLO

from tracklab.utils.cv2 import cv2_load_image
from tracklab.utils.coordinates import ltrb_to_ltwh

import logging
Expand All @@ -25,6 +25,7 @@ def collate_fn(batch):

class YOLOv8Pose(ImageLevelModule):
collate_fn = collate_fn
input_columns = []
output_columns = [
"image_id",
"video_id",
Expand All @@ -35,22 +36,23 @@ class YOLOv8Pose(ImageLevelModule):
"keypoints_conf",
]

def __init__(self, cfg, device, batch_size, **kwargs):
    """Load a YOLOv8 pose checkpoint and move it to the target device.

    Args:
        cfg: config node; only ``cfg.path_to_checkpoint`` is read here.
        device: torch device (string or object) the model is moved to.
        batch_size: forwarded to the ``ImageLevelModule`` base class.
        **kwargs: absorbed so extra injected config keys do not raise.
    """
    super().__init__(batch_size)
    self.cfg = cfg
    self.device = device
    # Ultralytics downloads the weights automatically if the file is absent.
    self.model = YOLO(cfg.path_to_checkpoint)
    self.model.to(device)
    # Monotonically increasing detection id — used as the Series name
    # for each detection emitted by process().
    self.id = 0

@torch.no_grad()
def preprocess(self, image, detections, metadata: pd.Series):
    """Wrap one already-decoded frame for batching.

    Args:
        image: decoded frame as an (H, W, C) array; passed through untouched.
        detections: unused here — present to satisfy the ImageLevelModule API.
        metadata: per-image metadata row; unused in this implementation.

    Returns:
        dict with the raw image and its (width, height), collated by
        ``collate_fn`` into the ``(images, shapes)`` batch that ``process``
        unpacks.
    """
    return {
        "image": image,
        "shape": (image.shape[1], image.shape[0]),
    }

@torch.no_grad()
def process(self, batch, metadatas: pd.DataFrame):
def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame):
images, shapes = batch
results_by_image = self.model(images)
detections = []
Expand All @@ -69,8 +71,8 @@ def process(self, batch, metadatas: pd.DataFrame):
bbox_conf=bbox.conf[0],
video_id=metadata.video_id,
category_id=1, # `person` class in posetrack
keypoints_xyc=keypoints.data[0],
keypoints_conf=np.mean(keypoints.conf[0], axis=0),
keypoints_xyc=keypoints,
keypoints_conf=np.mean(keypoints[:, 2], axis=0),
),
name=self.id,
)
Expand Down

0 comments on commit bec590d

Please sign in to comment.