diff --git a/README.md b/README.md
index 34bba8fd9d..5c6bc6ca4e 100644
--- a/README.md
+++ b/README.md
@@ -77,6 +77,7 @@ Refer to [torchserve docker](docker/README.md) for details.
 
 ## 🏆 Highlighted Examples
 
+* [Chatbot with Llama 2 on Mac 🦙💬](examples/LLM/llama2/chat_app)
 * [🤗 HuggingFace Transformers](examples/Huggingface_Transformers) with a [Better Transformer Integration/ Flash Attention & Xformer Memory Efficient ](examples/Huggingface_Transformers#Speed-up-inference-with-Better-Transformer)
 * [Model parallel inference](examples/Huggingface_Transformers#model-parallelism)
 * [MultiModal models with MMF](https://github.com/pytorch/serve/tree/master/examples/MMF-activity-recognition) combining text, audio and video
diff --git a/examples/LLM/llama2/chat_app/Readme.md b/examples/LLM/llama2/chat_app/Readme.md
new file mode 100644
index 0000000000..4684bd3132
--- /dev/null
+++ b/examples/LLM/llama2/chat_app/Readme.md
@@ -0,0 +1,142 @@
+
+# TorchServe Llama 2 Chatapp
+
+This is an example showing how to deploy a Llama 2 chat app using TorchServe.
+We use [streamlit](https://github.com/streamlit/streamlit) to create the app.
+
+We are using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) in this example.
+
+You can run this example on your laptop to understand how to use TorchServe.
+
+
+## Architecture
+
+![Chatbot Architecture](./screenshots/architecture.png)
+
+
+## Pre-requisites
+
+The following example has been tested on an M1 Mac.
+Before you install TorchServe, make sure you have the following installed:
+1) JDK 17
+
+Make sure your javac version is `17.x.x`:
+```
+javac --version
+javac 17.0.8
+```
+You can download it from [java](https://www.oracle.com/java/technologies/downloads/#jdk17-mac)
+2) Install conda with support for arm64
+
+3) Since we are running this example on a Mac, we use the 7B Llama 2 model.
+Download the llama2-7b weights by following the instructions [here](https://github.com/pytorch/serve/tree/master/examples/large_models/Huggingface_accelerate/llama2#step-1-download-model-permission)
+
+4) Install streamlit with
+
+```
+python -m pip install -r requirements.txt
+```
+
+
+### Steps
+
+#### Install TorchServe
+Install TorchServe with the following steps:
+
+```
+python ts_scripts/install_dependencies.py
+pip install torchserve torch-model-archiver torch-workflow-archiver
+```
+
+#### Package model for TorchServe
+
+Run this script to create `llamacpp.tar.gz` to be loaded in TorchServe:
+
+```
+source package_llama.sh
+```
+This creates the quantized weights in `$LLAMA2_WEIGHTS`.
+
+For subsequent runs, we don't need to regenerate these weights; we only need to package the handler and model-config.yaml in the tar file.
+
+Hence, you can skip the model generation by running the script as follows:
+
+```
+source package_llama.sh false
+```
+
+You might need to run the command below if the script output indicates it:
+```
+sudo xcodebuild -license
+```
+
+The script sets an env variable `LLAMA2_Q4_MODEL` and uses it in the handler.
+In an actual use-case, you would set the path to the weights in `model-config.yaml`
+
+```
+handler:
+    model_name: "llama-cpp"
+    model_path: "=1.26.0
\ No newline at end of file
diff --git a/examples/LLM/llama2/chat_app/screenshots/Client.png b/examples/LLM/llama2/chat_app/screenshots/Client.png
new file mode 100644
index 0000000000..6e876c244a
Binary files /dev/null and b/examples/LLM/llama2/chat_app/screenshots/Client.png differ
diff --git a/examples/LLM/llama2/chat_app/screenshots/Server.png b/examples/LLM/llama2/chat_app/screenshots/Server.png
new file mode 100644
index 0000000000..5e8c1d7b7c
Binary files /dev/null and b/examples/LLM/llama2/chat_app/screenshots/Server.png differ
diff --git a/examples/LLM/llama2/chat_app/screenshots/Workers.png b/examples/LLM/llama2/chat_app/screenshots/Workers.png
new file mode 100644
index 0000000000..a8418a4fd2
Binary files /dev/null and b/examples/LLM/llama2/chat_app/screenshots/Workers.png differ
diff --git a/examples/LLM/llama2/chat_app/screenshots/architecture.png b/examples/LLM/llama2/chat_app/screenshots/architecture.png
new file mode 100644
index 0000000000..5c158175c3
Binary files /dev/null and b/examples/LLM/llama2/chat_app/screenshots/architecture.png differ
diff --git a/examples/LLM/llama2/chat_app/screenshots/batch_size.png b/examples/LLM/llama2/chat_app/screenshots/batch_size.png
new file mode 100644
index 0000000000..1cb128ade5
Binary files /dev/null and b/examples/LLM/llama2/chat_app/screenshots/batch_size.png differ
diff --git a/examples/LLM/llama2/chat_app/torchserve_server_app.py b/examples/LLM/llama2/chat_app/torchserve_server_app.py
new file mode 100644
index 0000000000..74b1b2060d
--- /dev/null
+++ b/examples/LLM/llama2/chat_app/torchserve_server_app.py
@@ -0,0 +1,171 @@
+import json
+import os
+
+import requests
+import streamlit as st
+
+MODEL_NAME = "llamacpp"
+# App title
+st.set_page_config(page_title="🦙💬 Llama 2 TorchServe Serve")
+
+
+def start_server():
+    os.system("torchserve --start --model-store model_store --ncs")
+    st.session_state.started = True
+    st.session_state.stopped = False
+    st.session_state.registered = False
+
+
+def stop_server():
+    os.system("torchserve --stop")
+    st.session_state.stopped = True
+    st.session_state.started = False
+    st.session_state.registered = False
+
+
+def _register_model(url):
+    res = requests.post(url)
+    if res.status_code != 200:
+        server_state_container.error("Error registering model", icon="🚫")
+        st.session_state.started = True
+        return
+    st.session_state.registered = True
+    st.session_state.started = False
+    st.session_state.stopped = False
+    server_state_container.caption(res.text)
+
+
+def register_model():
+    if not st.session_state.started:
+        server_state_container.caption("TorchServe is not running. Start it")
+        return
+    url = (
+        f"http://localhost:8081/models?model_name={MODEL_NAME}&url={MODEL_NAME}"
+        f".tar.gz&initial_workers=1&synchronous=true"
+    )
+    _register_model(url)
+
+
+def get_status():
+    if st.session_state.registered:
+        url = f"http://localhost:8081/models/{MODEL_NAME}"
+        res = requests.get(url)
+        if res.status_code != 200:
+            model_state_container.error("Error getting model status", icon="🚫")
+            return
+        status = json.loads(res.text)[0]
+        model_state_container.write(status)
+
+
+def scale_workers(workers):
+    if st.session_state.registered:
+        num_workers = st.session_state[workers]
+        url = (
+            f"http://localhost:8081/models/{MODEL_NAME}?min_worker="
+            f"{str(num_workers)}&synchronous=true"
+        )
+        res = requests.put(url)
+        server_state_container.caption(res.text)
+
+
+def set_batch_size(batch_size):
+    if st.session_state.registered:
+        url = f"http://localhost:8081/models/{MODEL_NAME}/1.0"
+        res = requests.delete(url)
+        server_state_container.caption(res.text)
+        st.session_state.registered = False
+
+    batch_size = st.session_state[batch_size]
+    url = (
+        f"http://localhost:8081/models?model_name={MODEL_NAME}&url={MODEL_NAME}"
+        f".tar.gz&batch_size={str(batch_size)}&initial_workers={str(workers)}"
+        f"&synchronous=true&max_batch_delay={str(max_batch_delay)}"
+    )
+    _register_model(url)
+
+
+def set_max_batch_delay(max_batch_delay):
+    if st.session_state.registered:
+        url = f"http://localhost:8081/models/{MODEL_NAME}/1.0"
+        res = requests.delete(url)
+        server_state_container.caption(res.text)
+        st.session_state.registered = False
+
+    max_batch_delay = st.session_state[max_batch_delay]
+    url = (
+        f"http://localhost:8081/models?model_name={MODEL_NAME}&url="
+        f"{MODEL_NAME}.tar.gz&batch_size={str(batch_size)}&initial_workers="
+        f"{str(workers)}&synchronous=true&max_batch_delay={str(max_batch_delay)}"
+    )
+    _register_model(url)
+
+
+if "started" not in st.session_state:
+    st.session_state.started = False
+if "stopped" not in st.session_state:
+    st.session_state.stopped = False
+if "registered" not in st.session_state:
+    st.session_state.registered = False
+
+with st.sidebar:
+    st.title("🦙💬 Llama 2 TorchServe Server ")
+
+    st.button("Start Server", on_click=start_server)
+    st.button("Stop Server", on_click=stop_server)
+    st.button("Register Llama2", on_click=register_model)
+    workers = st.sidebar.slider(
+        "Num Workers",
+        key="Num Workers",
+        min_value=1,
+        max_value=4,
+        value=1,
+        step=1,
+        on_change=scale_workers,
+        args=("Num Workers",),
+    )
+    batch_size = st.sidebar.select_slider(
+        "Batch Size",
+        key="Batch Size",
+        options=[2**j for j in range(0, 8)],
+        on_change=set_batch_size,
+        args=("Batch Size",),
+    )
+    max_batch_delay = st.sidebar.slider(
+        "Max Batch Delay",
+        key="Max Batch Delay",
+        min_value=100,
+        max_value=10000,
+        value=100,
+        step=100,
+        on_change=set_max_batch_delay,
+        args=("Max Batch Delay",),
+    )
+
+    if st.session_state.started:
+        st.success("Started TorchServe", icon="✅")
+
+    if st.session_state.stopped:
+        st.success("Stopped TorchServe", icon="✅")
+
+    if st.session_state.registered:
+        st.success("Registered model", icon="✅")
+
+st.title("TorchServe Status")
+server_state_container = st.container()
+server_state_container.subheader("Server status:")
+
+if st.session_state.started:
+    server_state_container.success("Started TorchServe", icon="✅")
+
+if st.session_state.stopped:
+    server_state_container.success("Stopped TorchServe", icon="✅")
+
+if st.session_state.registered:
+    server_state_container.success("Registered model", icon="✅")
+
+model_state_container = st.container()
+with model_state_container:
+    st.subheader("Model Status")
+
+with model_state_container:
+    st.button("Model Status", on_click=get_status)
diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt
index 641873358f..2b1b907552 100644
--- a/ts_scripts/spellcheck_conf/wordlist.txt
+++ b/ts_scripts/spellcheck_conf/wordlist.txt
@@ -1095,4 +1095,13 @@ PreprocessCallCount
 AOT
 microbatches
 tokenization
+Chatapp
+autoscaled
+cpp
+javac
+llamacpp
+streamlit
 tp
+quantized
+Chatbot
+LLM
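Note: once the server app above has registered the model, it can also be queried directly over TorchServe's standard inference API, independent of the Streamlit client. A minimal sketch, assuming TorchServe is running locally on the default inference port 8080 and the model is registered as `llamacpp`; the prompt string and timeout are illustrative:

```
import requests

# Hypothetical smoke test: send a raw prompt to the registered llamacpp model.
# Assumes TorchServe is running locally with the default inference port (8080).
prompt = "What is the capital of France?"
response = requests.post(
    "http://localhost:8080/predictions/llamacpp",
    data=prompt,
    timeout=120,
)
print(response.status_code)
print(response.text)
```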