Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cover images embedded in safetensors metadata #15319

Merged
merged 1 commit into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion extensions-builtin/Lora/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(self, name, filename):

def read_metadata():
metadata = sd_models.read_metadata_from_safetensors(filename)
metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text

return metadata

Expand Down
2 changes: 1 addition & 1 deletion extensions-builtin/Lora/ui_extra_networks_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def create_item(self, name, index=None, enable_filter=True):
"name": name,
"filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path),
"preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),
"description": self.find_description(path),
"search_terms": search_terms,
"local_preview": f"{path}.{shared.opts.samples_format}",
Expand Down
42 changes: 42 additions & 0 deletions modules/ui_extra_networks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import functools
import os.path
import urllib.parse
from base64 import b64decode
from io import BytesIO
from pathlib import Path
from typing import Optional, Union
from dataclasses import dataclass
Expand All @@ -11,6 +13,7 @@
import json
import html
from fastapi.exceptions import HTTPException
from PIL import Image

from modules.infotext_utils import image_from_url_text

Expand Down Expand Up @@ -108,6 +111,31 @@ def fetch_file(filename: str = ""):
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})


def fetch_cover_images(page: str = "", item: str = "", index: int = 0):
from starlette.responses import Response

page = next(iter([x for x in extra_pages if x.name == page]), None)
if page is None:
raise HTTPException(status_code=404, detail="File not found")

metadata = page.metadata.get(item)
if metadata is None:
raise HTTPException(status_code=404, detail="File not found")

cover_images = json.loads(metadata.get('ssmd_cover_images', {}))
image = cover_images[index] if index < len(cover_images) else None
if not image:
raise HTTPException(status_code=404, detail="File not found")

try:
image = Image.open(BytesIO(b64decode(image)))
buffer = BytesIO()
image.save(buffer, format=image.format)
return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())
except Exception as err:
raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err


def get_metadata(page: str = "", item: str = ""):
from starlette.responses import JSONResponse

Expand All @@ -119,6 +147,8 @@ def get_metadata(page: str = "", item: str = ""):
if metadata is None:
return JSONResponse({})

metadata = {i:metadata[i] for i in metadata if i != 'ssmd_cover_images'} # those are cover images, and they are too big to display in UI as text

return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})


Expand All @@ -142,6 +172,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):

def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])

Expand Down Expand Up @@ -627,6 +658,17 @@ def find_preview(self, path):

return None

def find_embedded_preview(self, path, name, metadata):
"""
Find if embedded preview exists in safetensors metadata and return endpoint for it.
"""

file = f"{path}.safetensors"
if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0:
return f"./sd_extra_networks/cover-images?page={self.extra_networks_tabname}&item={name}"

return None

def find_description(self, path):
"""
Find and read a description file for a given path (without extension).
Expand Down