Commit
Merge pull request #70 from spicecloud/inference-callback
Inference callback
djstein authored Aug 29, 2023
2 parents 2175611 + ca07a3f commit f9d5d95
Showing 2 changed files with 64 additions and 42 deletions.
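In short, the commit drops the shared CallbackOptionsBase dataclass and instead registers a progress callback on the Inference class: diffusers invokes it every `callback_steps` steps with `(step, timestep, latents)`, and the agent turns the step counter into a percentage that is pushed to the backend through the `updateInferenceJob` mutation's new `statusDetails` field. A minimal, self-contained sketch of that idea follows; the `report_status` helper and `ProgressReporter` class are illustrative stand-ins, not part of the repository:

    from typing import Any, Dict
    import torch

    def report_status(inference_job_id: str, status_details: Dict[str, Any]) -> None:
        # Stand-in for the agent's updateInferenceJob mutation, which now
        # accepts a statusDetails payload alongside status.
        print(f"{inference_job_id}: {status_details}")

    class ProgressReporter:
        """Illustrative callable mirroring callback_for_stable_diffusion."""

        def __init__(self, inference_job_id: str, num_inference_steps: int) -> None:
            self.inference_job_id = inference_job_id
            self.num_inference_steps = num_inference_steps

        def __call__(self, step: int, timestep: int, latents: torch.FloatTensor) -> None:
            # Same formula as in the diff below: percentage of completed steps.
            progress = ((step + 1) / self.num_inference_steps) * 100
            report_status(self.inference_job_id, {"progress": progress})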
69 changes: 51 additions & 18 deletions spice_agent/inference/actions.py
@@ -9,14 +9,13 @@
from diffusers import (
DiffusionPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLImg2ImgPipeline,
)
from gql import gql
from gql.transport.exceptions import TransportQueryError

from spice_agent.inference.types import (
CallbackOptionsBase,
OutputForStableDiffusionPipeline,
InputForStableDiffusionPipeline,
InputForStableDiffusionXLPipeline,
@@ -182,6 +181,15 @@ class Inference:
def __init__(self, spice) -> None:
self.spice = spice
self.device = self.spice.get_device()
self.pipe = None
self.inference_job_id = None
self.pipe_input: Optional[
Union[
StableDiffusionPipelineInput,
StableDiffusionXLPipelineInput,
StableDiffusionXLImg2ImgPipelineInput,
]
] = None

# logging.basicConfig(level=logging.INFO)
if self.spice.DEBUG:
@@ -192,7 +200,8 @@ def __init__(self, spice) -> None:
def _update_inference_job(
self,
inference_job_id: str,
status: str,
status: Optional[str] = None,
status_details: Optional[Dict[str, Any]] = None,
was_guarded: Optional[bool] = None,
text_output: Optional[str] = None,
file_outputs_ids: list[str] = [],
@@ -208,6 +217,7 @@ def _update_inference_job(
updatedAt
completedAt
status
statusDetails
model {
name
slug
@@ -227,9 +237,11 @@
"""
)

input: Dict[str, str | float | list[str]] = {"inferenceJobId": inference_job_id}
input: Dict[str, Any] = {"inferenceJobId": inference_job_id}
if status is not None:
input["status"] = status
if status_details is not None:
input["statusDetails"] = status_details
if was_guarded is not None:
input["wasGuarded"] = was_guarded
if text_output is not None:
@@ -311,6 +323,28 @@ def _get_generator(self, seed: int) -> torch.Generator:
# PyTorch releases.
return torch.manual_seed(seed)

def callback_for_stable_diffusion(
self, step: int, timestep: int, latents: torch.FloatTensor
) -> None:
"""
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
""" # noqa

# Need access to vae to decode here:
if self.pipe and self.inference_job_id and self.pipe_input:
# Update progress on backend
self._update_inference_job(
inference_job_id=self.inference_job_id,
status_details={
"progress": (
(step + 1)
/ self.pipe_input.inference_options.num_inference_steps
)
* 100
},
)

def run_pipeline(
self,
inference_job_id: str,
@@ -329,6 +363,8 @@ def run_pipeline(
return

inference_job_id = result["updateInferenceJob"]["id"]
self.inference_job_id = inference_job_id

model_repo_id = result["updateInferenceJob"]["model"]["repoId"]
text_input = result["updateInferenceJob"]["textInput"]
is_text_input = result["updateInferenceJob"]["model"]["isTextInput"]
@@ -368,6 +404,7 @@ def run_pipeline(
prompt = text_input
negative_prompt = options.get("negative_prompt", "")
generator = self._get_generator(int(options.get("seed", -1)))
callback = self.callback_for_stable_diffusion

was_guarded = False
if not save_at.exists():
@@ -378,6 +415,7 @@ def run_pipeline(
use_safetensors=True,
)
pipe = pipe.to(self.device) # type: ignore
self.pipe = pipe

# Get input for Stable Diffusion Pipelines
input_for_stable_diffusion_pipeline = (
@@ -386,9 +424,6 @@ def run_pipeline(
)
)

# Get callback options
callback_options = CallbackOptionsBase()

# Configure Stable Diffusion TASK
if isinstance(pipe, StableDiffusionPipeline):
# Configure inference options for stable diffusion pipeline
@@ -417,23 +452,23 @@ def run_pipeline(
InferenceOptionsForStableDiffusionPipeline,
):
raise ValueError(
"Infernece options for stable diffusion pipeline not configured!" # noqa
"Inference options for stable diffusion pipeline not configured!" # noqa
)

# Specify input for stable diffusion pipeline
stable_diffusion_pipeline_input = StableDiffusionPipelineInput(
input=input_for_stable_diffusion_pipeline,
inference_options=inference_options_for_stable_diffusion,
callback_options=callback_options,
output=output_for_stable_diffusion_pipeline,
)
self.pipe_input = stable_diffusion_pipeline_input

pipe_result = pipe(
**asdict(stable_diffusion_pipeline_input.input),
**asdict(stable_diffusion_pipeline_input.inference_options),
**asdict(stable_diffusion_pipeline_input.callback_options),
**asdict(stable_diffusion_pipeline_input.output),
generator=generator,
callback=callback,
) # type:ignore
# Configure MOE for xl diffusion base + refinement TASK
elif isinstance(pipe, StableDiffusionXLPipeline):
@@ -469,20 +504,18 @@ def run_pipeline(
stable_diffusion_pipeline_xl_input = StableDiffusionXLPipelineInput( # noqa
input=input_for_stable_diffusion_xl,
inference_options=inference_options_for_stable_diffusion_xl,
callback_options=callback_options,
output=output_for_stable_diffusion_xl_pipeline,
)
self.pipe_input = stable_diffusion_pipeline_xl_input

latents = pipe(
**asdict(stable_diffusion_pipeline_xl_input.input),
**asdict(
stable_diffusion_pipeline_xl_input.inference_options
),
**asdict(
stable_diffusion_pipeline_xl_input.callback_options
),
**asdict(stable_diffusion_pipeline_xl_input.output),
generator=generator,
callback=callback,
).images # type: ignore

refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
@@ -526,18 +559,17 @@ def run_pipeline(
stable_diffusion_pipeline_xl_img2img_input = StableDiffusionXLImg2ImgPipelineInput( # noqa
input=input_for_stable_diffusion_xl_img2img,
inference_options=inference_options_for_stable_diffusion_xl_img2img,
callback_options=callback_options,
output=output_for_stable_diffusion_xl_img2img_pipeline,
)
self.pipe_input = stable_diffusion_pipeline_xl_img2img_input

# Note, we do not attach a callback here since refinement
# of the image is not as time consuming
pipe_result = refiner(
**asdict(stable_diffusion_pipeline_xl_img2img_input.input),
**asdict(
stable_diffusion_pipeline_xl_img2img_input.inference_options
),
**asdict(
stable_diffusion_pipeline_xl_img2img_input.callback_options
),
**asdict(stable_diffusion_pipeline_xl_img2img_input.output),
generator=generator,
) # type: ignore
Expand All @@ -551,6 +583,7 @@ def run_pipeline(
# denoting whether the corresponding generated image likely
# represents "not-safe-for-work" (nsfw) content, according to the
# `safety_checker`.
# TODO decode output based on output of actual pipeline
result = pipe_result[0][0] # type: ignore
if len(pipe_result) > 1 and pipe_result[1]: # type: ignore
was_guarded = pipe_result[1][0] # type: ignore
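For reference, this is roughly how the new wiring looks from the diffusers side: the callback is passed as a keyword argument to the pipeline call, while `callback_steps` arrives through the expanded inference-options dataclass. A hedged usage sketch against the diffusers 0.x callback API; the repo id, prompt, and device below are placeholders, not values from this commit:

    import torch
    from diffusers import StableDiffusionPipeline

    def log_progress(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        # Called by the pipeline every `callback_steps` denoising steps.
        print(f"step={step} timestep={timestep}")

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # placeholder model repo id
        use_safetensors=True,
    )
    pipe = pipe.to("cpu")  # the agent resolves this via spice.get_device()

    image = pipe(
        prompt="a watercolor fox",  # placeholder prompt
        num_inference_steps=30,
        callback=log_progress,
        callback_steps=1,
    ).images[0]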
37 changes: 13 additions & 24 deletions spice_agent/inference/types.py
@@ -1,29 +1,10 @@
import torch
from typing import Optional, Dict, Union, List, Any, Tuple, Callable
from typing import Optional, Dict, Union, List, Any, Tuple
from dataclasses import dataclass
from PIL import Image as PILImage
import numpy as np


# Base Data Classes --------------------------------------------------------------------
@dataclass
class CallbackOptionsBase:
"""
Callback options
Args:
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
""" # noqa

callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None
callback_steps: int = 1


# Data Classes for Stable Diffusion Pipeline -------------------------------------------
@dataclass
class InputForStableDiffusionPipeline:
@@ -77,6 +58,9 @@ class InferenceOptionsForStableDiffusionPipeline:
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -94,6 +78,7 @@ class InferenceOptionsForStableDiffusionPipeline:
guidance_scale: float = 7.5
num_images_per_prompt: Optional[int] = 1
eta: float = 0.0
callback_steps: int = 1
cross_attention_kwargs: Optional[Dict[str, Any]] = None
guidance_rescale: float = 0.7

@@ -187,6 +172,9 @@ class InferenceOptionsForStableDiffusionXLPipeline:
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -210,6 +198,7 @@ class InferenceOptionsForStableDiffusionXLPipeline:
guidance_scale: float = 7.5
num_images_per_prompt: Optional[int] = 1
eta: float = 0.0
callback_steps: int = 1
cross_attention_kwargs: Optional[Dict[str, Any]] = None
guidance_rescale: float = 0.7
original_size: Optional[Tuple[int, int]] = None
@@ -294,6 +283,9 @@ class InferenceOptionsForStableDiffusionXLImg2ImgPipeline:
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -322,6 +314,7 @@ class InferenceOptionsForStableDiffusionXLImg2ImgPipeline:
guidance_scale: float = 7.5
num_images_per_prompt: Optional[int] = 1
eta: float = 0.0
callback_steps: int = 1
cross_attention_kwargs: Optional[Dict[str, Any]] = None
guidance_rescale: float = 0.7
original_size: Optional[Tuple[int, int]] = None
@@ -336,29 +329,25 @@ class InferenceOptionsForStableDiffusionXLImg2ImgPipeline:
# These have the following 4 attributes:
# - input: these are inputs sent to the inference task
# - inference_options: these are options that affect the run
# - callback_options: these are options that observe the run
# - output: these are additional arguments that specify output behavior


@dataclass
class StableDiffusionPipelineInput:
input: InputForStableDiffusionPipeline
inference_options: InferenceOptionsForStableDiffusionPipeline
callback_options: CallbackOptionsBase
output: OutputForStableDiffusionPipeline


@dataclass
class StableDiffusionXLPipelineInput:
input: InputForStableDiffusionXLPipeline
inference_options: InferenceOptionsForStableDiffusionXLPipeline
callback_options: CallbackOptionsBase
output: OutputForStableDiffusionPipeline


@dataclass
class StableDiffusionXLImg2ImgPipelineInput:
input: InputForStableDiffusionXLImg2ImgPipeline
inference_options: InferenceOptionsForStableDiffusionXLImg2ImgPipeline
callback_options: CallbackOptionsBase
output: OutputForStableDiffusionPipeline

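With CallbackOptionsBase gone, the pipeline input dataclasses are composed of input, inference_options, and output only, and `callback_steps` rides on the inference options. A hedged sketch of the resulting shape; the values are illustrative, and any field not set here is assumed to keep its dataclass default:

    from dataclasses import asdict

    from spice_agent.inference.types import InferenceOptionsForStableDiffusionPipeline

    options = InferenceOptionsForStableDiffusionPipeline(
        guidance_scale=7.5,
        callback_steps=1,  # previously carried by the removed CallbackOptionsBase
    )

    # actions.py expands these into the pipeline call and passes the callback
    # separately, roughly: pipe(**asdict(...), callback=..., generator=...)
    print(asdict(options))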