diff --git a/torchbenchmark/models/stable_diffusion/__init__.py b/torchbenchmark/models/stable_diffusion_text_encoder/__init__.py
similarity index 100%
rename from torchbenchmark/models/stable_diffusion/__init__.py
rename to torchbenchmark/models/stable_diffusion_text_encoder/__init__.py
diff --git a/torchbenchmark/models/stable_diffusion/install.py b/torchbenchmark/models/stable_diffusion_text_encoder/install.py
similarity index 100%
rename from torchbenchmark/models/stable_diffusion/install.py
rename to torchbenchmark/models/stable_diffusion_text_encoder/install.py
diff --git a/torchbenchmark/models/stable_diffusion/metadata.yaml b/torchbenchmark/models/stable_diffusion_text_encoder/metadata.yaml
similarity index 100%
rename from torchbenchmark/models/stable_diffusion/metadata.yaml
rename to torchbenchmark/models/stable_diffusion_text_encoder/metadata.yaml
diff --git a/torchbenchmark/models/stable_diffusion_unet/__init__.py b/torchbenchmark/models/stable_diffusion_unet/__init__.py
new file mode 100644
index 0000000000..7f65f76601
--- /dev/null
+++ b/torchbenchmark/models/stable_diffusion_unet/__init__.py
@@ -0,0 +1,73 @@
+"""
+HuggingFace Stable Diffusion model (UNet benchmark).
+It requires users to set the "HUGGING_FACE_HUB_TOKEN" environment variable
+to authorize login, and to agree to the HuggingFace terms and conditions.
+"""
+from torchbenchmark.tasks import COMPUTER_VISION
+from torchbenchmark.util.model import BenchmarkModel
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin
+
+import torch
+from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
+
+
+class Model(BenchmarkModel, HuggingFaceAuthMixin):
+    """Benchmark wrapper around the Stable Diffusion 2 pipeline.
+
+    ``get_module`` exposes only the pipeline's UNet together with example
+    inputs, while ``eval`` runs the whole pipeline on a text prompt.
+    """
+
+    task = COMPUTER_VISION.GENERATION
+
+    DEFAULT_TRAIN_BSIZE = 1
+    DEFAULT_EVAL_BSIZE = 1
+    ALLOW_CUSTOMIZE_BSIZE = False
+    # Default eval precision on CUDA device is fp16
+    DEFAULT_EVAL_CUDA_PRECISION = "fp16"
+
+    def __init__(self, test, device, batch_size=None, extra_args=None):
+        """Load the scheduler and pipeline, then move the pipeline to the device.
+
+        Args:
+            test: Forwarded to ``super().__init__``.
+            device: Forwarded to ``super().__init__``.
+            batch_size: Forwarded to ``super().__init__``.
+            extra_args: Optional list forwarded to ``super().__init__``;
+                ``None`` is treated as an empty list.
+        """
+        # A fresh list per call avoids sharing one mutable default object
+        # between instances.
+        if extra_args is None:
+            extra_args = []
+        HuggingFaceAuthMixin.__init__(self)
+        super().__init__(test=test, device=device,
+                         batch_size=batch_size, extra_args=extra_args)
+        model_id = "stabilityai/stable-diffusion-2"
+        scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
+        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)
+        self.example_inputs = "a photo of an astronaut riding a horse on mars"
+        self.pipe.to(self.device)
+
+    def enable_fp16_half(self):
+        """Do nothing; no precision change is applied to the pipeline here."""
+        pass
+
+    def get_module(self):
+        """Return the pipeline's UNet and a list of three example input tensors.
+
+        Each tensor is created and then moved to ``self.device``.
+        """
+        random_input = torch.randn(1, 4, 128, 128).to(self.device)
+        timestep = torch.tensor([1.0]).to(self.device)
+        encoder_hidden_states = torch.randn(1, 1, 1024).to(self.device)
+        return self.pipe.unet, [random_input, timestep, encoder_hidden_states]
+
+    def train(self):
+        """Always raise ``NotImplementedError``; this model has no train test."""
+        raise NotImplementedError("Train test is not implemented for the stable diffusion model.")
+
+    def eval(self):
+        """Run the pipeline on the example prompt and return its output in a 1-tuple."""
+        image = self.pipe(self.example_inputs)
+        return (image, )
diff --git a/torchbenchmark/models/stable_diffusion_unet/install.py b/torchbenchmark/models/stable_diffusion_unet/install.py
new file mode 100644
index 0000000000..a9f4576593
--- /dev/null
+++ b/torchbenchmark/models/stable_diffusion_unet/install.py
@@ -0,0 +1,27 @@
+from torchbenchmark.util.framework.diffusers import install_diffusers
+from torchbenchmark.util.framework.huggingface.model_factory import HuggingFaceAuthMixin
+import torch
+import os
+import warnings
+
+MODEL_NAME = "stabilityai/stable-diffusion-2"
+
+
+def load_model_checkpoint():
+    """Instantiate the pipeline once so that its weights get downloaded.
+
+    The pipeline object itself is discarded; presumably the useful side
+    effect is that ``from_pretrained`` caches the files locally (confirm
+    against the diffusers documentation).
+    """
+    from diffusers import StableDiffusionPipeline
+    StableDiffusionPipeline.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, safety_checker=None)
+
+
+if __name__ == "__main__":
+    # Without the token, nothing is installed or downloaded: only warn.
+    if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
+        warnings.warn("Make sure to set `HUGGING_FACE_HUB_TOKEN` so you can download weights")
+    else:
+        install_diffusers()
+        load_model_checkpoint()
diff --git a/torchbenchmark/models/stable_diffusion_unet/metadata.yaml b/torchbenchmark/models/stable_diffusion_unet/metadata.yaml
new file mode 100644
index 0000000000..4a03e1edcb
--- /dev/null
+++ b/torchbenchmark/models/stable_diffusion_unet/metadata.yaml
@@ -0,0 +1,12 @@
+devices:
+  NVIDIA A100-SXM4-40GB:
+    # NOTE(review): the model class sets DEFAULT_EVAL_BSIZE = 1 and
+    # ALLOW_CUSTOMIZE_BSIZE = False; confirm that 32 is intended here.
+    eval_batch_size: 32
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+train_benchmark: false
+train_deterministic: false
+not_implemented:
+- device: cpu