Hello @JaMe76, as mentioned in discussion #302 (wrapping an OCR detector in a deepdoctection ObjectDetector), you suggested creating a wrapper class for a new or custom OCR engine.
I want to use a YOLOv10 model for layout detection together with DocTR, so that with the matching service and page parsing I can filter the categories and take the text output from DocTR.
Since LayoutParser is based on detectron2, there is dd.D2FrcnnDetector, which we pass to ImageLayoutService(), and that layout service can then be used in a DoctectionPipe together with DocTR.
Is there anything specific for YOLO models, or do I need to create something for that?
Below is my code; instead of the detectron2 route I want to use a YOLOv10 model. Can you also tell me how to swap in a different model in place of detectron2 in general?
I have written a wrapper for the YOLO class, shared below, and I tried to filter using page_parsing.
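For reference, this is roughly the stock detectron2 route I am replacing; a minimal sketch following the deepdoctection tutorials (the registered model name layout/d2_model_0829999_layout_inf_only.pt is the one from the docs and may differ in your installation):

import deepdoctection as dd

# Stock detectron2-based layout detection, for comparison with the YOLO wrapper below
d2_name = "layout/d2_model_0829999_layout_inf_only.pt"
path_weights = dd.ModelDownloadManager.maybe_download_weights_and_configs(d2_name)
path_configs = dd.ModelCatalog.get_full_path_configs(d2_name)
profile = dd.ModelCatalog.get_profile(d2_name)

d2_detector = dd.D2FrcnnDetector(path_configs, path_weights, profile.categories, device="cpu")
layout_service = dd.ImageLayoutService(layout_detector=d2_detector, to_image=True, crop_image=True)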
YOLO Wrapper class:
# -*- coding: utf-8 -*-
# File: yolov10detector.py
# Copyright 2024 Your Name. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
YOLOv10 document detection engine for layout extraction
"""
from __future__ import annotations

from pathlib import Path
from typing import Mapping, Optional

from ultralytics import YOLOv10  # type: ignore

from ..utils.settings import ObjectTypes, TypeOrStr
from ..utils.types import PathLikeOrStr, PixelValues, Requirement
from .base import DetectionResult, ModelCategories, ObjectDetector
def _yolo10_to_detectresult(results) -> list[DetectionResult]:
    """
    Converts YOLOv10 detection results into DetectionResult objects.

    :param results: A single ultralytics Results object
    :return: A list of DetectionResult objects
    """
    all_results: list[DetectionResult] = []
    # Mapping of model class ids to class names, e.g. {0: "Caption", ...}
    names = results.names

    # Each detected box carries its own coordinates, confidence score and class id
    for box, score, cls in zip(results.boxes.xyxy, results.boxes.conf, results.boxes.cls):
        x1, y1, x2, y2 = box.tolist()
        class_id = int(cls.item())
        all_results.append(
            DetectionResult(
                box=[x1, y1, x2, y2],
                score=float(score.item()),
                # YOLO class ids are zero-based, while the registered model profile
                # uses category ids starting at 1, hence the shift
                class_id=class_id + 1,
                class_name=names[class_id],
            )
        )
    return all_results


def predict_yolo10(np_img: PixelValues, model, conf_threshold: float = 0.2, iou_threshold: float = 0.8) -> list[DetectionResult]:
    """
    Run inference using the YOLOv10 model.

    :param np_img: Input image as numpy array (BGR format)
    :param model: YOLOv10 model instance
    :param conf_threshold: Confidence threshold for detections
    :param iou_threshold: Intersection-over-Union threshold for non-max suppression
    :return: A list of DetectionResult objects
    """
    # ultralytics returns a list with one Results object per input image
    results = model(source=np_img, conf=conf_threshold, iou=iou_threshold)[0]
    return _yolo10_to_detectresult(results)
class YoloDetector(ObjectDetector):
    """
    Document detector using a YOLOv10 model for layout analysis.

    Model weights must be placed at `.cache/deepdoctection/weights/yolo10/yolov10x_best.pt`.
    The detector predicts categories of document elements such as text, tables, figures, headers, etc.
    """

    def __init__(self,
                 conf_threshold: float = 0.2,
                 iou_threshold: float = 0.8,
                 model_weights: Optional[PathLikeOrStr] = None,
                 categories: Optional[Mapping[int, TypeOrStr]] = None) -> None:
        """
        :param conf_threshold: Confidence threshold for YOLOv10 detections.
        :param iou_threshold: IoU threshold for non-max suppression.
        :param model_weights: Path to the YOLOv10 model weights file.
        :param categories: Mapping of category ids to category names or ObjectTypes for the YOLOv10 classes.
        """
        if model_weights is None:
            raise ValueError("A path to the model weights must be provided.")
        if categories is None:
            raise ValueError("A dictionary of category mappings must be provided.")
        self.name = "yolo_detector"
        self.model_id = self.get_model_id()
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        # Keep the weights path around so that clone() can rebuild the instance
        self.model_weights = Path(model_weights)
        # Load the YOLOv10 model with the specified weights
        self.model = YOLOv10(model_weights)
        self.categories = ModelCategories(init_categories=categories)

    def predict(self, np_img: PixelValues) -> list[DetectionResult]:
        """
        Run inference on a document image using YOLOv10 and return detection results.

        :param np_img: Input image as numpy array (BGR format)
        :return: A list of DetectionResult objects.
        """
        return predict_yolo10(np_img, self.model, self.conf_threshold, self.iou_threshold)

    @classmethod
    def get_requirements(cls) -> list[Requirement]:
        # No additional requirements are checked; the YOLO weights are expected to be stored locally.
        return []

    def clone(self) -> YoloDetector:
        """
        Clone the current detector instance, passing the categories along as well.
        """
        return self.__class__(conf_threshold=self.conf_threshold,
                              iou_threshold=self.iou_threshold,
                              model_weights=self.model_weights,
                              categories=self.categories.get_categories())

    def get_category_names(self) -> tuple[ObjectTypes, ...]:
        """
        Get the category names used by YOLOv10 for document detection.
        """
        return self.categories.get_categories(as_dict=False)
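A quick way to sanity-check the wrapper in isolation is to run it directly on an image; a minimal sketch, where sample.png is a placeholder file name and the category strings assume the YOLOTYPE enum from the example below is already registered:

import cv2

# Hypothetical local weights path; adjust to your setup
weights = ".cache/deepdoctection/weights/yolo10/yolov10x_best.pt"
categories = {i + 1: name for i, name in enumerate(
    ["Caption", "Footnote", "Formula", "List-item", "Page-footer",
     "Page-header", "Picture", "Section-header", "Table", "Text", "Title"])}

detector = YoloDetector(model_weights=weights, categories=categories)
np_img = cv2.imread("sample.png")  # numpy array in BGR order, as expected by predict()
for r in detector.predict(np_img):
    print(r.class_id, r.class_name, round(float(r.score), 3), r.box)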
Example or implementation code:
import deepdoctection as dd
import matplotlib.pyplot as plt
from deepdoctection.extern.doctrocr import DoctrTextRecognizer, DoctrTextlineDetector
from deepdoctection.extern.model import ModelCatalog, ModelDownloadManager
from deepdoctection.extern.yolodetector import YoloDetector
from deepdoctection.pipe.doctectionpipe import DoctectionPipe
from deepdoctection.pipe.layout import ImageLayoutService
from deepdoctection.pipe.text import TextExtractionService

# Model registration name and input directory
model_name = "yolo/yolo10x/yolov10x_best.pt"
path = "/home/adinat/Python Ml/img/"
@dd.object_types_registry.register("YOLOTYPE")
class YOLOTYPE(dd.ObjectTypes):
    """Labels of the custom YOLOv10 layout model, not registered by default"""

    # Note: no trailing commas here; they would turn the enum values into tuples
    CAPTION = "Caption"
    FOOTNOTE = "Footnote"
    FORMULA = "Formula"
    LIST_ITEM = "List-item"
    PAGE_FOOTER = "Page-footer"
    PAGE_HEADER = "Page-header"
    PICTURE = "Picture"
    SECTION_HEADER = "Section-header"
    TABLE = "Table"
    TEXT = "Text"
    TITLE = "Title"
dd.ModelCatalog.register(model_name, dd.ModelProfile(
    name=model_name,
    description="YOLOv10 model for layout analysis",
    tp_model=False,
    size=[],
    categories={
        1: YOLOTYPE.CAPTION,
        2: YOLOTYPE.FOOTNOTE,
        3: YOLOTYPE.FORMULA,
        4: YOLOTYPE.LIST_ITEM,
        5: YOLOTYPE.PAGE_FOOTER,
        6: YOLOTYPE.PAGE_HEADER,
        7: YOLOTYPE.PICTURE,
        8: YOLOTYPE.SECTION_HEADER,
        9: YOLOTYPE.TABLE,
        10: YOLOTYPE.TEXT,
        11: YOLOTYPE.TITLE,
    },
    model_wrapper="YoloDetector"
))

yolo_weights_path = dd.ModelCatalog.get_full_path_weights(model_name)
categories = dd.ModelCatalog.get_profile(model_name).categories

cats = [
    YOLOTYPE.CAPTION,
    YOLOTYPE.FOOTNOTE,
    YOLOTYPE.FORMULA,
    YOLOTYPE.LIST_ITEM,
    YOLOTYPE.PAGE_FOOTER,
    YOLOTYPE.PAGE_HEADER,
    YOLOTYPE.PICTURE,
    YOLOTYPE.SECTION_HEADER,
    YOLOTYPE.TABLE,
    YOLOTYPE.TEXT,
    YOLOTYPE.TITLE,
]

yolo_detector = YoloDetector(model_weights=yolo_weights_path, categories=categories)
layout_service = ImageLayoutService(yolo_detector)
path_weights_tl = ModelDownloadManager.maybe_download_weights_and_configs("doctr/db_resnet50/pt/db_resnet50-ac60cadc.pt")
categories_tl = ModelCatalog.get_profile("doctr/db_resnet50/pt/db_resnet50-ac60cadc.pt").categories
det = DoctrTextlineDetector("db_resnet50", path_weights_tl, categories_tl, "cpu")
doctrdet = ImageLayoutService(det, to_image=True, crop_image=True)

path_weights_tr = ModelDownloadManager.maybe_download_weights_and_configs("doctr/crnn_vgg16_bn/pt/crnn_vgg16_bn-9762b0b0.pt")
rec = DoctrTextRecognizer("crnn_vgg16_bn", path_weights_tr, "cpu")
text = TextExtractionService(rec, extract_from_roi=dd.LayoutType.WORD)

map_comp = dd.MatchingService(parent_categories=cats,
                              child_categories=[dd.LayoutType.WORD],
                              matching_rule="ioa", threshold=0.5,
                              max_parent_only=True)

text_order_comp = dd.TextOrderService(text_container=dd.LayoutType.WORD,
                                      text_block_categories=cats,
                                      floating_text_block_categories=cats,
                                      include_residual_text_container=True)

page_parsing = dd.PageParsingService(text_container=dd.LayoutType.WORD,
                                     floating_text_block_categories=[YOLOTYPE.TEXT],
                                     include_residual_text_container=True)

pipe_comp_list = [layout_service, doctrdet, text, map_comp, text_order_comp]
analyzer = DoctectionPipe(pipeline_component_list=pipe_comp_list, page_parsing_service=page_parsing)

df = analyzer.analyze(path=path)
df.reset_state()

for dp in df:
    print(dp.text)
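To see what the YOLO stage actually detected before debugging the filtering, the parsed pages can be rendered; a short sketch, assuming the viz() helper available on deepdoctection Page objects:

# Render each parsed page with bounding boxes and category names drawn in
df.reset_state()
for dp in df:
    plt.figure(figsize=(12, 17))
    plt.axis("off")
    plt.imshow(dp.viz())
    plt.show()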
The problem: with this setup I am unable to filter the output based on the categories.
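For instance, I would expect something along these lines to give me only the chosen categories; a sketch, assuming the Page objects expose layouts with category_name and text as shown in the deepdoctection tutorials:

# Keep only selected layout categories from each parsed page
wanted = {"Table", "Text", "Title"}  # values of the registered YOLOTYPE members
df.reset_state()
for dp in df:
    for layout in dp.layouts:
        if layout.category_name in wanted:
            print(layout.category_name, ":", layout.text)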