Skip to content

Commit

Permalink
SD: image encoder bench (#1881)
Browse files Browse the repository at this point in the history
Summary:
The problem was that our benchmarks only covered the text encoder, which is not actually the bottleneck, so we now benchmark the UNet as well.

The dynamo runners assume that the inputs are nn.Module instances. We might be able to relax this to accept a HF pipeline object, but in the meantime this should be fine.

Pull Request resolved: #1881

Reviewed By: xuzhao9

Differential Revision: D49152514

Pulled By: msaroufim

fbshipit-source-id: ab38963d2faabab5f3e3351ec0bca0e5e2d418ef
  • Loading branch information
msaroufim authored and facebook-github-bot committed Sep 11, 2023
1 parent feb15ff commit 9cbeee3
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 0 deletions.
49 changes: 49 additions & 0 deletions torchbenchmark/models/stable_diffusion_unet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
HuggingFace Stable Diffusion model.
It requires users to specify "HUGGINGFACE_AUTH_TOKEN" in environment variable
to authorize login and agree HuggingFace terms and conditions.
"""
from torchbenchmark.tasks import COMPUTER_VISION
from torchbenchmark.util.model import BenchmarkModel
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin

import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler


class Model(BenchmarkModel, HuggingFaceAuthMixin):
    """Benchmark wrapper for the Stable Diffusion 2 pipeline that exposes its UNet.

    ``get_module`` hands back ``self.pipe.unet`` together with randomly
    generated inputs, while ``eval`` runs the complete pipeline on a fixed
    text prompt. Training is not supported.
    """

    task = COMPUTER_VISION.GENERATION

    DEFAULT_TRAIN_BSIZE = 1
    DEFAULT_EVAL_BSIZE = 1
    ALLOW_CUSTOMIZE_BSIZE = False
    # Default eval precision on CUDA device is fp16
    DEFAULT_EVAL_CUDA_PRECISION = "fp16"

    def __init__(self, test, device, batch_size=None, extra_args=None):
        """Load the pipeline and its Euler discrete scheduler, then move it to ``device``.

        Args:
            test: Forwarded unchanged to ``BenchmarkModel.__init__``.
            device: Forwarded unchanged to ``BenchmarkModel.__init__``.
            batch_size: Forwarded unchanged to ``BenchmarkModel.__init__``.
            extra_args: Optional list of extra arguments forwarded to
                ``BenchmarkModel.__init__``; ``None`` is treated as an empty list.
        """
        # Build a fresh list per call instead of sharing one mutable default
        # object across every instance.
        if extra_args is None:
            extra_args = []
        HuggingFaceAuthMixin.__init__(self)
        super().__init__(test=test, device=device,
                         batch_size=batch_size, extra_args=extra_args)
        model_id = "stabilityai/stable-diffusion-2"
        scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)
        self.example_inputs = "a photo of an astronaut riding a horse on mars"
        self.pipe.to(self.device)

    def enable_fp16_half(self):
        """Do nothing.

        NOTE(review): presumably overrides a base-class hook so that the
        pipeline is not converted to half precision here -- confirm against
        ``BenchmarkModel``.
        """
        pass

    def get_module(self):
        """Return the pipeline's UNet and a list of random example inputs.

        Returns:
            A pair ``(self.pipe.unet, inputs)`` where ``inputs`` holds, in
            order, a ``(1, 4, 128, 128)`` random tensor, a one-element
            timestep tensor ``[1.0]`` and a ``(1, 1, 1024)`` random tensor,
            each moved to ``self.device``.
        """
        random_input = torch.randn(1, 4, 128, 128).to(self.device)
        timestep = torch.tensor([1.0]).to(self.device)
        encoder_hidden_states = torch.randn(1, 1, 1024).to(self.device)
        return self.pipe.unet, [random_input, timestep, encoder_hidden_states]

    def train(self):
        """Always raise: training is not implemented for this model.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("Train test is not implemented for the stable diffusion model.")

    def eval(self):
        """Run the full pipeline on the example prompt.

        Returns:
            A one-element tuple containing the pipeline's output.
        """
        image = self.pipe(self.example_inputs)
        return (image, )
17 changes: 17 additions & 0 deletions torchbenchmark/models/stable_diffusion_unet/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from torchbenchmark.util.framework.diffusers import install_diffusers
from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin
import torch
import os
import warnings
MODEL_NAME = "stabilityai/stable-diffusion-2"

def load_model_checkpoint():
    """Instantiate the pipeline for ``MODEL_NAME`` in fp16 without a safety checker.

    The constructed pipeline object is discarded; presumably the call exists
    to populate the local HuggingFace cache at install time.
    """
    from diffusers import StableDiffusionPipeline

    load_options = {
        "torch_dtype": torch.float16,
        "safety_checker": None,
    }
    StableDiffusionPipeline.from_pretrained(MODEL_NAME, **load_options)

if __name__ == "__main__":
    # Installation and the checkpoint load are attempted only when the token
    # variable is present; otherwise just warn and do nothing.
    if "HUGGING_FACE_HUB_TOKEN" not in os.environ:
        # The message names the same variable that is checked above.
        warnings.warn("Make sure to set `HUGGING_FACE_HUB_TOKEN` so you can download weights")
    else:
        install_diffusers()
        load_model_checkpoint()
10 changes: 10 additions & 0 deletions torchbenchmark/models/stable_diffusion_unet/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
devices:
NVIDIA A100-SXM4-40GB:
eval_batch_size: 32
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
not_implemented:
- device: cpu

0 comments on commit 9cbeee3

Please sign in to comment.