Skip to content

Commit

Permalink
add fp variant support for macos
Browse files Browse the repository at this point in the history
  • Loading branch information
Ankush Patel authored and Ankush Patel committed Aug 2, 2023
1 parent 1fb38c5 commit 6e4296e
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions spice_agent/inference/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def run_pipeline(
LOGGER.info(f""" [*] Model: {model_repo_id}.""")
LOGGER.info(f""" [*] Text Input: '{text_input}'""")

variant = "fp16"
torch_dtype = torch.float16
if torch.backends.mps.is_available():
variant = "fp32"
torch_dtype = torch.float32
mps_empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
Expand All @@ -152,8 +156,8 @@ def run_pipeline(
if not save_at.exists():
pipe = DiffusionPipeline.from_pretrained(
model_repo_id,
torch_dtype=torch.float16,
variant="fp16",
torch_dtype=torch_dtype,
variant=variant,
use_safetensors=True,
)
pipe = pipe.to(self.device) # type: ignore
Expand All @@ -168,8 +172,8 @@ def run_pipeline(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=pipe.text_encoder_2,
vae=pipe.vae,
torch_dtype=torch.float16,
variant="fp16",
torch_dtype=torch_dtype,
variant=variant,
use_safetensors=True,
)

Expand Down Expand Up @@ -199,7 +203,7 @@ def run_pipeline(
# represents "not-safe-for-work" (nsfw) content, according to the
# `safety_checker`.
result = pipe_result[0][0] # type: ignore
if len(pipe_result) > 1:
if len(pipe_result) > 1 and pipe_result[1]: # type: ignore
was_guarded = pipe_result[1][0] # type: ignore
result.save(save_at) # type: ignore
else:
Expand Down

0 comments on commit 6e4296e

Please sign in to comment.