
Update llm_diffusion_serving_app
ravi9 committed Nov 4, 2024
1 parent c346a1f commit 0d965e6
Showing 4 changed files with 48 additions and 31 deletions.
24 changes: 14 additions & 10 deletions examples/usecases/llm_diffusion_serving_app/docker/client_app.py
@@ -66,7 +66,7 @@

st.subheader("Number of images to generate")
num_images = st.sidebar.number_input(
"num_of_img", min_value=1, max_value=6, value=2, step=1
"num_of_img", min_value=1, max_value=8, value=2, step=1
)

st.subheader("LLM Model parameters")
@@ -87,7 +87,7 @@
)


st.subheader("SD Model parameters")
st.subheader("Stable Diffusion Parameters")
num_inference_steps = st.sidebar.number_input(
"steps", min_value=1, max_value=100, value=5, step=1
)
@@ -207,19 +207,22 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt):
intro_container = st.container()
with intro_container:
st.markdown("""
The multi-image generation app generate similar image generation prompts using **LLaMA3** and
The multi-image generation app generates similar image generation prompts using **LLaMA-3.2** and
these prompts are then processed in parallel by **Stable Diffusion**, which is optimized
using the **latent-consistency/lcm-sdxl** model and accelerated with **Torch.compile** using the
**OpenVINO** backend. This approach enables efficient and high-quality image generation,
offering users a selection of interpretations to choose from.
""")
st.image("workflow-2.png")
st.markdown("""
**NOTE:** The initial image generations might take longer due to model initialization and warm-up.
Subsequent generations should be faster !
""")

user_prompt = st.text_input("Enter an Image Generation Prompt :")

prompt_container = st.container()
status_container = st.container()
# image_container = st.container()

if 'gen_images' not in st.session_state:
st.session_state.gen_images = []
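For context, a minimal sketch (not part of this commit) of how the LLM-generated prompts could be fanned out to Stable Diffusion in parallel, as the intro text above describes. It assumes TorchServe's default inference port 8080 and a placeholder model name "stable-diffusion"; the app's actual `generate_sd_response_v2()` may differ.

```python
# Hypothetical sketch: send each generated prompt to TorchServe concurrently.
# Endpoint path and model name are assumptions based on TorchServe defaults.
import asyncio
import aiohttp

SD_URL = "http://localhost:8080/predictions/stable-diffusion"

async def fetch_image(session: aiohttp.ClientSession, prompt: str) -> bytes:
    async with session.post(SD_URL, data=prompt) as resp:
        resp.raise_for_status()
        return await resp.read()

async def generate_all(prompts: list[str]) -> list[bytes]:
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(*(fetch_image(session, p) for p in prompts))

# images = asyncio.run(generate_all(["a red fox", "a red fox at dusk"]))
```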
@@ -233,9 +236,11 @@ def display_images_in_grid(images, captions):
col.image(img, caption=caption, use_column_width=True)

def display_prompts():
prompt_container.write(f"Generated prompts:")
for pr in st.session_state.llm_prompts:
prompt_container.write(pr)
prompt_container.write(f"Generated Prompts:")
prompt_list = ""
for i, pr in enumerate(st.session_state.llm_prompts, 1):
prompt_list += f"{i}. {pr}\n"
prompt_container.markdown(prompt_list)

if 'llm_prompts' not in st.session_state:
st.session_state.llm_prompts = None
@@ -258,14 +263,13 @@ def display_prompts():
display_prompts()
prompt_container.write(f"LLM time: {st.session_state.llm_time:.2f} sec.")
else:
st.warning("Start TorchServe and Register models in the Server/Control Center App running at port 8084 ...", icon="⚠️")
st.warning("Start TorchServe and Register models in the Server App running at port 8084.", icon="⚠️")


if not st.session_state.llm_prompts:
prompt_container.write(f"Enter Image Generation Prompt and Click Generate Prompts !")
pass
elif len(st.session_state.llm_prompts) != num_images:
#prompt_container.write(f"Generate the prompts again!")
st.warning("Generate the prompts again !", icon="⚠️")
pass
else:
@@ -276,7 +280,7 @@ def display_prompts():

sd_res = asyncio.run(generate_sd_response_v2(st.session_state.llm_prompts))

if sd_res is not None: # Only proceed if no errors
if sd_res is not None:
images = sd_response_postprocess(sd_res)
st.session_state.gen_images[:0] = images

@@ -9,6 +9,6 @@ pt2:
handler:
profile: true
model_store_dir: "/home/model-server/model-store/"
max_new_tokens: 50
max_new_tokens: 40
compile: true
fx_graph_cache: true
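For reference, a hedged sketch of how a custom handler could consume the `handler` block of this model-config.yaml. It assumes TorchServe exposes the parsed YAML on the context as `ctx.model_yaml_config` and that PyTorch >= 2.1 is available for `fx_graph_cache`; defaults shown are illustrative.

```python
# Hedged sketch (not from this commit): reading the "handler" block of
# model-config.yaml inside a custom handler. Key names mirror the diff above.
import torch
import torch._inductor.config as inductor_config

def apply_handler_config(model, ctx):
    cfg = getattr(ctx, "model_yaml_config", {}).get("handler", {})
    if cfg.get("fx_graph_cache", False):
        # Persist compiled FX graphs across worker restarts (PyTorch >= 2.1).
        inductor_config.fx_graph_cache = True
    if cfg.get("compile", False):
        model = torch.compile(model)
    max_new_tokens = cfg.get("max_new_tokens", 40)
    return model, max_new_tokens
```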
@@ -1,4 +1,4 @@
minWorkers: 1
minWorkers: 2
maxWorkers: 4
maxBatchDelay: 200
responseTimeout: 3600
@@ -7,8 +7,6 @@ pt2:
backend: "openvino"
options:
device: "CPU"
model_caching: true
cache_dir: "./sd_model_cache"
config:
PERFORMANCE_HINT: "LATENCY"
handler:
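The `pt2` section above maps roughly onto a `torch.compile` call with the OpenVINO backend. A hedged sketch of that mapping follows; it requires the `openvino` package, uses a stand-in model, and the option keys simply mirror the YAML in this diff rather than an exhaustive list.

```python
# Hedged sketch: torch.compile with the OpenVINO backend, options mirroring
# the pt2 section above (device + config: PERFORMANCE_HINT). Stand-in model.
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()

compiled = torch.compile(
    model,
    backend="openvino",
    options={
        "device": "CPU",
        "config": {"PERFORMANCE_HINT": "LATENCY"},
    },
)

with torch.no_grad():
    out = compiled(torch.randn(1, 8))
```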
49 changes: 32 additions & 17 deletions examples/usecases/llm_diffusion_serving_app/docker/server_app.py
@@ -7,6 +7,9 @@

import requests
import streamlit as st
import logging

logger = logging.getLogger(__name__)

MODEL_NAME_LLM = os.environ["MODEL_NAME_LLM"]
MODEL_NAME_LLM = MODEL_NAME_LLM.replace("/", "---")
@@ -60,7 +63,7 @@ def _register_model(url, MODEL_NAME):
st.session_state.started = True
return False

print(f"registering {MODEL_NAME}")
print(f"Registering {MODEL_NAME}")
st.session_state.registered[MODEL_NAME] = True
st.session_state.stopped = False
server_state_container.caption(res.text)
@@ -72,19 +75,32 @@ def register_model(MODEL_NAME):
if not st.session_state.started:
server_state_container.caption("TorchServe is not running. Start it")
return
url = f"http://127.0.0.1:8081/models?model_name={MODEL_NAME}&url={MODEL_NAME}&batch_size=1&max_batch_delay=3000&initial_workers=1&synchronous=true&disable_token_authorization=true"

url = (f"http://127.0.0.1:8081/models"
f"?model_name={MODEL_NAME}"
f"&url={MODEL_NAME}"
f"&batch_size=1"
f"&max_batch_delay=3000"
f"&initial_workers=1"
f"&synchronous=true"
f"&disable_token_authorization=true")

return _register_model(url, MODEL_NAME)

def register_models(models: list):
for model in models:
if not register_model(model):
for model in models:
success = register_model(model)
# If registration fails, exit the function early
if not success:
logger.error(f"Failed to register model: {model}")
return
# Call scale_sd_workers after model registration, which overrides min_workers in model-config.yaml
scale_sd_workers()
logger.info("All models registered successfully.")

def get_model_status():
for MODEL_NAME in [MODEL_NAME_LLM, MODEL_NAME_SD]:
print(
f"registered state for {MODEL_NAME} is {st.session_state.registered[MODEL_NAME]}"
)
print(f"Registered state for {MODEL_NAME} is {st.session_state.registered[MODEL_NAME]}")
if st.session_state.registered[MODEL_NAME]:
url = f"http://localhost:8081/models/{MODEL_NAME}"
res = requests.get(url)
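For reference, the query string assembled in `register_model()` above amounts to a POST against TorchServe's management API (default port 8081). A minimal hedged sketch, with "stable-diffusion" as a placeholder archive name and only the main parameters mirrored:

```python
# Hedged sketch: registering a model archive via TorchServe's management API.
# The archive name is a placeholder; parameters mirror register_model() above.
import requests

def register(model_name: str, mgmt: str = "http://127.0.0.1:8081") -> None:
    params = {
        "model_name": model_name,
        "url": model_name,  # archive already present in the model store
        "batch_size": 1,
        "max_batch_delay": 3000,
        "initial_workers": 1,
        "synchronous": "true",
    }
    resp = requests.post(f"{mgmt}/models", params=params)
    resp.raise_for_status()
    print(resp.text)

# register("stable-diffusion")
```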
@@ -99,10 +115,9 @@ def get_model_status():
else:
model_state_container.write(f"{MODEL_NAME} is not registered ! ")


def scale_sd_workers(workers):
def scale_sd_workers(workers_key="sd_workers"):
if st.session_state.registered[MODEL_NAME_SD]:
num_workers = st.session_state[workers]
num_workers = st.session_state.get(workers_key, 2)
url = (
f"http://localhost:8081/models/{MODEL_NAME_SD}?min_worker="
f"{str(num_workers)}&synchronous=true"
@@ -172,13 +187,13 @@ def get_sw_versions():
st.subheader("SD Model parameters")

workers_sd = st.sidebar.number_input(
"Num Workers SD",
key="Num Workers SD",
"Num Workers for Stable Diffusion",
key="sd_workers",
min_value=1,
max_value=4,
value=4,
value=2,
on_change=scale_sd_workers,
args=("Num Workers SD",),
args=("sd_workers",),
)
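The widget change above leans on Streamlit's key/on_change/args mechanism: the widget stores its value in `st.session_state` under `key`, and the callback fires with `args` before the rerun, reading the fresh value by that key. A standalone toy sketch of the pattern (not this app's code):

```python
# Toy sketch of the Streamlit pattern used above: value lives in
# st.session_state["sd_workers"]; the callback reads it by key on change.
import streamlit as st

def on_workers_change(state_key: str) -> None:
    st.write(f"Scaling to {st.session_state.get(state_key, 2)} workers...")

st.sidebar.number_input(
    "Num Workers",
    key="sd_workers",
    min_value=1,
    max_value=4,
    value=2,
    on_change=on_workers_change,
    args=("sd_workers",),
)
```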

if st.session_state.started:
Expand All @@ -194,17 +209,17 @@ def get_sw_versions():
st.success(f"Registered model {MODEL_NAME_SD}", icon="✅")



# Server Page UI

st.title("Multi-Image Generation App Control Center")
image_container = st.container()
with image_container:
st.markdown("""
This Streamlit app is designed to generate multiple images based on a provided text prompt.
It leverages **TorchServe** for efficient model serving and management, and utilizes **LLaMA3**
It leverages **TorchServe** for efficient model serving and management, and utilizes **LLaMA3.2**
for prompt generation, and **Stable Diffusion**
with **latent-consistency/lcm-sdxl** and **Torch.compile** using **OpenVINO backend** for image generation.
After Starting TorchServe and Registering models, go to Client App running at port 8085.
""")
st.image("workflow-1.png")


