Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
bstandaert committed Jun 26, 2024
2 parents 6a4b32a + 92b57a4 commit b342af6
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 108 deletions.
130 changes: 35 additions & 95 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ lightning = "^2.0"
pytorch_lightning = "^2.0"
numpy = "^1.23.5"
openmim = "^0.3.9"
ultralytics = "8.0.61"
ultralytics = "8.0.100"
sphinx = "^7.2"
sphinx_rtd_theme = "^2.0"
myst-parser = "^2.0"
Expand Down
10 changes: 10 additions & 0 deletions tracklab/configs/modules/pose_bottomup/yolov8-pose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

_target_: tracklab.wrappers.YOLOv8Pose
batch_size: 4
cfg:
# Models available:
# yolov8n-pose.pt, yolov8s-pose.pt, yolov8m-pose.pt, yolov8l-pose.pt, yolov8x-pose.pt, yolov8x-pose-p6.pt
# These models will be downloaded automatically if not found in the path.
path_to_checkpoint: "${model_dir}/yolo/yolov8m-pose.pt"

min_confidence: 0.4
2 changes: 1 addition & 1 deletion tracklab/utils/cv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def draw_ignore_region(patch, image_metadata):
def print_count_frame(patch, frame, nframes):
draw_text(
patch,
f"{frame}/{nframes}",
f"{frame+1}/{nframes}",
(6, 15),
fontFace=1,
fontScale=2.0,
Expand Down
7 changes: 3 additions & 4 deletions tracklab/wrappers/datasets/external_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ def __init__(self, dataset_path: str, video_path: str, *args, **kwargs):
if not mimetypes.guess_type(video_path)[0].startswith('video'):
continue
nframes = self.get_frame_count(video_path)
video_id = i
video_name = video_path.stem
video_id = video_name
image_metadata.extend(
[
{
Expand Down Expand Up @@ -115,11 +115,10 @@ def __init__(self, dataset_path: str, video_path: str, *args, **kwargs):
video_metadata,
image_metadata,
None,
image_metadata
)

sets = {"val": val_set}

super().__init__(dataset_path, sets, *args, **kwargs)
super().__init__(dataset_path, dict(val=val_set), *args, **kwargs)

@staticmethod
def get_frame_count(video_path):
Expand Down
16 changes: 9 additions & 7 deletions tracklab/wrappers/detect_multiple/yolov8_pose_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import numpy as np
import pandas as pd

from typing import Any
from tracklab.pipeline.imagelevel_module import ImageLevelModule

os.environ["YOLO_VERBOSE"] = "False"
from ultralytics import YOLO

from tracklab.utils.cv2 import cv2_load_image
from tracklab.utils.coordinates import ltrb_to_ltwh

import logging
Expand All @@ -25,6 +25,7 @@ def collate_fn(batch):

class YOLOv8Pose(ImageLevelModule):
collate_fn = collate_fn
input_columns = []
output_columns = [
"image_id",
"video_id",
Expand All @@ -35,22 +36,23 @@ class YOLOv8Pose(ImageLevelModule):
"keypoints_conf",
]

def __init__(self, cfg, device, batch_size):
def __init__(self, cfg, device, batch_size, **kwargs):
super().__init__(batch_size)
self.cfg = cfg
self.device = device
self.model = YOLO(cfg.path_to_checkpoint)
self.model.to(device)
self.id = 0

@torch.no_grad()
def preprocess(self, metadata: pd.Series):
image = cv2_load_image(metadata.file_path)
def preprocess(self, image, detections, metadata: pd.Series):
return {
"image": image,
"shape": (image.shape[1], image.shape[0]),
}

@torch.no_grad()
def process(self, batch, metadatas: pd.DataFrame):
def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame):
images, shapes = batch
results_by_image = self.model(images)
detections = []
Expand All @@ -69,8 +71,8 @@ def process(self, batch, metadatas: pd.DataFrame):
bbox_conf=bbox.conf[0],
video_id=metadata.video_id,
category_id=1, # `person` class in posetrack
keypoints_xyc=keypoints.data[0],
keypoints_conf=np.mean(keypoints.conf[0], axis=0),
keypoints_xyc=keypoints,
keypoints_conf=np.mean(keypoints[:, 2], axis=0),
),
name=self.id,
)
Expand Down

0 comments on commit b342af6

Please sign in to comment.