Diffusion Fast Example #2902
Changes from 13 commits
Download_model.py
@@ -0,0 +1,9 @@
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,
    use_safetensors=True,
)
pipeline.save_pretrained("./Base_Diffusion_model")
README.md
@@ -0,0 +1,64 @@
## Diffusion-Fast

[Diffusion fast](https://github.com/huggingface/diffusion-fast) is a simple and efficient PyTorch-native way of optimizing Stable Diffusion XL (SDXL).

It features the following optimizations (sketched in code below):
* Running with the bfloat16 precision
* scaled_dot_product_attention (SDPA)
* torch.compile
* Combining q, k, v projections for attention computation
* Dynamic int8 quantization
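
For orientation, here is a rough sketch of what those optimizations look like when applied directly to a `diffusers` pipeline in plain Python. This is not the code used by this example (the actual implementation lives in diffusion-fast's `pipeline_utils.py` and is driven by `model_config.yaml`); the `fuse_qkv_projections` call and compile flags below are assumptions based on the public diffusers/PyTorch APIs:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# bfloat16 precision; with PyTorch 2.x, diffusers dispatches attention to
# scaled_dot_product_attention (SDPA) by default.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Combine the q, k, v projections into a single matmul per attention block.
pipe.fuse_qkv_projections()

# Compile the two heaviest submodules.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# Dynamic int8 quantization is applied via torchao inside diffusion-fast's
# pipeline_utils.py and is omitted here.

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```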

Details about the optimizations and various results can be found in this [blog post](https://pytorch.org/blog/accelerating-generative-ai-3/).
The example has been tested on A10, A100 as well as H100 GPUs.

### Prerequisites

`cd` to the example folder `examples/image_generation/diffusion_fast`.

Install the dependencies and upgrade torch to a nightly build (currently required):
```
git clone https://github.com/huggingface/diffusion-fast.git
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --ignore-installed
pip install accelerate transformers peft
pip install --no-cache-dir git+https://github.com/pytorch-labs/ao@54bcd5a10d0abbe7b0c045052029257099f83fd9
pip install pandas matplotlib seaborn
```
### Step 1: Download the Stable Diffusion model

```bash
python Download_model.py
```
This saves the model in `Base_Diffusion_model`.
### Step 2: Generate model archive
At this stage we're creating the model archive which includes the configuration of our model in [model_config.yaml](./model_config.yaml).
It's also the point where we need to decide if we want to deploy our model on a single GPU or on multiple GPUs; a hypothetical multi-GPU sketch follows below.
For the single-GPU case we can use the default configuration that can be found in [model_config.yaml](./model_config.yaml).
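
For reference, a multi-GPU deployment would mainly change the worker and device settings in the YAML. The sketch below is hypothetical: the `deviceIds` key is taken from TorchServe's model-config options and should be verified against your TorchServe version; the `handler` section stays the same as in the single-GPU file.

```yaml
# Hypothetical multi-GPU variant (verify deviceIds support in your TorchServe version)
minWorkers: 2
maxWorkers: 2
deviceType: "gpu"
deviceIds: [0, 1]  # one worker pinned per GPU
```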

```
torch-model-archiver --model-name diffusion_fast --version 1.0 --handler diffusion_fast_handler.py --config-file model_config.yaml --extra-files "diffusion-fast/utils/pipeline_utils.py" --archive-format no-archive
mv Base_Diffusion_model diffusion_fast/
```
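
Because `--archive-format no-archive` produces a plain directory rather than a `.mar` file, after the `mv` the layout should look roughly like this (sketch; the exact contents of `MAR-INF/MANIFEST.json` vary):

```
diffusion_fast/
├── MAR-INF/
│   └── MANIFEST.json
├── diffusion_fast_handler.py
├── model_config.yaml
├── pipeline_utils.py
└── Base_Diffusion_model/
```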

### Step 3: Add the model archive to the model store

```
mkdir model_store
mv diffusion_fast model_store
```

### Step 4: Start TorchServe

```
torchserve --start --ts-config config.properties --model-store model_store --models diffusion_fast
```

### Step 5: Run inference

```
python query.py --url "http://localhost:8080/predictions/diffusion_fast" --prompt "a photo of an astronaut riding a horse on mars"
```
The generated image will be written to the file `output-<timestamp>.jpg`.
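
Note that the handler returns the image as a nested JSON list of pixel values rather than as binary, which is why `query.py` rebuilds the image with numpy and PIL. For a quick smoke test without the helper script, a plain `curl` call against the same endpoint also works:

```
curl http://localhost:8080/predictions/diffusion_fast -d "a photo of an astronaut riding a horse on mars" > response.json
```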
config.properties
@@ -0,0 +1,4 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
max_response_size=655350000
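
The large `max_response_size` accommodates the handler's JSON-list image payloads. A back-of-the-envelope estimate (assuming roughly four characters per serialized pixel value) shows why TorchServe's default response-size limit would not do:

```python
# Approximate size of one 768x768 RGB image serialized as a JSON list of ints.
values = 768 * 768 * 3    # 1,769,472 pixel values
approx_bytes = values * 4  # ~7 MB of JSON text at ~4 chars per value
print(approx_bytes)        # 7077888 -- comfortably below 655350000
```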
diffusion_fast_handler.py
@@ -0,0 +1,132 @@
import logging
import os
from pathlib import Path

import numpy as np
import torch
from pipeline_utils import load_pipeline

from ts.handler_utils.timer import timed
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class DiffusionFastHandler(BaseHandler):
    """
    Diffusion-Fast handler class for text-to-image generation.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Loads and initializes the Diffusion Fast model.
        Args:
            ctx (context): It is a JSON object containing information
            pertaining to the model artifact parameters.
        """
        self.context = ctx
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        # TorchServe reports gpu_id < 0 (or None) for the CPU case, so check
        # for a valid id before building the CUDA device string.
        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available() and gpu_id is not None and gpu_id >= 0:
            self.device = torch.device("cuda:" + str(gpu_id))
        else:
            self.device = torch.device("cpu")

        self.num_inference_steps = ctx.model_yaml_config["handler"][
            "num_inference_steps"
        ]

        # Parameters for the model
        compile_unet = ctx.model_yaml_config["handler"]["compile_unet"]
        compile_vae = ctx.model_yaml_config["handler"]["compile_vae"]
        compile_mode = ctx.model_yaml_config["handler"]["compile_mode"]
        enable_fused_projections = ctx.model_yaml_config["handler"][
            "enable_fused_projections"
        ]
        do_quant = ctx.model_yaml_config["handler"]["do_quant"]
        change_comp_config = ctx.model_yaml_config["handler"]["change_comp_config"]
        no_sdpa = ctx.model_yaml_config["handler"]["no_sdpa"]
        no_bf16 = ctx.model_yaml_config["handler"]["no_bf16"]
        upcast_vae = ctx.model_yaml_config["handler"]["upcast_vae"]

        # Load model weights
        model_weights = Path(ctx.model_yaml_config["handler"]["model_weights"])

        ckpt = os.path.join(model_dir, model_weights)
        self.pipeline = load_pipeline(
            ckpt=ckpt,
            compile_unet=compile_unet,
            compile_vae=compile_vae,
            compile_mode=compile_mode,
            enable_fused_projections=enable_fused_projections,
            do_quant=do_quant,
            change_comp_config=change_comp_config,
            no_bf16=no_bf16,
            no_sdpa=no_sdpa,
            upcast_vae=upcast_vae,
        )

        logger.info("Diffusion Fast model loaded successfully")

        self.initialized = True
    @timed
    def preprocess(self, requests):
        """Basic text preprocessing of the user's prompt.
        Args:
            requests (list): The input data in the form of text is passed on
            to the preprocess function.
        Returns:
            list: The preprocess function returns a list of prompts.
        """
        assert (
            len(requests) == 1
        ), "Diffusion Fast is currently only supported with batch_size=1"

        inputs = []
        for data in requests:
            input_text = data.get("data")
            if input_text is None:
                input_text = data.get("body")
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            logger.info("Received text: '%s'", input_text)
            inputs.append(input_text)
        return inputs

Review comment (on the `logger.info("Received text: ...")` line): this can be removed for prod. / Reply: done

    @timed
    def inference(self, inputs):
        """Generates the image relevant to the received text.
        Args:
            inputs (list): List of text prompts from the preprocess function.
        Returns:
            list: It returns a list of the generated images for the input texts.
        """
        # Run the Diffusion Fast pipeline on the batch of prompts.
        inferences = self.pipeline(
            inputs, num_inference_steps=self.num_inference_steps, height=768, width=768
        ).images

        logger.info("Generated image: '%s'", inferences)
        return inferences

Review comment (on the `logger.info("Generated image: ...")` line): ditto, can be removed for prod. / Reply: done

    @timed
    def postprocess(self, inference_output):
        """The post-process function converts the generated images into a
        TorchServe-readable format.
        Args:
            inference_output (list): It contains the generated images for the input texts.
        Returns:
            (list): Returns a list of the images, each as a nested list of pixel values.
        """
        images = []
        for image in inference_output:
            images.append(np.array(image).tolist())
        return images
model_config.yaml
@@ -0,0 +1,18 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 3600
deviceType: "gpu"
handler:
  model_weights: "./Base_Diffusion_model"
  num_inference_steps: 30
  compile_unet: true
  compile_mode: "max-autotune"
  compile_vae: true
  enable_fused_projections: true
  do_quant: "int8dynamic"
  change_comp_config: true
  no_sdpa: false
  no_bf16: false
  upcast_vae: false
  profile: true

Review comment (on the `model_weights` key): can we change to model_path to align with the other LMI example style? / Reply: done
query.py
@@ -0,0 +1,27 @@
import argparse
import json
from datetime import datetime

import numpy as np
import requests
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument(
    "--url", type=str, required=True, help="Torchserve inference endpoint"
)
parser.add_argument(
    "--prompt", type=str, required=True, help="Prompt for image generation"
)
parser.add_argument(
    "--filename",
    type=str,
    default="output-{}.jpg".format(str(datetime.now().strftime("%Y%m%d%H%M%S"))),
    help="Filename of output image",
)
args = parser.parse_args()

response = requests.post(args.url, data=args.prompt)
# Construct image from response
image = Image.fromarray(np.array(json.loads(response.text), dtype="uint8"))
image.save(args.filename)
Review comment (on the device selection in `initialize`): the order needs to be changed b/c the input gpu_id is < 0 for the cpu case. / Reply: updated the logic