diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md b/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md
index 67817768647..506809182f1 100644
--- a/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md
+++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md
@@ -43,14 +43,36 @@ python text2images.py \
--captions "a photo of an astronaut riding a horse on mars"
```
-Below are two results comparison of fp32 model and int8 model. Note int8 model is trained on an Intel® Xeon® Platinum 8480+ Processor.
+You can also use BF16 UNet for inference on some steps of denoising loop instead of INT8 UNet to improve output images quality, to do so, just add `--use_bf16` argument in the above command.
+
+Below are two results comparison of fp32 model, int8 model and mixture of bf16 model and int8 model. Note int8 model is trained on an Intel® Xeon® Platinum 8480+ Processor.
-With caption `"a photo of an astronaut riding a horse on mars"`, results of fp32 model and int8 model are listed left and right respectively.
+With caption `"a photo of an astronaut riding a horse on mars"`, results of fp32 model, int8 model and mixture of bf16 model and int8 model are listed left, middle and right respectively.
+
-With caption `"The Milky Way lies in the sky, with the golden snow mountain lies below, high definition"`, results of fp32 model and int8 model are listed left and right respectively.
+With caption `"The Milky Way lies in the sky, with the golden snow mountain lies below, high definition"`, results of fp32 model, int8 model and mixture of bf16 model and int8 model are listed left, middle and right respectively.
+
+
+## FID evaluation
+We have also evaluated FID scores on COCO2017 validation dataset for FP32 model, BF16 model, INT8 model and mixture of BF16 and INT8 model. FID results are listed below.
+
+| Precision | FP32 | BF16 | INT8 | INT8+BF16 |
+|----------------------|-------|-------|-------|-----------|
+| FID on COCO2017 val | 30.48 | 30.58 | 35.46 | 30.63 |
+
+To evaluated FID score on COCO2017 validation dataset for mixture of BF16 and INT8 model, you can use below command.
+
+```bash
+python evaluate_fid.py \
+ --model_name_or_path runwayml/stable-diffusion-v1-5 \
+ --int8_model_path sdv1-5-qat_kd/quant_model.pt \
+ --dataset_path /path/to/COCO2017 \
+ --output_dir ./output_images \
+ --precision int8-bf16
+```
diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/evaluate_fid.py b/examples/huggingface/pytorch/text-to-image/quantization/qat/evaluate_fid.py
new file mode 100644
index 00000000000..0a9b03ded21
--- /dev/null
+++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/evaluate_fid.py
@@ -0,0 +1,160 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+import copy
+import logging
+import os
+import time
+import numpy as np
+import pathlib
+
+import torch
+from PIL import Image
+from diffusers import StableDiffusionPipeline
+from torchmetrics.image.fid import FrechetInceptionDistance
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+from text2images import StableDiffusionPipelineMixedPrecision
+
+logging.getLogger().setLevel(logging.INFO)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_name_or_path", type=str, default="", help="Model path")
+ parser.add_argument("--int8_model_path", type=str, default="", help="INT8 model path")
+ parser.add_argument("--dataset_path", type=str, default="", help="COCO2017 dataset path")
+ parser.add_argument("--output_dir", type=str, default=None,help="output path")
+ parser.add_argument("--seed", type=int, default=42, help="random seed")
+ parser.add_argument('--precision', type=str, default="fp32", help='precision: fp32, bf16, int8, int8-bf16')
+ parser.add_argument('-i', '--iterations', default=-1, type=int, help='number of total iterations to run')
+ parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
+ parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
+ parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
+ parser.add_argument('--dist-backend', default='ccl', type=str, help='distributed backend')
+
+ args = parser.parse_args()
+ return args
+
+def main():
+
+ args = parse_args()
+ logging.info(f"Parameters {args}")
+
+ # CCL related
+ os.environ['MASTER_ADDR'] = str(os.environ.get('MASTER_ADDR', '127.0.0.1'))
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
+ os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))
+
+ if args.dist_url == "env://" and args.world_size == -1:
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ print("World size: ", args.world_size)
+
+ args.distributed = args.world_size > 1
+ if args.distributed:
+ if args.dist_url == "env://" and args.rank == -1:
+ args.rank = int(os.environ["RANK"])
+
+ # load model
+ pipe = StableDiffusionPipelineMixedPrecision.from_pretrained(args.model_name_or_path)
+ pipe.HIGH_PRECISION_STEPS = 5
+
+ # data type
+ if args.precision == "fp32":
+ print("Running fp32 ...")
+ dtype=torch.float32
+ elif args.precision == "bf16":
+ print("Running bf16 ...")
+ dtype=torch.bfloat16
+ elif args.precision == "int8" or args.precision == "int8-bf16":
+ print(f"Running {args.precision} ...")
+ if args.precision == "int8-bf16":
+ unet_bf16 = copy.deepcopy(pipe.unet).to(device=pipe.unet.device, dtype=torch.bfloat16)
+ pipe.unet_bf16 = unet_bf16
+ from quantization_modules import load_int8_model
+ pipe.unet = load_int8_model(pipe.unet, args.int8_model_path, "fake" in args.int8_model_path)
+ else:
+ raise ValueError("--precision needs to be the following:: fp32, bf16, fp16, int8, int8-bf16")
+
+ # pipe.to(dtype)
+ if args.distributed:
+ torch.distributed.init_process_group(backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank)
+ print("Rank and world size: ", torch.distributed.get_rank()," ", torch.distributed.get_world_size())
+ # print("Create DistributedDataParallel in CPU")
+ # pipe = torch.nn.parallel.DistributedDataParallel(pipe)
+
+ # prepare dataloader
+ val_coco = dset.CocoCaptions(root = '{}/val2017'.format(args.dataset_path),
+ annFile = '{}/annotations/captions_val2017.json'.format(args.dataset_path),
+ transform=transforms.Compose([transforms.Resize((512, 512)), transforms.PILToTensor(), ]))
+
+ if args.distributed:
+ val_sampler = torch.utils.data.distributed.DistributedSampler(val_coco, shuffle=False)
+ else:
+ val_sampler = None
+
+ val_dataloader = torch.utils.data.DataLoader(val_coco,
+ batch_size=1,
+ shuffle=False,
+ num_workers=0,
+ sampler=val_sampler)
+
+ print("Running accuracy ...")
+ # run model
+ if args.distributed:
+ torch.distributed.barrier()
+ fid = FrechetInceptionDistance(normalize=True)
+ for i, (images, prompts) in enumerate(val_dataloader):
+ prompt = prompts[0][0]
+ real_image = images[0]
+ print("prompt: ", prompt)
+ if args.precision == "bf16":
+ context = torch.cpu.amp.autocast(dtype=dtype)
+ with context, torch.no_grad():
+ output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images
+ else:
+ with torch.no_grad():
+ output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images
+
+ if args.output_dir:
+ if not os.path.exists(args.output_dir):
+ os.mkdir(args.output_dir)
+ image_name = time.strftime("%Y%m%d_%H%M%S")
+ Image.fromarray((output[0] * 255).round().astype("uint8")).save(f"{args.output_dir}/fake_image_{image_name}.png")
+ Image.fromarray(real_image.permute(1, 2, 0).numpy()).save(f"{args.output_dir}/real_image_{image_name}.png")
+
+ fake_image = torch.tensor(output[0]).unsqueeze(0).permute(0, 3, 1, 2)
+ real_image = real_image.unsqueeze(0) / 255.0
+
+ fid.update(real_image, real=True)
+ fid.update(fake_image, real=False)
+
+ if args.iterations > 0 and i == args.iterations - 1:
+ break
+
+ if args.distributed:
+ torch.distributed.barrier()
+ print(f"FID: {float(fid.compute())}")
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8 bf16.png b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8 bf16.png
new file mode 100644
index 00000000000..620fddfbcc3
Binary files /dev/null and b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8 bf16.png differ
diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/a photo of an astronaut riding a horse on mars int8 bf16.png b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/a photo of an astronaut riding a horse on mars int8 bf16.png
new file mode 100644
index 00000000000..03e5879dfc7
Binary files /dev/null and b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/a photo of an astronaut riding a horse on mars int8 bf16.png differ
diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt b/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt
index a0c6c11844d..c20b15c17a2 100644
--- a/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt
+++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt
@@ -4,5 +4,7 @@ transformers==4.30.2
datasets
torch
torchvision
+torchmetrics
+torch-fidelity
Pillow
git+https://github.com/intel/neural-compressor.git
\ No newline at end of file
diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py b/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py
index da7290284d2..8b46dc11144 100644
--- a/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py
+++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py
@@ -1,10 +1,13 @@
import argparse
+import copy
import math
import os
import shlex
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from typing import Any, Callable, Dict, List, Optional, Union
def parse_args():
parser = argparse.ArgumentParser()
@@ -52,6 +55,12 @@ def parse_args():
default=0,
help="cuda_id.",
)
+ parser.add_argument(
+ "--use_bf16",
+ action="store_true",
+ default=False,
+ help="Whether use bf16 UNet for quality improvement.",
+ )
args = parser.parse_args()
return args
@@ -90,20 +99,240 @@ def generate_images(
grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
return grid, images
+HIGH_PRECISION_STEPS = 5
+
+class StableDiffusionPipelineMixedPrecision(StableDiffusionPipeline):
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ if not hasattr(self, 'HIGH_PRECISION_STEPS'):
+ self.HIGH_PRECISION_STEPS = HIGH_PRECISION_STEPS
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ if (i < self.HIGH_PRECISION_STEPS or i > len(timesteps)-1 - self.HIGH_PRECISION_STEPS) and hasattr(self, 'unet_bf16'):
+ context = torch.cpu.amp.autocast(dtype=torch.bfloat16)
+ with context:
+ noise_pred = self.unet_bf16(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+ else:
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ elif output_type == "pil":
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # 10. Convert to PIL
+ image = self.numpy_to_pil(image)
+ else:
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
if __name__ == "__main__":
args = parse_args()
# Load models and create wrapper for stable diffusion
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
- pipeline = StableDiffusionPipeline.from_pretrained(
+ pipeline = StableDiffusionPipelineMixedPrecision.from_pretrained(
args.pretrained_model_name_or_path, unet=unet
)
pipeline.safety_checker = lambda images, clip_input: (images, False)
results_path = args.pretrained_model_name_or_path
if args.quantized_model_name_or_path and os.path.exists(args.quantized_model_name_or_path):
+ if args.use_bf16:
+ unet_bf16 = copy.deepcopy(unet).to(device=unet.device, dtype=torch.bfloat16)
+ setattr(pipeline, "unet_bf16", unet_bf16)
from quantization_modules import load_int8_model
- unet = load_int8_model(unet, args.quantized_model_name_or_path)
+ unet = load_int8_model(unet, args.quantized_model_name_or_path, 'fake' in args.quantized_model_name_or_path)
unet.eval()
setattr(pipeline, "unet", unet)
results_path = os.path.dirname(args.quantized_model_name_or_path)