🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch

dryadsystems/hazy-diffusers

 
 

Diffusers + FlashAttention

This branch of HuggingFace Diffusers incorporates FlashAttention and is optimized for high throughput.

Installation

The FlashAttention implementation in this repo depends on the experimental cutlass branch of the FlashAttention repository. FlashAttention requires CUDA 11, NVCC, and a Turing or Ampere GPU.
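The requirements above can be checked before building. The following is a minimal sketch, not part of this repo: the helper names are hypothetical, and it assumes "Turing or Ampere" corresponds to CUDA compute capabilities 7.5, 8.0, and 8.6, and that the CUDA version PyTorch was built with is the one that matters.

```python
import shutil

# Assumption: "Turing or Ampere" is read as these CUDA compute capabilities.
SUPPORTED_CAPABILITIES = {(7, 5), (8, 0), (8, 6)}


def cuda_major(version):
    """Return the major version of a string like '11.6', or None if absent."""
    return int(version.split(".")[0]) if version else None


def is_supported_gpu(capability):
    """True if a (major, minor) compute capability is in the assumed set."""
    return tuple(capability) in SUPPORTED_CAPABILITIES


def check_environment():
    """Return a list of problems found; an empty list means all checks passed."""
    try:
        import torch  # imported lazily so the helpers above work without it
    except ImportError:
        return ["PyTorch is not installed"]
    problems = []
    if shutil.which("nvcc") is None:
        problems.append("nvcc not found on PATH")
    if cuda_major(torch.version.cuda) != 11:
        problems.append(f"PyTorch reports CUDA {torch.version.cuda}, expected 11.x")
    if not torch.cuda.is_available():
        problems.append("no CUDA device visible")
    elif not is_supported_gpu(torch.cuda.get_device_capability()):
        problems.append("GPU does not look like Turing or Ampere")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```

If the script prints nothing, the checks passed. A warning does not necessarily mean the build will fail, only that the environment differs from the one described above.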

To install FlashAttention:

git clone https://github.com/HazyResearch/flash-attention.git
cd flash-attention
git checkout cutlass
git submodule init
git submodule update
python setup.py install
cd ..

To install diffusers:

git clone https://github.com/HazyResearch/diffusers.git
cd diffusers
pip install -e .
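After both installs, a quick import check can confirm that the packages are visible to the current Python environment. This is a rough sketch: `missing_modules` is a hypothetical helper, and the module name `flash_attn` is an assumption about what the FlashAttention `setup.py` installs.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumption: FlashAttention installs as `flash_attn`.
    missing = missing_modules(["flash_attn", "diffusers"])
    print("missing:", ", ".join(missing) if missing else "none")
```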

Running

A sample benchmark, following HuggingFace's benchmark of diffusers:

import time
import torch
from diffusers import StableDiffusionPipeline

# inference only: disable gradient tracking globally
torch.set_grad_enabled(False)

torch.manual_seed(1231)
torch.cuda.manual_seed(1231)

prompt = "a photo of an astronaut riding a horse on mars"

# let cuDNN autotune convolution algorithms for the input shapes it sees
torch.backends.cudnn.benchmark = True

# make sure you're logged in with `huggingface-cli login`
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16
).to("cuda")

batch_size = 10

# warmup: a short run so one-time setup costs are excluded from the timings
with torch.inference_mode():
    image = pipe([prompt] * batch_size, num_inference_steps=5).images[0]

# timed runs: synchronize before and after so all queued GPU work is counted
for _ in range(3):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.inference_mode():
        image = pipe([prompt] * batch_size, num_inference_steps=50).images[0]
    torch.cuda.synchronize()
    print(f"Pipeline inference took {time.perf_counter() - start_time:.2f} seconds")
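Since the benchmark prints the time per batch while the stated goal is throughput, the per-run timings can be converted to images per second. The sketch below is illustrative only: `summarize_runs` is a hypothetical helper, and the timings in the usage lines are placeholders, not measurements from this repo.

```python
from statistics import mean


def summarize_runs(timings_s, batch_size):
    """Return (mean seconds per batch, images per second) for the timed runs."""
    if not timings_s:
        raise ValueError("at least one timing is required")
    latency = mean(timings_s)
    return latency, batch_size / latency


# Placeholder timings in seconds, for illustration only.
latency, throughput = summarize_runs([4.0, 4.0, 4.0], batch_size=10)
print(f"{latency:.2f} s per batch, {throughput:.2f} images/s")
```

To use it with the benchmark above, collect each run's elapsed time in a list instead of only printing it, then pass that list together with `batch_size`.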
