From 5f6e1e3b1654e4aa1576f16ee0aa542c96d46540 Mon Sep 17 00:00:00 2001 From: Ankush Patel Date: Sun, 27 Aug 2023 18:20:21 -0700 Subject: [PATCH 1/5] split progress update and image_preview_update into two threads --- spice_agent/inference/actions.py | 91 ++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 23 deletions(-) diff --git a/spice_agent/inference/actions.py b/spice_agent/inference/actions.py index b3615ae..56b257e 100644 --- a/spice_agent/inference/actions.py +++ b/spice_agent/inference/actions.py @@ -1,5 +1,6 @@ import json import logging +import threading import os from compel import Compel, ReturnedEmbeddingsType from pathlib import Path @@ -198,6 +199,11 @@ def __init__(self, spice) -> None: ] ] = None + # Threading + self.progress_thread = None + self.image_preview_thread = None + self.update_inference_job_lock = threading.Lock() + # logging.basicConfig(level=logging.INFO) if self.spice.DEBUG: transformers.logging.set_verbosity_debug() # type: ignore @@ -352,18 +358,25 @@ def callback_for_stable_diffusion( }, ) - def callback_for_stable_diffusion_xl( - self, step: int, timestep: int, latents: torch.FloatTensor - ) -> None: - """ - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - """ # noqa + def update_progress(self, step: int): + if self.pipe and self.inference_job_id and self.pipe_input: + progress = ( + (step + 1) / self.pipe_input.inference_options.num_inference_steps + ) * 100 - # Need access to vae to decode here: - if self.pipe and self.inference_job_id and self.pipe_input and step > 0: + print(f"Sending progress update: {progress}") + with self.update_inference_job_lock: + self._update_inference_job( + inference_job_id=self.inference_job_id, + status_details={"progress": progress}, + ) + print("Sending progress complete!") + + def update_image_preview(self, step: int, latents: torch.FloatTensor): + if self.pipe and self.inference_job_id and self.pipe_input: # Send preview images with torch.no_grad(): + print(f"Sending image preview update! {step}") # make sure the VAE is in float32 mode, as it overflows in float16 self.pipe.vae.to(dtype=torch.float32) @@ -399,19 +412,42 @@ def callback_for_stable_diffusion_xl( path=save_at ) file_id = upload_file_response.json()["data"]["uploadFile"]["id"] - self._update_inference_job( - inference_job_id=self.inference_job_id, - status="COMPLETE", - file_outputs_ids=file_id, - status_details={ - "progress": ( - (step + 1) - / self.pipe_input.inference_options.num_inference_steps - ) - * 100, - "current_image_file_name": file_name, - }, + + with self.update_inference_job_lock: + self._update_inference_job( + inference_job_id=self.inference_job_id, + status="COMPLETE", + file_outputs_ids=file_id, + status_details={ + "current_image_file_name": file_name, + }, + ) + print("Sending image preview update complete!") + + def callback_for_stable_diffusion_xl( + self, step: int, timestep: int, latents: torch.FloatTensor + ) -> None: + """ + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + """ # noqa + + # Need access to vae to decode here: + if self.pipe and self.inference_job_id and self.pipe_input and step > 0: + if not self.progress_thread or not self.progress_thread.is_alive(): + self.progress_thread = threading.Thread( + target=self.update_progress, args=(step,) + ) + self.progress_thread.start() + + if ( + not self.image_preview_thread + or not self.image_preview_thread.is_alive() + ): + self.image_preview_thread = threading.Thread( + target=self.update_image_preview, args=(step, latents) ) + self.image_preview_thread.start() def run_pipeline( self, @@ -473,8 +509,6 @@ def run_pipeline( negative_prompt = options.get("negative_prompt", "") generator = self._get_generator(int(options.get("seed", -1))) - options["callback_steps"] = 10 - was_guarded = False if not save_at.exists(): pipe = DiffusionPipeline.from_pretrained( @@ -587,6 +621,13 @@ def run_pipeline( callback=self.callback_for_stable_diffusion_xl, ).images # type: ignore + # Cleanup threads + if self.progress_thread: + self.progress_thread.join() + + if self.image_preview_thread: + self.image_preview_thread.join() + refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-refiner-1.0", text_encoder_2=pipe.text_encoder_2, @@ -667,6 +708,10 @@ def run_pipeline( response = self._update_inference_job( inference_job_id=inference_job_id, status="COMPLETE", + status_details={ + "progress": 100, + "current_image_file_name": f"{inference_job_id}.png", + }, file_outputs_ids=file_id, was_guarded=was_guarded, ) From 2fd43de4ab4fd47242d1e1ccb9ec356c238e740f Mon Sep 17 00:00:00 2001 From: Ankush Patel Date: Tue, 29 Aug 2023 00:11:40 -0700 Subject: [PATCH 2/5] add sd2.1 diffusion callback --- spice_agent/inference/actions.py | 118 +++++++++++++++++++++++-------- 1 file changed, 90 insertions(+), 28 deletions(-) diff --git a/spice_agent/inference/actions.py b/spice_agent/inference/actions.py index 56b257e..4c0f37b 100644 --- a/spice_agent/inference/actions.py +++ b/spice_agent/inference/actions.py @@ -13,6 +13,7 @@ StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, ) +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput from diffusers.models.attention_processor import ( AttnProcessor2_0, @@ -336,47 +337,24 @@ def _get_generator(self, seed: int) -> torch.Generator: # PyTorch releases. return torch.manual_seed(seed) - def callback_for_stable_diffusion( - self, step: int, timestep: int, latents: torch.FloatTensor - ) -> None: - """ - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - """ # noqa - - # Need access to vae to decode here: - if self.pipe and self.inference_job_id and self.pipe_input and step > 0: - # Update progress on backend - self._update_inference_job( - inference_job_id=self.inference_job_id, - status_details={ - "progress": ( - (step + 1) - / self.pipe_input.inference_options.num_inference_steps - ) - * 100 - }, - ) - def update_progress(self, step: int): if self.pipe and self.inference_job_id and self.pipe_input: progress = ( (step + 1) / self.pipe_input.inference_options.num_inference_steps ) * 100 - print(f"Sending progress update: {progress}") with self.update_inference_job_lock: self._update_inference_job( inference_job_id=self.inference_job_id, status_details={"progress": progress}, ) - print("Sending progress complete!") - def update_image_preview(self, step: int, latents: torch.FloatTensor): + def update_image_preview_for_stable_diffusion_xl( + self, step: int, latents: torch.FloatTensor + ): if self.pipe and self.inference_job_id and self.pipe_input: # Send preview images with torch.no_grad(): - print(f"Sending image preview update! {step}") # make sure the VAE is in float32 mode, as it overflows in float16 self.pipe.vae.to(dtype=torch.float32) @@ -422,7 +400,82 @@ def update_image_preview(self, step: int, latents: torch.FloatTensor): "current_image_file_name": file_name, }, ) - print("Sending image preview update complete!") + + def update_image_preview_for_stable_diffusion( + self, step: int, latents: torch.FloatTensor + ): + if self.pipe and self.inference_job_id and self.pipe_input: + # Send preview images + with torch.no_grad(): + image = self.pipe.vae.decode( + latents / self.pipe.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.pipe.run_safety_checker( + image, self.device, torch.FloatTensor + ) + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.pipe.image_processor.postprocess( + image, do_denormalize=do_denormalize + ) + + result = StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) + + file_name = f"{self.inference_job_id}-preview-step-{step}.png" + save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) + + image = result[0][0] + image.save(save_at) + if len(result) > 1 and result[1]: # type: ignore + was_guarded = result[1][0] + else: + was_guarded = False + + upload_file_response = self.spice.uploader.upload_file_via_api( + path=save_at + ) + file_id = upload_file_response.json()["data"]["uploadFile"]["id"] + + with self.update_inference_job_lock: + self._update_inference_job( + inference_job_id=self.inference_job_id, + status="COMPLETE", + file_outputs_ids=file_id, + was_guarded=was_guarded, + status_details={ + "current_image_file_name": file_name, + }, + ) + + def callback_for_stable_diffusion( + self, step: int, timestep: int, latents: torch.FloatTensor + ) -> None: + """ + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + """ # noqa + if self.pipe and self.inference_job_id and self.pipe_input and step > 0: + if not self.progress_thread or not self.progress_thread.is_alive(): + self.progress_thread = threading.Thread( + target=self.update_progress, args=(step,) + ) + self.progress_thread.start() + + if ( + not self.image_preview_thread + or not self.image_preview_thread.is_alive() + ): + self.image_preview_thread = threading.Thread( + target=self.update_image_preview_for_stable_diffusion, + args=(step, latents), + ) + self.image_preview_thread.start() def callback_for_stable_diffusion_xl( self, step: int, timestep: int, latents: torch.FloatTensor @@ -445,7 +498,8 @@ def callback_for_stable_diffusion_xl( or not self.image_preview_thread.is_alive() ): self.image_preview_thread = threading.Thread( - target=self.update_image_preview, args=(step, latents) + target=self.update_image_preview_for_stable_diffusion_xl, + args=(step, latents), ) self.image_preview_thread.start() @@ -573,6 +627,14 @@ def run_pipeline( generator=generator, callback=self.callback_for_stable_diffusion, ) # type:ignore + + # Cleanup threads + if self.progress_thread: + self.progress_thread.join() + + if self.image_preview_thread: + self.image_preview_thread.join() + # Configure MOE for xl diffusion base + refinement TASK elif isinstance(pipe, StableDiffusionXLPipeline): # Configure input for stable diffusion xl pipeline From c7c7d1390a0dd6f5caca33ab248a9e7f21c909c7 Mon Sep 17 00:00:00 2001 From: Ankush Patel Date: Wed, 30 Aug 2023 12:08:22 -0700 Subject: [PATCH 3/5] add file version history in status details --- spice_agent/inference/actions.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/spice_agent/inference/actions.py b/spice_agent/inference/actions.py index 4c0f37b..66fb988 100644 --- a/spice_agent/inference/actions.py +++ b/spice_agent/inference/actions.py @@ -374,7 +374,7 @@ def update_image_preview_for_stable_diffusion_xl( else: latents = latents.float() # type: ignore - file_name = f"{self.inference_job_id}-preview-step-{step}.png" + file_name = f"{self.inference_job_id}-{step}.png" save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) image = self.pipe.vae.decode( @@ -396,9 +396,7 @@ def update_image_preview_for_stable_diffusion_xl( inference_job_id=self.inference_job_id, status="COMPLETE", file_outputs_ids=file_id, - status_details={ - "current_image_file_name": file_name, - }, + status_details={"file_version_history": {step: file_name}}, ) def update_image_preview_for_stable_diffusion( @@ -427,7 +425,7 @@ def update_image_preview_for_stable_diffusion( images=image, nsfw_content_detected=has_nsfw_concept ) - file_name = f"{self.inference_job_id}-preview-step-{step}.png" + file_name = f"{self.inference_job_id}-{step}.png" save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) image = result[0][0] @@ -448,9 +446,7 @@ def update_image_preview_for_stable_diffusion( status="COMPLETE", file_outputs_ids=file_id, was_guarded=was_guarded, - status_details={ - "current_image_file_name": file_name, - }, + status_details={"file_version_history": {step: file_name}}, ) def callback_for_stable_diffusion( @@ -470,7 +466,7 @@ def callback_for_stable_diffusion( if ( not self.image_preview_thread or not self.image_preview_thread.is_alive() - ): + ) and step > self.pipe_input.inference_options.num_inference_steps // 2: self.image_preview_thread = threading.Thread( target=self.update_image_preview_for_stable_diffusion, args=(step, latents), @@ -496,7 +492,7 @@ def callback_for_stable_diffusion_xl( if ( not self.image_preview_thread or not self.image_preview_thread.is_alive() - ): + ) and step > self.pipe_input.inference_options.num_inference_steps // 2: self.image_preview_thread = threading.Thread( target=self.update_image_preview_for_stable_diffusion_xl, args=(step, latents), @@ -556,13 +552,15 @@ def run_pipeline( text_output=json.dumps(result), ) elif is_text_input and is_file_output: - SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True) - save_at = Path(SPICE_INFERENCE_DIRECTORY / f"{inference_job_id}.png") - prompt = text_input negative_prompt = options.get("negative_prompt", "") generator = self._get_generator(int(options.get("seed", -1))) + # TODO: remove max_step once a file version system is in place + max_step = options.get("num_inference_steps", 999) + SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True) + file_name = f"{inference_job_id}-{max_step}.png" + save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) was_guarded = False if not save_at.exists(): pipe = DiffusionPipeline.from_pretrained( @@ -767,12 +765,13 @@ def run_pipeline( path=save_at ) file_id = upload_file_response.json()["data"]["uploadFile"]["id"] + response = self._update_inference_job( inference_job_id=inference_job_id, status="COMPLETE", status_details={ "progress": 100, - "current_image_file_name": f"{inference_job_id}.png", + "file_version_history": {max_step: file_name}, }, file_outputs_ids=file_id, was_guarded=was_guarded, From 7012fc9d75da36b5d9c5b8e7c564abb842acf7d6 Mon Sep 17 00:00:00 2001 From: Ankush Patel Date: Wed, 30 Aug 2023 12:17:31 -0700 Subject: [PATCH 4/5] specify callbacks --- spice_agent/inference/actions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spice_agent/inference/actions.py b/spice_agent/inference/actions.py index 46147b9..9740439 100644 --- a/spice_agent/inference/actions.py +++ b/spice_agent/inference/actions.py @@ -562,8 +562,8 @@ def run_pipeline( SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True) file_name = f"{inference_job_id}-{max_step}.png" save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) - was_guarded = False + if not save_at.exists(): pipe = DiffusionPipeline.from_pretrained( model_repo_id, @@ -625,7 +625,7 @@ def run_pipeline( **asdict(stable_diffusion_pipeline_input.inference_options), **asdict(stable_diffusion_pipeline_input.output), generator=generator, - callback=callback, + callback=self.callback_for_stable_diffusion, ) # type:ignore # Cleanup threads @@ -680,7 +680,7 @@ def run_pipeline( ), **asdict(stable_diffusion_pipeline_xl_input.output), generator=generator, - callback=callback, + callback=self.callback_for_stable_diffusion_xl, ).images # type: ignore # Cleanup threads From ec057a1e4d6d522f81e3df9cfc0b4774bee3ccff Mon Sep 17 00:00:00 2001 From: Ankush Patel Date: Wed, 30 Aug 2023 16:02:55 -0700 Subject: [PATCH 5/5] remove file_version_history --- spice_agent/inference/actions.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/spice_agent/inference/actions.py b/spice_agent/inference/actions.py index 9740439..aaa9633 100644 --- a/spice_agent/inference/actions.py +++ b/spice_agent/inference/actions.py @@ -49,6 +49,11 @@ LOGGER = logging.getLogger(__name__) +# TODO: Currently, we only support single image outputs for +# Stable Diffusion pipelines. This value is a placeholder for when +# multi-image outputs are added +IMAGE_GROUP_VALUE = 0 + def _init_compel( device, pipeline: DiffusionPipeline, prompt: str, negative_prompt: str = "" @@ -375,7 +380,7 @@ def update_image_preview_for_stable_diffusion_xl( else: latents = latents.float() # type: ignore - file_name = f"{self.inference_job_id}-{step}.png" + file_name = f"{self.inference_job_id}-{step}-{IMAGE_GROUP_VALUE}.png" save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) image = self.pipe.vae.decode( @@ -397,7 +402,6 @@ def update_image_preview_for_stable_diffusion_xl( inference_job_id=self.inference_job_id, status="COMPLETE", file_outputs_ids=file_id, - status_details={"file_version_history": {step: file_name}}, ) def update_image_preview_for_stable_diffusion( @@ -426,7 +430,7 @@ def update_image_preview_for_stable_diffusion( images=image, nsfw_content_detected=has_nsfw_concept ) - file_name = f"{self.inference_job_id}-{step}.png" + file_name = f"{self.inference_job_id}-{step}-{IMAGE_GROUP_VALUE}.png" save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) image = result[0][0] @@ -447,7 +451,6 @@ def update_image_preview_for_stable_diffusion( status="COMPLETE", file_outputs_ids=file_id, was_guarded=was_guarded, - status_details={"file_version_history": {step: file_name}}, ) def callback_for_stable_diffusion( @@ -557,10 +560,9 @@ def run_pipeline( negative_prompt = options.get("negative_prompt", "") generator = self._get_generator(int(options.get("seed", -1))) - # TODO: remove max_step once a file version system is in place max_step = options.get("num_inference_steps", 999) SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True) - file_name = f"{inference_job_id}-{max_step}.png" + file_name = f"{inference_job_id}-{max_step}-{IMAGE_GROUP_VALUE}.png" save_at = Path(SPICE_INFERENCE_DIRECTORY / file_name) was_guarded = False @@ -773,7 +775,6 @@ def run_pipeline( status="COMPLETE", status_details={ "progress": 100, - "file_version_history": {max_step: file_name}, }, file_outputs_ids=file_id, was_guarded=was_guarded,