-
-
Notifications
You must be signed in to change notification settings - Fork 256
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(onnx): ViT zero-shot image classification (#863)
* feat(onnx): ViT zero-shot image classification * update(pyproject): numpy dependency * update(pyproject): regex dependency * Update test_check_examples.py
- Loading branch information
Showing
52 changed files
with
739 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file added
BIN
+3.48 KB
assets/image_label_binary/off_road_vehicle/0007e214f726261736e90a813e7d31c8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.74 KB
assets/image_label_binary/off_road_vehicle/00b36c88e0c0582068265d8cbc809a82.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.82 KB
assets/image_label_binary/off_road_vehicle/04c92c524190c528670667585a33236a.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.72 KB
assets/image_label_binary/off_road_vehicle/07a7e16586dd98688c2a3eda5c82c514.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+2.46 KB
assets/image_label_binary/off_road_vehicle/0b2b8747c637fb7e12c1fd0a46ee4f24.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+4 KB
assets/image_label_binary/off_road_vehicle/0d7a677f283561475a493711f626ad51.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.77 KB
assets/image_label_binary/off_road_vehicle/0ef9d278727c10dd71bd534eece42aa1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.55 KB
assets/image_label_binary/off_road_vehicle/0f191260b3d1760e8f0c0a9d8b0a4e81.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.99 KB
assets/image_label_binary/off_road_vehicle/0fcdaf0cb257440fb14983b1bbc3a557.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+5.46 KB
assets/image_label_binary/off_road_vehicle/1a0ae5178a0ccd4b39b7fad5729981ac.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+2.27 KB
assets/image_label_binary/off_road_vehicle/1a1d94d1bc009f3dc067cdf6e53ecb24.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.96 KB
assets/image_label_binary/off_road_vehicle/1ad19e90cb26000f041fd7c8796a80ab.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+5.23 KB
assets/image_label_binary/off_road_vehicle/1b066d90ad334236d1d4d82cd09f9fbd.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+4.18 KB
assets/image_label_binary/off_road_vehicle/1ce468b575d97c4da40c1127d8a1762e.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.84 KB
assets/image_label_binary/off_road_vehicle/1d4ca676da2cac85de82310a3262f053.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.96 KB
assets/image_label_binary/off_road_vehicle/1f82cd88d47ddffe42bc3de2f23b68bf.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.53 KB
assets/image_label_binary/off_road_vehicle/2a0c175f26a70f705148cfe3c9b155c4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.41 KB
assets/image_label_binary/off_road_vehicle/2af57a97625592076e65467ee6508c88.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.35 KB
assets/image_label_binary/off_road_vehicle/2bf2ed21ed421381086bf704a64f5b5f.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.67 KB
assets/image_label_binary/off_road_vehicle/2d517193517551b21ee50decff1590a0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+2.87 KB
assets/image_label_binary/off_road_vehicle/2d5c576bd90f2d1c19a6933529237488.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.7 KB
assets/image_label_binary/off_road_vehicle/3a1b86ade8ec9e10d9526788d751e8c2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.85 KB
assets/image_label_binary/off_road_vehicle/3a67fb45061eea367490e90768d68bdd.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+3.38 KB
assets/image_label_binary/off_road_vehicle/3ae160d7de29831c9fa06a039ba66776.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+4.07 KB
assets/image_label_binary/off_road_vehicle/3b98adec326f963cec5b2b109124689e.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added
BIN
+5.48 KB
assets/image_label_binary/off_road_vehicle/3fca65e030849348541fca91e8ae91e6.png
Oops, something went wrong.
Binary file added
BIN
+4.86 KB
assets/image_label_binary/off_road_vehicle/4b8c10c12837115178a9075192c1f1c5.png
Oops, something went wrong.
Binary file added
BIN
+3.02 KB
assets/image_label_binary/off_road_vehicle/4bf9b3ce1ed7cecce47061c3f9f8314f.png
Oops, something went wrong.
Binary file added
BIN
+3.54 KB
assets/image_label_binary/off_road_vehicle/4da1a018239f7684c9628c48074a56b0.png
Oops, something went wrong.
Binary file added
BIN
+3.49 KB
assets/image_label_binary/off_road_vehicle/4ee60f4567d4d903ff042df598bd7317.png
Oops, something went wrong.
Binary file added
BIN
+3.21 KB
assets/image_label_binary/off_road_vehicle/4f1c36d943cb6f79ede955d36281abf7.png
Oops, something went wrong.
Binary file added
BIN
+4.35 KB
assets/image_label_binary/off_road_vehicle/5a8431a404ea5683e074075af0ba344e.png
Oops, something went wrong.
Binary file added
BIN
+3.83 KB
assets/image_label_binary/off_road_vehicle/5ef3e8da4d99b89ff2f4e38ce0465b5e.png
Oops, something went wrong.
Binary file added
BIN
+4.88 KB
assets/image_label_binary/off_road_vehicle/5f19b671a48aa37fe9457889d7e59b6c.png
Oops, something went wrong.
Binary file added
BIN
+4.91 KB
assets/image_label_binary/off_road_vehicle/6a2f122c4b878fd4b73a1209bb4674ff.png
Oops, something went wrong.
Binary file added
BIN
+3.58 KB
assets/image_label_binary/off_road_vehicle/6b7db18e8980602a5766738a206c6c8b.png
Oops, something went wrong.
Binary file added
BIN
+3.46 KB
assets/image_label_binary/off_road_vehicle/6c60fe1e327d972638459f1e4a0f9434.png
Oops, something went wrong.
Binary file added
BIN
+3.98 KB
assets/image_label_binary/off_road_vehicle/6d9a40f3825c0fc39f22540adf5646cf.png
Oops, something went wrong.
Binary file added
BIN
+4.45 KB
assets/image_label_binary/off_road_vehicle/7bb8d34f415127caf90ee074d8de50ab.png
Oops, something went wrong.
Binary file added
BIN
+3.66 KB
assets/image_label_binary/off_road_vehicle/7cbc7eb99ee4a9a8aa914fe1e8828bb7.png
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# -*- coding: utf-8 -*- | ||
# Time : 2023/10/24 5:39 | ||
# Author : QIN2DIM | ||
# GitHub : https://github.com/QIN2DIM | ||
# Description: | ||
import os | ||
import shutil | ||
import sys | ||
from pathlib import Path | ||
|
||
from PIL import Image | ||
from tqdm import tqdm | ||
|
||
from hcaptcha_challenger import ( | ||
DataLake, | ||
install, | ||
ModelHub, | ||
ZeroShotImageClassifier, | ||
register_pipline, | ||
) | ||
|
||
# Runs at import time, before any model is requested below.
# NOTE(review): presumably refreshes the locally cached hcaptcha_challenger resources - confirm.
install(upgrade=True)

# Dataset location, resolved relative to this script: <repo>/assets/image_label_binary/off_road_vehicle
assets_dir = Path(__file__).parent.parent / "assets"
images_dir = assets_dir / "image_label_binary" / "off_road_vehicle"
|
||
|
||
def auto_labeling():
    """
    Sort every image found directly inside ``images_dir`` into its ``yes`` / ``bad``
    sub-folder, using a zero-shot image classifier.

    Example:
    ---
    1. Roughly observe the distribution of the dataset and design a DataLake for the challenge prompt.
        - ChallengePrompt: "Please click each image containing an off-road vehicle"
        - positive_labels --> ["off-road vehicle"]
        - negative_labels --> ["bicycle", "car"]
    2. You can design them in batches and save them as YAML files,
       which the classifier can read to build the DataLake automatically.
    3. Note that positive_labels is a list, and you can specify multiple labels for this variable
       if the label pointed to by the prompt contains ambiguity.

    Side effects: the ``yes`` and ``bad`` folders are deleted and recreated on every run,
    and on Windows the dataset folder is opened in the file explorer when labeling is done.

    :return: None
    """
    # Refresh experiment environment
    modelhub = ModelHub.from_github_repo()
    modelhub.parse_objects()

    yes_dir = images_dir.joinpath("yes")
    bad_dir = images_dir.joinpath("bad")
    for cd in [yes_dir, bad_dir]:
        shutil.rmtree(cd, ignore_errors=True)
        cd.mkdir(parents=True, exist_ok=True)

    # !! IMPORT !!
    # Prompt: "Please click each image containing an off-road vehicle"
    data_lake = DataLake.from_prompts(
        positive_labels=["off-road vehicle"], negative_labels=["bicycle", "car"]
    )

    # Parse DataLake and build the model pipline
    tool = ZeroShotImageClassifier.from_datalake(data_lake)
    model = register_pipline(modelhub)

    # Snapshot the directory once and keep regular files only, so the progress total
    # is exact up front and the freshly created yes/bad folders are never visited.
    image_paths = [path for path in images_dir.iterdir() if path.is_file()]

    with tqdm(total=len(image_paths), desc=f"Labeling | {images_dir.name}") as progress:
        for image_path in image_paths:
            # Open through a context manager so the file handle is released
            # right after inference instead of lingering until garbage collection.
            with Image.open(image_path) as image:
                # The label at position 0 is the highest scoring target
                results = tool(model, image)

            # we're only dealing with binary classification tasks here
            if results[0]["label"] in data_lake.positive_labels:
                output_path = yes_dir.joinpath(image_path.name)
            else:
                output_path = bad_dir.joinpath(image_path.name)
            shutil.copyfile(image_path, output_path)

            progress.update(1)

    # os.startfile only exists on Windows
    if "win32" in sys.platform:
        os.startfile(images_dir)


if __name__ == "__main__":
    auto_labeling()
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
154 changes: 154 additions & 0 deletions
154
hcaptcha_challenger/components/zero_shot_image_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# -*- coding: utf-8 -*- | ||
# Time : 2023/10/20 17:28 | ||
# Author : QIN2DIM | ||
# GitHub : https://github.com/QIN2DIM | ||
# Description: zero-shot image classification | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from dataclasses import field | ||
from pathlib import Path | ||
from typing import List, Dict, Any, Literal, Iterable | ||
|
||
import onnxruntime | ||
from PIL.Image import Image | ||
|
||
from hcaptcha_challenger.components.prompt_handler import split_prompt_message, label_cleaning | ||
from hcaptcha_challenger.onnx.clip import MossCLIP | ||
from hcaptcha_challenger.onnx.modelhub import ModelHub | ||
from hcaptcha_challenger.onnx.utils import is_cuda_pipline_available | ||
from hcaptcha_challenger.utils import from_dict_to_model | ||
|
||
|
||
def _load_onnx_session(model_path, kind: str):
    """Validate an optional user-supplied ONNX model path and open an inference session for it.

    :param model_path: value of the ``visual_path`` / ``textual_path`` keyword, may be falsy
    :param kind: ``"visual"`` or ``"textual"``; only used to build the error messages
    :return: an ``onnxruntime.InferenceSession``, or None when no path was supplied
    :raises ValueError: the supplied path is not a ``pathlib.Path``
    :raises FileNotFoundError: the supplied path does not point to an existing file
    """
    if not model_path:
        return None
    if not isinstance(model_path, Path):
        raise ValueError(f"{kind}_path should be a pathlib.Path")
    if not model_path.is_file():
        raise FileNotFoundError(
            f"Select to use {kind} ONNX model, but the specified model does not exist - "
            f"{kind}_path={model_path!r}"
        )
    # Hand over a plain string: every onnxruntime release accepts ``str``,
    # while path-like objects are only accepted by newer ones.
    return onnxruntime.InferenceSession(
        str(model_path), providers=onnxruntime.get_available_providers()
    )


def register_pipline(
    modelhub: ModelHub, *, fmt: Literal["onnx", "transformers"] | None = None, **kwargs
):
    """
    Build the zero-shot classification pipeline, backed by ONNX or by ``transformers``.

    Ace Model:
    - laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K --> ONNX 1.7GB
    - QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336 --> ONNX

    :param modelhub: model registry; ``parse_objects()`` is invoked on it when its
        ``label_alias`` is empty, and it resolves the default CLIP networks.
    :param fmt: requested backend. ``None`` and ``"transformers"`` are both re-resolved
        from ``is_cuda_pipline_available`` ("transformers" if truthy, otherwise "onnx").
    :param kwargs:
        ONNX backend:
          - ``visual_path`` / ``textual_path``: ``pathlib.Path`` of a custom ONNX model
          - ``visual_model`` / ``textual_model``: model names passed to ``modelhub.match_net``
            (defaults: ``modelhub.DEFAULT_CLIP_VISUAL_MODEL`` / ``DEFAULT_CLIP_TEXTUAL_MODEL``)
        transformers backend:
          - ``checkpoint``, ``task``, ``batch_size``: forwarded to ``transformers.pipeline``
    :return: a ``MossCLIP`` instance (ONNX), a ``transformers`` pipeline, or None when
        ``fmt`` is not one of the supported values.
    :raises ValueError: ``visual_path`` / ``textual_path`` is not a ``pathlib.Path``
    :raises FileNotFoundError: ``visual_path`` / ``textual_path`` does not exist
    """
    if fmt in ["transformers", None]:
        # NOTE(review): is_cuda_pipline_available is used without being called. If the
        # imported name is a function rather than a flag, this test is always truthy - confirm.
        fmt = "transformers" if is_cuda_pipline_available else "onnx"

    if fmt in ["onnx"]:
        if not modelhub.label_alias:
            modelhub.parse_objects()

        # Explicit model files take precedence over the models registered in the hub.
        v_net = _load_onnx_session(kwargs.get("visual_path"), "visual")
        t_net = _load_onnx_session(kwargs.get("textual_path"), "textual")

        if v_net is None:
            visual_model = kwargs.get("visual_model", modelhub.DEFAULT_CLIP_VISUAL_MODEL)
            v_net = modelhub.match_net(visual_model)
        if t_net is None:
            textual_model = kwargs.get("textual_model", modelhub.DEFAULT_CLIP_TEXTUAL_MODEL)
            t_net = modelhub.match_net(textual_model)

        return MossCLIP.from_pluggable_model(v_net, t_net)

    if fmt in ["transformers"]:
        # Imported lazily so the ONNX backend works without these optional packages.
        from transformers import pipeline  # type:ignore

        import torch  # type:ignore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        checkpoint = kwargs.get("checkpoint", "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K")
        task = kwargs.get("task", "zero-shot-image-classification")
        batch_size = kwargs.get("batch_size", 16)
        return pipeline(task=task, device=device, model=checkpoint, batch_size=batch_size)

    # Unsupported ``fmt``: kept as a None result for backward compatibility.
    return None
|
||
|
||
@dataclass
class ZeroShotImageClassifier:
    """Holds the label sets of one challenge and runs a detector against them."""

    # Labels that count as a "True" answer.
    positive_labels: List[str] = field(default_factory=list)
    # Every label handed to the detector as ``candidate_labels``.
    candidate_labels: List[str] = field(default_factory=list)

    @classmethod
    def from_prompt(cls, positive_labels: List[str], candidate_labels: List[str]):
        """Create a classifier from label lists that are already prepared."""
        return cls(positive_labels, candidate_labels)

    @classmethod
    def from_datalake(cls, datalake: DataLake):
        """Create a classifier from the labels produced by ``datalake.format_inputs()``."""
        positives, candidates = datalake.format_inputs()
        return cls(positives, candidates)

    def __call__(self, detector: MossCLIP, image: Image, *args, **kwargs):
        """
        Run ``detector`` on ``image`` against ``self.candidate_labels``.

        :param detector: callable invoked as ``detector(image, candidate_labels=...)``
        :param image: a single image or an iterable of images
        :return: whatever the detector returns
        """
        # A MossCLIP detector is always given an iterable, so a lone image is wrapped in a list.
        wrap_single = isinstance(detector, MossCLIP) and not isinstance(image, Iterable)
        payload = [image] if wrap_single else image
        return detector(payload, candidate_labels=self.candidate_labels)
|
||
|
||
@dataclass
class DataLake:
    """Positive / negative label description of one binary image challenge."""

    # Indicate the label with the meaning "True",
    # preferably an independent noun or clause
    positive_labels: List[str] | str

    # Indicate the label with the meaning "False",
    # preferably an independent noun or clause
    negative_labels: List[str]

    # Attributes reserved for AutoLabeling
    # Used to indicate the directory where the dataset is located
    # input_dir = db_dir.joinpath(*joined_dirs).absolute()
    joined_dirs: List[str] | None = None

    @classmethod
    def from_serialized(cls, fields: Dict[str, Any]):
        """Build an instance from a dict by delegating to ``from_dict_to_model``."""
        return from_dict_to_model(cls, fields)

    @classmethod
    def from_prompts(cls, positive_labels: List[str], negative_labels: List[str]):
        """Build an instance from explicit positive and negative labels."""
        return cls(positive_labels, negative_labels)

    def format_inputs(self):
        """
        Normalise the positive labels and assemble the candidate label list.

        ``self.positive_labels`` is overwritten with the cleaned labels.

        :return: ``(positive_labels, candidate_labels)``; the candidates are the cleaned
            positive labels followed by the negative labels
        """
        # A bare string is treated as a single positive label.
        if isinstance(self.positive_labels, str):
            self.positive_labels = [self.positive_labels]

        # When the input is a challenge prompt, cut it into phrases
        cleaned = [
            split_prompt_message(label_cleaning(raw.replace("_", " ")), "en")
            for raw in self.positive_labels
        ]
        self.positive_labels = cleaned

        candidates = list(cleaned)
        negatives = self.negative_labels
        # Anything other than a non-empty list contributes no negative labels.
        if isinstance(negatives, list) and negatives:
            candidates += negatives

        return cleaned, candidates
Oops, something went wrong.