Skip to content

Commit

Permalink
Add joytag automation
Browse files Browse the repository at this point in the history
  • Loading branch information
gingerchicken committed Sep 16, 2024
1 parent 10a5b9f commit 20d16ee
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
8 changes: 8 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ RUN pip install -r requirements.txt
# Download opennsfw model
RUN python3 -c "import opennsfw2; opennsfw2.make_open_nsfw_model()"

# Download joytag model
# The Hugging Face repository stores its large files with Git LFS, so git-lfs
# is installed and its filters are enabled before cloning; otherwise the clone
# may contain LFS pointer files instead of the real files.
# apt-get is used rather than apt (whose CLI is not meant for scripts), and the
# package lists are removed in the same layer to keep the image smaller.
RUN apt-get update \
    && apt-get install -y git-lfs \
    && rm -rf /var/lib/apt/lists/* \
    && git lfs install
# Shallow clones: only the latest revision is needed at build time.
RUN git clone --depth 1 https://huggingface.co/fancyfeast/joytag ~/.joytag/model
RUN git clone --depth 1 https://github.com/fpgaminer/joytag ~/.joytag/joytag

# Install the joytag requirements
RUN pip install -r ~/.joytag/joytag/requirements.txt

# Copy all of the current code to the /app/ directory
COPY . /app/

Expand Down
2 changes: 2 additions & 0 deletions booru/automation/tag/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .tag_automation import *
from .metadata import *
from .tags import *
from .joytag import JoytagAutomation

__INITIALISED__ = False

Expand All @@ -17,6 +18,7 @@ def perform_setup():
__INITIALISED__ = True

# Register the automations
TagAutomationRegistry().register(JoytagAutomation())
TagAutomationRegistry().register(AnimatedContentTagAutomation())
TagAutomationRegistry().register(LargeFileSizeTagAutomation())

Expand Down
113 changes: 113 additions & 0 deletions booru/automation/tag/joytag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import sys
import os

# Add the directory containing Models.py to the Python path.
# The joytag code is a plain checkout under ~/.joytag/joytag rather than an
# installed package, so its directory has to be put on sys.path before
# `from Models import VisionModel` below can resolve. This must stay above
# that import.
sys.path.append(os.path.expanduser('~/.joytag/joytag'))

from Models import VisionModel # ~/.joytag/joytag/Models.py
from PIL import Image
import torch.amp.autocast_mode
from pathlib import Path
import torch
import torchvision.transforms.functional as TVF

from .tag_automation import TagAutomation
from booru.models.tags import Tag
from booru.models.posts import Post

class JoytagAutomation(TagAutomation):
    """
    An automation for tagging images with Joytag.

    The model and its tag list are loaded lazily on the first prediction and
    cached on the instance, so later predictions reuse them instead of
    reloading from disk.
    """

    # Directory holding the model files and `top_tags.txt`.
    model_path = Path(os.path.expanduser('~/.joytag/model'))
    # A tag is applied only when its sigmoid score is strictly above this.
    threshold = 0.7

    def __init__(self):
        super().__init__()

        # Filled in by __prepare_model() the first time a prediction runs.
        self.__model = None
        self.__top_tags = None

    def __prepare_model(self):
        """
        Load the VisionModel and its tag list from `model_path` if needed.

        Does nothing when both are already cached on the instance.
        """

        if self.__model is not None and self.__top_tags is not None:
            return

        model = VisionModel.load_model(self.model_path, device='cpu')
        model.eval()
        model = model.to('cpu')

        # One tag per line; blank lines are ignored. The position of a tag in
        # this list is used as its index into the model's 'tags' output.
        with open(self.model_path / 'top_tags.txt', 'r', encoding='utf-8') as f:
            top_tags = [line.strip() for line in f if line.strip()]

        self.__model = model
        self.__top_tags = top_tags

    # Wrappers
    def __prepare_image(self, image: Image.Image, target_size: int) -> torch.Tensor:
        """
        Convert a PIL image into a normalised tensor for the model.

        The image is centred on a white square canvas, resized to
        `target_size` x `target_size` when necessary, scaled to [0, 1] and
        normalised per channel.
        """

        # Pad image to square
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2

        padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize image
        if max_dim != target_size:
            padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

        # Convert to tensor
        image_tensor = TVF.pil_to_tensor(padded_image) / 255.0

        # Normalize
        image_tensor = TVF.normalize(
            image_tensor,
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )

        return image_tensor

    @torch.no_grad()
    def __predict(self, image: Image.Image):
        """
        Run the model on `image`.

        Returns a tuple `(tag_string, scores)`: `scores` maps every tag name
        to its sigmoid score as a Python float, and `tag_string` is the
        comma-separated list of tags whose score is above `threshold`.
        """

        self.__prepare_model()  # Ensure the model is loaded

        image_tensor = self.__prepare_image(image, self.__model.image_size)
        batch = {
            'image': image_tensor.unsqueeze(0).to('cpu'),
        }

        with torch.amp.autocast_mode.autocast('cpu', enabled=True):
            preds = self.__model(batch)
            tag_preds = preds['tags'].sigmoid().cpu()

        # Plain floats are easier for callers to compare than 0-dim tensors;
        # .float() avoids depending on whatever dtype autocast produced.
        scores = dict(zip(self.__top_tags, tag_preds[0].float().tolist()))
        predicted_tags = [tag for tag, score in scores.items() if score > self.threshold]
        tag_string = ', '.join(predicted_tags)

        return tag_string, scores

    # Override
    def get_tags(self, post: Post) -> list[Tag]:
        """
        Return the tags Joytag predicts for the post's media.

        Posts flagged as videos yield an empty list.
        """

        # Videos are skipped.
        # NOTE(review): the original comment said "video or gif", but only
        # `post.is_video` is checked - confirm whether it covers animated images.
        if post.is_video:
            return []

        # Image.open is lazy and keeps the file open; the context manager
        # closes it once the prediction has read the pixel data.
        with Image.open(post.get_media_path()) as image:
            _, scores = self.__predict(image)

        # Select the tags over the threshold
        return [
            Tag.create_or_get(tag)
            for tag, score in scores.items()
            if score > self.threshold
        ]

0 comments on commit 20d16ee

Please sign in to comment.