Skip to content

Commit

Permalink
Merge pull request #60 from spicecloud/inference-job-options
Browse files Browse the repository at this point in the history
Inference job options
  • Loading branch information
djstein authored Aug 8, 2023
2 parents 9f5abcc + e6e6504 commit 20081a0
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions spice_agent/inference/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Dict, Optional
from typing import Optional, Dict, Any

from diffusers import (
DiffusionPipeline,
Expand Down Expand Up @@ -69,6 +69,7 @@ def _update_inference_job(
textInput
textOutput
wasGuarded
options
}
}
}
Expand Down Expand Up @@ -103,6 +104,34 @@ def _update_inference_job(
else:
raise exception

def _get_stable_diffusion_options(self, options: Dict[str, Any]) -> dict:
"""
Parses any inference options that may be defined for
StableDiffusionPipeline
"""

stable_diffusion_options: dict = {}

if "negative_prompt" in options:
stable_diffusion_options["negative_prompt"] = options["negative_prompt"]

if "guidance_scale" in options:
stable_diffusion_options["guidance_scale"] = options["guidance_scale"]

if "num_inference_steps" in options:
stable_diffusion_options["num_inference_steps"] = options[
"num_inference_steps"
]

if "seed" in options:
# Note, completely reproducible results are not guaranteed across
# PyTorch releases.
stable_diffusion_options["generator"] = torch.manual_seed(
int(options["seed"])
)

return stable_diffusion_options

def run_pipeline(
self,
inference_job_id: str,
Expand All @@ -127,6 +156,7 @@ def run_pipeline(
is_text_output = result["updateInferenceJob"]["model"]["isTextOutput"]
result["updateInferenceJob"]["model"]["isFileInput"]
is_file_output = result["updateInferenceJob"]["model"]["isFileOutput"]
options = result["updateInferenceJob"]["options"]

LOGGER.info(f""" [*] Model: {model_repo_id}.""")
LOGGER.info(f""" [*] Text Input: '{text_input}'""")
Expand Down Expand Up @@ -155,6 +185,7 @@ def run_pipeline(
elif is_text_input and is_file_output:
SPICE_INFERENCE_DIRECTORY.mkdir(parents=True, exist_ok=True)
save_at = Path(SPICE_INFERENCE_DIRECTORY / f"{inference_job_id}.png")
stable_diffusion_options = self._get_stable_diffusion_options(options)
was_guarded = False
if not save_at.exists():
pipe = DiffusionPipeline.from_pretrained(
Expand Down Expand Up @@ -184,21 +215,19 @@ def run_pipeline(

latents = pipe(
prompt=text_input,
# negative_prompt=negative_prompt,
# num_images_per_prompt=num_images_per_prompt,
# num_inference_steps=n_steps,
output_type="latent",
**stable_diffusion_options,
).images # type: ignore

pipe_result = refiner(
prompt=text_input,
# negative_prompt=negative_prompt,
# num_images_per_prompt=num_images_per_prompt,
# num_inference_steps=n_steps,
image=latents, # type: ignore
**stable_diffusion_options,
) # type: ignore
else:
pipe_result = pipe(text_input, return_dict=False) # type:ignore
pipe_result = pipe(
text_input, return_dict=False, **stable_diffusion_options
) # type:ignore

# pipe returns a tuple in the form the first element is a list with
# the generated images, and the second element is a list of `bool`s
Expand Down

0 comments on commit 20081a0

Please sign in to comment.