diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md b/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md index 67817768647..506809182f1 100644 --- a/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md +++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/README.md @@ -43,14 +43,36 @@ python text2images.py \ --captions "a photo of an astronaut riding a horse on mars" ``` -Below are two results comparison of fp32 model and int8 model. Note int8 model is trained on an Intel® Xeon® Platinum 8480+ Processor. +You can also use BF16 UNet for inference on some steps of denoising loop instead of INT8 UNet to improve output images quality, to do so, just add `--use_bf16` argument in the above command. + +Below are two results comparison of fp32 model, int8 model and mixture of bf16 model and int8 model. Note int8 model is trained on an Intel® Xeon® Platinum 8480+ Processor.
-With caption `"a photo of an astronaut riding a horse on mars"`, results of fp32 model and int8 model are listed left and right respectively. +With caption `"a photo of an astronaut riding a horse on mars"`, results of fp32 model, int8 model and mixture of bf16 model and int8 model are listed left, middle and right respectively.
FP32 INT8 +INT8 BF16 -With caption `"The Milky Way lies in the sky, with the golden snow mountain lies below, high definition"`, results of fp32 model and int8 model are listed left and right respectively. +With caption `"The Milky Way lies in the sky, with the golden snow mountain lies below, high definition"`, results of fp32 model, int8 model and mixture of bf16 model and int8 model are listed left, middle and right respectively.
FP32 INT8 +INT8 BF16 + +## FID evaluation +We have also evaluated FID scores on COCO2017 validation dataset for FP32 model, BF16 model, INT8 model and mixture of BF16 and INT8 model. FID results are listed below. + +| Precision | FP32 | BF16 | INT8 | INT8+BF16 | +|----------------------|-------|-------|-------|-----------| +| FID on COCO2017 val | 30.48 | 30.58 | 35.46 | 30.63 | + +To evaluated FID score on COCO2017 validation dataset for mixture of BF16 and INT8 model, you can use below command. + +```bash +python evaluate_fid.py \ + --model_name_or_path runwayml/stable-diffusion-v1-5 \ + --int8_model_path sdv1-5-qat_kd/quant_model.pt \ + --dataset_path /path/to/COCO2017 \ + --output_dir ./output_images \ + --precision int8-bf16 +``` diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/evaluate_fid.py b/examples/huggingface/pytorch/text-to-image/quantization/qat/evaluate_fid.py new file mode 100644 index 00000000000..0a9b03ded21 --- /dev/null +++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/evaluate_fid.py @@ -0,0 +1,160 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import copy +import logging +import os +import time +import numpy as np +import pathlib + +import torch +from PIL import Image +from diffusers import StableDiffusionPipeline +from torchmetrics.image.fid import FrechetInceptionDistance +import torchvision.datasets as dset +import torchvision.transforms as transforms +from text2images import StableDiffusionPipelineMixedPrecision + +logging.getLogger().setLevel(logging.INFO) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name_or_path", type=str, default="", help="Model path") + parser.add_argument("--int8_model_path", type=str, default="", help="INT8 model path") + parser.add_argument("--dataset_path", type=str, default="", help="COCO2017 dataset path") + parser.add_argument("--output_dir", type=str, default=None,help="output path") + parser.add_argument("--seed", type=int, default=42, help="random seed") + parser.add_argument('--precision', type=str, default="fp32", help='precision: fp32, bf16, int8, int8-bf16') + parser.add_argument('-i', '--iterations', default=-1, type=int, help='number of total iterations to run') + parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') + parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training') + parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training') + parser.add_argument('--dist-backend', default='ccl', type=str, help='distributed backend') + + args = parser.parse_args() + return args + +def main(): + + args = parse_args() + logging.info(f"Parameters {args}") + + # CCL related + os.environ['MASTER_ADDR'] = str(os.environ.get('MASTER_ADDR', '127.0.0.1')) + os.environ['MASTER_PORT'] = '29500' + os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0)) + os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1)) + + if args.dist_url == "env://" and args.world_size == -1: + args.world_size = int(os.environ["WORLD_SIZE"]) + print("World size: ", args.world_size) + + args.distributed = args.world_size > 1 + if args.distributed: + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + + # load model + pipe = StableDiffusionPipelineMixedPrecision.from_pretrained(args.model_name_or_path) + pipe.HIGH_PRECISION_STEPS = 5 + + # data type + if args.precision == "fp32": + print("Running fp32 ...") + dtype=torch.float32 + elif args.precision == "bf16": + print("Running bf16 ...") + dtype=torch.bfloat16 + elif args.precision == "int8" or args.precision == "int8-bf16": + print(f"Running {args.precision} ...") + if args.precision == "int8-bf16": + unet_bf16 = copy.deepcopy(pipe.unet).to(device=pipe.unet.device, dtype=torch.bfloat16) + pipe.unet_bf16 = unet_bf16 + from quantization_modules import load_int8_model + pipe.unet = load_int8_model(pipe.unet, args.int8_model_path, "fake" in args.int8_model_path) + else: + raise ValueError("--precision needs to be the following:: fp32, bf16, fp16, int8, int8-bf16") + + # pipe.to(dtype) + if args.distributed: + torch.distributed.init_process_group(backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + print("Rank and world size: ", torch.distributed.get_rank()," ", torch.distributed.get_world_size()) + # print("Create DistributedDataParallel in CPU") + # pipe = torch.nn.parallel.DistributedDataParallel(pipe) + + # prepare dataloader + val_coco = dset.CocoCaptions(root = '{}/val2017'.format(args.dataset_path), + annFile = '{}/annotations/captions_val2017.json'.format(args.dataset_path), + transform=transforms.Compose([transforms.Resize((512, 512)), transforms.PILToTensor(), ])) + + if args.distributed: + val_sampler = torch.utils.data.distributed.DistributedSampler(val_coco, shuffle=False) + else: + val_sampler = None + + val_dataloader = torch.utils.data.DataLoader(val_coco, + batch_size=1, + shuffle=False, + num_workers=0, + sampler=val_sampler) + + print("Running accuracy ...") + # run model + if args.distributed: + torch.distributed.barrier() + fid = FrechetInceptionDistance(normalize=True) + for i, (images, prompts) in enumerate(val_dataloader): + prompt = prompts[0][0] + real_image = images[0] + print("prompt: ", prompt) + if args.precision == "bf16": + context = torch.cpu.amp.autocast(dtype=dtype) + with context, torch.no_grad(): + output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images + else: + with torch.no_grad(): + output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images + + if args.output_dir: + if not os.path.exists(args.output_dir): + os.mkdir(args.output_dir) + image_name = time.strftime("%Y%m%d_%H%M%S") + Image.fromarray((output[0] * 255).round().astype("uint8")).save(f"{args.output_dir}/fake_image_{image_name}.png") + Image.fromarray(real_image.permute(1, 2, 0).numpy()).save(f"{args.output_dir}/real_image_{image_name}.png") + + fake_image = torch.tensor(output[0]).unsqueeze(0).permute(0, 3, 1, 2) + real_image = real_image.unsqueeze(0) / 255.0 + + fid.update(real_image, real=True) + fid.update(fake_image, real=False) + + if args.iterations > 0 and i == args.iterations - 1: + break + + if args.distributed: + torch.distributed.barrier() + print(f"FID: {float(fid.compute())}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8 bf16.png b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8 bf16.png new file mode 100644 index 00000000000..620fddfbcc3 Binary files /dev/null and b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8 bf16.png differ diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/a photo of an astronaut riding a horse on mars int8 bf16.png b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/a photo of an astronaut riding a horse on mars int8 bf16.png new file mode 100644 index 00000000000..03e5879dfc7 Binary files /dev/null and b/examples/huggingface/pytorch/text-to-image/quantization/qat/int8 bf16 images/a photo of an astronaut riding a horse on mars int8 bf16.png differ diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt b/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt index a0c6c11844d..c20b15c17a2 100644 --- a/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt +++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/requirements.txt @@ -4,5 +4,7 @@ transformers==4.30.2 datasets torch torchvision +torchmetrics +torch-fidelity Pillow git+https://github.com/intel/neural-compressor.git \ No newline at end of file diff --git a/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py b/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py index da7290284d2..8b46dc11144 100644 --- a/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py +++ b/examples/huggingface/pytorch/text-to-image/quantization/qat/text2images.py @@ -1,10 +1,13 @@ import argparse +import copy import math import os import shlex import torch from PIL import Image from diffusers import StableDiffusionPipeline, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from typing import Any, Callable, Dict, List, Optional, Union def parse_args(): parser = argparse.ArgumentParser() @@ -52,6 +55,12 @@ def parse_args(): default=0, help="cuda_id.", ) + parser.add_argument( + "--use_bf16", + action="store_true", + default=False, + help="Whether use bf16 UNet for quality improvement.", + ) args = parser.parse_args() return args @@ -90,20 +99,240 @@ def generate_images( grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows) return grid, images +HIGH_PRECISION_STEPS = 5 + +class StableDiffusionPipelineMixedPrecision(StableDiffusionPipeline): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if not hasattr(self, 'HIGH_PRECISION_STEPS'): + self.HIGH_PRECISION_STEPS = HIGH_PRECISION_STEPS + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if (i < self.HIGH_PRECISION_STEPS or i > len(timesteps)-1 - self.HIGH_PRECISION_STEPS) and hasattr(self, 'unet_bf16'): + context = torch.cpu.amp.autocast(dtype=torch.bfloat16) + with context: + noise_pred = self.unet_bf16( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + if __name__ == "__main__": args = parse_args() # Load models and create wrapper for stable diffusion unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") - pipeline = StableDiffusionPipeline.from_pretrained( + pipeline = StableDiffusionPipelineMixedPrecision.from_pretrained( args.pretrained_model_name_or_path, unet=unet ) pipeline.safety_checker = lambda images, clip_input: (images, False) results_path = args.pretrained_model_name_or_path if args.quantized_model_name_or_path and os.path.exists(args.quantized_model_name_or_path): + if args.use_bf16: + unet_bf16 = copy.deepcopy(unet).to(device=unet.device, dtype=torch.bfloat16) + setattr(pipeline, "unet_bf16", unet_bf16) from quantization_modules import load_int8_model - unet = load_int8_model(unet, args.quantized_model_name_or_path) + unet = load_int8_model(unet, args.quantized_model_name_or_path, 'fake' in args.quantized_model_name_or_path) unet.eval() setattr(pipeline, "unet", unet) results_path = os.path.dirname(args.quantized_model_name_or_path)