Commit bd29731: Supported bf16 int8 mixture precision inference during denoising loop for SD (#1203)
XinyuYe-Intel authored Jul 28, 2023 · 1 parent 4c75efe
Showing 6 changed files with 418 additions and 5 deletions.
**README.md** (excerpt):

```bash
python text2images.py \
    ...
    --captions "a photo of an astronaut riding a horse on mars"
```

You can also use a BF16 UNet for some steps of the denoising loop instead of the INT8 UNet to improve output image quality. To do so, add the `--use_bf16` argument to the above command.
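For instance, appended to the earlier command (the elided arguments stay the same), the flag would look roughly like this:

```bash
python text2images.py \
    ...
    --captions "a photo of an astronaut riding a horse on mars" \
    --use_bf16
```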

Below are two comparisons of results from the FP32 model, the INT8 model, and the mixture of the BF16 and INT8 models. Note that the INT8 model was trained on an Intel® Xeon® Platinum 8480+ Processor.
<br>
With the caption `"a photo of an astronaut riding a horse on mars"`, results of the FP32 model, the INT8 model, and the BF16/INT8 mixture are shown left, middle, and right respectively.
<br>
<img src="./fp32 images/a photo of an astronaut riding a horse on mars fp32.png" width = "300" height = "300" alt="FP32" align=center />
<img src="./int8 images/a photo of an astronaut riding a horse on mars int8.png" width = "300" height = "300" alt="INT8" align=center />
<img src="./int8 bf16 images/a photo of an astronaut riding a horse on mars int8 bf16.png" width = "300" height = "300" alt="INT8 BF16" align=center />

With the caption `"The Milky Way lies in the sky, with the golden snow mountain lies below, high definition"`, results of the FP32 model, the INT8 model, and the BF16/INT8 mixture are shown left, middle, and right respectively.
<br>
<img src="./fp32 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition fp32.png" width = "300" height = "300" alt="FP32" align=center />
<img src="./int8 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8.png" width = "300" height = "300" alt="INT8" align=center />
<img src="./int8 bf16 images/The Milky Way lies in the sky, with the golden snow mountain lies below, high definition int8 bf16.png" width = "300" height = "300" alt="INT8 BF16" align=center />
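Conceptually, the mixed-precision pipeline keeps two copies of the UNet and routes each denoising step to one of them. Below is a minimal sketch of the idea (the names `unet_bf16` and `HIGH_PRECISION_STEPS` follow `evaluate_fid.py` later on this page; the exact dispatch rule and the omission of classifier-free guidance are simplifying assumptions, not the pipeline's verbatim logic):

```python
import torch

def denoise(pipe, scheduler, latents, text_embeddings):
    """Assumed mixed-precision loop: BF16 UNet for the first
    HIGH_PRECISION_STEPS steps, INT8 UNet for the remainder."""
    for step, t in enumerate(scheduler.timesteps):
        unet = pipe.unet_bf16 if step < pipe.HIGH_PRECISION_STEPS else pipe.unet
        with torch.no_grad():
            noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents
```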

## FID evaluation
We have also evaluated FID scores on the COCO2017 validation dataset for the FP32 model, the BF16 model, the INT8 model, and the mixture of the BF16 and INT8 models. The FID results are listed below.

| Precision | FP32 | BF16 | INT8 | INT8+BF16 |
|----------------------|-------|-------|-------|-----------|
| FID on COCO2017 val | 30.48 | 30.58 | 35.46 | 30.63 |
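For reference, the metric behind these numbers can be computed with `torchmetrics` (as `evaluate_fid.py` below does); a minimal sketch using random stand-in tensors rather than real images:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# normalize=True: inputs are float NCHW tensors in [0, 1]
fid = FrechetInceptionDistance(normalize=True)
real = torch.rand(4, 3, 512, 512)  # stand-ins for COCO2017 val images
fake = torch.rand(4, 3, 512, 512)  # stand-ins for generated images
fid.update(real, real=True)
fid.update(fake, real=False)
print(float(fid.compute()))
```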

To evaluate the FID score on the COCO2017 validation dataset for the mixture of the BF16 and INT8 models, you can use the command below.

```bash
python evaluate_fid.py \
--model_name_or_path runwayml/stable-diffusion-v1-5 \
--int8_model_path sdv1-5-qat_kd/quant_model.pt \
--dataset_path /path/to/COCO2017 \
--output_dir ./output_images \
--precision int8-bf16
```
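Evaluating all 5,000 COCO2017 validation captions is slow; the `--iterations` flag of `evaluate_fid.py` (see below) restricts the run to the first N prompts, e.g. `--iterations 100` for a quick sanity check.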
**evaluate_fid.py** (new file):

```python
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import copy
import logging
import os
import time
import numpy as np
import pathlib

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from torchmetrics.image.fid import FrechetInceptionDistance
import torchvision.datasets as dset
import torchvision.transforms as transforms
from text2images import StableDiffusionPipelineMixedPrecision

logging.getLogger().setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="", help="Model path")
    parser.add_argument("--int8_model_path", type=str, default="", help="INT8 model path")
    parser.add_argument("--dataset_path", type=str, default="", help="COCO2017 dataset path")
    parser.add_argument("--output_dir", type=str, default=None, help="output path")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument('--precision', type=str, default="fp32", help='precision: fp32, bf16, int8, int8-bf16')
    parser.add_argument('-i', '--iterations', default=-1, type=int, help='number of total iterations to run')
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='ccl', type=str, help='distributed backend')

    args = parser.parse_args()
    return args

def main():

    args = parse_args()
    logging.info(f"Parameters {args}")

    # CCL related
    os.environ['MASTER_ADDR'] = str(os.environ.get('MASTER_ADDR', '127.0.0.1'))
    os.environ['MASTER_PORT'] = '29500'
    os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
    os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    print("World size: ", args.world_size)

    args.distributed = args.world_size > 1
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])

    # load model
    pipe = StableDiffusionPipelineMixedPrecision.from_pretrained(args.model_name_or_path)
    # number of denoising steps to run with the BF16 UNet when --precision is int8-bf16
    pipe.HIGH_PRECISION_STEPS = 5

    # data type
    if args.precision == "fp32":
        print("Running fp32 ...")
        dtype = torch.float32
    elif args.precision == "bf16":
        print("Running bf16 ...")
        dtype = torch.bfloat16
    elif args.precision == "int8" or args.precision == "int8-bf16":
        print(f"Running {args.precision} ...")
        if args.precision == "int8-bf16":
            # keep a BF16 copy of the UNet for the high-precision denoising steps
            unet_bf16 = copy.deepcopy(pipe.unet).to(device=pipe.unet.device, dtype=torch.bfloat16)
            pipe.unet_bf16 = unet_bf16
        from quantization_modules import load_int8_model
        pipe.unet = load_int8_model(pipe.unet, args.int8_model_path, "fake" in args.int8_model_path)
    else:
        raise ValueError("--precision must be one of: fp32, bf16, int8, int8-bf16")

    # pipe.to(dtype)
    if args.distributed:
        torch.distributed.init_process_group(backend=args.dist_backend,
                                             init_method=args.dist_url,
                                             world_size=args.world_size,
                                             rank=args.rank)
        print("Rank and world size: ", torch.distributed.get_rank(), " ", torch.distributed.get_world_size())
        # print("Create DistributedDataParallel in CPU")
        # pipe = torch.nn.parallel.DistributedDataParallel(pipe)

    # prepare dataloader
    val_coco = dset.CocoCaptions(root='{}/val2017'.format(args.dataset_path),
                                 annFile='{}/annotations/captions_val2017.json'.format(args.dataset_path),
                                 transform=transforms.Compose([transforms.Resize((512, 512)), transforms.PILToTensor()]))

    if args.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_coco, shuffle=False)
    else:
        val_sampler = None

    val_dataloader = torch.utils.data.DataLoader(val_coco,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 num_workers=0,
                                                 sampler=val_sampler)

print("Running accuracy ...")
# run model
if args.distributed:
torch.distributed.barrier()
fid = FrechetInceptionDistance(normalize=True)
for i, (images, prompts) in enumerate(val_dataloader):
prompt = prompts[0][0]
real_image = images[0]
print("prompt: ", prompt)
if args.precision == "bf16":
context = torch.cpu.amp.autocast(dtype=dtype)
with context, torch.no_grad():
output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images
else:
with torch.no_grad():
output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images

if args.output_dir:
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
image_name = time.strftime("%Y%m%d_%H%M%S")
Image.fromarray((output[0] * 255).round().astype("uint8")).save(f"{args.output_dir}/fake_image_{image_name}.png")
Image.fromarray(real_image.permute(1, 2, 0).numpy()).save(f"{args.output_dir}/real_image_{image_name}.png")

fake_image = torch.tensor(output[0]).unsqueeze(0).permute(0, 3, 1, 2)
real_image = real_image.unsqueeze(0) / 255.0

fid.update(real_image, real=True)
fid.update(fake_image, real=False)

if args.iterations > 0 and i == args.iterations - 1:
break

    if args.distributed:
        torch.distributed.barrier()
    print(f"FID: {float(fid.compute())}")

if __name__ == '__main__':
    main()
```
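The script also supports multi-process evaluation via the `ccl` distributed backend. A hypothetical two-rank CPU launch (assuming Intel MPI and the oneCCL bindings for PyTorch are installed; the `PMI_RANK`/`PMI_SIZE` variables the script reads are set by the launcher):

```bash
# Hypothetical launcher invocation; adjust ranks and paths to your setup.
mpirun -n 2 python evaluate_fid.py \
    --model_name_or_path runwayml/stable-diffusion-v1-5 \
    --int8_model_path sdv1-5-qat_kd/quant_model.pt \
    --dataset_path /path/to/COCO2017 \
    --precision int8-bf16
```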
**requirements.txt** (excerpt):

```
transformers==4.30.2
datasets
torch
torchvision
torchmetrics
torch-fidelity
Pillow
git+https://github.com/intel/neural-compressor.git
```