Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference callback with image preview #73

Merged
merged 6 commits into from
Aug 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 196 additions & 18 deletions spice_agent/inference/actions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import threading
import os
from compel import Compel, ReturnedEmbeddingsType
from pathlib import Path
Expand All @@ -12,6 +13,15 @@
StableDiffusionXLPipeline,
StableDiffusionXLImg2ImgPipeline,
)

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from gql import gql
from gql.transport.exceptions import TransportQueryError

Expand Down Expand Up @@ -39,6 +49,11 @@

LOGGER = logging.getLogger(__name__)

# TODO: Currently, we only support single image outputs for
# Stable Diffusion pipelines. This value is a placeholder for when
# multi-image outputs are added.
# It is used as the trailing component of generated image file names,
# i.e. "<inference job id>-<step>-<IMAGE_GROUP_VALUE>.png".
IMAGE_GROUP_VALUE = 0


def _init_compel(
device, pipeline: DiffusionPipeline, prompt: str, negative_prompt: str = ""
Expand Down Expand Up @@ -191,6 +206,11 @@ def __init__(self, spice) -> None:
]
] = None

# Threading
self.progress_thread = None
self.image_preview_thread = None
self.update_inference_job_lock = threading.Lock()

# logging.basicConfig(level=logging.INFO)
if self.spice.DEBUG:
transformers.logging.set_verbosity_debug() # type: ignore
Expand Down Expand Up @@ -323,27 +343,165 @@ def _get_generator(self, seed: int) -> torch.Generator:
# PyTorch releases.
return torch.manual_seed(seed)

def update_progress(self, step: int):
    """Send the completion percentage of the running inference job.

    The percentage is ``(step + 1) / num_inference_steps * 100``, where
    ``num_inference_steps`` is read from ``self.pipe_input.inference_options``.
    Nothing is sent unless a pipeline, an inference job id and a pipeline
    input are all currently set on this instance.

    Args:
        step: Index of the denoising step that has just finished.
    """
    if not (self.pipe and self.inference_job_id and self.pipe_input):
        return

    total_steps = self.pipe_input.inference_options.num_inference_steps
    percent_done = ((step + 1) / total_steps) * 100

    # Job updates are serialized through a lock shared with the
    # image-preview updates.
    with self.update_inference_job_lock:
        self._update_inference_job(
            inference_job_id=self.inference_job_id,
            status_details={"progress": percent_done},
        )

def update_image_preview_for_stable_diffusion_xl(
    self, step: int, latents: torch.FloatTensor
):
    """Decode intermediate SDXL latents into a preview image and upload it.

    The latents are decoded with the pipeline's VAE, optionally watermarked,
    converted to a PIL image, saved under ``SPICE_INFERENCE_DIRECTORY`` as
    ``"<inference job id>-<step>-<IMAGE_GROUP_VALUE>.png"``, uploaded through
    ``self.spice.uploader`` and finally attached to the inference job.
    Nothing happens unless a pipeline, an inference job id and a pipeline
    input are all currently set on this instance.

    Args:
        step: Index of the denoising step the latents belong to; it is
            embedded in the preview file name.
        latents: Latents handed to the pipeline callback for that step.
    """
    if not (self.pipe and self.inference_job_id and self.pipe_input):
        return

    vae = self.pipe.vae

    # Send preview images
    with torch.no_grad():
        # make sure the VAE is in float32 mode, as it overflows in float16
        vae.to(dtype=torch.float32)

        use_torch_2_0_or_xformers = isinstance(
            vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                LoRAXFormersAttnProcessor,
                LoRAAttnProcessor2_0,
            ),
        )
        if use_torch_2_0_or_xformers:
            # With these attention processors part of the VAE is moved back
            # to the dtype of the latents instead of upcasting the latents.
            vae.post_quant_conv.to(latents.dtype)
            vae.decoder.conv_in.to(latents.dtype)
            vae.decoder.mid_block.to(latents.dtype)
        else:
            latents = latents.float()  # type: ignore

        file_name = f"{self.inference_job_id}-{step}-{IMAGE_GROUP_VALUE}.png"
        save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name)

        image = vae.decode(
            latents / vae.config.scaling_factor, return_dict=False
        )[0]

        # The SDXL pipeline leaves ``watermark`` as None when watermarking is
        # disabled or unavailable, so only apply it when one is present.
        watermark = getattr(self.pipe, "watermark", None)
        if watermark is not None:
            image = watermark.apply_watermark(image)

        image = self.pipe.image_processor.postprocess(image, output_type="pil")

        # Only the first image of the output is used (single-image outputs).
        image = StableDiffusionXLPipelineOutput(images=image)[0][0]
        image.save(save_at)

    upload_file_response = self.spice.uploader.upload_file_via_api(path=save_at)
    file_id = upload_file_response.json()["data"]["uploadFile"]["id"]

    # NOTE(review): the preview is reported with status "COMPLETE" even
    # though inference is still running — confirm the backend expects this.
    with self.update_inference_job_lock:
        self._update_inference_job(
            inference_job_id=self.inference_job_id,
            status="COMPLETE",
            file_outputs_ids=file_id,
        )

def update_image_preview_for_stable_diffusion(
    self, step: int, latents: torch.FloatTensor
):
    """Decode intermediate Stable Diffusion latents into an uploaded preview.

    The latents are decoded with the pipeline's VAE, passed through the
    pipeline's safety checker, converted by the image processor, saved under
    ``SPICE_INFERENCE_DIRECTORY`` as
    ``"<inference job id>-<step>-<IMAGE_GROUP_VALUE>.png"``, uploaded through
    ``self.spice.uploader`` and attached to the inference job together with
    the safety-checker verdict for the first image.
    Nothing happens unless a pipeline, an inference job id and a pipeline
    input are all currently set on this instance.

    Args:
        step: Index of the denoising step the latents belong to; it is
            embedded in the preview file name.
        latents: Latents handed to the pipeline callback for that step.
    """
    if not (self.pipe and self.inference_job_id and self.pipe_input):
        return

    # Send preview images
    with torch.no_grad():
        image = self.pipe.vae.decode(
            latents / self.pipe.vae.config.scaling_factor, return_dict=False
        )[0]

        # The third argument of run_safety_checker is a dtype. Pass the
        # dtype the latents are in rather than the legacy tensor class
        # ``torch.FloatTensor``, which is not a ``torch.dtype``.
        image, has_nsfw_concept = self.pipe.run_safety_checker(
            image, self.device, latents.dtype
        )

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            # Images flagged by the safety checker are not denormalized.
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.pipe.image_processor.postprocess(
            image, do_denormalize=do_denormalize
        )

    result = StableDiffusionPipelineOutput(
        images=image, nsfw_content_detected=has_nsfw_concept
    )

    file_name = f"{self.inference_job_id}-{step}-{IMAGE_GROUP_VALUE}.png"
    save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name)

    # Only the first image of the output is used (single-image outputs).
    image = result[0][0]
    image.save(save_at)
    if len(result) > 1 and result[1]:  # type: ignore
        was_guarded = result[1][0]
    else:
        was_guarded = False

    upload_file_response = self.spice.uploader.upload_file_via_api(path=save_at)
    file_id = upload_file_response.json()["data"]["uploadFile"]["id"]

    # NOTE(review): the preview is reported with status "COMPLETE" even
    # though inference is still running — confirm the backend expects this.
    with self.update_inference_job_lock:
        self._update_inference_job(
            inference_job_id=self.inference_job_id,
            status="COMPLETE",
            file_outputs_ids=file_id,
            was_guarded=was_guarded,
        )

def callback_for_stable_diffusion(
    self, step: int, timestep: int, latents: torch.FloatTensor
) -> None:
    """Pipeline callback that reports progress and previews in the background.

    Intended to be passed as the ``callback`` of a Stable Diffusion pipeline,
    which calls it as
    ``callback(step: int, timestep: int, latents: torch.FloatTensor)``.

    For every step after the first, a progress update is started on a
    worker thread unless the previous one is still running. Once ``step``
    is past half of ``num_inference_steps``, an image preview is started the
    same way. Steps that arrive while the matching worker is busy are
    skipped rather than queued.
    """
    if not (
        self.pipe and self.inference_job_id and self.pipe_input and step > 0
    ):
        return

    progress_worker = self.progress_thread
    if not (progress_worker and progress_worker.is_alive()):
        self.progress_thread = threading.Thread(
            target=self.update_progress, args=(step,)
        )
        self.progress_thread.start()

    preview_worker = self.image_preview_thread
    if preview_worker and preview_worker.is_alive():
        # A preview is still being decoded/uploaded; skip this step.
        return

    if step > self.pipe_input.inference_options.num_inference_steps // 2:
        self.image_preview_thread = threading.Thread(
            target=self.update_image_preview_for_stable_diffusion,
            args=(step, latents),
        )
        self.image_preview_thread.start()

def callback_for_stable_diffusion_xl(
    self, step: int, timestep: int, latents: torch.FloatTensor
) -> None:
    """Pipeline callback that reports SDXL progress and previews in the background.

    Intended to be passed as the ``callback`` of a Stable Diffusion XL
    pipeline, which calls it as
    ``callback(step: int, timestep: int, latents: torch.FloatTensor)``.

    For every step after the first, a progress update is started on a
    worker thread unless the previous one is still running. Once ``step``
    is past half of ``num_inference_steps``, an image preview is started the
    same way. Steps that arrive while the matching worker is busy are
    skipped rather than queued.

    All job updates go through ``update_progress`` and
    ``update_image_preview_for_stable_diffusion_xl``, which take
    ``update_inference_job_lock``; this callback never calls
    ``_update_inference_job`` directly, so it does not block the denoising
    loop on a backend request.
    """
    if not (
        self.pipe and self.inference_job_id and self.pipe_input and step > 0
    ):
        return

    progress_worker = self.progress_thread
    if not progress_worker or not progress_worker.is_alive():
        self.progress_thread = threading.Thread(
            target=self.update_progress, args=(step,)
        )
        self.progress_thread.start()

    preview_worker = self.image_preview_thread
    preview_idle = not preview_worker or not preview_worker.is_alive()
    if (
        preview_idle
        and step > self.pipe_input.inference_options.num_inference_steps // 2
    ):
        # The preview worker decodes the latents with the pipeline's VAE.
        self.image_preview_thread = threading.Thread(
            target=self.update_image_preview_for_stable_diffusion_xl,
            args=(step, latents),
        )
        self.image_preview_thread.start()

def run_pipeline(
self,
Expand Down Expand Up @@ -398,15 +556,16 @@ def run_pipeline(
text_output=json.dumps(result),
)
elif is_text_input and is_file_output:
SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True)
save_at = Path(SPICE_INFERENCE_DIRECTORY / f"{inference_job_id}.png")

prompt = text_input
negative_prompt = options.get("negative_prompt", "")
generator = self._get_generator(int(options.get("seed", -1)))
callback = self.callback_for_stable_diffusion

max_step = options.get("num_inference_steps", 999)
SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True)
file_name = f"{inference_job_id}-{max_step}-{IMAGE_GROUP_VALUE}.png"
save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name)
was_guarded = False

if not save_at.exists():
pipe = DiffusionPipeline.from_pretrained(
model_repo_id,
Expand Down Expand Up @@ -468,8 +627,16 @@ def run_pipeline(
**asdict(stable_diffusion_pipeline_input.inference_options),
**asdict(stable_diffusion_pipeline_input.output),
generator=generator,
callback=callback,
callback=self.callback_for_stable_diffusion,
) # type:ignore

# Cleanup threads
if self.progress_thread:
self.progress_thread.join()

if self.image_preview_thread:
self.image_preview_thread.join()

# Configure MOE for xl diffusion base + refinement TASK
elif isinstance(pipe, StableDiffusionXLPipeline):
# Configure input for stable diffusion xl pipeline
Expand Down Expand Up @@ -515,9 +682,16 @@ def run_pipeline(
),
**asdict(stable_diffusion_pipeline_xl_input.output),
generator=generator,
callback=callback,
callback=self.callback_for_stable_diffusion_xl,
).images # type: ignore

# Cleanup threads
if self.progress_thread:
self.progress_thread.join()

if self.image_preview_thread:
self.image_preview_thread.join()

refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=pipe.text_encoder_2,
Expand Down Expand Up @@ -595,9 +769,13 @@ def run_pipeline(
path=save_at
)
file_id = upload_file_response.json()["data"]["uploadFile"]["id"]

response = self._update_inference_job(
inference_job_id=inference_job_id,
status="COMPLETE",
status_details={
"progress": 100,
},
file_outputs_ids=file_id,
was_guarded=was_guarded,
)
Expand Down