Skip to content

Commit

Permalink
return generator hint as part of ParseResult
Browse files Browse the repository at this point in the history
    - allows for setting the generator dynamically
  • Loading branch information
d3x-at committed Sep 24, 2024
1 parent 14a10ce commit e0c3cc8
Show file tree
Hide file tree
Showing 19 changed files with 58 additions and 74 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def main():
if params is None:
return

image_generator, metadata = params
parser, metadata, parsing_context = params

...
```
Expand Down Expand Up @@ -91,14 +91,14 @@ def main():
parameters, parsing_context = parser.read_parameters(image)

# parse() builds a standardized data structure from the raw parameters
samplers, metadata = parser.parse(parameters, parsing_context)
generator, samplers, metadata = parser.parse(parameters, parsing_context)

except ParserError:
...

# creating a PromptInfo object from the obtained data allows for the use
# of convenience poperties like ".prompts" or ".models"
prompt_info = PromptInfo(parser, samplers, metadata)
prompt_info = PromptInfo(generator, samplers, metadata)
```

### Output
Expand Down
5 changes: 3 additions & 2 deletions examples/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
usage: uvicorn fast_api:app
test: curl -X "POST" http://127.0.0.1:8000/api/parse -F "image=@/path/to/image.png"
"""

from fastapi import FastAPI, UploadFile
from PIL import Image
from sd_parsers import ParserManager
Expand All @@ -21,5 +22,5 @@ def parse(image: UploadFile):
if params is None:
return {"success": False}

image_generator, metadata = params
return {"success": True, "image_generator": image_generator, "metadata": metadata}
parser, metadata, _ = params
return {"success": True, "parser": type(parser).__name__, "metadata": metadata}
15 changes: 8 additions & 7 deletions src/sd_parsers/_parser_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Provides the ParserManager class."""

from __future__ import annotations

import logging
Expand Down Expand Up @@ -78,12 +79,12 @@ def parse(
with _get_image(image) as image:
for parser, parameters, parsing_context in self._read_parameters(image):
try:
samplers, metadata = parser.parse(parameters, parsing_context)
generator, samplers, metadata = parser.parse(parameters, parsing_context)

return PromptInfo(parser, samplers, metadata)
return PromptInfo(generator, samplers, metadata)

except ParserError as error:
logger.debug("error in %s parser: %s", parser.generator.value, error)
logger.debug("error in parser: %s", error)

return None

Expand All @@ -103,13 +104,13 @@ def read_parameters(self, image: Union[str, bytes, Path, SupportsRead[bytes], Im
- PIL.UnidentifiedImageError: If the image cannot be opened and identified.
- ValueError: If a StringIO instance is used for `image`.
"""

with _get_image(image) as image:
try:
parser, parameters, _ = next(
parser, parameters, parsing_context = next(
iter(self._read_parameters(image, lambda x: x._COMPLEXITY_INDEX))
)
return parser.generator, parameters
return parser, parameters, parsing_context

except StopIteration:
return None
Expand All @@ -129,4 +130,4 @@ def _read_parameters(
yield parser, parameters, parsing_context

except ParserError as error:
logger.debug("error in %s parser: %s", parser.generator.value, error)
logger.debug("error in parser: %s", error)
18 changes: 7 additions & 11 deletions src/sd_parsers/data.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Data classes representing a subset of image generation parameters."""

from __future__ import annotations

import itertools
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional

if TYPE_CHECKING:
from .parser import Parser
from typing import Any, Dict, Iterable, List, Optional


class Generators(str, Enum):
Expand Down Expand Up @@ -107,6 +105,9 @@ def __hash__(self) -> int:
class PromptInfo:
"""Contains structured image generation parameters."""

generator: Generators
"""Image generator which might have produced the parsed image."""

samplers: List[Sampler]
"""Samplers used in generating the parsed image."""

Expand All @@ -119,7 +120,7 @@ class PromptInfo:

def __init__(
self,
parser: Parser,
generator: Generators,
samplers: List[Sampler],
metadata: Dict[Any, Any],
):
Expand All @@ -131,15 +132,10 @@ def __init__(
samplers: The samplers used in generating the parsed image.
metadata: Any additional parameters which are found in the image metadata.
"""
self._parser = parser
self.generator = generator
self.samplers = samplers
self.metadata = metadata

@property
def generator(self) -> Generators:
"""Image generater which might have produced the parsed image."""
return self._parser.generator

@property
def full_prompt(self) -> str:
"""
Expand Down
8 changes: 2 additions & 6 deletions src/sd_parsers/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Provides the Parser base class & other useful utility."""

from __future__ import annotations

from abc import ABC, abstractmethod
Expand All @@ -20,7 +21,7 @@
or to create a new key using the given formatting instruction (see `FormatField`).
"""

ParseResult = Tuple[List[_data.Sampler], Dict[Any, Any]]
ParseResult = Tuple[_data.Generators, List[_data.Sampler], Dict[Any, Any]]
"""The result of Parser.parse() is a tuple of encountered samplers and remaining metadata."""

_EXIF_TAGS = {v: k for k, v in ExifTags.TAGS.items()}
Expand All @@ -35,11 +36,6 @@ class Parser(ABC):
def __init__(self, normalize_parameters: bool = True):
self.do_normalization_pass = normalize_parameters

@property
@abstractmethod
def generator(self) -> _data.Generators:
"""Identifier for the inferred image generator."""

@abstractmethod
def read_parameters(
self,
Expand Down
7 changes: 2 additions & 5 deletions src/sd_parsers/parsers/_automatic1111.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Parser for images generated by AUTOMATIC1111's stable-diffusion-webui or similar."""

import json
import re
from contextlib import suppress
Expand All @@ -20,10 +21,6 @@ class AUTOMATIC1111Parser(Parser):

_COMPLEXITY_INDEX = 100

@property
def generator(self):
return Generators.AUTOMATIC1111

def read_parameters(self, image: Image, use_text: bool = True):
try:
if image.format == "PNG":
Expand Down Expand Up @@ -71,7 +68,7 @@ def parse(self, parameters: Dict[str, Any], _) -> ParseResult:
if negative_prompt:
sampler["negative_prompts"] = [Prompt(1, negative_prompt)]

return [Sampler(**sampler)], metadata
return Generators.AUTOMATIC1111, [Sampler(**sampler)], metadata


def get_sampler_info(lines):
Expand Down
7 changes: 2 additions & 5 deletions src/sd_parsers/parsers/_comfyui.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Parser for images generated by ComfyUI or similar."""

import json
import logging
from collections import defaultdict
Expand All @@ -24,10 +25,6 @@
class ComfyUIParser(Parser):
"""Parser for images generated by ComfyUI"""

@property
def generator(self):
return Generators.COMFYUI

def read_parameters(self, image: Image, use_text: bool = True):
if image.format != "PNG":
raise MetadataError("unsupported image format", image.format)
Expand All @@ -50,7 +47,7 @@ def parse(self, parameters: Dict[str, Any], _) -> ParseResult:
raise ParserError("error reading parameters") from error

samplers, metadata = ImageContext.extract(self, prompt, workflow)
return samplers, metadata
return Generators.COMFYUI, samplers, metadata


class ImageContext:
Expand Down
7 changes: 2 additions & 5 deletions src/sd_parsers/parsers/_dummy_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Example stub for additional parsers"""

import json
from typing import Any, Dict

Expand All @@ -14,10 +15,6 @@ class DummyParser(Parser):
Example stub for additional parsers
"""

@property
def generator(self):
return Generators.UNKNOWN

def read_parameters(self, image: Image, use_text: bool = True):
"""
Read the relevant generation parameters from the given image.
Expand Down Expand Up @@ -95,4 +92,4 @@ def parse(self, parameters: Dict[str, Any], parsing_context: Any) -> ParseResult
raise ParserError("something happened here") from error

# return list of samplers and unused working parameters
return [Sampler(**sampler)], working_parameters
return Generators.UNKNOWN, [Sampler(**sampler)], working_parameters
7 changes: 2 additions & 5 deletions src/sd_parsers/parsers/_fooocus.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Parser for images generated by AUTOMATIC1111's stable-diffusion-webui or similar."""

import copy
import json
from typing import Any, Dict
Expand All @@ -19,10 +20,6 @@ class FooocusParser(Parser):

_COMPLEXITY_INDEX = 90

@property
def generator(self):
return Generators.FOOOCUS

def read_parameters(self, image: Image, use_text: bool = True):
try:
if image.format == "PNG":
Expand Down Expand Up @@ -71,4 +68,4 @@ def parse(self, _parameters: Dict[str, Any], _) -> ParseResult:
except KeyError as error:
raise ParserError("error reading parameter value") from error

return [Sampler(**sampler)], parameters
return Generators.FOOOCUS, [Sampler(**sampler)], parameters
5 changes: 3 additions & 2 deletions src/sd_parsers/parsers/_invokeai/_variant_dream.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Read generation parameters from images generated by legacy InvokeAI."""

from __future__ import annotations

import re
from contextlib import suppress
from typing import Any, Dict

from sd_parsers.data import Prompt, Sampler
from sd_parsers.data import Generators, Prompt, Sampler
from sd_parsers.exceptions import ParserError
from sd_parsers.parser import Parser, ParseResult, pop_keys

Expand Down Expand Up @@ -55,7 +56,7 @@ def _parse_dream(parser: Parser, parameters: dict) -> ParseResult:
# prompts
_add_prompts(sampler, prompts, {})

return [Sampler(**sampler)], metadata
return Generators.INVOKEAI, [Sampler(**sampler)], metadata


def _get_sampler(parser: Parser, metadata: Dict[str, Any], key: str):
Expand Down
5 changes: 3 additions & 2 deletions src/sd_parsers/parsers/_invokeai/_variant_invokeai_meta.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Read generation parameters from the newest InvokeAI metadata format."""

from __future__ import annotations

import json
from contextlib import suppress
from typing import Any, Dict

from sd_parsers.data import Model, Prompt, Sampler
from sd_parsers.data import Generators, Model, Prompt, Sampler
from sd_parsers.exceptions import ParserError
from sd_parsers.parser import Parser, ParseResult

Expand Down Expand Up @@ -44,7 +45,7 @@ def _parse_invokeai_meta(parser: Parser, parameters: Dict[str, Any]) -> ParseRes
metadata=model_info,
)

return [Sampler(**sampler)], metadata
return Generators.INVOKEAI, [Sampler(**sampler)], metadata


__all__ = ["_parse_invokeai_meta"]
5 changes: 3 additions & 2 deletions src/sd_parsers/parsers/_invokeai/_variant_sd_metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Read generation parameters for an image containing a `sd-metadata` field."""

from __future__ import annotations

import json
from typing import Any, Dict

from sd_parsers.data import Model, Sampler
from sd_parsers.data import Generators, Model, Sampler
from sd_parsers.exceptions import ParserError
from sd_parsers.parser import Parser, ParseResult

Expand Down Expand Up @@ -50,7 +51,7 @@ def _parse_sd_metadata(parser: Parser, parameters: Dict[str, Any]) -> ParseResul
if model_name or model_hash:
sampler["model"] = Model(name=model_name, hash=model_hash)

return [Sampler(**sampler)], {**metadata, **metadata_image}
return Generators.INVOKEAI, [Sampler(**sampler)], {**metadata, **metadata_image}


__all__ = ["_parse_sd_metadata"]
6 changes: 1 addition & 5 deletions src/sd_parsers/parsers/_invokeai/parser.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Parser for images generated by InvokeAI."""

from __future__ import annotations

from typing import Any, Callable, Dict, NamedTuple

from PIL.Image import Image

from sd_parsers.data import Generators
from sd_parsers.exceptions import MetadataError
from sd_parsers.parser import Parser, ParseResult

Expand Down Expand Up @@ -43,10 +43,6 @@ class VariantParser(NamedTuple):
class InvokeAIParser(Parser):
"""parser for images generated by invokeai"""

@property
def generator(self):
return Generators.INVOKEAI

def read_parameters(self, image: Image, use_text: bool = True):
if image.format != "PNG":
raise MetadataError("unsupported image format", image.format)
Expand Down
7 changes: 2 additions & 5 deletions src/sd_parsers/parsers/_novelai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Parser for images generated by NovelAI or similar."""

import copy
import json
import re
Expand All @@ -18,10 +19,6 @@
class NovelAIParser(Parser):
"""parser for images generated by NovelAI"""

@property
def generator(self):
return Generators.NOVELAI

def read_parameters(self, image: Image, use_text: bool = True):
if image.format != "PNG":
raise MetadataError("unsupported image format", image.format)
Expand Down Expand Up @@ -76,4 +73,4 @@ def parse(self, parameters: Dict[str, Any], _) -> ParseResult:
model_name, model_hash = match.groups()
sampler["model"] = Model(name=model_name, hash=model_hash)

return [Sampler(**sampler)], metadata
return Generators.NOVELAI, [Sampler(**sampler)], metadata
5 changes: 3 additions & 2 deletions tests/test_automatic1111.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from PIL import Image
from sd_parsers.data import Model, Prompt, Sampler
from sd_parsers.data import Model, Prompt, Sampler, Generators
from sd_parsers.parsers import AUTOMATIC1111Parser, _automatic1111

from tests.tools import RESOURCE_PATH
Expand Down Expand Up @@ -49,8 +49,9 @@ def test_parse(filename: str, expected):
with Image.open(RESOURCE_PATH / "parsers/AUTOMATIC1111" / filename) as image:
params = parser.read_parameters(image)

samplers, metadata = parser.parse(*params)
generator, samplers, metadata = parser.parse(*params)

assert generator == Generators.AUTOMATIC1111
assert samplers == expected_samplers
assert metadata == expected_metadata

Expand Down
Loading

0 comments on commit e0c3cc8

Please sign in to comment.