Skip to content

Commit

Permalink
feat(onnx): ViT zero-shot image classification (#863)
Browse files Browse the repository at this point in the history
* feat(onnx): ViT zero-shot image classification

* update(pyproject): numpy dependency

* update(pyproject): regex dependency

* Update test_check_examples.py
  • Loading branch information
QIN2DIM authored Oct 24, 2023
1 parent 34f7cbb commit db775a4
Show file tree
Hide file tree
Showing 52 changed files with 739 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,5 @@ examples/*.png
docs/*.xmind
tests_private/
.run/
assets/image_label_binary/off_road_vehicle/yes
assets/image_label_binary/off_road_vehicle/bad
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
92 changes: 92 additions & 0 deletions examples/demo_classifier_clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# Time : 2023/10/24 5:39
# Author : QIN2DIM
# GitHub : https://github.com/QIN2DIM
# Description:
import os
import shutil
import sys
from pathlib import Path

from PIL import Image
from tqdm import tqdm

from hcaptcha_challenger import (
DataLake,
install,
ModelHub,
ZeroShotImageClassifier,
register_pipline,
)

# Executed at import time, before any model is requested.
# NOTE(review): presumably fetches/upgrades the packaged model assets — confirm against `install`.
install(upgrade=True)

# `assets/` sits next to the directory containing this script (two levels above the file itself).
assets_dir = Path(__file__).parent.parent.joinpath("assets")
# Folder holding the unlabeled images for the "off_road_vehicle" binary task; `auto_labeling` reads from it.
images_dir = assets_dir.joinpath("image_label_binary/off_road_vehicle")


def auto_labeling():
    """
    Sort every image in ``images_dir`` into a ``yes`` or ``bad`` sub-folder
    using a zero-shot image classifier.

    Example:
    ---
    1. Roughly observe the distribution of the dataset and design a DataLake for the challenge prompt.
        - ChallengePrompt: "Please click each image containing an off-road vehicle"
        - positive_labels --> ["off-road vehicle"]
        - negative_labels --> ["bicycle", "car"]
    2. You can design them in batches and save them as YAML files,
        which the classifier can read and turn into a DataLake automatically.
    3. Note that positive_labels is a list, and you can specify multiple labels for this variable
        if the label pointed to by the prompt is ambiguous.

    :return: None. Images are copied (not moved) into the output folders.
    """
    # Refresh experiment environment
    modelhub = ModelHub.from_github_repo()
    modelhub.parse_objects()

    # Start from empty output folders so results of a previous run never leak into this one
    yes_dir = images_dir.joinpath("yes")
    bad_dir = images_dir.joinpath("bad")
    for cd in [yes_dir, bad_dir]:
        shutil.rmtree(cd, ignore_errors=True)
        cd.mkdir(parents=True, exist_ok=True)

    # !! IMPORT !!
    # Prompt: "Please click each image containing an off-road vehicle"
    data_lake = DataLake.from_prompts(
        positive_labels=["off-road vehicle"], negative_labels=["bicycle", "car"]
    )

    # Parse DataLake and build the model pipeline
    tool = ZeroShotImageClassifier.from_datalake(data_lake)
    model = register_pipline(modelhub)

    # List the directory once and keep regular files only; the `yes`/`bad`
    # folders created above live inside `images_dir` and must be skipped.
    image_paths = [path for path in images_dir.iterdir() if path.is_file()]

    with tqdm(total=len(image_paths), desc=f"Labeling | {images_dir.name}") as progress:
        for image_path in image_paths:
            # Close the file handle as soon as the prediction is done,
            # otherwise one descriptor per image stays open until GC.
            with Image.open(image_path) as image:
                # The label at position 0 is the highest scoring target
                results = tool(model, image)

            # we're only dealing with binary classification tasks here
            if results[0]["label"] in data_lake.positive_labels:
                output_path = yes_dir.joinpath(image_path.name)
            else:
                output_path = bad_dir.joinpath(image_path.name)
            shutil.copyfile(image_path, output_path)

            progress.update(1)

    # Convenience on Windows only: open the dataset folder in the file explorer
    if "win32" in sys.platform:
        os.startfile(images_dir)


if __name__ == "__main__":
auto_labeling()
File renamed without changes.
File renamed without changes.
8 changes: 8 additions & 0 deletions hcaptcha_challenger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

from hcaptcha_challenger.components.image_classifier import Classifier as BinaryClassifier
from hcaptcha_challenger.components.image_classifier import LocalBinaryClassifier
from hcaptcha_challenger.components.zero_shot_image_classifier import (
ZeroShotImageClassifier,
DataLake,
register_pipline,
)
from hcaptcha_challenger.components.image_label_area_select import AreaSelector
from hcaptcha_challenger.components.prompt_handler import (
label_cleaning,
Expand All @@ -28,6 +33,9 @@
__all__ = [
"BinaryClassifier",
"LocalBinaryClassifier",
"ZeroShotImageClassifier",
"register_pipline",
"DataLake",
"AreaSelector",
"label_cleaning",
"diagnose_task",
Expand Down
154 changes: 154 additions & 0 deletions hcaptcha_challenger/components/zero_shot_image_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
# Time : 2023/10/20 17:28
# Author : QIN2DIM
# GitHub : https://github.com/QIN2DIM
# Description: zero-shot image classification
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from typing import List, Dict, Any, Literal, Iterable

import onnxruntime
from PIL.Image import Image

from hcaptcha_challenger.components.prompt_handler import split_prompt_message, label_cleaning
from hcaptcha_challenger.onnx.clip import MossCLIP
from hcaptcha_challenger.onnx.modelhub import ModelHub
from hcaptcha_challenger.onnx.utils import is_cuda_pipline_available
from hcaptcha_challenger.utils import from_dict_to_model


def _load_onnx_session_from_kwargs(kwargs, key, kind):
    """Return an InferenceSession for the model path stored under ``kwargs[key]``, or None if absent."""
    model_path = kwargs.get(key)
    if not model_path:
        return None
    if not isinstance(model_path, Path):
        raise ValueError(f"{key} should be a pathlib.Path")
    if not model_path.is_file():
        raise FileNotFoundError(
            f"Select to use {kind} ONNX model, but the specified model does not exist - {key}={model_path!r}"
        )
    # Pass a plain string: not every onnxruntime release accepts os.PathLike objects
    return onnxruntime.InferenceSession(
        str(model_path), providers=onnxruntime.get_available_providers()
    )


def register_pipline(
    modelhub: ModelHub, *, fmt: Literal["onnx", "transformers"] | None = None, **kwargs
):
    """
    Build the zero-shot image classification pipeline.

    Ace Model:
        - laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K --> ONNX 1.7GB
        - QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336 --> ONNX

    :param modelhub: Model registry used to resolve the default CLIP ONNX models.
        ``parse_objects()`` is called on it when ``label_alias`` is empty.
    :param fmt: Requested backend. ``None`` and ``"transformers"`` are both resolved
        through ``is_cuda_pipline_available`` and fall back to ``"onnx"`` when it is falsy.
    :param kwargs:
        ONNX backend:
            - visual_path / textual_path: ``pathlib.Path`` to a local ONNX model
            - visual_model / textual_model: model names passed to ``modelhub.match_net``
        transformers backend:
            - checkpoint, task, batch_size: forwarded to ``transformers.pipeline``
    :raises ValueError: ``fmt`` is not supported, or a ``*_path`` argument is not a ``pathlib.Path``.
    :raises FileNotFoundError: a ``*_path`` argument does not point to an existing file.
    :return: ``MossCLIP`` instance (ONNX) or a ``transformers`` pipeline.
    """
    if fmt not in ("onnx", "transformers", None):
        # Previously an unknown format silently produced ``None``
        raise ValueError(f"Unsupported pipeline format - {fmt=}")

    if fmt in ["transformers", None]:
        # NOTE(review): `is_cuda_pipline_available` is used as a truth value, not called —
        # confirm it is a bool and not a function (a function object is always truthy).
        fmt = "transformers" if is_cuda_pipline_available else "onnx"

    if fmt == "onnx":
        if not modelhub.label_alias:
            modelhub.parse_objects()

        # Explicit local model files take precedence over the modelhub defaults
        v_net = _load_onnx_session_from_kwargs(kwargs, "visual_path", "visual")
        t_net = _load_onnx_session_from_kwargs(kwargs, "textual_path", "textual")

        if v_net is None:
            visual_model = kwargs.get("visual_model", modelhub.DEFAULT_CLIP_VISUAL_MODEL)
            v_net = modelhub.match_net(visual_model)
        if t_net is None:
            textual_model = kwargs.get("textual_model", modelhub.DEFAULT_CLIP_TEXTUAL_MODEL)
            t_net = modelhub.match_net(textual_model)

        return MossCLIP.from_pluggable_model(v_net, t_net)

    # fmt == "transformers": heavy optional dependencies are imported lazily
    from transformers import pipeline  # type:ignore

    import torch  # type:ignore

    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = kwargs.get("checkpoint", "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K")
    task = kwargs.get("task", "zero-shot-image-classification")
    batch_size = kwargs.get("batch_size", 16)
    return pipeline(task=task, device=device, model=checkpoint, batch_size=batch_size)


@dataclass
class ZeroShotImageClassifier:
    """
    Holds the label sets for one zero-shot classification task and
    forwards images to a detector together with the candidate labels.
    """

    positive_labels: List[str] = field(default_factory=list)
    candidate_labels: List[str] = field(default_factory=list)

    @classmethod
    def from_prompt(cls, positive_labels: List[str], candidate_labels: List[str]):
        """Build a classifier directly from already prepared label lists."""
        return cls(positive_labels, candidate_labels)

    @classmethod
    def from_datalake(cls, datalake: DataLake):
        """Build a classifier from the label pair produced by ``datalake.format_inputs()``."""
        positives, candidates = datalake.format_inputs()
        return cls(positives, candidates)

    def __call__(self, detector: MossCLIP, image: Image, *args, **kwargs):
        """
        Run ``detector`` on ``image`` against ``self.candidate_labels``.

        A non-iterable image handed to a ``MossCLIP`` detector is wrapped
        in a one-element list first; any other input is passed through untouched.
        """
        wrap_single = isinstance(detector, MossCLIP) and not isinstance(image, Iterable)
        batch = [image] if wrap_single else image
        return detector(batch, candidate_labels=self.candidate_labels)


@dataclass
class DataLake:
    """Positive/negative label description of a single binary classification task."""

    positive_labels: List[str] | str
    """
    Label(s) that carry the meaning "True",
    ideally a standalone noun or clause
    """

    negative_labels: List[str]
    """
    Label(s) that carry the meaning "False",
    ideally a standalone noun or clause
    """

    joined_dirs: List[str] | None = None
    """
    Reserved for AutoLabeling.
    Points at the directory in which the dataset is stored:
    input_dir = db_dir.joinpath(*joined_dirs).absolute()
    """

    @classmethod
    def from_serialized(cls, fields: Dict[str, Any]):
        """Create an instance from a plain mapping via ``from_dict_to_model``."""
        instance = from_dict_to_model(cls, fields)
        return instance

    @classmethod
    def from_prompts(cls, positive_labels: List[str], negative_labels: List[str]):
        """Create an instance from explicit positive and negative labels."""
        return cls(positive_labels, negative_labels)

    def format_inputs(self):
        """
        Normalize the positive labels and derive the candidate label list.

        ``self.positive_labels`` is overwritten with the cleaned labels.

        :return: ``(positive_labels, candidate_labels)`` where the candidates are
            the positives followed by the negatives (when negatives is a non-empty list).
        """
        # A lone string is treated as a single-item list
        if isinstance(self.positive_labels, str):
            self.positive_labels = [self.positive_labels]

        # Entries may be full challenge prompts, so each one is reduced to its phrase
        cleaned = [
            split_prompt_message(label_cleaning(raw.replace("_", " ")), "en")
            for raw in self.positive_labels
        ]
        self.positive_labels = cleaned

        candidates = list(cleaned)
        negatives = self.negative_labels
        if isinstance(negatives, list) and negatives:
            candidates += negatives

        return cleaned, candidates
Loading

0 comments on commit db775a4

Please sign in to comment.