diff --git a/spice_agent/inference/actions.py b/spice_agent/inference/actions.py
index 31fff4c..aaa9633 100644
--- a/spice_agent/inference/actions.py
+++ b/spice_agent/inference/actions.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import threading
 import os
 from compel import Compel, ReturnedEmbeddingsType
 from pathlib import Path
@@ -12,6 +13,15 @@
     StableDiffusionXLPipeline,
     StableDiffusionXLImg2ImgPipeline,
 )
+
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from gql import gql
 from gql.transport.exceptions import TransportQueryError
 
@@ -39,6 +49,11 @@
 
 LOGGER = logging.getLogger(__name__)
 
+# TODO: Currently, we only support single image outputs for
+# Stable Diffusion pipelines. This value is a placeholder for when
+# multi-image outputs are added
+IMAGE_GROUP_VALUE = 0
+
 
 def _init_compel(
     device, pipeline: DiffusionPipeline, prompt: str, negative_prompt: str = ""
@@ -191,6 +206,11 @@ def __init__(self, spice) -> None:
             ]
         ] = None
 
+        # Threading
+        self.progress_thread = None
+        self.image_preview_thread = None
+        self.update_inference_job_lock = threading.Lock()
+
         # logging.basicConfig(level=logging.INFO)
         if self.spice.DEBUG:
             transformers.logging.set_verbosity_debug()  # type: ignore
@@ -323,6 +343,116 @@ def _get_generator(self, seed: int) -> torch.Generator:
         # PyTorch releases.
         return torch.manual_seed(seed)
 
+    def update_progress(self, step: int):
+        if self.pipe and self.inference_job_id and self.pipe_input:
+            progress = (
+                (step + 1) / self.pipe_input.inference_options.num_inference_steps
+            ) * 100
+
+            with self.update_inference_job_lock:
+                self._update_inference_job(
+                    inference_job_id=self.inference_job_id,
+                    status_details={"progress": progress},
+                )
+
+    def update_image_preview_for_stable_diffusion_xl(
+        self, step: int, latents: torch.FloatTensor
+    ):
+        if self.pipe and self.inference_job_id and self.pipe_input:
+            # Send preview images
+            with torch.no_grad():
+                # make sure the VAE is in float32 mode, as it overflows in float16
+                self.pipe.vae.to(dtype=torch.float32)
+
+                use_torch_2_0_or_xformers = isinstance(
+                    self.pipe.vae.decoder.mid_block.attentions[0].processor,
+                    (
+                        AttnProcessor2_0,
+                        XFormersAttnProcessor,
+                        LoRAXFormersAttnProcessor,
+                        LoRAAttnProcessor2_0,
+                    ),
+                )
+                if use_torch_2_0_or_xformers:
+                    self.pipe.vae.post_quant_conv.to(latents.dtype)
+                    self.pipe.vae.decoder.conv_in.to(latents.dtype)
+                    self.pipe.vae.decoder.mid_block.to(latents.dtype)
+                else:
+                    latents = latents.float()  # type: ignore
+
+                file_name = f"{self.inference_job_id}-{step}-{IMAGE_GROUP_VALUE}.png"
+                save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name)
+
+                image = self.pipe.vae.decode(
+                    latents / self.pipe.vae.config.scaling_factor, return_dict=False
+                )[0]
+                image = self.pipe.watermark.apply_watermark(image)
+                image = self.pipe.image_processor.postprocess(image, output_type="pil")
+
+                image = StableDiffusionXLPipelineOutput(images=image)[0][0]
+                image.save(save_at)
+
+            upload_file_response = self.spice.uploader.upload_file_via_api(
+                path=save_at
+            )
+            file_id = upload_file_response.json()["data"]["uploadFile"]["id"]
+
+            with self.update_inference_job_lock:
+                self._update_inference_job(
+                    inference_job_id=self.inference_job_id,
+                    status="COMPLETE",
+                    file_outputs_ids=file_id,
+                )
+
+    def update_image_preview_for_stable_diffusion(
+        self, step: int, latents: torch.FloatTensor
+    ):
+        if self.pipe and self.inference_job_id and self.pipe_input:
+            # Send preview images
+            with torch.no_grad():
+                image = self.pipe.vae.decode(
+                    latents / self.pipe.vae.config.scaling_factor, return_dict=False
+                )[0]
+                image, has_nsfw_concept = self.pipe.run_safety_checker(
+                    image, self.device, torch.FloatTensor
+                )
+
+                if has_nsfw_concept is None:
+                    do_denormalize = [True] * image.shape[0]
+                else:
+                    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+                image = self.pipe.image_processor.postprocess(
+                    image, do_denormalize=do_denormalize
+                )
+
+                result = StableDiffusionPipelineOutput(
+                    images=image, nsfw_content_detected=has_nsfw_concept
+                )
+
+                file_name = f"{self.inference_job_id}-{step}-{IMAGE_GROUP_VALUE}.png"
+                save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name)
+
+                image = result[0][0]
+                image.save(save_at)
+                if len(result) > 1 and result[1]:  # type: ignore
+                    was_guarded = result[1][0]
+                else:
+                    was_guarded = False
+
+            upload_file_response = self.spice.uploader.upload_file_via_api(
+                path=save_at
+            )
+            file_id = upload_file_response.json()["data"]["uploadFile"]["id"]
+
+            with self.update_inference_job_lock:
+                self._update_inference_job(
+                    inference_job_id=self.inference_job_id,
+                    status="COMPLETE",
+                    file_outputs_ids=file_id,
+                    was_guarded=was_guarded,
+                )
+
     def callback_for_stable_diffusion(
         self, step: int, timestep: int, latents: torch.FloatTensor
     ) -> None:
@@ -330,20 +460,48 @@ def callback_for_stable_diffusion(
         A function that will be called every `callback_steps` steps during inference. The function will be
         called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
         """  # noqa
+        if self.pipe and self.inference_job_id and self.pipe_input and step > 0:
+            if not self.progress_thread or not self.progress_thread.is_alive():
+                self.progress_thread = threading.Thread(
+                    target=self.update_progress, args=(step,)
+                )
+                self.progress_thread.start()
+
+            if (
+                not self.image_preview_thread
+                or not self.image_preview_thread.is_alive()
+            ) and step > self.pipe_input.inference_options.num_inference_steps // 2:
+                self.image_preview_thread = threading.Thread(
+                    target=self.update_image_preview_for_stable_diffusion,
+                    args=(step, latents),
+                )
+                self.image_preview_thread.start()
+
+    def callback_for_stable_diffusion_xl(
+        self, step: int, timestep: int, latents: torch.FloatTensor
+    ) -> None:
+        """
+        A function that will be called every `callback_steps` steps during inference. The function will be
+        called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+        """  # noqa
         # Need access to vae to decode here:
-        if self.pipe and self.inference_job_id and self.pipe_input:
-            # Update progress on backend
-            self._update_inference_job(
-                inference_job_id=self.inference_job_id,
-                status_details={
-                    "progress": (
-                        (step + 1)
-                        / self.pipe_input.inference_options.num_inference_steps
-                    )
-                    * 100
-                },
-            )
+        if self.pipe and self.inference_job_id and self.pipe_input and step > 0:
+            if not self.progress_thread or not self.progress_thread.is_alive():
+                self.progress_thread = threading.Thread(
+                    target=self.update_progress, args=(step,)
+                )
+                self.progress_thread.start()
+
+            if (
+                not self.image_preview_thread
+                or not self.image_preview_thread.is_alive()
+            ) and step > self.pipe_input.inference_options.num_inference_steps // 2:
+                self.image_preview_thread = threading.Thread(
+                    target=self.update_image_preview_for_stable_diffusion_xl,
+                    args=(step, latents),
+                )
+                self.image_preview_thread.start()
 
 
     def run_pipeline(
         self,
@@ -398,15 +556,16 @@ def run_pipeline(
                 text_output=json.dumps(result),
             )
         elif is_text_input and is_file_output:
-            SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True)
-            save_at = Path(SPICE_INFERENCE_DIRECTORY / f"{inference_job_id}.png")
-
             prompt = text_input
             negative_prompt = options.get("negative_prompt", "")
            generator = self._get_generator(int(options.get("seed", -1)))
-            callback = self.callback_for_stable_diffusion
+            max_step = options.get("num_inference_steps", 999)
+            SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True)
+            file_name = f"{inference_job_id}-{max_step}-{IMAGE_GROUP_VALUE}.png"
+            save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name)
 
             was_guarded = False
+
             if not save_at.exists():
                 pipe = DiffusionPipeline.from_pretrained(
                     model_repo_id,
@@ -468,8 +627,16 @@ def run_pipeline(
                     **asdict(stable_diffusion_pipeline_input.inference_options),
                     **asdict(stable_diffusion_pipeline_input.output),
                     generator=generator,
-                    callback=callback,
+                    callback=self.callback_for_stable_diffusion,
                 )  # type:ignore
+
+                # Cleanup threads
+                if self.progress_thread:
+                    self.progress_thread.join()
+
+                if self.image_preview_thread:
+                    self.image_preview_thread.join()
+
             # Configure MOE for xl diffusion base + refinement TASK
             elif isinstance(pipe, StableDiffusionXLPipeline):
                 # Configure input for stable diffusion xl pipeline
@@ -515,9 +682,16 @@ def run_pipeline(
                     ),
                     **asdict(stable_diffusion_pipeline_xl_input.output),
                     generator=generator,
-                    callback=callback,
+                    callback=self.callback_for_stable_diffusion_xl,
                 ).images  # type: ignore
 
+                # Cleanup threads
+                if self.progress_thread:
+                    self.progress_thread.join()
+
+                if self.image_preview_thread:
+                    self.image_preview_thread.join()
+
                 refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                     "stabilityai/stable-diffusion-xl-refiner-1.0",
                     text_encoder_2=pipe.text_encoder_2,
@@ -595,9 +769,13 @@ def run_pipeline(
                     path=save_at
                 )
                 file_id = upload_file_response.json()["data"]["uploadFile"]["id"]
+
                 response = self._update_inference_job(
                     inference_job_id=inference_job_id,
                     status="COMPLETE",
+                    status_details={
+                        "progress": 100,
+                    },
                     file_outputs_ids=file_id,
                     was_guarded=was_guarded,
                 )