llama2 70b chat accelerate example (#2494)

* llam2 accelerate example
* add readme
* fmt
* fixing the padding and prompt
* update steps
* Updated readme with more details
* changed to inheriting from basehandler
* add model_path
* change to int8
* add download cmd
* update download path
* minor edit for model_path

---------

Co-authored-by: Geeta Chauhan <4461127+chauhang@users.noreply.github.com>
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
Co-authored-by: Ankith Gunapal <agunapal@ischool.Berkeley.edu>
Co-authored-by: Hamid Shojanazeri <hamid.nazeri2010@gmail.com>
1 parent 683608b, commit 04e0b37
Showing 6 changed files with 224 additions and 0 deletions.
60 changes: 60 additions & 0 deletions
examples/large_models/Huggingface_accelerate/llama2/Readme.md
@@ -0,0 +1,60 @@
# Loading meta-llama/Llama-2-70b-chat-hf on AWS EC2 g5.24xlarge using accelerate

This document describes how to serve large Hugging Face models on limited resources using accelerate. This option is activated with `low_cpu_mem_usage=True`: the model is first created on the meta device (with empty weights), and the state dict is then loaded into it (shard by shard in the case of a sharded checkpoint).

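As a rough back-of-the-envelope check of why memory is the limiting resource here (the handler in this commit passes `load_in_8bit=True`), the sketch below estimates the size of the weights at different precisions. The parameter count (about 70 billion) and the GPU memory of the instance (4 GPUs with 24 GB each) are assumptions that are not stated in this README, and the estimate ignores activations, the KV cache, and framework overhead.

```python
# Rough estimate of model weight memory; all figures are assumptions for illustration.
NUM_PARAMS = 70_000_000_000  # assumed: roughly 70B parameters for Llama-2-70b
NUM_GPUS = 4                 # assumed: number of GPUs on a g5.24xlarge
GPU_MEM_GB = 24              # assumed: memory per GPU, in GB

BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_footprint_gb(num_params: int, dtype: str) -> float:
    """Return the memory needed for the weights alone, in GB (10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


def fits_on_gpus(num_params: int, dtype: str) -> bool:
    """Check whether the weights alone fit in the combined GPU memory."""
    return weight_footprint_gb(num_params, dtype) <= NUM_GPUS * GPU_MEM_GB


if __name__ == "__main__":
    for dtype in BYTES_PER_PARAM:
        size = weight_footprint_gb(NUM_PARAMS, dtype)
        print(f"{dtype}: {size:.0f} GB, fits: {fits_on_gpus(NUM_PARAMS, dtype)}")
```

Under these assumptions, half-precision weights alone would exceed the combined GPU memory, while 8-bit weights leave some headroom for inference.
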
### Step 1: Get permission and download the model

Follow [these instructions](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to get permission to access the model.

Log in with a Hugging Face account:
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

```bash
python ../Download_model.py --model_path model --model_name meta-llama/Llama-2-70b-chat-hf
```
The model will be saved in the following path: `model/models--meta-llama--Llama-2-70b-chat-hf`.

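The `model_path` entry in `model-config.yaml` points at a snapshot directory inside the downloaded folder. As a minimal sketch, assuming the download follows the Hugging Face cache layout (`models--<org>--<name>/snapshots/<revision>`) seen in the paths of this commit, the hypothetical helpers below list the snapshot directories that can be copied into the config. `cache_folder_name` and `find_snapshots` are illustrative names and are not part of `Download_model.py`.

```python
from pathlib import Path


def cache_folder_name(repo_id: str) -> str:
    """Map a Hub repo id to its cache folder name: 'org/name' -> 'models--org--name'."""
    return "models--" + repo_id.replace("/", "--")


def find_snapshots(model_path: str, repo_id: str) -> list:
    """Return the snapshot directories found under model_path for the given repo id."""
    snapshots_dir = Path(model_path) / cache_folder_name(repo_id) / "snapshots"
    if not snapshots_dir.is_dir():
        return []
    return sorted(str(p) for p in snapshots_dir.iterdir() if p.is_dir())


if __name__ == "__main__":
    # Prints nothing if the model has not been downloaded into ./model yet.
    for snapshot in find_snapshots("model", "meta-llama/Llama-2-70b-chat-hf"):
        print(snapshot)
```
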
### Step 2: Generate MAR file

Add the downloaded path to `model_path:` in `model-config.yaml` and run the following.

```bash
torch-model-archiver --model-name llama2-70b-chat --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format no-archive
```

If you are using conda and notice issues with mpi4py, install openmpi-mpicc with the following command:

```
conda install -c conda-forge openmpi-mpicc
```

### Step 3: Add the MAR file to the model store

```bash
mkdir model_store
mv llama2-70b-chat model_store
mv model model_store/llama2-70b-chat
```

### Step 4: Start TorchServe

Update `config.properties` and start TorchServe:

```bash
torchserve --start --ncs --ts-config config.properties --model-store model_store --models llama2-70b-chat
```

### Step 5: Run inference

```bash
curl -v "http://localhost:8080/predictions/llama2-70b-chat" -T sample_text.txt
```

This results in the following output:
```
Mayonnaise is a thick, creamy condiment made from a mixture of egg yolks, oil, vinegar or lemon juice, and seasonings'
```
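
The same request can also be sent from Python. The following is a minimal sketch using only the standard library; the function names are hypothetical, the default host and model name are copied from the `curl` command above, and the timeout is an assumption chosen to mirror `responseTimeout` in `model-config.yaml`.

```python
import urllib.request


def build_prediction_request(
    text: str,
    model_name: str = "llama2-70b-chat",
    host: str = "http://localhost:8080",
) -> urllib.request.Request:
    """Build (but do not send) a request equivalent to `curl -T sample_text.txt`."""
    return urllib.request.Request(
        url=f"{host}/predictions/{model_name}",
        data=text.encode("utf-8"),
        method="PUT",  # curl -T uploads the file with an HTTP PUT
    )


def predict(text: str, timeout: float = 1200.0) -> str:
    """Send the prompt to a running TorchServe instance and return the response body."""
    request = build_prediction_request(text)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read().decode("utf-8")
```

With TorchServe running as in the previous step, `predict("what is the recipe of mayonnaise?")` would return the generated text as a string.
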
6 changes: 6 additions & 0 deletions
examples/large_models/Huggingface_accelerate/llama2/config.properties
@@ -0,0 +1,6 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
enable_envvars_config=true
install_py_dep_per_model=true

139 changes: 139 additions & 0 deletions
examples/large_models/Huggingface_accelerate/llama2/custom_handler.py
@@ -0,0 +1,139 @@
import logging
from abc import ABC

import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from accelerate import init_empty_weights
from accelerate import load_checkpoint_and_dispatch

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LlamaHandler(BaseHandler, ABC):
    """
    Transformers handler class for text generation with Llama 2.
    """

    def __init__(self):
        super(LlamaHandler, self).__init__()
        self.max_length = None
        self.max_new_tokens = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx: Context):
        """In this initialize function, the HF large model is loaded in 8-bit and
        distributed across the available GPUs using accelerate.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        model_dir = ctx.system_properties.get("model_dir")
        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        logger.info("Model %s loading", ctx.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="balanced",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The Llama 2 tokenizer has no pad token, so add one and grow the embeddings to match.
        self.tokenizer.add_special_tokens(
            {
                "pad_token": "<PAD>",
            }
        )
        self.model.resize_token_embeddings(self.model.config.vocab_size + 1)

        logger.info("Model %s loaded successfully", ctx.model_name)
        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing: tokenizes the text of each request in the batch.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
                            containing the input text to be processed.
        Returns:
            tuple: A tuple with two tensors: the batch of input ids and the batch of
                attention masks.
        """
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            input_ids, attention_mask = self.encode_input_text(input_text)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.model.device)
        return input_ids_batch, attention_mask_batch

    def encode_input_text(self, input_text):
        """
        Encodes a single input text using the tokenizer.
        Args:
            input_text (str): The input text to be encoded.
        Returns:
            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
        """
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        logger.info("Received text: '%s'", input_text)
        inputs = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_length,
            padding=True,
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask

    def inference(self, input_batch):
        """
        Generates text for each input in the batch using the loaded transformers
        checkpoint.
        Args:
            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
                                of attention masks, as returned by the preprocess function.
        Returns:
            list: A list of strings with the generated text for each input text in the batch.
        """
        input_ids_batch, attention_mask_batch = input_batch
        input_ids_batch = input_ids_batch.to(self.model.device)
        outputs = self.model.generate(
            input_ids_batch,
            attention_mask=attention_mask_batch,
            max_new_tokens=self.max_new_tokens,  # limit generated tokens, not total length
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        logger.info("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the generated response into Torchserve readable format.
        Args:
            inference_output (list): It contains the generated response for each input text.
        Returns:
            (list): Returns a list of the generated texts.
        """
        return inference_output
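
In the handler, `preprocess` joins the per-request tensors with `torch.cat`, which requires every sequence to have the same length. Because `encode_input_text` tokenizes one text at a time, `padding=True` has no longer sequence to pad against. This should not matter when requests arrive one at a time, but it would for larger batches. As a library-free sketch of what padding a batch to a common length involves (the function name, the pad id of 0, and the default of left padding are illustrative assumptions, not part of this commit):

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_id: int = 0, side: str = "left"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id lists to the longest length and build matching attention masks."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    max_len = max((len(seq) for seq in sequences), default=0)
    input_ids, attention_mask = [], []
    for seq in sequences:
        padding = [pad_id] * (max_len - len(seq))
        mask_pad = [0] * len(padding)  # padded positions are masked out
        real = [1] * len(seq)          # real tokens are attended to
        if side == "left":
            input_ids.append(padding + seq)
            attention_mask.append(mask_pad + real)
        else:
            input_ids.append(seq + padding)
            attention_mask.append(real + mask_pad)
    return input_ids, attention_mask


if __name__ == "__main__":
    ids, mask = pad_batch([[5, 6, 7], [8]])
    print(ids)   # [[5, 6, 7], [0, 0, 8]]
    print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

Left padding is used as the default here because decoder-only models continue generating from the last position of each row. With a real tokenizer, the same effect is usually obtained by tokenizing the whole list of texts in a single call.
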
13 changes: 13 additions & 0 deletions
examples/large_models/Huggingface_accelerate/llama2/model-config.yaml
@@ -0,0 +1,13 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"

handler:
    model_name: "meta-llama/Llama-2-70b-chat-hf"
    model_path: "model/models--meta-llama--Llama-2-70b-chat-hf/snapshots/9ff8b00464fc439a64bb374769dec3dd627be1c2"
    max_length: 50
    max_new_tokens: 50
    manual_seed: 40
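
The handler reads the `handler` section when the model is loaded, casting the numeric fields to `int` and joining `model_path` onto the model directory that TorchServe provides. A minimal sketch of that logic, using a plain dictionary in place of the parsed YAML (the function name, the `/tmp/model_dir` directory, and the shortened `model_path` value are hypothetical):

```python
from typing import Any, Dict

REQUIRED_KEYS = ("model_name", "model_path", "max_length", "max_new_tokens", "manual_seed")


def resolve_handler_config(model_dir: str, handler_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Validate the handler section and return its values in the types the handler uses."""
    missing = [key for key in REQUIRED_KEYS if key not in handler_cfg]
    if missing:
        raise KeyError(f"missing handler keys: {missing}")
    return {
        "model_name": handler_cfg["model_name"],
        # Same join as in the handler: model_path is relative to model_dir.
        "model_path": f'{model_dir}/{handler_cfg["model_path"]}',
        "max_length": int(handler_cfg["max_length"]),
        "max_new_tokens": int(handler_cfg["max_new_tokens"]),
        "manual_seed": int(handler_cfg["manual_seed"]),
    }


if __name__ == "__main__":
    cfg = {
        "model_name": "meta-llama/Llama-2-70b-chat-hf",
        "model_path": "model/snapshot",  # placeholder, not the real snapshot path
        "max_length": 50,
        "max_new_tokens": 50,
        "manual_seed": 40,
    }
    print(resolve_handler_config("/tmp/model_dir", cfg))
```

Checking the keys up front gives a clearer error than a `KeyError` raised halfway through model loading.
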
5 changes: 5 additions & 0 deletions
examples/large_models/Huggingface_accelerate/llama2/requirements.txt
@@ -0,0 +1,5 @@
transformers==4.31.0
accelerate
bitsandbytes
scipy
mpi4py
1 change: 1 addition & 0 deletions
examples/large_models/Huggingface_accelerate/llama2/sample_text.txt
@@ -0,0 +1 @@
what is the recipe of mayonnaise?