From 6b2a07c871ef27f9fddbaaf6240c0c0337ad7412 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 16 Nov 2023 12:06:05 -0800 Subject: [PATCH 01/49] fmt --- .../inferentia2/llama2/inf2_cb_handler.py | 433 ++++++++++++++++++ .../inferentia2/llama2/model-config.yaml | 1 + 2 files changed, 434 insertions(+) create mode 100644 examples/large_models/inferentia2/llama2/inf2_cb_handler.py diff --git a/examples/large_models/inferentia2/llama2/inf2_cb_handler.py b/examples/large_models/inferentia2/llama2/inf2_cb_handler.py new file mode 100644 index 0000000000..2d41057ba2 --- /dev/null +++ b/examples/large_models/inferentia2/llama2/inf2_cb_handler.py @@ -0,0 +1,433 @@ +import logging +import os +import types +from abc import ABC + +import torch +import transformers +from transformers import AutoConfig + +from ts.context import Context +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) +logger.info("Transformers version %s", transformers.__version__) + + +class LlamaHandler(BaseHandler, ABC): + """ + Transformers handler class for sequence, token classification and question answering. + """ + + def __init__(self): + super(LlamaHandler, self).__init__() + self.max_length = None + self.max_new_tokens = None + self.tokenizer = None + self.micro_batch_size = 1 + self.encoded_empty_padding = None + self.prefilled_ts_inf2_encoded_padding = False + self.initialized = False + + def initialize(self, ctx: Context): + """In this initialize function, the HF large model is loaded and + partitioned using DeepSpeed. + Args: + ctx (context): It is a JSON Object containing information + pertaining to the model artifacts parameters. + """ + model_dir = ctx.system_properties.get("model_dir") + model_checkpoint_dir = ctx.model_yaml_config.get("handler", {}).get( + "model_checkpoint_dir", "" + ) + model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}" + os.environ["NEURONX_CACHE"] = "on" + os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache" + os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference" + + # settings for model compiliation and loading + amp = ctx.model_yaml_config.get("handler", {}).get("amp", "fp32") + tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) + self.max_length = int(ctx.model_yaml_config["handler"]["max_length"]) + self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) + self.micro_batch_size = int( + ctx.model_yaml_config.get("micro_batching", {}).get("micro_batch_size", 1) + ) + + # allocate "tp_degree" number of neuron cores to the worker process + os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) + try: + num_neuron_cores_available = ( + torch_neuronx.xla_impl.data_parallel.device_count() + ) + assert num_neuron_cores_available >= int(tp_degree) + except (RuntimeError, AssertionError) as error: + logger.error( + "Required number of neuron cores for tp_degree " + + str(tp_degree) + + " are not available: " + + str(error) + ) + + raise error + + self.tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint_path) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = LlamaForSampling.from_pretrained( + model_checkpoint_path, + batch_size=ctx.system_properties.get("batch_size"), + amp=amp, + tp_degree=tp_degree, + ) + logger.info("Starting to compile the model") + self.model.to_neuron() + logger.info("Model has been successfully compiled") + model_config = AutoConfig.from_pretrained(model_checkpoint_path) + self.model = HuggingFaceGenerationModelAdapter(model_config, self.model) + + self.model.resize_token_embeddings(self.model.config.vocab_size + 1) + + # Replace _update_model_kwargs_for_generation of model with a method that extracts the kv cache for us + old_update = self.model._update_model_kwargs_for_generation + ctx.cache = {} + ctx.kv_cache = {} + encoded = self.tokenizer( + "", return_tensors="pt", padding=True, return_token_type_ids=False + ) + encoded["past_key_values"] = None + self.context.cache["ts_inf2_encoded_padding"] = { + # "stopping_criteria": self._create_stopping_criteria(req_id, max_new_tokens=data["max_new_tokens"]), + "stopping_criteria": self._create_stopping_criteria( + "ts_inf2_encoded_padding", max_new_tokens=self.max_new_tokens + ), + "init_encoded": encoded, + "prompt_length": len(encoded["input_ids"]), + } + + def extract_past_key_values_func(self, *args, **kwargs): + ctx.kv_cache["past_key_values"] = args[0]["past_key_values"][0] + if self.prefilled_ts_inf2_encoded_padding is False: + ctx.kv_cache["ts_inf2_empty_padding_past_key_values"] = args[0][ + "past_key_values" + ][1] + return old_update(*args, **kwargs) + + self.model._update_model_kwargs_for_generation = types.MethodType( + extract_past_key_values_func, self.model + ) + + logger.info("Model %s loaded successfully", ctx.model_name) + self.initialized = True + + def preprocess(self, requests): + """ + Basic text preprocessing, based on the user's choice of application mode. + Args: + requests (list): A list of dictionaries with a "data" or "body" field, each + containing the input text to be processed. + Returns: + tuple: A tuple with two tensors: the batch of input ids and the batch of + attention masks. + """ + self._clean_cache() + + prefill, decode = [], [] + for req_id, req_data in zip(self.context.request_ids.values(), requests): + # Tokenizer requests which are not prefilled yet + if not req_id in self.context.cache: + data = req_data["body"] or req_data["data"] + if isinstance(data, (bytes, bytearray)): + data = data.decode("utf-8") + logger.info("Received text: '%s'", data) + encoded = self.tokenizer( + data, return_tensors="pt", padding=True, return_token_type_ids=False + ) + encoded["past_key_values"] = None + self.context.cache[req_id] = { + # "stopping_criteria": self._create_stopping_criteria(req_id, max_new_tokens=data["max_new_tokens"]), + "stopping_criteria": self._create_stopping_criteria( + req_id, max_new_tokens=self.max_new_tokens + ), + "encoded": encoded, + "prompt_length": len(encoded["input_ids"]), + } + prefill.append(req_id) + else: + decode.append(req_id) + + return prefill, decode + + def inference(self, input_batch): + """ + Predicts the class (or classes) of the received text using the serialized transformers + checkpoint. + Args: + input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch + of attention masks, as returned by the preprocess function. + Returns: + list: A list of strings with the predicted values for each input text in the batch. + """ + + prefill, decode_ids = input_batch + + # Prefill requests + results = {} + for req_id in prefill: + results[req_id] = self._run_prefill(req_id) + + # Decode the rest + if decode_ids: + decode_ids.extend( + ["ts_inf2_encoded_padding"] * (self.micro_batch_size - len(decode_ids)) + ) + decode_result = self._run_decode(decode_ids) if decode_ids else {} + results.update(decode_result) + return [results[i] for i in self.context.request_ids.values()] + + def postprocess(self, inference_output): + """Post Process Function converts the predicted response into Torchserve readable format. + Args: + inference_output (list): It contains the predicted response of the input text. + Returns: + (list): Returns a list of the Predictions and Explanations. + """ + + self.context.stopping_criteria = [ + self.context.cache[i]["stopping_criteria"] + for i in self.context.request_ids.values() + ] + + return inference_output + + @torch.no_grad() + def _run_prefill(self, req_id): + assert ( + self.context.cache[req_id]["encoded"]["past_key_values"] is None + ), "There should be no cached values" + # Pad input to match compiled model batch size + input_ids_batch, attention_mask_batch = [], [] + input_ids_batch.append(self.context.cache[req_id]["encoded"]["input_ids"]) + attention_mask_batch.append( + self.context.cache[req_id]["encoded"]["attention_mask"] + ) + input_ids_batch.extend( + [self.context.cache["ts_inf2_encoded_padding"]["init_encoded"]["input_ids"]] + * (self.micro_batch_size - 1) + ) + attention_mask_batch.extend( + [ + self.context.cache["ts_inf2_encoded_padding"]["init_encoded"][ + "attention_mask" + ] + ] + * (self.micro_batch_size - 1) + ) + input_ids_batch = torch.cat(input_ids_batch, dim=0) + attention_mask_batch = torch.cat(attention_mask_batch, dim=0) + output = self.model.generate( + input_ids_batch, + attention_mask=attention_mask_batch, + max_new_tokens=1, + return_dict_in_generate=True, + use_cache=True, + ) + + # Save empty padding output + if self.prefilled_ts_inf2_encoded_padding is False: + attention_mask = self.context.cache["ts_inf2_encoded_padding"]["encoded"][ + "attention_mask" + ] + attention_mask = torch.concat( + (attention_mask, torch.ones((1, 1), dtype=torch.int64)), dim=1 + ) + self.context.cache["ts_inf2_encoded_padding"]["encoded"] = { + "input_ids": output.sequences[1], + "attention_mask": attention_mask, + } + self.prefilled_ts_inf2_encoded_padding = True + + # Save extracted kv cache values and adjust attention mask for next call + self.context.cache[req_id]["encoded"][ + "past_key_values" + ] = self.context.kv_cache["past_key_values"] + del self.context.kv_cache["past_key_values"] + self.context.cache[req_id]["encoded"]["input_ids"] = output.sequences[0] + + attention_mask = self.context.cache[req_id]["encoded"]["attention_mask"] + attention_mask = torch.concat( + (attention_mask, torch.ones((1, 1), dtype=torch.int64)), dim=1 + ) + self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask + + result = { + "text": self.tokenizer.decode( + output.sequences[0], skip_special_tokens=True + ), + "ids": output.sequences[0].tolist(), + } + logger.info(f"_run_prefill result: {0}".format(result)) + return result["text"] + + def _run_decode(self, ids): + assert len(ids) + + encoded = self._prepare_model_inputs(ids) + + outputs = self.model.generate( + **encoded, max_new_tokens=1, return_dict_in_generate=True, use_cache=True + ) + + results = {} + for idx, req_id in enumerate(ids): + self.context.cache[req_id]["encoded"][ + "past_key_values" + ] = self._collect_kv_cache_of_idx_in_batch(idx) + self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[ + idx + ].unsqueeze(0) + attention_mask = encoded["attention_mask"][idx].unsqueeze(0) + attention_mask = torch.concat( + (attention_mask, torch.ones((1, 1), dtype=torch.int64)), dim=1 + ) + self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask + results[req_id] = { + "text": self.tokenizer.decode( + outputs.sequences[idx][-1], skip_special_tokens=True + ), + "ids": [outputs.sequences[idx][-1].item()], + } + del self.context.kv_cache["past_key_values"] + return results + + def _prepare_model_inputs(self, ids): + lengths = list( + torch.sum(self.context.cache[i]["encoded"]["attention_mask"], dim=1).item() + for i in ids + ) + max_len = max(lengths) + + input_ids = [] + attention_mask = [] + kv_cache = {} + for req_id, seq_len in zip(ids, lengths): + input_ids.append(self.context.cache[req_id]["encoded"]["input_ids"]) + attention_mask.append( + self.context.cache[req_id]["encoded"]["attention_mask"] + ) + + for layer_idx, layer_kv in enumerate( + self.context.cache[req_id]["encoded"]["past_key_values"] + ): + k, v = layer_kv + kv_cache[layer_idx] = kv_cache.get(layer_idx, {}) + kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get(0, []) + [k] + kv_cache[layer_idx][1] = kv_cache.get(layer_idx, {}).get(1, []) + [v] + padded_len = input_ids[-1].size()[-1] + if padded_len < max_len: + # Apply padding to input_ids, attention_mask and past_key_values + n = max_len - seq_len + input_ids[-1] = torch.concat( + ( + self.tokenizer.pad_token_id + + torch.zeros((1, n), dtype=torch.int64), + input_ids[-1], + ), + dim=1, + ) + attention_mask[-1] = torch.concat( + (torch.zeros((1, n), dtype=torch.int64), attention_mask[-1]), dim=1 + ) + + size_delta = list(kv_cache[0][0][-1].size()) + size_delta[2] = n + dtype = kv_cache[0][0][-1].dtype + for layer_idx in range(len(kv_cache)): + kv_cache[layer_idx][0][-1] = torch.concat( + ( + torch.zeros(size_delta, dtype=dtype), + kv_cache[layer_idx][0][-1], + ), + dim=2, + ) + kv_cache[layer_idx][1][-1] = torch.concat( + ( + torch.zeros(size_delta, dtype=dtype), + kv_cache[layer_idx][1][-1], + ), + dim=2, + ) + + elif padded_len > max_len: + # Truncate padding from input_ids, attention_mask and past_key_values + input_ids[-1] = input_ids[-1][:, -max_len:] + attention_mask[-1] = attention_mask[-1][:, -max_len:] + + for layer_idx in range(len(kv_cache)): + kv_cache[layer_idx][0][-1] = kv_cache[layer_idx][0][-1][ + :, :, (-max_len + 1) :, : + ] + kv_cache[layer_idx][1][-1] = kv_cache[layer_idx][1][-1][ + :, :, (-max_len + 1) :, : + ] + del self.context.cache[req_id]["encoded"]["past_key_values"] + + for layer_idx in range(len(kv_cache)): + kv_cache[layer_idx][0] = torch.concat(kv_cache[layer_idx][0], dim=0) + kv_cache[layer_idx][1] = torch.concat(kv_cache[layer_idx][1], dim=0) + + kv_cache = tuple( + (kv_cache[layer_idx][0], kv_cache[layer_idx][1]) + for layer_idx in range(len(kv_cache)) + ) + + encoded = { + "input_ids": torch.concat(input_ids, dim=0), + "attention_mask": torch.concat(attention_mask, dim=0), + "past_key_values": kv_cache, + } + return encoded + + def _collect_kv_cache_of_idx_in_batch(self, idx): + # The materialization of the tuple here is important for some reason (TODO: figure out why); Otherwise prediction differ + return tuple( + tuple(kv[idx, ...].unsqueeze(0) for kv in layers) + for layers in self.context.kv_cache["past_key_values"] + ) + + def _create_stopping_criteria(self, req_id, max_new_tokens=25): + class StoppingCriteria(object): + def __init__( + self, + cache, + req_id, + stop_token, + max_new_tokens, + ): + self.req_id = req_id + self.cache = cache + self.max_new_tokens = max_new_tokens + self.stop_token = stop_token + + def __call__(self, res): + self.max_new_tokens -= 1 + + if self.max_new_tokens == 0 or res["ids"][-1] == self.stop_token: + self.clean_up() + return True + return False + + def clean_up(self): + del self.cache[self.req_id] + + return StoppingCriteria( + self.context.cache, + req_id, + self.tokenizer.eos_token_id, + max_new_tokens, + ) + + def _clean_cache(self): + new_ids = set(self.context.request_ids.keys()) + for idx in self.context.kv_cache.keys(): + if idx not in new_ids: + del self.context.kv_cache[idx] diff --git a/examples/large_models/inferentia2/llama2/model-config.yaml b/examples/large_models/inferentia2/llama2/model-config.yaml index 031a409903..6913af49f7 100644 --- a/examples/large_models/inferentia2/llama2/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/model-config.yaml @@ -9,6 +9,7 @@ handler: amp: "bf16" tp_degree: 6 max_length: 100 + max_new_tokens: 25 micro_batching: micro_batch_size: 4 From 6e27cbeda85076252200cd1911067745a47114fc Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 19 Nov 2023 15:52:03 -0800 Subject: [PATCH 02/49] fmt --- .../inferentia2/llama2/inf2_cb_handler.py | 252 +++++++++--------- 1 file changed, 124 insertions(+), 128 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/inf2_cb_handler.py b/examples/large_models/inferentia2/llama2/inf2_cb_handler.py index 2d41057ba2..72c3ceae49 100644 --- a/examples/large_models/inferentia2/llama2/inf2_cb_handler.py +++ b/examples/large_models/inferentia2/llama2/inf2_cb_handler.py @@ -4,14 +4,15 @@ from abc import ABC import torch -import transformers -from transformers import AutoConfig +import torch_neuronx +from transformers import AutoConfig, LlamaTokenizer +from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter +from transformers_neuronx.llama.model import LlamaForSampling from ts.context import Context from ts.torch_handler.base_handler import BaseHandler logger = logging.getLogger(__name__) -logger.info("Transformers version %s", transformers.__version__) class LlamaHandler(BaseHandler, ABC): @@ -25,8 +26,6 @@ def __init__(self): self.max_new_tokens = None self.tokenizer = None self.micro_batch_size = 1 - self.encoded_empty_padding = None - self.prefilled_ts_inf2_encoded_padding = False self.initialized = False def initialize(self, ctx: Context): @@ -42,7 +41,7 @@ def initialize(self, ctx: Context): ) model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}" os.environ["NEURONX_CACHE"] = "on" - os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache" + os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference" # settings for model compiliation and loading @@ -75,7 +74,7 @@ def initialize(self, ctx: Context): self.tokenizer.pad_token = self.tokenizer.eos_token self.model = LlamaForSampling.from_pretrained( model_checkpoint_path, - batch_size=ctx.system_properties.get("batch_size"), + batch_size=self.micro_batch_size, amp=amp, tp_degree=tp_degree, ) @@ -85,31 +84,13 @@ def initialize(self, ctx: Context): model_config = AutoConfig.from_pretrained(model_checkpoint_path) self.model = HuggingFaceGenerationModelAdapter(model_config, self.model) - self.model.resize_token_embeddings(self.model.config.vocab_size + 1) - # Replace _update_model_kwargs_for_generation of model with a method that extracts the kv cache for us old_update = self.model._update_model_kwargs_for_generation ctx.cache = {} ctx.kv_cache = {} - encoded = self.tokenizer( - "", return_tensors="pt", padding=True, return_token_type_ids=False - ) - encoded["past_key_values"] = None - self.context.cache["ts_inf2_encoded_padding"] = { - # "stopping_criteria": self._create_stopping_criteria(req_id, max_new_tokens=data["max_new_tokens"]), - "stopping_criteria": self._create_stopping_criteria( - "ts_inf2_encoded_padding", max_new_tokens=self.max_new_tokens - ), - "init_encoded": encoded, - "prompt_length": len(encoded["input_ids"]), - } def extract_past_key_values_func(self, *args, **kwargs): - ctx.kv_cache["past_key_values"] = args[0]["past_key_values"][0] - if self.prefilled_ts_inf2_encoded_padding is False: - ctx.kv_cache["ts_inf2_empty_padding_past_key_values"] = args[0][ - "past_key_values" - ][1] + ctx.kv_cache["past_key_values"] = args[0]["past_key_values"] return old_update(*args, **kwargs) self.model._update_model_kwargs_for_generation = types.MethodType( @@ -131,7 +112,7 @@ def preprocess(self, requests): """ self._clean_cache() - prefill, decode = [], [] + prefill_req_ids, decode, prefill_input_text = [], [], [] for req_id, req_data in zip(self.context.request_ids.values(), requests): # Tokenizer requests which are not prefilled yet if not req_id in self.context.cache: @@ -139,23 +120,17 @@ def preprocess(self, requests): if isinstance(data, (bytes, bytearray)): data = data.decode("utf-8") logger.info("Received text: '%s'", data) - encoded = self.tokenizer( - data, return_tensors="pt", padding=True, return_token_type_ids=False - ) - encoded["past_key_values"] = None - self.context.cache[req_id] = { - # "stopping_criteria": self._create_stopping_criteria(req_id, max_new_tokens=data["max_new_tokens"]), - "stopping_criteria": self._create_stopping_criteria( - req_id, max_new_tokens=self.max_new_tokens - ), - "encoded": encoded, - "prompt_length": len(encoded["input_ids"]), - } - prefill.append(req_id) + prefill_input_text.append(data.strip()) + prefill_req_ids.append(req_id) else: decode.append(req_id) - return prefill, decode + prefill_encoded = None + if len(prefill_input_text) > 0: + prefill_encoded = self._run_tokenizer_batch( + self, prefill_input_text, prefill + ) + return prefill_req_ids, prefill_encoded, decode def inference(self, input_batch): """ @@ -168,18 +143,12 @@ def inference(self, input_batch): list: A list of strings with the predicted values for each input text in the batch. """ - prefill, decode_ids = input_batch + prefill, prefill_encoded, decode_ids = input_batch # Prefill requests - results = {} - for req_id in prefill: - results[req_id] = self._run_prefill(req_id) + results = self._run_prefill_batch(prefill, prefill_encoded) # Decode the rest - if decode_ids: - decode_ids.extend( - ["ts_inf2_encoded_padding"] * (self.micro_batch_size - len(decode_ids)) - ) decode_result = self._run_decode(decode_ids) if decode_ids else {} results.update(decode_result) return [results[i] for i in self.context.request_ids.values()] @@ -191,7 +160,6 @@ def postprocess(self, inference_output): Returns: (list): Returns a list of the Predictions and Explanations. """ - self.context.stopping_criteria = [ self.context.cache[i]["stopping_criteria"] for i in self.context.request_ids.values() @@ -199,84 +167,86 @@ def postprocess(self, inference_output): return inference_output - @torch.no_grad() - def _run_prefill(self, req_id): - assert ( - self.context.cache[req_id]["encoded"]["past_key_values"] is None - ), "There should be no cached values" + def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): # Pad input to match compiled model batch size - input_ids_batch, attention_mask_batch = [], [] - input_ids_batch.append(self.context.cache[req_id]["encoded"]["input_ids"]) - attention_mask_batch.append( - self.context.cache[req_id]["encoded"]["attention_mask"] - ) - input_ids_batch.extend( - [self.context.cache["ts_inf2_encoded_padding"]["init_encoded"]["input_ids"]] - * (self.micro_batch_size - 1) - ) - attention_mask_batch.extend( - [ - self.context.cache["ts_inf2_encoded_padding"]["init_encoded"][ - "attention_mask" - ] - ] - * (self.micro_batch_size - 1) + if self.micro_batch_size > len(prefill_req_ids): + prefill_input_text.extend( + [""] * (self.micro_batch_size - len(prefill_req_ids)) + ) + else: + return None + + batch_encoded = self.tokenizer( + prefill_input_text, + return_tensors="pt", + padding=True, + return_token_type_ids=False, ) - input_ids_batch = torch.cat(input_ids_batch, dim=0) - attention_mask_batch = torch.cat(attention_mask_batch, dim=0) - output = self.model.generate( - input_ids_batch, - attention_mask=attention_mask_batch, + for idx, req_id in enumerate(prefill_req_ids): + encoded = { + "input_ids": batch_encoded["input_ids"][idx], + "attention_mask": batch_encoded["attention_mask"][idx], + "past_key_values": None, + } + self.context.cache[req_id] = { + "stopping_criteria": self._create_stopping_criteria( + req_id, max_new_tokens=self.max_new_tokens + ), + "encoded": encoded, + "prompt_length": len(encoded["input_ids"][idx]), + } + + return batch_encoded + + @torch.no_grad() + def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): + outputs = self.model.generata( + input_ids=prefill_encoded["input_ids"], + attention_mask=prefill_encoded["attention_mask"], max_new_tokens=1, return_dict_in_generate=True, use_cache=True, ) - # Save empty padding output - if self.prefilled_ts_inf2_encoded_padding is False: - attention_mask = self.context.cache["ts_inf2_encoded_padding"]["encoded"][ - "attention_mask" - ] + outputs_decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + # Prefill requests + results = {} + for idx, req_id in enumerate(prefill_req_ids): + # Save extracted kv cache values and adjust attention mask for next call + self.context.cache[req_id]["encoded"][ + "past_key_values" + ] = self._collect_kv_cache_of_idx_in_batch(idx) + self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] + + device = next(iter(self.model.parameters())).device + dtype = torch.int64 + config = {"device": device, "dtype": dtype} + attention_mask = self.context.cache[req_id]["encoded"]["attention_mask"] attention_mask = torch.concat( - (attention_mask, torch.ones((1, 1), dtype=torch.int64)), dim=1 + (attention_mask, torch.ones((1, 1), **config)), dim=1 ) - self.context.cache["ts_inf2_encoded_padding"]["encoded"] = { - "input_ids": output.sequences[1], - "attention_mask": attention_mask, + self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask + + results[req_id] = { + "text": outputs_decoded[idx], + "ids": outputs.sequences[idx].tolist(), } - self.prefilled_ts_inf2_encoded_padding = True - # Save extracted kv cache values and adjust attention mask for next call - self.context.cache[req_id]["encoded"][ - "past_key_values" - ] = self.context.kv_cache["past_key_values"] del self.context.kv_cache["past_key_values"] - self.context.cache[req_id]["encoded"]["input_ids"] = output.sequences[0] - - attention_mask = self.context.cache[req_id]["encoded"]["attention_mask"] - attention_mask = torch.concat( - (attention_mask, torch.ones((1, 1), dtype=torch.int64)), dim=1 - ) - self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask - - result = { - "text": self.tokenizer.decode( - output.sequences[0], skip_special_tokens=True - ), - "ids": output.sequences[0].tolist(), - } - logger.info(f"_run_prefill result: {0}".format(result)) - return result["text"] + return results def _run_decode(self, ids): - assert len(ids) - encoded = self._prepare_model_inputs(ids) outputs = self.model.generate( **encoded, max_new_tokens=1, return_dict_in_generate=True, use_cache=True ) + device = next(iter(self.model.parameters())).device + dtype = torch.int64 + config = {"device": device, "dtype": dtype} + results = {} for idx, req_id in enumerate(ids): self.context.cache[req_id]["encoded"][ @@ -287,7 +257,7 @@ def _run_decode(self, ids): ].unsqueeze(0) attention_mask = encoded["attention_mask"][idx].unsqueeze(0) attention_mask = torch.concat( - (attention_mask, torch.ones((1, 1), dtype=torch.int64)), dim=1 + (attention_mask, torch.ones((1, 1), **config)), dim=1 ) self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask results[req_id] = { @@ -306,57 +276,82 @@ def _prepare_model_inputs(self, ids): ) max_len = max(lengths) + for idx in range(self.micro_batch_size - len(ids)): + ids.append("batch_padding") + lengths.append(0) + + device = next(iter(self.model.parameters())).device + dtype = torch.int64 + config = {"device": device, "dtype": dtype} + input_ids = [] attention_mask = [] kv_cache = {} + for req_id, seq_len in zip(ids, lengths): - input_ids.append(self.context.cache[req_id]["encoded"]["input_ids"]) - attention_mask.append( - self.context.cache[req_id]["encoded"]["attention_mask"] - ) + if req_id != "batch_padding": + input_ids.append(self.context.cache[req_id]["encoded"]["input_ids"]) + attention_mask.append( + self.context.cache[req_id]["encoded"]["attention_mask"] + ) + + for layer_idx, layer_kv in enumerate( + self.context.cache[req_id]["encoded"]["past_key_values"] + ): + k, v = layer_kv + kv_cache[layer_idx] = kv_cache.get(layer_idx, {}) + kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get(0, []) + [ + k + ] + kv_cache[layer_idx][1] = kv_cache.get(layer_idx, {}).get(1, []) + [ + v + ] + else: + config = {"device": device, "dtype": dtype} + input_ids.append( + self.tokenizer.pad_token_id + torch.zeros((1, max_len), **config) + ) + attention_mask.append(torch.zeros((1, max_len), **config)) + for layer_idx in range(len(kv_cache)): + kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get( + 0, [] + ) + torch.zeros((max_len), **config) + kv_cache[layer_idx][1] = kv_cache.get(layer_idx, {}).get( + 1, [] + ) + torch.zeros((max_len), **config) - for layer_idx, layer_kv in enumerate( - self.context.cache[req_id]["encoded"]["past_key_values"] - ): - k, v = layer_kv - kv_cache[layer_idx] = kv_cache.get(layer_idx, {}) - kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get(0, []) + [k] - kv_cache[layer_idx][1] = kv_cache.get(layer_idx, {}).get(1, []) + [v] padded_len = input_ids[-1].size()[-1] if padded_len < max_len: # Apply padding to input_ids, attention_mask and past_key_values n = max_len - seq_len input_ids[-1] = torch.concat( ( - self.tokenizer.pad_token_id - + torch.zeros((1, n), dtype=torch.int64), + self.tokenizer.pad_token_id + torch.zeros((1, n), **config), input_ids[-1], ), dim=1, ) attention_mask[-1] = torch.concat( - (torch.zeros((1, n), dtype=torch.int64), attention_mask[-1]), dim=1 + (torch.zeros((1, n), **config), attention_mask[-1]), dim=1 ) size_delta = list(kv_cache[0][0][-1].size()) size_delta[2] = n - dtype = kv_cache[0][0][-1].dtype for layer_idx in range(len(kv_cache)): kv_cache[layer_idx][0][-1] = torch.concat( ( - torch.zeros(size_delta, dtype=dtype), + torch.zeros(size_delta, **config), kv_cache[layer_idx][0][-1], ), dim=2, ) kv_cache[layer_idx][1][-1] = torch.concat( ( - torch.zeros(size_delta, dtype=dtype), + torch.zeros(size_delta, **config), kv_cache[layer_idx][1][-1], ), dim=2, ) - elif padded_len > max_len: # Truncate padding from input_ids, attention_mask and past_key_values input_ids[-1] = input_ids[-1][:, -max_len:] @@ -369,7 +364,8 @@ def _prepare_model_inputs(self, ids): kv_cache[layer_idx][1][-1] = kv_cache[layer_idx][1][-1][ :, :, (-max_len + 1) :, : ] - del self.context.cache[req_id]["encoded"]["past_key_values"] + if req_id != "batch_padding": + del self.context.cache[req_id]["encoded"]["past_key_values"] for layer_idx in range(len(kv_cache)): kv_cache[layer_idx][0] = torch.concat(kv_cache[layer_idx][0], dim=0) From 7dcfffdff6ce984ed8f4161dd302839b4de4f70f Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Nov 2023 10:09:11 -0800 Subject: [PATCH 03/49] fmt --- .../inferentia2/llama2/inf2_cb_handler.py | 57 +++++++++++-------- ts/protocol/otf_message_handler.py | 5 +- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/inf2_cb_handler.py b/examples/large_models/inferentia2/llama2/inf2_cb_handler.py index 72c3ceae49..600850f703 100644 --- a/examples/large_models/inferentia2/llama2/inf2_cb_handler.py +++ b/examples/large_models/inferentia2/llama2/inf2_cb_handler.py @@ -72,6 +72,7 @@ def initialize(self, ctx: Context): self.tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint_path) self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.model = LlamaForSampling.from_pretrained( model_checkpoint_path, batch_size=self.micro_batch_size, @@ -126,9 +127,9 @@ def preprocess(self, requests): decode.append(req_id) prefill_encoded = None - if len(prefill_input_text) > 0: + if len(prefill_req_ids) > 0: prefill_encoded = self._run_tokenizer_batch( - self, prefill_input_text, prefill + prefill_input_text, prefill_req_ids ) return prefill_req_ids, prefill_encoded, decode @@ -143,10 +144,14 @@ def inference(self, input_batch): list: A list of strings with the predicted values for each input text in the batch. """ - prefill, prefill_encoded, decode_ids = input_batch + prefill_req_ids, prefill_encoded, decode_ids = input_batch # Prefill requests - results = self._run_prefill_batch(prefill, prefill_encoded) + results = ( + self._run_prefill_batch(prefill_req_ids, prefill_encoded) + if prefill_req_ids + else {} + ) # Decode the rest decode_result = self._run_decode(decode_ids) if decode_ids else {} @@ -178,9 +183,12 @@ def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): batch_encoded = self.tokenizer( prefill_input_text, + # max_length=self.max_length, return_tensors="pt", padding=True, + add_special_tokens=True, return_token_type_ids=False, + truncation=True, ) for idx, req_id in enumerate(prefill_req_ids): encoded = { @@ -193,22 +201,23 @@ def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): req_id, max_new_tokens=self.max_new_tokens ), "encoded": encoded, - "prompt_length": len(encoded["input_ids"][idx]), + "prompt_length": encoded["input_ids"].shape[0], } return batch_encoded @torch.no_grad() def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): - outputs = self.model.generata( - input_ids=prefill_encoded["input_ids"], - attention_mask=prefill_encoded["attention_mask"], + outputs = self.model.generate( + **prefill_encoded, max_new_tokens=1, return_dict_in_generate=True, use_cache=True, ) - outputs_decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + outputs_decoded = self.tokenizer.batch_decode( + outputs.sequences, skip_special_tokens=True + ) # Prefill requests results = {} @@ -218,13 +227,12 @@ def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): "past_key_values" ] = self._collect_kv_cache_of_idx_in_batch(idx) self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] - device = next(iter(self.model.parameters())).device dtype = torch.int64 config = {"device": device, "dtype": dtype} attention_mask = self.context.cache[req_id]["encoded"]["attention_mask"] attention_mask = torch.concat( - (attention_mask, torch.ones((1, 1), **config)), dim=1 + (attention_mask, torch.ones((1), **config)), dim=0 ) self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask @@ -249,15 +257,15 @@ def _run_decode(self, ids): results = {} for idx, req_id in enumerate(ids): + if req_id == "batch_padding": + continue self.context.cache[req_id]["encoded"][ "past_key_values" ] = self._collect_kv_cache_of_idx_in_batch(idx) - self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[ - idx - ].unsqueeze(0) - attention_mask = encoded["attention_mask"][idx].unsqueeze(0) + self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] + attention_mask = encoded["attention_mask"][idx] attention_mask = torch.concat( - (attention_mask, torch.ones((1, 1), **config)), dim=1 + (attention_mask, torch.ones((1), **config)), dim=0 ) self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask results[req_id] = { @@ -271,7 +279,7 @@ def _run_decode(self, ids): def _prepare_model_inputs(self, ids): lengths = list( - torch.sum(self.context.cache[i]["encoded"]["attention_mask"], dim=1).item() + torch.sum(self.context.cache[i]["encoded"]["attention_mask"]).item() for i in ids ) max_len = max(lengths) @@ -309,9 +317,9 @@ def _prepare_model_inputs(self, ids): else: config = {"device": device, "dtype": dtype} input_ids.append( - self.tokenizer.pad_token_id + torch.zeros((1, max_len), **config) + self.tokenizer.pad_token_id + torch.zeros((max_len), **config) ) - attention_mask.append(torch.zeros((1, max_len), **config)) + attention_mask.append(torch.zeros((max_len), **config)) for layer_idx in range(len(kv_cache)): kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get( 0, [] @@ -326,13 +334,12 @@ def _prepare_model_inputs(self, ids): n = max_len - seq_len input_ids[-1] = torch.concat( ( - self.tokenizer.pad_token_id + torch.zeros((1, n), **config), + self.tokenizer.pad_token_id + torch.zeros((n), **config), input_ids[-1], - ), - dim=1, + ) ) attention_mask[-1] = torch.concat( - (torch.zeros((1, n), **config), attention_mask[-1]), dim=1 + (torch.zeros((n), **config), attention_mask[-1]) ) size_delta = list(kv_cache[0][0][-1].size()) @@ -377,8 +384,8 @@ def _prepare_model_inputs(self, ids): ) encoded = { - "input_ids": torch.concat(input_ids, dim=0), - "attention_mask": torch.concat(attention_mask, dim=0), + "input_ids": torch.stack(input_ids), + "attention_mask": torch.stack(attention_mask), "past_key_values": kv_cache, } return encoded diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index 91eb206f80..37d222e0e9 100644 --- a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -119,7 +119,10 @@ def create_predict_response( msg += struct.pack("!i", len(buf)) msg += buf else: - val = ret[idx] + if context.stopping_criteria: + val = ret[idx]["result"] + else: + val = ret[idx] # NOTE: Process bytes/bytearray case before processing the string case. if isinstance(val, (bytes, bytearray)): msg += struct.pack("!i", len(val)) From 5eed61e2cf5e6be3d99b416a22da81b8986b78c5 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Nov 2023 13:12:10 -0800 Subject: [PATCH 04/49] add space --- .../large_models/inferentia2/llama2/test_stream_response.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/large_models/inferentia2/llama2/test_stream_response.py b/examples/large_models/inferentia2/llama2/test_stream_response.py index 2c205dd3de..6f20249e30 100644 --- a/examples/large_models/inferentia2/llama2/test_stream_response.py +++ b/examples/large_models/inferentia2/llama2/test_stream_response.py @@ -9,6 +9,6 @@ for chunk in response.iter_content(chunk_size=None): if chunk: data = chunk.decode("utf-8") - print(data, end="", flush=True) + print(data, end=" ", flush=True) print("") From c81320aecde51ee4c179a815dde8a337a497e1c5 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Nov 2023 15:05:33 -0800 Subject: [PATCH 05/49] fmt --- .../large_models/inferentia2/llama2/Readme.md | 17 +++++++---------- .../llama2/continuous_batching/Readme.md | 8 ++++++++ .../inf2_handler.py} | 16 +++++++--------- .../continuous_batching/model-config.yaml | 14 ++++++++++++++ .../inferentia2/llama2/streamer/Readme.md | 9 +++++++++ .../llama2/{ => streamer}/inf2_handler.py | 0 .../llama2/{ => streamer}/model-config.yaml | 1 - 7 files changed, 45 insertions(+), 20 deletions(-) create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/Readme.md rename examples/large_models/inferentia2/llama2/{inf2_cb_handler.py => continuous_batching/inf2_handler.py} (97%) create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml create mode 100644 examples/large_models/inferentia2/llama2/streamer/Readme.md rename examples/large_models/inferentia2/llama2/{ => streamer}/inf2_handler.py (100%) rename examples/large_models/inferentia2/llama2/{ => streamer}/model-config.yaml (93%) diff --git a/examples/large_models/inferentia2/llama2/Readme.md b/examples/large_models/inferentia2/llama2/Readme.md index f882688a5e..201ed37b39 100644 --- a/examples/large_models/inferentia2/llama2/Readme.md +++ b/examples/large_models/inferentia2/llama2/Readme.md @@ -1,16 +1,13 @@ # Large model inference on Inferentia2 -This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe's features: -Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is built on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference. - -**Note**: To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, to generate the model graph, a specific batch size is used. Following this, when running inference, we need to pass input which matches the batch size that was used during compilation. Model compilation and input padding to match compiled model batch size is taken care of by the [custom handler](inf2_handler.py) in this example. +* demo1: [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support in folder streamer. +* demo2: continuous batching support in folder continuous_batching -The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. -The batch size is chosen to be a relatively large value, say 16 since micro batching enables running the preprocess(tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation. -Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. +Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is built on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference. -This example also demonstrates the utilization of neuronx cache to store inf2 model compilation artifacts using the `NEURONX_CACHE` and `NEURONX_DUMP_TO` environment variables in the custom handler. +This example folder demonstrates the utilization of neuronx cache to store inf2 model compilation artifacts using the `NEURONX_CACHE` and `NEURON_COMPILE_CACHE_URL` environment variables in the custom handler. When the model is loaded for the first time, the model is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache. On subsequent model load, the compilation artifacts in the neuronx cache serves as `Ahead of Time(AOT)` compilation artifacts and significantly reduces the model load time. For convenience, the compiled model artifacts for this example are made available on the Torchserve model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\ @@ -22,7 +19,7 @@ Get an Inf2 instance(Note: This example was tested on instance type:`inf2.24xlar DLAMI Name: ` Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher. **Note**: The `inf2.24xlarge` instance consists of 6 neuron chips with 2 neuron cores each. The total accelerator memory is 192GB. -Based on the configuration used in [model-config.yaml](model-config.yaml), with `tp_degree` set to 6, 3 of the 6 neuron chips are used, i.e 6 neuron cores. +Based on the configuration used in [model-config.yaml](streamer/model-config.yaml), with `tp_degree` set to 6, 3 of the 6 neuron chips are used, i.e 6 neuron cores. On loading the model, the accelerator memory consumed is 38.1GB (12.7GB per chip). ### Step 2: Package Installations @@ -85,7 +82,7 @@ python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13 ### Step 4: Package model artifacts ```bash -torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive +torch-model-archiver --model-name llama-2-13b --version 1.0 --handler /PATH/TO/inf2_handler.py -r requirements.txt --config-file /PATH/TO/model-config.yaml --archive-format no-archive mv llama-2-13b-split llama-2-13b ``` diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md new file mode 100644 index 0000000000..0295e13a58 --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md @@ -0,0 +1,8 @@ +# Demo2: Llama-2 Using TorchServe continuous batching on inf2 + +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. + +**Note**: To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, to generate the model graph, a specific batch size is used. Following this, when running inference, we need to pass input which matches the batch size that was used during compilation. Model compilation and input padding to match compiled model batch size is taken care of by the [custom handler](inf2_handler.py) in this example. + +The batch size [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. It is the batch size used for the Inf2 model compilation. +Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. diff --git a/examples/large_models/inferentia2/llama2/inf2_cb_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py similarity index 97% rename from examples/large_models/inferentia2/llama2/inf2_cb_handler.py rename to examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index 600850f703..0b3d51a997 100644 --- a/examples/large_models/inferentia2/llama2/inf2_cb_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -25,7 +25,7 @@ def __init__(self): self.max_length = None self.max_new_tokens = None self.tokenizer = None - self.micro_batch_size = 1 + self.batch_size = 1 self.initialized = False def initialize(self, ctx: Context): @@ -49,8 +49,8 @@ def initialize(self, ctx: Context): tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) self.max_length = int(ctx.model_yaml_config["handler"]["max_length"]) self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) - self.micro_batch_size = int( - ctx.model_yaml_config.get("micro_batching", {}).get("micro_batch_size", 1) + self.batch_size = int( + ctx.model_yaml_config.get("handler", {}).get("batch_size", 1) ) # allocate "tp_degree" number of neuron cores to the worker process @@ -75,7 +75,7 @@ def initialize(self, ctx: Context): self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.model = LlamaForSampling.from_pretrained( model_checkpoint_path, - batch_size=self.micro_batch_size, + batch_size=self.batch_size, amp=amp, tp_degree=tp_degree, ) @@ -174,10 +174,8 @@ def postprocess(self, inference_output): def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): # Pad input to match compiled model batch size - if self.micro_batch_size > len(prefill_req_ids): - prefill_input_text.extend( - [""] * (self.micro_batch_size - len(prefill_req_ids)) - ) + if self.batch_size > len(prefill_req_ids): + prefill_input_text.extend([""] * (self.batch_size - len(prefill_req_ids))) else: return None @@ -284,7 +282,7 @@ def _prepare_model_inputs(self, ids): ) max_len = max(lengths) - for idx in range(self.micro_batch_size - len(ids)): + for idx in range(self.batch_size - len(ids)): ids.append("batch_padding") lengths.append(0) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml new file mode 100644 index 0000000000..8078cc122c --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -0,0 +1,14 @@ +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 100 +responseTimeout: 10800 +batchSize: 4 +continuousBatching: true + +handler: + model_checkpoint_dir: "llama-2-13b-split" + amp: "bf16" + tp_degree: 6 + max_length: 100 + max_new_tokens: 25 + batch_size: 4 diff --git a/examples/large_models/inferentia2/llama2/streamer/Readme.md b/examples/large_models/inferentia2/llama2/streamer/Readme.md new file mode 100644 index 0000000000..684b418e8b --- /dev/null +++ b/examples/large_models/inferentia2/llama2/streamer/Readme.md @@ -0,0 +1,9 @@ +# Demo1: Llama-2 Using TorchServe micro-batching and Streamer on inf2 + +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. + +**Note**: To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, to generate the model graph, a specific batch size is used. Following this, when running inference, we need to pass input which matches the batch size that was used during compilation. Model compilation and input padding to match compiled model batch size is taken care of by the [custom handler](inf2_handler.py) in this example. + +The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. +The batch size is chosen to be a relatively large value, say 16 since micro batching enables running the preprocess(tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation. +Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. diff --git a/examples/large_models/inferentia2/llama2/inf2_handler.py b/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py similarity index 100% rename from examples/large_models/inferentia2/llama2/inf2_handler.py rename to examples/large_models/inferentia2/llama2/streamer/inf2_handler.py diff --git a/examples/large_models/inferentia2/llama2/model-config.yaml b/examples/large_models/inferentia2/llama2/streamer/model-config.yaml similarity index 93% rename from examples/large_models/inferentia2/llama2/model-config.yaml rename to examples/large_models/inferentia2/llama2/streamer/model-config.yaml index 6913af49f7..031a409903 100644 --- a/examples/large_models/inferentia2/llama2/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/streamer/model-config.yaml @@ -9,7 +9,6 @@ handler: amp: "bf16" tp_degree: 6 max_length: 100 - max_new_tokens: 25 micro_batching: micro_batch_size: 4 From 426e930087f41f3b8444c3d3808c42eed0822a13 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Nov 2023 15:09:40 -0800 Subject: [PATCH 06/49] fmt --- .../inferentia2/llama2/continuous_batching/inf2_handler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index 0b3d51a997..b15a4e4ba2 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -181,7 +181,6 @@ def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): batch_encoded = self.tokenizer( prefill_input_text, - # max_length=self.max_length, return_tensors="pt", padding=True, add_special_tokens=True, @@ -235,7 +234,7 @@ def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask results[req_id] = { - "text": outputs_decoded[idx], + "result": outputs_decoded[idx], "ids": outputs.sequences[idx].tolist(), } @@ -267,7 +266,7 @@ def _run_decode(self, ids): ) self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask results[req_id] = { - "text": self.tokenizer.decode( + "result": self.tokenizer.decode( outputs.sequences[idx][-1], skip_special_tokens=True ), "ids": [outputs.sequences[idx][-1].item()], From eb9881615450b16e1ff7c648c67bafc049ea6b37 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Nov 2023 16:27:36 -0800 Subject: [PATCH 07/49] fmt --- .../inferentia2/llama2/continuous_batching/inf2_handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index b15a4e4ba2..52a109b7ff 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -176,8 +176,6 @@ def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): # Pad input to match compiled model batch size if self.batch_size > len(prefill_req_ids): prefill_input_text.extend([""] * (self.batch_size - len(prefill_req_ids))) - else: - return None batch_encoded = self.tokenizer( prefill_input_text, From 8d1251f1e85462cc99597cce3e782fec50c5a269 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 21 Nov 2023 23:30:26 -0800 Subject: [PATCH 08/49] fmt --- .../llama2/continuous_batching/Dockerfile | 10 + .../llama2/continuous_batching/Readme.md | 2 + .../inf2-llama-2-continuous-batching.ipynb | 179 ++++++++++++++++++ .../continuous_batching/model-config.yaml | 2 +- .../llama2/continuous_batching/test.sh | 5 + examples/large_models/utils/Download_model.py | 9 +- 6 files changed, 205 insertions(+), 2 deletions(-) create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb create mode 100755 examples/large_models/inferentia2/llama2/continuous_batching/test.sh diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile b/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile new file mode 100644 index 0000000000..2d08cc08cd --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile @@ -0,0 +1,10 @@ +FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04 +# workaround transformers 4.35 failure of downloading HF model bin files +RUN pip install transformers==4.34 +WORKDIR /home/model-server +RUN git clone https://github.com/pytorch/serve.git \ + && cd serve +WORKDIR /home/model-server/serve +RUN git checkout feat/inf2_cb +RUN pip install pygit2 +RUN python ts_scripts/install_from_src.py diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md index 0295e13a58..3aa764c0b5 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md @@ -6,3 +6,5 @@ This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) The batch size [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. It is the batch size used for the Inf2 model compilation. Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. + +`inf2-llama-2-continuous-batching.ipynb` is the notebook example. diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb new file mode 100644 index 0000000000..756bd33d76 --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -0,0 +1,179 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## TorchServe Continuous Batching Serve Llama-2 on Inferentia-2\n", + "This notebook demonstrates TorchServe continuous batching serving Llama-2-13b on Inferentia-2 `inf2.24xlarge`." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Build a customized docker container to install the code changes from this [PR](https://github.com/pytorch/serve/pull/2803).\n", + "This section can be skipped once [Neuron DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers) release TorchServe latest version." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-1.amazonaws.com\n", + "!docker pull 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!cat Dockerfile" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2023-11-21T22:05:30.551799Z", + "end_time": "2023-11-21T22:05:30.698105Z" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!docker build -t neuron-sdk-215:torchserve-cb ." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Serialize Llama-2-13b-hf model, save checkpoints and precompile the model with batch size 4 inside the docker container" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Enter into docker container\n", + "!mkdir model_store\n", + "\n", + "!docker exec -it -v model_store/:/home/model-server/model_store --device /dev/neuron0:/dev/neuron0 --device /dev/neuron1:/dev/neuron1 --device /dev/neuron2:/dev/neuron2 --device /dev/neuron3:/dev/neuron3 --device /dev/neuron4:/dev/neuron4 --device /dev/neuron5:/dev/neuron5 neuron-sdk-215:torchserve-cb bash" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Inside the container\n", + "\n", + "# login in Hugginface hub\n", + "huggingface-cli login --token $HUGGINGFACE_TOKEN\n", + "\n", + "# Save checkpoints\n", + "cd /home/model-server/serve/examples/large_models/inferentia2/llama2/continuous_batching\n", + "python ../../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split'\n", + "\n", + "# Create TorchServe model artifacts\n", + "torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r ../requirements.txt --config-file model-config.yaml --archive-format no-archive\n", + "mv llama-2-13b-split llama-2-13b\n", + "mv llama-2-13b /home/model-server/model_store\n", + "\n", + "# Precompile complete once the log \"Model llama-2-13b loaded successfully\"\n", + "torchserve --ncs --start --model-store /home/model-server/model_store --models llama-2-13b --ts-config ../config.properties\n", + "\n", + "# Exit the container" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Run inference" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Start the container\n", + "!docker run -it -v model_store:/opt/ml/model --device /dev/neuron0:/dev/neuron0 --device /dev/neuron1:/dev/neuron1 --device /dev/neuron2:/dev/neuron2 --device /dev/neuron3:/dev/neuron3 --device /dev/neuron4:/dev/neuron4 --device /dev/neuron5:/dev/neuron5 -p 8080:8080 -p 8081:8081 -p 8082:8082 neuron-sdk-215:torchserve-cb" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Run single inference request\n", + "!python test_stream_response.py" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "# Run multiple inference requests concurrently\n", + "!./tesh.sh" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml index 8078cc122c..18a3ab92f6 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -8,7 +8,7 @@ continuousBatching: true handler: model_checkpoint_dir: "llama-2-13b-split" amp: "bf16" - tp_degree: 6 + tp_degree: 12 max_length: 100 max_new_tokens: 25 batch_size: 4 diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/test.sh b/examples/large_models/inferentia2/llama2/continuous_batching/test.sh new file mode 100755 index 0000000000..14de9eb993 --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/test.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +for i in {1..64}; do + python ../test_stream_response.py > t_${i} & +done diff --git a/examples/large_models/utils/Download_model.py b/examples/large_models/utils/Download_model.py index 2e8f6c9579..d67a0526e9 100644 --- a/examples/large_models/utils/Download_model.py +++ b/examples/large_models/utils/Download_model.py @@ -39,6 +39,13 @@ def hf_model(model_str): parser.add_argument( "--model_name", "-m", type=hf_model, required=True, help="HuggingFace model name" ) +parser.add_argument( + "--use_auth_token", + "-", + type=bool, + default=False, + help="Use HF authentication token", +) parser.add_argument("--revision", "-r", type=str, default="main", help="Revision") args = parser.parse_args() # Only download pytorch checkpoint files @@ -49,6 +56,6 @@ def hf_model(model_str): revision=args.revision, allow_patterns=allow_patterns, cache_dir=args.model_path, - use_auth_token=False, + use_auth_token=args.use_auth_token, ) print(f"Files for '{args.model_name}' is downloaded to '{snapshot_path}'") From 81c453259d71a33f78acdf6afe3a7b2331fbc9a9 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Nov 2023 12:45:49 -0800 Subject: [PATCH 09/49] fix regression test --- ts/protocol/otf_message_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index 37d222e0e9..ed09ed3057 100644 --- a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -119,7 +119,7 @@ def create_predict_response( msg += struct.pack("!i", len(buf)) msg += buf else: - if context.stopping_criteria: + if context.stopping_criteria and ret[idx]["result"]: val = ret[idx]["result"] else: val = ret[idx] From 9f2e450a0de25bcf1b55cdfa766a323db016ae91 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Nov 2023 16:22:33 -0800 Subject: [PATCH 10/49] check key result --- ts/protocol/otf_message_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index ed09ed3057..9439eecf30 100644 --- a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -119,7 +119,7 @@ def create_predict_response( msg += struct.pack("!i", len(buf)) msg += buf else: - if context.stopping_criteria and ret[idx]["result"]: + if context.stopping_criteria and "result" in ret[idx]: val = ret[idx]["result"] else: val = ret[idx] From 687a1f574ccc73077fef0228db6fa5bb4e3b9ddf Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Nov 2023 17:28:09 -0800 Subject: [PATCH 11/49] fmt --- ts/protocol/otf_message_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index 9439eecf30..b8dc61a414 100644 --- a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -119,7 +119,7 @@ def create_predict_response( msg += struct.pack("!i", len(buf)) msg += buf else: - if context.stopping_criteria and "result" in ret[idx]: + if context and context.stopping_criteria and "result" in ret[idx]: val = ret[idx]["result"] else: val = ret[idx] From 632e89639bdcb8aa3b4ccf34adedbd8e31b22012 Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 22 Nov 2023 18:24:53 -0800 Subject: [PATCH 12/49] update folder --- .../inf2-llama-2-continuous-batching.ipynb | 4 ++-- .../inferentia2/llama2/continuous_batching/requirements.txt | 1 + examples/large_models/utils/Download_model.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index 756bd33d76..6b199395ea 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -75,7 +75,7 @@ "# Enter into docker container\n", "!mkdir model_store\n", "\n", - "!docker exec -it -v model_store/:/home/model-server/model_store --device /dev/neuron0:/dev/neuron0 --device /dev/neuron1:/dev/neuron1 --device /dev/neuron2:/dev/neuron2 --device /dev/neuron3:/dev/neuron3 --device /dev/neuron4:/dev/neuron4 --device /dev/neuron5:/dev/neuron5 neuron-sdk-215:torchserve-cb bash" + "!docker run -it -v model_store:/home/model-server/model_store --device /dev/neuron0:/dev/neuron0 --device /dev/neuron1:/dev/neuron1 --device /dev/neuron2:/dev/neuron2 --device /dev/neuron3:/dev/neuron3 --device /dev/neuron4:/dev/neuron4 --device /dev/neuron5:/dev/neuron5 neuron-sdk-215:torchserve-cb bash" ], "metadata": { "collapsed": false @@ -96,7 +96,7 @@ "python ../../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split'\n", "\n", "# Create TorchServe model artifacts\n", - "torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r ../requirements.txt --config-file model-config.yaml --archive-format no-archive\n", + "torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", "mv llama-2-13b-split llama-2-13b\n", "mv llama-2-13b /home/model-server/model_store\n", "\n", diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt b/examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt new file mode 100644 index 0000000000..080a0d3acc --- /dev/null +++ b/examples/large_models/inferentia2/llama2/continuous_batching/requirements.txt @@ -0,0 +1 @@ +sentencepiece diff --git a/examples/large_models/utils/Download_model.py b/examples/large_models/utils/Download_model.py index d67a0526e9..3b3367a2e7 100644 --- a/examples/large_models/utils/Download_model.py +++ b/examples/large_models/utils/Download_model.py @@ -41,7 +41,7 @@ def hf_model(model_str): ) parser.add_argument( "--use_auth_token", - "-", + "-t", type=bool, default=False, help="Use HF authentication token", From f6f6df169b1d8fd5225eba885088c7a995c266b5 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Nov 2023 11:47:52 -0800 Subject: [PATCH 13/49] fmt --- .../inferentia2/llama2/continuous_batching/inf2_handler.py | 4 ++-- .../large_models/inferentia2/llama2/streamer/inf2_handler.py | 4 +++- .../large_models/inferentia2/llama2/test_stream_response.py | 4 +++- ts/protocol/otf_message_handler.py | 5 +---- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index 52a109b7ff..b60cc9824f 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -232,7 +232,7 @@ def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask results[req_id] = { - "result": outputs_decoded[idx], + "text": outputs_decoded[idx], "ids": outputs.sequences[idx].tolist(), } @@ -264,7 +264,7 @@ def _run_decode(self, ids): ) self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask results[req_id] = { - "result": self.tokenizer.decode( + "text": self.tokenizer.decode( outputs.sequences[idx][-1], skip_special_tokens=True ), "ids": [outputs.sequences[idx][-1].item()], diff --git a/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py b/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py index f4e3e43a37..a56db7ca9c 100644 --- a/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py @@ -136,7 +136,9 @@ def inference(self, tokenized_input): for new_text in self.output_streamer: logger.debug("send response stream") send_intermediate_predict_response( - new_text[: len(micro_batch_req_id_map)], + { + "text": new_text[: len(micro_batch_req_id_map)], + }, micro_batch_req_id_map, "Intermediate Prediction success", 200, diff --git a/examples/large_models/inferentia2/llama2/test_stream_response.py b/examples/large_models/inferentia2/llama2/test_stream_response.py index 6f20249e30..018841320c 100644 --- a/examples/large_models/inferentia2/llama2/test_stream_response.py +++ b/examples/large_models/inferentia2/llama2/test_stream_response.py @@ -1,3 +1,4 @@ +import orjson import requests response = requests.post( @@ -9,6 +10,7 @@ for chunk in response.iter_content(chunk_size=None): if chunk: data = chunk.decode("utf-8") - print(data, end=" ", flush=True) + data = orjson.loads(data) + print(data["text"], end=" ", flush=True) print("") diff --git a/ts/protocol/otf_message_handler.py b/ts/protocol/otf_message_handler.py index b8dc61a414..91eb206f80 100644 --- a/ts/protocol/otf_message_handler.py +++ b/ts/protocol/otf_message_handler.py @@ -119,10 +119,7 @@ def create_predict_response( msg += struct.pack("!i", len(buf)) msg += buf else: - if context and context.stopping_criteria and "result" in ret[idx]: - val = ret[idx]["result"] - else: - val = ret[idx] + val = ret[idx] # NOTE: Process bytes/bytearray case before processing the string case. if isinstance(val, (bytes, bytearray)): msg += struct.pack("!i", len(val)) From 31446cfedb97441c892fb42c5a1d7d6e2154e7cc Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Nov 2023 20:27:30 -0800 Subject: [PATCH 14/49] update key name --- .../inferentia2/llama2/continuous_batching/inf2_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index b60cc9824f..ae0b86fbad 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -233,7 +233,7 @@ def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): results[req_id] = { "text": outputs_decoded[idx], - "ids": outputs.sequences[idx].tolist(), + "tokens": outputs.sequences[idx].tolist(), } del self.context.kv_cache["past_key_values"] @@ -267,7 +267,7 @@ def _run_decode(self, ids): "text": self.tokenizer.decode( outputs.sequences[idx][-1], skip_special_tokens=True ), - "ids": [outputs.sequences[idx][-1].item()], + "tokens": [outputs.sequences[idx][-1].item()], } del self.context.kv_cache["past_key_values"] return results @@ -409,7 +409,7 @@ def __init__( def __call__(self, res): self.max_new_tokens -= 1 - if self.max_new_tokens == 0 or res["ids"][-1] == self.stop_token: + if self.max_new_tokens == 0 or res["tokens"][-1] == self.stop_token: self.clean_up() return True return False From 60f8a4c660e8b6cc829a38481385846d222e4f42 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 25 Nov 2023 21:00:46 -0800 Subject: [PATCH 15/49] add orjson --- .../inferentia2/llama2/continuous_batching/Dockerfile | 1 + requirements/developer.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile b/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile index 2d08cc08cd..31c7648eec 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile @@ -8,3 +8,4 @@ WORKDIR /home/model-server/serve RUN git checkout feat/inf2_cb RUN pip install pygit2 RUN python ts_scripts/install_from_src.py +ENV TS_INSTALL_PY_DEP_PER_MODEL true diff --git a/requirements/developer.txt b/requirements/developer.txt index 7b8fa63e31..cfb141037c 100644 --- a/requirements/developer.txt +++ b/requirements/developer.txt @@ -18,3 +18,4 @@ intel_extension_for_pytorch==2.1.0; sys_platform != 'win32' and sys_platform != onnxruntime==1.15.0 googleapis-common-protos onnx==1.14.1 +orjson From 7cee16737b405ef1868ba25595a8f3973f24b83e Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 26 Nov 2023 19:47:16 -0800 Subject: [PATCH 16/49] update streamer --- .../large_models/inferentia2/llama2/streamer/inf2_handler.py | 4 +--- ts/handler_utils/hf_batch_streamer.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py b/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py index a56db7ca9c..f4e3e43a37 100644 --- a/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/streamer/inf2_handler.py @@ -136,9 +136,7 @@ def inference(self, tokenized_input): for new_text in self.output_streamer: logger.debug("send response stream") send_intermediate_predict_response( - { - "text": new_text[: len(micro_batch_req_id_map)], - }, + new_text[: len(micro_batch_req_id_map)], micro_batch_req_id_map, "Intermediate Prediction success", 200, diff --git a/ts/handler_utils/hf_batch_streamer.py b/ts/handler_utils/hf_batch_streamer.py index 5f89dbcef0..2f8c029ce5 100644 --- a/ts/handler_utils/hf_batch_streamer.py +++ b/ts/handler_utils/hf_batch_streamer.py @@ -27,7 +27,7 @@ def put(self, value): ) for index in range(self.batch_size): - self.streamers[index].put(value[index : index + 1]) + self.streamers[index].put({"text": value[index : index + 1]}) def end(self): for streamer in self.streamers: From 540115df1d46b7c6116b6bdcb37eeb5354dda6c2 Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 26 Nov 2023 20:12:35 -0800 Subject: [PATCH 17/49] add key text for streamer iterator --- ts/handler_utils/hf_batch_streamer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ts/handler_utils/hf_batch_streamer.py b/ts/handler_utils/hf_batch_streamer.py index 2f8c029ce5..aa76ebf3c9 100644 --- a/ts/handler_utils/hf_batch_streamer.py +++ b/ts/handler_utils/hf_batch_streamer.py @@ -27,7 +27,7 @@ def put(self, value): ) for index in range(self.batch_size): - self.streamers[index].put({"text": value[index : index + 1]}) + self.streamers[index].put(value[index : index + 1]) def end(self): for streamer in self.streamers: @@ -40,7 +40,7 @@ def __next__(self): values = [] for iterator in self.streamer_iterators: try: - values.append(next(iterator)) + values.append({"text": next(iterator)}) except StopIteration: values.append(None) From 63f42b55c93a2acf8e566ae33939681591eb4acf Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 28 Nov 2023 10:18:28 -0800 Subject: [PATCH 18/49] update test_hf_batch_streamer output --- ts/tests/unit_tests/test_hf_batch_streamer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ts/tests/unit_tests/test_hf_batch_streamer.py b/ts/tests/unit_tests/test_hf_batch_streamer.py index 199bc3e0ed..e0d3b69885 100644 --- a/ts/tests/unit_tests/test_hf_batch_streamer.py +++ b/ts/tests/unit_tests/test_hf_batch_streamer.py @@ -23,8 +23,8 @@ def test_hf_batch_streamer(): for data in streamer: assert len(data) == 2 - output1 += data[0] - output2 += data[1] + output1 += data[0]["text"] + output2 += data[1]["text"] assert output1 == input1 assert output2 == input2 From 42d4719ce03f65e48e6e2f500af99ba9f6115ee7 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 1 Dec 2023 19:35:33 -0800 Subject: [PATCH 19/49] integrate split checkpoint in handler --- .../continuous_batching/inf2_handler.py | 35 ++++++++++++++++--- .../continuous_batching/model-config.yaml | 1 + 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index ae0b86fbad..4a7dfda3c5 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -5,9 +5,10 @@ import torch import torch_neuronx -from transformers import AutoConfig, LlamaTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter from transformers_neuronx.llama.model import LlamaForSampling +from transformers_neuronx.module import save_pretrained_split from ts.context import Context from ts.torch_handler.base_handler import BaseHandler @@ -40,13 +41,29 @@ def initialize(self, ctx: Context): "model_checkpoint_dir", "" ) model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}" + model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}' + + if not os.path.exists(model_checkpoint_path): + # Load and save the CPU model + model_cpu = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True + ) + save_pretrained_split(model_cpu, model_checkpoint_path) + # Load and save tokenizer for the model + tokenizer = AutoTokenizer.from_pretrained(model_path, return_tensors="pt") + tokenizer.save_pretrained(model_checkpoint_path) + os.environ["NEURONX_CACHE"] = "on" os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" - os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference" + os.environ["NEURON_CC_FLAGS"] = "-O1 --model-type=transformer" # settings for model compiliation and loading amp = ctx.model_yaml_config.get("handler", {}).get("amp", "fp32") tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) + context_length_estimate = ctx.model_yaml_config.get("handler", {}).get( + "context_length_estimate", None + ) + self.max_length = int(ctx.model_yaml_config["handler"]["max_length"]) self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) self.batch_size = int( @@ -73,11 +90,14 @@ def initialize(self, ctx: Context): self.tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint_path) self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.tokenizer.padding_side = "left" self.model = LlamaForSampling.from_pretrained( model_checkpoint_path, batch_size=self.batch_size, amp=amp, tp_degree=tp_degree, + n_positions=self.max_length if self.max_length > 2048 else 2048, + context_length_estimate=context_length_estimate, ) logger.info("Starting to compile the model") self.model.to_neuron() @@ -112,6 +132,7 @@ def preprocess(self, requests): attention masks. """ self._clean_cache() + logger.info(f"requests size={len(requests)}") prefill_req_ids, decode, prefill_input_text = [], [], [] for req_id, req_data in zip(self.context.request_ids.values(), requests): @@ -302,6 +323,7 @@ def _prepare_model_inputs(self, ids): self.context.cache[req_id]["encoded"]["past_key_values"] ): k, v = layer_kv + logger.info(f"layer_idx={layer_idx}, past_key_values, k={k}, v={v}") kv_cache[layer_idx] = kv_cache.get(layer_idx, {}) kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get(0, []) + [ k @@ -324,6 +346,7 @@ def _prepare_model_inputs(self, ids): ) + torch.zeros((max_len), **config) padded_len = input_ids[-1].size()[-1] + logger.info(f"req_id={req_id}, padded_len={padded_len}, max_len={max_len}") if padded_len < max_len: # Apply padding to input_ids, attention_mask and past_key_values n = max_len - seq_len @@ -336,7 +359,7 @@ def _prepare_model_inputs(self, ids): attention_mask[-1] = torch.concat( (torch.zeros((n), **config), attention_mask[-1]) ) - + continue size_delta = list(kv_cache[0][0][-1].size()) size_delta[2] = n for layer_idx in range(len(kv_cache)): @@ -356,8 +379,10 @@ def _prepare_model_inputs(self, ids): ) elif padded_len > max_len: # Truncate padding from input_ids, attention_mask and past_key_values - input_ids[-1] = input_ids[-1][:, -max_len:] - attention_mask[-1] = attention_mask[-1][:, -max_len:] + logger.info(f"padded_len shape={input_ids[-1].size()}") + input_ids[-1] = input_ids[-1][-max_len:] + attention_mask[-1] = attention_mask[-1][-max_len:] + continue for layer_idx in range(len(kv_cache)): kv_cache[layer_idx][0][-1] = kv_cache[layer_idx][0][-1][ diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml index 18a3ab92f6..f384635529 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -6,6 +6,7 @@ batchSize: 4 continuousBatching: true handler: + model_path: "" model_checkpoint_dir: "llama-2-13b-split" amp: "bf16" tp_degree: 12 From 5a5252e35686774261bd63a7d5ac646ab0e9995b Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 2 Dec 2023 17:57:56 -0800 Subject: [PATCH 20/49] fmt --- .../inf2-llama-2-continuous-batching.ipynb | 23 +-- .../continuous_batching/inf2_handler.py | 167 ++++++------------ .../continuous_batching/model-config.yaml | 8 +- 3 files changed, 62 insertions(+), 136 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index 6b199395ea..c262b58d57 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -58,15 +58,6 @@ "collapsed": false } }, - { - "cell_type": "markdown", - "source": [ - "### Serialize Llama-2-13b-hf model, save checkpoints and precompile the model with batch size 4 inside the docker container" - ], - "metadata": { - "collapsed": false - } - }, { "cell_type": "code", "execution_count": null, @@ -86,19 +77,13 @@ "execution_count": null, "outputs": [], "source": [ - "# Inside the container\n", - "\n", "# login in Hugginface hub\n", - "huggingface-cli login --token $HUGGINGFACE_TOKEN\n", - "\n", - "# Save checkpoints\n", - "cd /home/model-server/serve/examples/large_models/inferentia2/llama2/continuous_batching\n", - "python ../../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split'\n", + "!huggingface-cli login --token $HUGGINGFACE_TOKEN\n", "\n", "# Create TorchServe model artifacts\n", - "torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", - "mv llama-2-13b-split llama-2-13b\n", - "mv llama-2-13b /home/model-server/model_store\n", + "!torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", + "!mkdir -p /home/model-server/model_store\n", + "!mv llama-2-13b /home/model-server/model_store\n", "\n", "# Precompile complete once the log \"Model llama-2-13b loaded successfully\"\n", "torchserve --ncs --start --model-store /home/model-server/model_store --models llama-2-13b --ts-config ../config.properties\n", diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index 4a7dfda3c5..335b65bf2b 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -28,6 +28,7 @@ def __init__(self): self.tokenizer = None self.batch_size = 1 self.initialized = False + self.batch_ids = [] def initialize(self, ctx: Context): """In this initialize function, the HF large model is loaded and @@ -65,12 +66,15 @@ def initialize(self, ctx: Context): ) self.max_length = int(ctx.model_yaml_config["handler"]["max_length"]) - self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"]) + self.max_new_tokens = ctx.model_yaml_config["handler"]["max_new_tokens"] self.batch_size = int( ctx.model_yaml_config.get("handler", {}).get("batch_size", 1) ) - # allocate "tp_degree" number of neuron cores to the worker process + for i in range(self.batch_size): + self.batch_ids.append(i) + + # allocate "tp_degree" number of neuron cores to the worker process os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) try: num_neuron_cores_available = ( @@ -131,28 +135,33 @@ def preprocess(self, requests): tuple: A tuple with two tensors: the batch of input ids and the batch of attention masks. """ - self._clean_cache() logger.info(f"requests size={len(requests)}") - prefill_req_ids, decode, prefill_input_text = [], [], [] + prefill_req_ids, decode_req_ids, prefill_input = [], [], [] for req_id, req_data in zip(self.context.request_ids.values(), requests): # Tokenizer requests which are not prefilled yet if not req_id in self.context.cache: - data = req_data["body"] or req_data["data"] + data = req_data["data"] if isinstance(data, (bytes, bytearray)): data = data.decode("utf-8") + max_new_tokens = int( + req_data.get("max_new_tokens", str(self.max_new_tokens)) + ) + prefill_input.append( + {"data": data.strip(), "max_new_token": max_new_tokens} + ) logger.info("Received text: '%s'", data) - prefill_input_text.append(data.strip()) prefill_req_ids.append(req_id) else: - decode.append(req_id) + decode_req_ids.append(req_id) prefill_encoded = None if len(prefill_req_ids) > 0: prefill_encoded = self._run_tokenizer_batch( - prefill_input_text, prefill_req_ids + prefill_input, prefill_req_ids, decode_req_ids ) - return prefill_req_ids, prefill_encoded, decode + + return prefill_req_ids, prefill_encoded, decode_req_ids def inference(self, input_batch): """ @@ -165,18 +174,16 @@ def inference(self, input_batch): list: A list of strings with the predicted values for each input text in the batch. """ - prefill_req_ids, prefill_encoded, decode_ids = input_batch + prefill_req_ids, prefill_encoded, decode_req_ids = input_batch - # Prefill requests - results = ( - self._run_prefill_batch(prefill_req_ids, prefill_encoded) - if prefill_req_ids - else {} - ) + results = {} + if prefill_req_ids: + # Prefill requests + results.update(self._run_prefill_batch(prefill_req_ids, prefill_encoded)) + else: + # Decode the rest + results.update(self._run_decode(decode_req_ids)) - # Decode the rest - decode_result = self._run_decode(decode_ids) if decode_ids else {} - results.update(decode_result) return [results[i] for i in self.context.request_ids.values()] def postprocess(self, inference_output): @@ -193,10 +200,13 @@ def postprocess(self, inference_output): return inference_output - def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): + def _run_tokenizer_batch(self, prefill_input, prefill_req_ids, decode_req_ids): # Pad input to match compiled model batch size - if self.batch_size > len(prefill_req_ids): - prefill_input_text.extend([""] * (self.batch_size - len(prefill_req_ids))) + prefill_input_text = [""] * self.batch_size + for req_id, input_data in zip(prefill_req_ids, prefill_input): + idx = self.batch_ids.pop() + prefill_input_text[idx] = input_data["data"] + self.context.cache[req_id] = {"batch_idx": idx} batch_encoded = self.tokenizer( prefill_input_text, @@ -210,22 +220,34 @@ def _run_tokenizer_batch(self, prefill_input_text, prefill_req_ids): encoded = { "input_ids": batch_encoded["input_ids"][idx], "attention_mask": batch_encoded["attention_mask"][idx], - "past_key_values": None, - } - self.context.cache[req_id] = { - "stopping_criteria": self._create_stopping_criteria( - req_id, max_new_tokens=self.max_new_tokens - ), - "encoded": encoded, - "prompt_length": encoded["input_ids"].shape[0], } + self.context.cache[req_id].update( + { + "stopping_criteria": self._create_stopping_criteria( + req_id, max_new_tokens=prefill_input[idx]["max_new_tokens"] + ), + "encoded": encoded, + "prompt_length": encoded["input_ids"].shape[0], + } + ) + + for req_id in decode_req_ids: + idx = self.context.cache[req_id]["batch_idx"] + batch_encoded["input_ids"][idx] = self.context.cache[req_id]["encoded"][ + "input_ids" + ] + batch_encoded["attention_mask"][idx] = self.context.cache[req_id][ + "encoded" + ]["attention_mask"] return batch_encoded @torch.no_grad() def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): + self.model.reset_generation() outputs = self.model.generate( - **prefill_encoded, + prefill_encoded["input_ids"], + prefill_encoded["attention_mask"], max_new_tokens=1, return_dict_in_generate=True, use_cache=True, @@ -238,10 +260,6 @@ def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): # Prefill requests results = {} for idx, req_id in enumerate(prefill_req_ids): - # Save extracted kv cache values and adjust attention mask for next call - self.context.cache[req_id]["encoded"][ - "past_key_values" - ] = self._collect_kv_cache_of_idx_in_batch(idx) self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] device = next(iter(self.model.parameters())).device dtype = torch.int64 @@ -256,8 +274,6 @@ def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): "text": outputs_decoded[idx], "tokens": outputs.sequences[idx].tolist(), } - - del self.context.kv_cache["past_key_values"] return results def _run_decode(self, ids): @@ -275,9 +291,7 @@ def _run_decode(self, ids): for idx, req_id in enumerate(ids): if req_id == "batch_padding": continue - self.context.cache[req_id]["encoded"][ - "past_key_values" - ] = self._collect_kv_cache_of_idx_in_batch(idx) + self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] attention_mask = encoded["attention_mask"][idx] attention_mask = torch.concat( @@ -290,7 +304,6 @@ def _run_decode(self, ids): ), "tokens": [outputs.sequences[idx][-1].item()], } - del self.context.kv_cache["past_key_values"] return results def _prepare_model_inputs(self, ids): @@ -310,7 +323,6 @@ def _prepare_model_inputs(self, ids): input_ids = [] attention_mask = [] - kv_cache = {} for req_id, seq_len in zip(ids, lengths): if req_id != "batch_padding": @@ -319,31 +331,12 @@ def _prepare_model_inputs(self, ids): self.context.cache[req_id]["encoded"]["attention_mask"] ) - for layer_idx, layer_kv in enumerate( - self.context.cache[req_id]["encoded"]["past_key_values"] - ): - k, v = layer_kv - logger.info(f"layer_idx={layer_idx}, past_key_values, k={k}, v={v}") - kv_cache[layer_idx] = kv_cache.get(layer_idx, {}) - kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get(0, []) + [ - k - ] - kv_cache[layer_idx][1] = kv_cache.get(layer_idx, {}).get(1, []) + [ - v - ] else: config = {"device": device, "dtype": dtype} input_ids.append( self.tokenizer.pad_token_id + torch.zeros((max_len), **config) ) attention_mask.append(torch.zeros((max_len), **config)) - for layer_idx in range(len(kv_cache)): - kv_cache[layer_idx][0] = kv_cache.get(layer_idx, {}).get( - 0, [] - ) + torch.zeros((max_len), **config) - kv_cache[layer_idx][1] = kv_cache.get(layer_idx, {}).get( - 1, [] - ) + torch.zeros((max_len), **config) padded_len = input_ids[-1].size()[-1] logger.info(f"req_id={req_id}, padded_len={padded_len}, max_len={max_len}") @@ -359,64 +352,18 @@ def _prepare_model_inputs(self, ids): attention_mask[-1] = torch.concat( (torch.zeros((n), **config), attention_mask[-1]) ) - continue - size_delta = list(kv_cache[0][0][-1].size()) - size_delta[2] = n - for layer_idx in range(len(kv_cache)): - kv_cache[layer_idx][0][-1] = torch.concat( - ( - torch.zeros(size_delta, **config), - kv_cache[layer_idx][0][-1], - ), - dim=2, - ) - kv_cache[layer_idx][1][-1] = torch.concat( - ( - torch.zeros(size_delta, **config), - kv_cache[layer_idx][1][-1], - ), - dim=2, - ) elif padded_len > max_len: # Truncate padding from input_ids, attention_mask and past_key_values logger.info(f"padded_len shape={input_ids[-1].size()}") input_ids[-1] = input_ids[-1][-max_len:] attention_mask[-1] = attention_mask[-1][-max_len:] - continue - - for layer_idx in range(len(kv_cache)): - kv_cache[layer_idx][0][-1] = kv_cache[layer_idx][0][-1][ - :, :, (-max_len + 1) :, : - ] - kv_cache[layer_idx][1][-1] = kv_cache[layer_idx][1][-1][ - :, :, (-max_len + 1) :, : - ] - if req_id != "batch_padding": - del self.context.cache[req_id]["encoded"]["past_key_values"] - - for layer_idx in range(len(kv_cache)): - kv_cache[layer_idx][0] = torch.concat(kv_cache[layer_idx][0], dim=0) - kv_cache[layer_idx][1] = torch.concat(kv_cache[layer_idx][1], dim=0) - - kv_cache = tuple( - (kv_cache[layer_idx][0], kv_cache[layer_idx][1]) - for layer_idx in range(len(kv_cache)) - ) encoded = { "input_ids": torch.stack(input_ids), "attention_mask": torch.stack(attention_mask), - "past_key_values": kv_cache, } return encoded - def _collect_kv_cache_of_idx_in_batch(self, idx): - # The materialization of the tuple here is important for some reason (TODO: figure out why); Otherwise prediction differ - return tuple( - tuple(kv[idx, ...].unsqueeze(0) for kv in layers) - for layers in self.context.kv_cache["past_key_values"] - ) - def _create_stopping_criteria(self, req_id, max_new_tokens=25): class StoppingCriteria(object): def __init__( @@ -448,9 +395,3 @@ def clean_up(self): self.tokenizer.eos_token_id, max_new_tokens, ) - - def _clean_cache(self): - new_ids = set(self.context.request_ids.keys()) - for idx in self.context.kv_cache.keys(): - if idx not in new_ids: - del self.context.kv_cache[idx] diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml index f384635529..897d505c89 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -2,14 +2,14 @@ minWorkers: 1 maxWorkers: 1 maxBatchDelay: 100 responseTimeout: 10800 -batchSize: 4 +batchSize: 8 continuousBatching: true handler: - model_path: "" + model_path: "model/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55" model_checkpoint_dir: "llama-2-13b-split" amp: "bf16" tp_degree: 12 max_length: 100 - max_new_tokens: 25 - batch_size: 4 + max_new_tokens: 50 + batch_size: 8 From dd42d7c3d4433c0e25c2b74efaa5e81830f89db2 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 2 Dec 2023 19:34:27 -0800 Subject: [PATCH 21/49] fmt --- .../continuous_batching/inf2_handler.py | 96 +++++++------------ 1 file changed, 32 insertions(+), 64 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index 335b65bf2b..300622d9dd 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -28,7 +28,7 @@ def __init__(self): self.tokenizer = None self.batch_size = 1 self.initialized = False - self.batch_ids = [] + self.batch_empty_ids = [] def initialize(self, ctx: Context): """In this initialize function, the HF large model is loaded and @@ -72,7 +72,7 @@ def initialize(self, ctx: Context): ) for i in range(self.batch_size): - self.batch_ids.append(i) + self.batch_empty_ids.append(i) # allocate "tp_degree" number of neuron cores to the worker process os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) @@ -155,13 +155,13 @@ def preprocess(self, requests): else: decode_req_ids.append(req_id) - prefill_encoded = None + encoded = None if len(prefill_req_ids) > 0: - prefill_encoded = self._run_tokenizer_batch( + encoded = self._run_tokenizer_batch( prefill_input, prefill_req_ids, decode_req_ids ) - return prefill_req_ids, prefill_encoded, decode_req_ids + return prefill_req_ids, decode_req_ids, encoded def inference(self, input_batch): """ @@ -174,12 +174,12 @@ def inference(self, input_batch): list: A list of strings with the predicted values for each input text in the batch. """ - prefill_req_ids, prefill_encoded, decode_req_ids = input_batch + prefill_req_ids, decode_req_ids, encoded = input_batch results = {} if prefill_req_ids: # Prefill requests - results.update(self._run_prefill_batch(prefill_req_ids, prefill_encoded)) + results.update(self._run_prefill(prefill_req_ids + decode_req_ids, encoded)) else: # Decode the rest results.update(self._run_decode(decode_req_ids)) @@ -204,7 +204,7 @@ def _run_tokenizer_batch(self, prefill_input, prefill_req_ids, decode_req_ids): # Pad input to match compiled model batch size prefill_input_text = [""] * self.batch_size for req_id, input_data in zip(prefill_req_ids, prefill_input): - idx = self.batch_ids.pop() + idx = self.batch_empty_ids.pop() prefill_input_text[idx] = input_data["data"] self.context.cache[req_id] = {"batch_idx": idx} @@ -231,56 +231,26 @@ def _run_tokenizer_batch(self, prefill_input, prefill_req_ids, decode_req_ids): } ) - for req_id in decode_req_ids: - idx = self.context.cache[req_id]["batch_idx"] - batch_encoded["input_ids"][idx] = self.context.cache[req_id]["encoded"][ - "input_ids" - ] - batch_encoded["attention_mask"][idx] = self.context.cache[req_id][ - "encoded" - ]["attention_mask"] - - return batch_encoded + ids = prefill_req_ids + decode_req_ids + return self._prepare_model_inputs(ids) @torch.no_grad() - def _run_prefill_batch(self, prefill_req_ids, prefill_encoded): + def _run_prefill(self, prefill_req_ids, prefill_encoded): self.model.reset_generation() - outputs = self.model.generate( - prefill_encoded["input_ids"], - prefill_encoded["attention_mask"], - max_new_tokens=1, - return_dict_in_generate=True, - use_cache=True, - ) - - outputs_decoded = self.tokenizer.batch_decode( - outputs.sequences, skip_special_tokens=True - ) - - # Prefill requests - results = {} - for idx, req_id in enumerate(prefill_req_ids): - self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] - device = next(iter(self.model.parameters())).device - dtype = torch.int64 - config = {"device": device, "dtype": dtype} - attention_mask = self.context.cache[req_id]["encoded"]["attention_mask"] - attention_mask = torch.concat( - (attention_mask, torch.ones((1), **config)), dim=0 - ) - self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask - - results[req_id] = { - "text": outputs_decoded[idx], - "tokens": outputs.sequences[idx].tolist(), - } - return results + return self._generate_token(prefill_req_ids, prefill_encoded) + @torch.no_grad() def _run_decode(self, ids): encoded = self._prepare_model_inputs(ids) + return self._generate_token(ids, encoded) + def _generate_token(self, req_ids, encoded): outputs = self.model.generate( - **encoded, max_new_tokens=1, return_dict_in_generate=True, use_cache=True + encoded["input_ids"], + encoded["attention_mask"], + max_new_tokens=1, + return_dict_in_generate=True, + use_cache=True, ) device = next(iter(self.model.parameters())).device @@ -288,10 +258,7 @@ def _run_decode(self, ids): config = {"device": device, "dtype": dtype} results = {} - for idx, req_id in enumerate(ids): - if req_id == "batch_padding": - continue - + for idx, req_id in enumerate(req_ids): self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] attention_mask = encoded["attention_mask"][idx] attention_mask = torch.concat( @@ -307,15 +274,13 @@ def _run_decode(self, ids): return results def _prepare_model_inputs(self, ids): - lengths = list( - torch.sum(self.context.cache[i]["encoded"]["attention_mask"]).item() - for i in ids - ) - max_len = max(lengths) + lengths = [0] * self.batch_size + for i in ids: + lengths[self.context.cache[i]["batch_idx"]] = torch.sum( + self.context.cache[i]["encoded"]["attention_mask"] + ).item() - for idx in range(self.batch_size - len(ids)): - ids.append("batch_padding") - lengths.append(0) + max_len = max(lengths) device = next(iter(self.model.parameters())).device dtype = torch.int64 @@ -325,12 +290,11 @@ def _prepare_model_inputs(self, ids): attention_mask = [] for req_id, seq_len in zip(ids, lengths): - if req_id != "batch_padding": + if seq_len > 0: input_ids.append(self.context.cache[req_id]["encoded"]["input_ids"]) attention_mask.append( self.context.cache[req_id]["encoded"]["attention_mask"] ) - else: config = {"device": device, "dtype": dtype} input_ids.append( @@ -369,12 +333,14 @@ class StoppingCriteria(object): def __init__( self, cache, + batch_empty_ids, req_id, stop_token, max_new_tokens, ): self.req_id = req_id self.cache = cache + self.batch_empty_ids = batch_empty_ids self.max_new_tokens = max_new_tokens self.stop_token = stop_token @@ -387,10 +353,12 @@ def __call__(self, res): return False def clean_up(self): + self.batch_empty_ids.append(self.cache[self.req_id]["batch_idx"]) del self.cache[self.req_id] return StoppingCriteria( self.context.cache, + self.batch_empty_ids, req_id, self.tokenizer.eos_token_id, max_new_tokens, From 100927ed0d5223e91483d5a7c7fb1c2d692c5154 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 11 Dec 2023 09:54:31 -0800 Subject: [PATCH 22/49] fmt --- .../continuous_batching/inf2_handler.py | 100 ++++++++++-------- 1 file changed, 57 insertions(+), 43 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index 300622d9dd..a0cbb8c98b 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -1,6 +1,5 @@ import logging import os -import types from abc import ABC import torch @@ -51,14 +50,16 @@ def initialize(self, ctx: Context): ) save_pretrained_split(model_cpu, model_checkpoint_path) # Load and save tokenizer for the model - tokenizer = AutoTokenizer.from_pretrained(model_path, return_tensors="pt") + tokenizer = AutoTokenizer.from_pretrained( + model_path, return_tensors="pt", padding_side="left" + ) tokenizer.save_pretrained(model_checkpoint_path) os.environ["NEURONX_CACHE"] = "on" os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" os.environ["NEURON_CC_FLAGS"] = "-O1 --model-type=transformer" - # settings for model compiliation and loading + # settings for model compilation and loading amp = ctx.model_yaml_config.get("handler", {}).get("amp", "fp32") tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) context_length_estimate = ctx.model_yaml_config.get("handler", {}).get( @@ -74,7 +75,7 @@ def initialize(self, ctx: Context): for i in range(self.batch_size): self.batch_empty_ids.append(i) - # allocate "tp_degree" number of neuron cores to the worker process + # allocate "tp_degree" number of neuron cores to the worker process os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) try: num_neuron_cores_available = ( @@ -91,7 +92,9 @@ def initialize(self, ctx: Context): raise error - self.tokenizer = LlamaTokenizer.from_pretrained(model_checkpoint_path) + self.tokenizer = LlamaTokenizer.from_pretrained( + model_checkpoint_path, return_tensors="pt", padding_side="left" + ) self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.tokenizer.padding_side = "left" @@ -109,19 +112,6 @@ def initialize(self, ctx: Context): model_config = AutoConfig.from_pretrained(model_checkpoint_path) self.model = HuggingFaceGenerationModelAdapter(model_config, self.model) - # Replace _update_model_kwargs_for_generation of model with a method that extracts the kv cache for us - old_update = self.model._update_model_kwargs_for_generation - ctx.cache = {} - ctx.kv_cache = {} - - def extract_past_key_values_func(self, *args, **kwargs): - ctx.kv_cache["past_key_values"] = args[0]["past_key_values"] - return old_update(*args, **kwargs) - - self.model._update_model_kwargs_for_generation = types.MethodType( - extract_past_key_values_func, self.model - ) - logger.info("Model %s loaded successfully", ctx.model_name) self.initialized = True @@ -148,20 +138,23 @@ def preprocess(self, requests): req_data.get("max_new_tokens", str(self.max_new_tokens)) ) prefill_input.append( - {"data": data.strip(), "max_new_token": max_new_tokens} + {"data": data.strip(), "max_new_tokens": max_new_tokens} ) - logger.info("Received text: '%s'", data) + logger.info( + "Received text: '%s', max_new_token=%d", data, max_new_tokens + ) prefill_req_ids.append(req_id) else: decode_req_ids.append(req_id) - encoded = None if len(prefill_req_ids) > 0: - encoded = self._run_tokenizer_batch( + encoded = self._run_tokenize_prefill( prefill_input, prefill_req_ids, decode_req_ids ) + else: + encoded = self._prepare_model_inputs() - return prefill_req_ids, decode_req_ids, encoded + return encoded def inference(self, input_batch): """ @@ -200,13 +193,20 @@ def postprocess(self, inference_output): return inference_output - def _run_tokenizer_batch(self, prefill_input, prefill_req_ids, decode_req_ids): + def _run_tokenize_prefill(self, prefill_input, prefill_req_ids, decode_req_ids): # Pad input to match compiled model batch size prefill_input_text = [""] * self.batch_size - for req_id, input_data in zip(prefill_req_ids, prefill_input): + for i, req_id in enumerate(prefill_req_ids): idx = self.batch_empty_ids.pop() + input_data = prefill_input[i] prefill_input_text[idx] = input_data["data"] - self.context.cache[req_id] = {"batch_idx": idx} + self.context.cache[req_id] = { + "batch_idx": idx, + "stopping_criteria": self._create_stopping_criteria( + req_id, max_new_tokens=input_data["max_new_tokens"] + ), + } + logger.info(f"_run_tokenizer_batch prefill_input_text={prefill_input_text}") batch_encoded = self.tokenizer( prefill_input_text, @@ -216,16 +216,16 @@ def _run_tokenizer_batch(self, prefill_input, prefill_req_ids, decode_req_ids): return_token_type_ids=False, truncation=True, ) - for idx, req_id in enumerate(prefill_req_ids): + seq_length = min(batch_encoded.input_ids.shape[-1]) + for req_id in prefill_req_ids: + idx = self.context.cache[req_id]["batch_idx"] encoded = { - "input_ids": batch_encoded["input_ids"][idx], - "attention_mask": batch_encoded["attention_mask"][idx], + "input_ids": batch_encoded.input_ids[idx, :seq_length], + "attention_mask": batch_encoded.attention_mask[idx, :seq_length], } + logger.info(f"encoded={encoded}") self.context.cache[req_id].update( { - "stopping_criteria": self._create_stopping_criteria( - req_id, max_new_tokens=prefill_input[idx]["max_new_tokens"] - ), "encoded": encoded, "prompt_length": encoded["input_ids"].shape[0], } @@ -247,18 +247,22 @@ def _run_decode(self, ids): def _generate_token(self, req_ids, encoded): outputs = self.model.generate( encoded["input_ids"], - encoded["attention_mask"], + attention_mask=encoded["attention_mask"], max_new_tokens=1, return_dict_in_generate=True, use_cache=True, ) + outputs_decoded = self.tokenizer.batch_decode( + outputs.sequences, skip_special_tokens=True + ) device = next(iter(self.model.parameters())).device dtype = torch.int64 config = {"device": device, "dtype": dtype} results = {} - for idx, req_id in enumerate(req_ids): + for req_id in req_ids: + idx = self.context.cache[req_id]["batch_idx"] self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] attention_mask = encoded["attention_mask"][idx] attention_mask = torch.concat( @@ -266,21 +270,23 @@ def _generate_token(self, req_ids, encoded): ) self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask results[req_id] = { - "text": self.tokenizer.decode( - outputs.sequences[idx][-1], skip_special_tokens=True - ), - "tokens": [outputs.sequences[idx][-1].item()], + "text": outputs_decoded[idx], + "tokens": outputs.sequences[idx].tolist(), } return results - def _prepare_model_inputs(self, ids): + def _prepare_model_inputs(self, req_ids): lengths = [0] * self.batch_size - for i in ids: - lengths[self.context.cache[i]["batch_idx"]] = torch.sum( - self.context.cache[i]["encoded"]["attention_mask"] + idx_to_req_id = {} + for req_id in req_ids: + idx = self.context.cache[req_id]["batch_idx"] + lengths[idx] = torch.sum( + self.context.cache[req_id]["encoded"]["attention_mask"] ).item() + idx_to_req_id[idx] = req_id max_len = max(lengths) + logger.info(f"_prepare_model_inputs lengths={lengths}") device = next(iter(self.model.parameters())).device dtype = torch.int64 @@ -289,8 +295,11 @@ def _prepare_model_inputs(self, ids): input_ids = [] attention_mask = [] - for req_id, seq_len in zip(ids, lengths): + for idx in range(self.batch_size): + seq_len = lengths[idx] if seq_len > 0: + req_id = idx_to_req_id[idx] + logger.info(f"idx={idx}, seq_len={seq_len}") input_ids.append(self.context.cache[req_id]["encoded"]["input_ids"]) attention_mask.append( self.context.cache[req_id]["encoded"]["attention_mask"] @@ -322,10 +331,15 @@ def _prepare_model_inputs(self, ids): input_ids[-1] = input_ids[-1][-max_len:] attention_mask[-1] = attention_mask[-1][-max_len:] + logger.info(f"input_ids={input_ids}, attention_mask={attention_mask}") + encoded = { "input_ids": torch.stack(input_ids), "attention_mask": torch.stack(attention_mask), + # "input_ids": input_ids, + # "attention_mask": attention_mask, } + logger.info(f"_prepare_model_inputs encoded={encoded}") return encoded def _create_stopping_criteria(self, req_id, max_new_tokens=25): From 06e84171ab02acd21daee42707dfee88174242b8 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 14 Dec 2023 10:05:30 -0800 Subject: [PATCH 23/49] fmt --- .../continuous_batching/inf2_handler.py | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index a0cbb8c98b..c5a584f136 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -15,6 +15,172 @@ logger = logging.getLogger(__name__) +class NeuronxContinuousBatching: + def __init__(self): + self.empty_idx_queue = [] + self.idx_to_req_id = {} + self.decoded_req_id = {} + self.batch_store = [] + self.batch_size = 1 + for idx, batch_id in enumerate(reversed(range(self.batch_size))): + self.empty_idx_queue.append(batch_id) + self.batch_store[idx] = {} + + def _get_empty_idx(self): + assert len(self.empty_idx_queue) > 0 + return self.empty_idx_queue.pop() + + def _add_empty_idx(self, idx): + self.empty_idx_queue.append(idx) + + def _get_idx(self, req_id): + idx = self.decoded_req_id.get(req_id, -1) + assert idx > 0 + return idx + + def _get_req_id(self, idx): + req_id = self.idx_to_req_id.get(idx, None) + assert req_id is not None + return req_id + + def _set_req_id(self, idx, req_id): + self.idx_to_req_id[idx] = req_id + + def _init_batch_store(self, batch_encoded): + for i in range(batch_encoded.input_ids.size(dim=0)): + self.batch_store[i] = { + "input_ids": batch_encoded.input_ids[i, :], + "attention_mask": batch_encoded.attention_mask[i, :], + } + + def _update_batch_store(self, idx, input_ids, attention_mask, length): + self.batch_store[idx] = { + "input_ids": input_ids[-length:], + "attention_mask": attention_mask[-length:], + } + + def preprocess(self, batch_requests, ctx): + prefill_input_text = [] + if len(self.decoded_req_id) == 0: + prefill_input_text = [""] * self.batch_size + + prefill_req_ids, decode_req_ids, prefill_input = [], [], [] + for req_id, req_data in zip(ctx.request_ids.values(), batch_requests): + # Tokenizer requests which are not prefilled yet + if not req_id in self.decoded_req_id: + idx = self._get_empty_idx() + self._set_req_id(idx, req_id) + + data = req_data["data"] + if isinstance(data, (bytes, bytearray)): + data = data.decode("utf-8") + max_new_tokens = int( + req_data.get("max_new_tokens", int(self.max_new_tokens)) + ) + if len(prefill_input_text) == self.batch_size: + prefill_input_text[idx] = data.strip() + else: + prefill_input_text.append(data.strip()) + + self.batch_store[idx] = { + "req_id": req_id, + "stopping_criteria": self._create_stopping_criteria( + req_id, max_new_tokens + ), + } + logger.info( + "Received text: '%s', max_new_token=%d", data, max_new_tokens + ) + + prefill_req_ids.append(req_id) + else: + decode_req_ids.append(req_id) + + def prefill(self, prefill_input, prefill_req_ids, decoded_req_ids): + if len(self.decoded_req_id) == 0: + prefill_input_text = [""] * self.batch_size + + for i, req_id in enumerate(prefill_req_ids): + idx = self._get_idx(req_id) + input_data = prefill_input[i] + prefill_input_text[idx] = input_data["data"] + self.context.cache[req_id] = { + "batch_idx": idx, + "stopping_criteria": self._create_stopping_criteria( + req_id, max_new_tokens=input_data["max_new_tokens"] + ), + } + logger.info(f"_run_tokenizer_batch prefill_input_text={prefill_input_text}") + + batch_encoded = self.tokenizer( + prefill_input_text, + return_tensors="pt", + padding=True, + add_special_tokens=True, + return_token_type_ids=False, + truncation=True, + ) + seq_length = min(batch_encoded.input_ids.shape[-1]) + for req_id in prefill_req_ids: + idx = self.context.cache[req_id]["batch_idx"] + encoded = { + "input_ids": batch_encoded.input_ids[idx, :seq_length], + "attention_mask": batch_encoded.attention_mask[idx, :seq_length], + } + logger.info(f"encoded={encoded}") + self.context.cache[req_id].update( + { + "encoded": encoded, + "prompt_length": encoded["input_ids"].shape[0], + } + ) + + ids = prefill_req_ids + decode_req_ids + return self._prepare_model_inputs(ids) + + def decode(self): + pass + + def clean(self, req_id): + pass + + def _create_stopping_criteria(self, req_id, max_new_tokens=25): + class StoppingCriteria(object): + def __init__( + self, + cache, + batch_empty_ids, + req_id, + stop_token, + max_new_tokens, + ): + self.req_id = req_id + self.cache = cache + self.batch_empty_ids = batch_empty_ids + self.max_new_tokens = max_new_tokens + self.stop_token = stop_token + + def __call__(self, res): + self.max_new_tokens -= 1 + + if self.max_new_tokens == 0 or res["tokens"][-1] == self.stop_token: + self.clean_up() + return True + return False + + def clean_up(self): + self.batch_empty_ids.append(self.cache[self.req_id]["batch_idx"]) + del self.cache[self.req_id] + + return StoppingCriteria( + self.context.cache, + self.batch_empty_ids, + req_id, + self.tokenizer.eos_token_id, + max_new_tokens, + ) + + class LlamaHandler(BaseHandler, ABC): """ Transformers handler class for sequence, token classification and question answering. From 607a349df205170f5e2de5bb1be77b8841452a8b Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 18 Dec 2023 16:02:46 -0800 Subject: [PATCH 24/49] fmt --- .../continuous_batching/inf2_handler.py | 609 ++++++------------ 1 file changed, 198 insertions(+), 411 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index c5a584f136..ac0233a2d0 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -1,13 +1,12 @@ import logging import os -from abc import ABC import torch import torch_neuronx -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer -from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer from transformers_neuronx.llama.model import LlamaForSampling from transformers_neuronx.module import save_pretrained_split +from transformers_neuronx.sampling import select_tokens from ts.context import Context from ts.torch_handler.base_handler import BaseHandler @@ -15,193 +14,25 @@ logger = logging.getLogger(__name__) -class NeuronxContinuousBatching: +class LlamaContinuousBatchingHandler(BaseHandler): def __init__(self): - self.empty_idx_queue = [] - self.idx_to_req_id = {} - self.decoded_req_id = {} - self.batch_store = [] - self.batch_size = 1 - for idx, batch_id in enumerate(reversed(range(self.batch_size))): - self.empty_idx_queue.append(batch_id) - self.batch_store[idx] = {} - - def _get_empty_idx(self): - assert len(self.empty_idx_queue) > 0 - return self.empty_idx_queue.pop() - - def _add_empty_idx(self, idx): - self.empty_idx_queue.append(idx) - - def _get_idx(self, req_id): - idx = self.decoded_req_id.get(req_id, -1) - assert idx > 0 - return idx - - def _get_req_id(self, idx): - req_id = self.idx_to_req_id.get(idx, None) - assert req_id is not None - return req_id - - def _set_req_id(self, idx, req_id): - self.idx_to_req_id[idx] = req_id - - def _init_batch_store(self, batch_encoded): - for i in range(batch_encoded.input_ids.size(dim=0)): - self.batch_store[i] = { - "input_ids": batch_encoded.input_ids[i, :], - "attention_mask": batch_encoded.attention_mask[i, :], - } - - def _update_batch_store(self, idx, input_ids, attention_mask, length): - self.batch_store[idx] = { - "input_ids": input_ids[-length:], - "attention_mask": attention_mask[-length:], - } - - def preprocess(self, batch_requests, ctx): - prefill_input_text = [] - if len(self.decoded_req_id) == 0: - prefill_input_text = [""] * self.batch_size - - prefill_req_ids, decode_req_ids, prefill_input = [], [], [] - for req_id, req_data in zip(ctx.request_ids.values(), batch_requests): - # Tokenizer requests which are not prefilled yet - if not req_id in self.decoded_req_id: - idx = self._get_empty_idx() - self._set_req_id(idx, req_id) - - data = req_data["data"] - if isinstance(data, (bytes, bytearray)): - data = data.decode("utf-8") - max_new_tokens = int( - req_data.get("max_new_tokens", int(self.max_new_tokens)) - ) - if len(prefill_input_text) == self.batch_size: - prefill_input_text[idx] = data.strip() - else: - prefill_input_text.append(data.strip()) - - self.batch_store[idx] = { - "req_id": req_id, - "stopping_criteria": self._create_stopping_criteria( - req_id, max_new_tokens - ), - } - logger.info( - "Received text: '%s', max_new_token=%d", data, max_new_tokens - ) - - prefill_req_ids.append(req_id) - else: - decode_req_ids.append(req_id) - - def prefill(self, prefill_input, prefill_req_ids, decoded_req_ids): - if len(self.decoded_req_id) == 0: - prefill_input_text = [""] * self.batch_size - - for i, req_id in enumerate(prefill_req_ids): - idx = self._get_idx(req_id) - input_data = prefill_input[i] - prefill_input_text[idx] = input_data["data"] - self.context.cache[req_id] = { - "batch_idx": idx, - "stopping_criteria": self._create_stopping_criteria( - req_id, max_new_tokens=input_data["max_new_tokens"] - ), - } - logger.info(f"_run_tokenizer_batch prefill_input_text={prefill_input_text}") - - batch_encoded = self.tokenizer( - prefill_input_text, - return_tensors="pt", - padding=True, - add_special_tokens=True, - return_token_type_ids=False, - truncation=True, - ) - seq_length = min(batch_encoded.input_ids.shape[-1]) - for req_id in prefill_req_ids: - idx = self.context.cache[req_id]["batch_idx"] - encoded = { - "input_ids": batch_encoded.input_ids[idx, :seq_length], - "attention_mask": batch_encoded.attention_mask[idx, :seq_length], - } - logger.info(f"encoded={encoded}") - self.context.cache[req_id].update( - { - "encoded": encoded, - "prompt_length": encoded["input_ids"].shape[0], - } - ) - - ids = prefill_req_ids + decode_req_ids - return self._prepare_model_inputs(ids) - - def decode(self): - pass - - def clean(self, req_id): - pass - - def _create_stopping_criteria(self, req_id, max_new_tokens=25): - class StoppingCriteria(object): - def __init__( - self, - cache, - batch_empty_ids, - req_id, - stop_token, - max_new_tokens, - ): - self.req_id = req_id - self.cache = cache - self.batch_empty_ids = batch_empty_ids - self.max_new_tokens = max_new_tokens - self.stop_token = stop_token - - def __call__(self, res): - self.max_new_tokens -= 1 - - if self.max_new_tokens == 0 or res["tokens"][-1] == self.stop_token: - self.clean_up() - return True - return False - - def clean_up(self): - self.batch_empty_ids.append(self.cache[self.req_id]["batch_idx"]) - del self.cache[self.req_id] - - return StoppingCriteria( - self.context.cache, - self.batch_empty_ids, - req_id, - self.tokenizer.eos_token_id, - max_new_tokens, - ) - - -class LlamaHandler(BaseHandler, ABC): - """ - Transformers handler class for sequence, token classification and question answering. - """ - - def __init__(self): - super(LlamaHandler, self).__init__() - self.max_length = None - self.max_new_tokens = None + super(LlamaContinuousBatchingHandler, self).__init__() + # the queue of seq_ids which are available for a new request + self.batch_size = 2 + self.max_new_tokens = 25 + self.max_length = 100 self.tokenizer = None - self.batch_size = 1 - self.initialized = False - self.batch_empty_ids = [] + self.decode_next_tokens = None + self.decode_cache_ids = None + self.decode_seq_ids = None + self.empty_seq_ids = [] + # map seq_id to req_id + self.seq_id_to_req_id = {} def initialize(self, ctx: Context): - """In this initialize function, the HF large model is loaded and - partitioned using DeepSpeed. - Args: - ctx (context): It is a JSON Object containing information - pertaining to the model artifacts parameters. - """ + super().initialize(ctx) + logger.info(f"Initialized {self.__class__}") + model_dir = ctx.system_properties.get("model_dir") model_checkpoint_dir = ctx.model_yaml_config.get("handler", {}).get( "model_checkpoint_dir", "" @@ -225,22 +56,25 @@ def initialize(self, ctx: Context): os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" os.environ["NEURON_CC_FLAGS"] = "-O1 --model-type=transformer" + self.max_length = int( + ctx.model_yaml_config.get("handler", {}).get("max_length", self.max_length) + ) + self.max_new_tokens = int( + ctx.model_yaml_config.get("handler", {}).get( + "max_new_tokens", self.max_new_tokens + ) + ) + self.batch_size = int( + ctx.model_yaml_config.get("handler", {}).get("batch_size", self.batch_size) + ) + # settings for model compilation and loading amp = ctx.model_yaml_config.get("handler", {}).get("amp", "fp32") tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) context_length_estimate = ctx.model_yaml_config.get("handler", {}).get( - "context_length_estimate", None - ) - - self.max_length = int(ctx.model_yaml_config["handler"]["max_length"]) - self.max_new_tokens = ctx.model_yaml_config["handler"]["max_new_tokens"] - self.batch_size = int( - ctx.model_yaml_config.get("handler", {}).get("batch_size", 1) + "context_length_estimate", self.max_length ) - for i in range(self.batch_size): - self.batch_empty_ids.append(i) - # allocate "tp_degree" number of neuron cores to the worker process os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) try: @@ -269,277 +103,230 @@ def initialize(self, ctx: Context): batch_size=self.batch_size, amp=amp, tp_degree=tp_degree, - n_positions=self.max_length if self.max_length > 2048 else 2048, + n_positions=self.max_length, context_length_estimate=context_length_estimate, ) logger.info("Starting to compile the model") self.model.to_neuron() logger.info("Model has been successfully compiled") - model_config = AutoConfig.from_pretrained(model_checkpoint_path) - self.model = HuggingFaceGenerationModelAdapter(model_config, self.model) + + # 1D: [seq_id] + # an empty slot if seq_id is -1 + self.decode_seq_ids = torch.full([self.batch_size], -1) + # 2D:[batch_size, next_cache_id] + self.decode_cache_ids = torch.zeros(self.batch_size, 1, dtype=torch.int64) + # 2D: [batch_size, next_token] + self.decode_next_tokens = torch.zeros(self.batch_size, 1, dtype=torch.int64) + + for seq_id, batch_id in enumerate(reversed(range(self.batch_size))): + self.empty_seq_ids.append(batch_id) logger.info("Model %s loaded successfully", ctx.model_name) self.initialized = True def preprocess(self, requests): - """ - Basic text preprocessing, based on the user's choice of application mode. - Args: - requests (list): A list of dictionaries with a "data" or "body" field, each - containing the input text to be processed. - Returns: - tuple: A tuple with two tensors: the batch of input ids and the batch of - attention masks. - """ - logger.info(f"requests size={len(requests)}") - - prefill_req_ids, decode_req_ids, prefill_input = [], [], [] + prefill_req_ids, prefill_seq_ids, prefill_input_text, decode_seq_ids = ( + [], + [], + [], + [], + ) for req_id, req_data in zip(self.context.request_ids.values(), requests): - # Tokenizer requests which are not prefilled yet if not req_id in self.context.cache: + prefill_req_ids.append(req_id) + seq_id = self._get_empty_seq_id() + prefill_seq_ids.append(seq_id) + data = req_data["data"] if isinstance(data, (bytes, bytearray)): data = data.decode("utf-8") max_new_tokens = int( - req_data.get("max_new_tokens", str(self.max_new_tokens)) - ) - prefill_input.append( - {"data": data.strip(), "max_new_tokens": max_new_tokens} + req_data.get("max_new_tokens", self.max_new_tokens) ) - logger.info( - "Received text: '%s', max_new_token=%d", data, max_new_tokens - ) - prefill_req_ids.append(req_id) + prefill_input_text.append(data.strip()) + + self.context.cache[req_id] = { + "seq_id": seq_id, + "stopping_criteria": self._create_stopping_criteria( + req_id=req_id, seq_id=seq_id, max_new_tokens=max_new_tokens + ), + } else: - decode_req_ids.append(req_id) + decode_seq_ids.append(self.context.cache[req_id]["seq_id"]) + prefill_tokens = None if len(prefill_req_ids) > 0: - encoded = self._run_tokenize_prefill( - prefill_input, prefill_req_ids, decode_req_ids + prefill_tokens = self.tokenizer( + prefill_input_text, return_tensors="pt", padding=True ) - else: - encoded = self._prepare_model_inputs() + return prefill_tokens, prefill_seq_ids, decode_seq_ids - return encoded + def inference(self, inputs): + prefill_tokens, prefill_seq_ids, req_decode_seq_ids = inputs + results = {} + # Test if this is the beginning of a continuous batching + go_to_decode = True if len(req_decode_seq_ids) > 0 else False + if len(prefill_seq_ids) > 0: + prefill_next_tokens, prefill_cache_ids = self._run_prefill( + prefill_tokens, prefill_seq_ids + ) + for i, prefill_seq_id in enumerate(prefill_seq_ids): + self._update_results( + results, prefill_seq_id, i, prefill_cache_ids, prefill_next_tokens + ) - def inference(self, input_batch): - """ - Predicts the class (or classes) of the received text using the serialized transformers - checkpoint. - Args: - input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch - of attention masks, as returned by the preprocess function. - Returns: - list: A list of strings with the predicted values for each input text in the batch. - """ + if go_to_decode: + decode_seq_ids = torch.where(self.decode_seq_ids > -1) + decode_cache_ids = torch.where(self.decode_cache_ids > 0) + decode_next_tokens = torch.where(self.decode_next_tokens > 0) + next_tokens = self._run_decode( + decode_next_tokens, decode_cache_ids, decode_seq_ids + ) - prefill_req_ids, decode_req_ids, encoded = input_batch + filter_prefill_seq_ids = ( + torch.isin(decode_seq_ids, torch.as_tensor(prefill_seq_ids)) + if len(prefill_seq_ids) > 0 + else torch.full(decode_seq_ids.shape, False) + ) - results = {} - if prefill_req_ids: - # Prefill requests - results.update(self._run_prefill(prefill_req_ids + decode_req_ids, encoded)) - else: - # Decode the rest - results.update(self._run_decode(decode_req_ids)) + decode_cache_ids = decode_cache_ids + 1 + for i, is_prefill_seq_id in enumerate(filter_prefill_seq_ids): + if not is_prefill_seq_id: + seq_id = decode_seq_ids[i] + if seq_id in req_decode_seq_ids: + self._update_results( + results, seq_id, i, decode_cache_ids, next_tokens + ) + else: + req_id = self._get_req_id(seq_id) + logger.warning( + f"Found request id:{req_id} in cache, but not in batch requests. Delete it" + ) + self.clean_up(self, seq_id, req_id) - return [results[i] for i in self.context.request_ids.values()] + return results def postprocess(self, inference_output): - """Post Process Function converts the predicted response into Torchserve readable format. - Args: - inference_output (list): It contains the predicted response of the input text. - Returns: - (list): Returns a list of the Predictions and Explanations. - """ self.context.stopping_criteria = [ - self.context.cache[i]["stopping_criteria"] - for i in self.context.request_ids.values() + self.batch_store[self._get_seq_id(req_id)]["stopping_criteria"]( + req_id, + ) + for req_id in self.context.request_ids.values() ] return inference_output - def _run_tokenize_prefill(self, prefill_input, prefill_req_ids, decode_req_ids): - # Pad input to match compiled model batch size - prefill_input_text = [""] * self.batch_size - for i, req_id in enumerate(prefill_req_ids): - idx = self.batch_empty_ids.pop() - input_data = prefill_input[i] - prefill_input_text[idx] = input_data["data"] - self.context.cache[req_id] = { - "batch_idx": idx, - "stopping_criteria": self._create_stopping_criteria( - req_id, max_new_tokens=input_data["max_new_tokens"] - ), - } - logger.info(f"_run_tokenizer_batch prefill_input_text={prefill_input_text}") - - batch_encoded = self.tokenizer( - prefill_input_text, - return_tensors="pt", - padding=True, - add_special_tokens=True, - return_token_type_ids=False, - truncation=True, - ) - seq_length = min(batch_encoded.input_ids.shape[-1]) - for req_id in prefill_req_ids: - idx = self.context.cache[req_id]["batch_idx"] - encoded = { - "input_ids": batch_encoded.input_ids[idx, :seq_length], - "attention_mask": batch_encoded.attention_mask[idx, :seq_length], - } - logger.info(f"encoded={encoded}") - self.context.cache[req_id].update( - { - "encoded": encoded, - "prompt_length": encoded["input_ids"].shape[0], - } - ) + def _get_empty_seq_id(self): + assert len(self.empty_seq_ids) > 0 + return self.empty_seq_ids.pop() - ids = prefill_req_ids + decode_req_ids - return self._prepare_model_inputs(ids) - - @torch.no_grad() - def _run_prefill(self, prefill_req_ids, prefill_encoded): - self.model.reset_generation() - return self._generate_token(prefill_req_ids, prefill_encoded) - - @torch.no_grad() - def _run_decode(self, ids): - encoded = self._prepare_model_inputs(ids) - return self._generate_token(ids, encoded) - - def _generate_token(self, req_ids, encoded): - outputs = self.model.generate( - encoded["input_ids"], - attention_mask=encoded["attention_mask"], - max_new_tokens=1, - return_dict_in_generate=True, - use_cache=True, - ) - outputs_decoded = self.tokenizer.batch_decode( - outputs.sequences, skip_special_tokens=True - ) + def _add_empty_seq_id(self, seq_id): + self.empty_seq_ids.append(seq_id) - device = next(iter(self.model.parameters())).device - dtype = torch.int64 - config = {"device": device, "dtype": dtype} + def _get_seq_id(self, req_id): + seq_id = None + cache = self.context.cache.get(req_id, None) + if cache: + seq_id = cache["seq_id"] + assert seq_id is not None, "{req_id} must have seq_id" + return seq_id - results = {} - for req_id in req_ids: - idx = self.context.cache[req_id]["batch_idx"] - self.context.cache[req_id]["encoded"]["input_ids"] = outputs.sequences[idx] - attention_mask = encoded["attention_mask"][idx] - attention_mask = torch.concat( - (attention_mask, torch.ones((1), **config)), dim=0 + def _get_req_id(self, seq_id): + req_id = self.seq_id_to_req_id(seq_id, None) + assert req_id is not None + return req_id + + def _pad_to_max(self, x): + for idx, item in enumerate(x): + x[idx] = x[idx] + [0] * (self.max_length - len(x[idx])) + return x + + def _run_prefill(self, tokens, seq_ids): + input_ids, attention_mask = tokens["input_ids"], tokens["attention_mask"] + logger.info( + f"before padding: input_ids={input_ids}, attention_mask={attention_mask}" + ) + input_ids = self._pad_to_max(input_ids) + attention_mask = self._pad_to_max(attention_mask) + logger.info( + f"after padding: input_ids={input_ids}, attention_mask={attention_mask}" + ) + n_active_seqs, context_len = input_ids.shape + cache_ids = ( + torch.arange(context_len) + .reshape(1, context_len) + .expand(n_active_seqs, context_len) + .mul(attention_mask) + ) + with torch.inference_mode(): + logits = self.model( + input_ids, cache_ids=self.cache_ids, start_ids=torch.as_tensor(seq_ids) ) - self.context.cache[req_id]["encoded"]["attention_mask"] = attention_mask - results[req_id] = { - "text": outputs_decoded[idx], - "tokens": outputs.sequences[idx].tolist(), - } - return results + next_tokens = select_tokens(logits) + + return next_tokens, cache_ids.max(dim=1, keepdim=True).values + 1 + + def _run_decode(self, next_tokens, cache_ids, seq_ids): + with torch.inference_mode(): + logits = self.model(next_tokens, cache_ids=cache_ids, start_ids=seq_ids) + next_tokens = select_tokens(logits) + return next_tokens + + def clean_up(self, seq_id, req_id): + # clean up + del self.seq_id_to_req_id[seq_id] + del self.context.cache[req_id] + self.decode_seq_ids[seq_id] = -1 + self.decode_cache_ids[seq_id, :] = torch.zeros( + 1, dtype=torch.int64, device="cpu" + ) + self.decode_next_tokens[seq_id, :] = torch.zeros( + 1, dtype=torch.int64, device="cpu" + ) - def _prepare_model_inputs(self, req_ids): - lengths = [0] * self.batch_size - idx_to_req_id = {} - for req_id in req_ids: - idx = self.context.cache[req_id]["batch_idx"] - lengths[idx] = torch.sum( - self.context.cache[req_id]["encoded"]["attention_mask"] - ).item() - idx_to_req_id[idx] = req_id - - max_len = max(lengths) - logger.info(f"_prepare_model_inputs lengths={lengths}") - - device = next(iter(self.model.parameters())).device - dtype = torch.int64 - config = {"device": device, "dtype": dtype} - - input_ids = [] - attention_mask = [] - - for idx in range(self.batch_size): - seq_len = lengths[idx] - if seq_len > 0: - req_id = idx_to_req_id[idx] - logger.info(f"idx={idx}, seq_len={seq_len}") - input_ids.append(self.context.cache[req_id]["encoded"]["input_ids"]) - attention_mask.append( - self.context.cache[req_id]["encoded"]["attention_mask"] - ) - else: - config = {"device": device, "dtype": dtype} - input_ids.append( - self.tokenizer.pad_token_id + torch.zeros((max_len), **config) - ) - attention_mask.append(torch.zeros((max_len), **config)) - - padded_len = input_ids[-1].size()[-1] - logger.info(f"req_id={req_id}, padded_len={padded_len}, max_len={max_len}") - if padded_len < max_len: - # Apply padding to input_ids, attention_mask and past_key_values - n = max_len - seq_len - input_ids[-1] = torch.concat( - ( - self.tokenizer.pad_token_id + torch.zeros((n), **config), - input_ids[-1], - ) - ) - attention_mask[-1] = torch.concat( - (torch.zeros((n), **config), attention_mask[-1]) - ) - elif padded_len > max_len: - # Truncate padding from input_ids, attention_mask and past_key_values - logger.info(f"padded_len shape={input_ids[-1].size()}") - input_ids[-1] = input_ids[-1][-max_len:] - attention_mask[-1] = attention_mask[-1][-max_len:] - - logger.info(f"input_ids={input_ids}, attention_mask={attention_mask}") - - encoded = { - "input_ids": torch.stack(input_ids), - "attention_mask": torch.stack(attention_mask), - # "input_ids": input_ids, - # "attention_mask": attention_mask, + # add seq_id back to self.empty_seq_ids + self._add_empty_seq_id(seq_id) + + def _update_results(self, results, seq_id, idx, cache_ids, next_tokens): + self.decode_cache_ids[seq_id, :] = cache_ids[idx, :] + self.decode_next_tokens[seq_id, :] = next_tokens[idx, :] + req_id = self._get_req_id(seq_id) + self.seq_id_to_req_id[seq_id] = req_id + results[req_id] = { + "text": self.tokenizer.decode( + next_tokens[idx, -1], skip_special_tokens=True + ), + "tokens": [next_tokens[idx, -1].item()], } - logger.info(f"_prepare_model_inputs encoded={encoded}") - return encoded - def _create_stopping_criteria(self, req_id, max_new_tokens=25): + def _create_stopping_criteria(self, req_id, seq_id, max_new_tokens): class StoppingCriteria(object): def __init__( self, - cache, - batch_empty_ids, + outer, req_id, + seq_id, stop_token, max_new_tokens, ): self.req_id = req_id - self.cache = cache - self.batch_empty_ids = batch_empty_ids + self.seq_id = seq_id + self.outer = outer self.max_new_tokens = max_new_tokens self.stop_token = stop_token def __call__(self, res): self.max_new_tokens -= 1 - if self.max_new_tokens == 0 or res["tokens"][-1] == self.stop_token: - self.clean_up() + if self.max_new_tokens == 0 or res["ids"][-1] == self.stop_token: + self.outer.clean_up(self.req_id, self.seq_id) return True return False - def clean_up(self): - self.batch_empty_ids.append(self.cache[self.req_id]["batch_idx"]) - del self.cache[self.req_id] - return StoppingCriteria( - self.context.cache, - self.batch_empty_ids, - req_id, - self.tokenizer.eos_token_id, - max_new_tokens, + outer=self, + req_id=req_id, + seq_id=seq_id, + stop_token=self.tokenizer.eos_token_id, + max_new_tokens=max_new_tokens, ) From ea28f27f6281355ae63e81345445160dd84eb20a Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 2 Jan 2024 17:24:29 -0800 Subject: [PATCH 25/49] fmt --- .../inferentia2/llama2/continuous_batching/inf2_handler.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index ac0233a2d0..bab0b08add 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -4,6 +4,7 @@ import torch import torch_neuronx from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer +from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig from transformers_neuronx.llama.model import LlamaForSampling from transformers_neuronx.module import save_pretrained_split from transformers_neuronx.sampling import select_tokens @@ -98,6 +99,11 @@ def initialize(self, ctx: Context): self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.tokenizer.padding_side = "left" + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=self.batch_size + ) + neuron_config = NeuronConfig(continuous_batching=continuous_batching_config) self.model = LlamaForSampling.from_pretrained( model_checkpoint_path, batch_size=self.batch_size, @@ -105,6 +111,7 @@ def initialize(self, ctx: Context): tp_degree=tp_degree, n_positions=self.max_length, context_length_estimate=context_length_estimate, + neuron_config=neuron_config, ) logger.info("Starting to compile the model") self.model.to_neuron() From e54e853c63c707fad24b8b244ec6b8b4a54cec4c Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 5 Jan 2024 14:05:38 -0800 Subject: [PATCH 26/49] update notebook --- .../continuous_batching/inf2-llama-2-continuous-batching.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index c262b58d57..3fa6e8da54 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -79,6 +79,7 @@ "source": [ "# login in Hugginface hub\n", "!huggingface-cli login --token $HUGGINGFACE_TOKEN\n", + "!python ~/serve/examples/large_models/utils/Download_model.py --model_path model --model_name models--meta-llama--Llama-2-13b-hf --use_auth_token\n", "\n", "# Create TorchServe model artifacts\n", "!torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", From 31af68170ffd2b799dd9bf65891365adc1b8fffc Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 5 Jan 2024 15:58:48 -0800 Subject: [PATCH 27/49] fmt --- .../continuous_batching/inf2_handler.py | 344 +--------------- .../continuous_batching/model-config.yaml | 2 +- ts/handler_utils/utils.py | 11 + ...ase_neuronx_continuous_batching_handler.py | 381 ++++++++++++++++++ 4 files changed, 404 insertions(+), 334 deletions(-) create mode 100644 ts/handler_utils/utils.py create mode 100644 ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index bab0b08add..f4b8afc711 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -1,339 +1,17 @@ -import logging -import os +from ts.handler_utils.utils import import_class +from ts.torch_handler.distributed.base_neuronx_continuous_batching_handler import ( + BaseNeuronXContinuousBatchingHandler, +) -import torch -import torch_neuronx -from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer -from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig -from transformers_neuronx.llama.model import LlamaForSampling -from transformers_neuronx.module import save_pretrained_split -from transformers_neuronx.sampling import select_tokens -from ts.context import Context -from ts.torch_handler.base_handler import BaseHandler - -logger = logging.getLogger(__name__) - - -class LlamaContinuousBatchingHandler(BaseHandler): +class LlamaContinuousBatchingHandler(BaseNeuronXContinuousBatchingHandler): def __init__(self): - super(LlamaContinuousBatchingHandler, self).__init__() - # the queue of seq_ids which are available for a new request - self.batch_size = 2 - self.max_new_tokens = 25 - self.max_length = 100 - self.tokenizer = None - self.decode_next_tokens = None - self.decode_cache_ids = None - self.decode_seq_ids = None - self.empty_seq_ids = [] - # map seq_id to req_id - self.seq_id_to_req_id = {} - - def initialize(self, ctx: Context): - super().initialize(ctx) - logger.info(f"Initialized {self.__class__}") - - model_dir = ctx.system_properties.get("model_dir") - model_checkpoint_dir = ctx.model_yaml_config.get("handler", {}).get( - "model_checkpoint_dir", "" - ) - model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}" - model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}' - - if not os.path.exists(model_checkpoint_path): - # Load and save the CPU model - model_cpu = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True - ) - save_pretrained_split(model_cpu, model_checkpoint_path) - # Load and save tokenizer for the model - tokenizer = AutoTokenizer.from_pretrained( - model_path, return_tensors="pt", padding_side="left" - ) - tokenizer.save_pretrained(model_checkpoint_path) - - os.environ["NEURONX_CACHE"] = "on" - os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" - os.environ["NEURON_CC_FLAGS"] = "-O1 --model-type=transformer" - - self.max_length = int( - ctx.model_yaml_config.get("handler", {}).get("max_length", self.max_length) - ) - self.max_new_tokens = int( - ctx.model_yaml_config.get("handler", {}).get( - "max_new_tokens", self.max_new_tokens - ) - ) - self.batch_size = int( - ctx.model_yaml_config.get("handler", {}).get("batch_size", self.batch_size) - ) - - # settings for model compilation and loading - amp = ctx.model_yaml_config.get("handler", {}).get("amp", "fp32") - tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) - context_length_estimate = ctx.model_yaml_config.get("handler", {}).get( - "context_length_estimate", self.max_length - ) - - # allocate "tp_degree" number of neuron cores to the worker process - os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) - try: - num_neuron_cores_available = ( - torch_neuronx.xla_impl.data_parallel.device_count() - ) - assert num_neuron_cores_available >= int(tp_degree) - except (RuntimeError, AssertionError) as error: - logger.error( - "Required number of neuron cores for tp_degree " - + str(tp_degree) - + " are not available: " - + str(error) - ) - - raise error - - self.tokenizer = LlamaTokenizer.from_pretrained( - model_checkpoint_path, return_tensors="pt", padding_side="left" - ) - self.tokenizer.pad_token = self.tokenizer.eos_token - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.tokenizer.padding_side = "left" - - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=self.batch_size - ) - neuron_config = NeuronConfig(continuous_batching=continuous_batching_config) - self.model = LlamaForSampling.from_pretrained( - model_checkpoint_path, - batch_size=self.batch_size, - amp=amp, - tp_degree=tp_degree, - n_positions=self.max_length, - context_length_estimate=context_length_estimate, - neuron_config=neuron_config, - ) - logger.info("Starting to compile the model") - self.model.to_neuron() - logger.info("Model has been successfully compiled") - - # 1D: [seq_id] - # an empty slot if seq_id is -1 - self.decode_seq_ids = torch.full([self.batch_size], -1) - # 2D:[batch_size, next_cache_id] - self.decode_cache_ids = torch.zeros(self.batch_size, 1, dtype=torch.int64) - # 2D: [batch_size, next_token] - self.decode_next_tokens = torch.zeros(self.batch_size, 1, dtype=torch.int64) - - for seq_id, batch_id in enumerate(reversed(range(self.batch_size))): - self.empty_seq_ids.append(batch_id) - - logger.info("Model %s loaded successfully", ctx.model_name) - self.initialized = True - - def preprocess(self, requests): - prefill_req_ids, prefill_seq_ids, prefill_input_text, decode_seq_ids = ( - [], - [], - [], - [], - ) - for req_id, req_data in zip(self.context.request_ids.values(), requests): - if not req_id in self.context.cache: - prefill_req_ids.append(req_id) - seq_id = self._get_empty_seq_id() - prefill_seq_ids.append(seq_id) - - data = req_data["data"] - if isinstance(data, (bytes, bytearray)): - data = data.decode("utf-8") - max_new_tokens = int( - req_data.get("max_new_tokens", self.max_new_tokens) - ) - prefill_input_text.append(data.strip()) - - self.context.cache[req_id] = { - "seq_id": seq_id, - "stopping_criteria": self._create_stopping_criteria( - req_id=req_id, seq_id=seq_id, max_new_tokens=max_new_tokens - ), - } - else: - decode_seq_ids.append(self.context.cache[req_id]["seq_id"]) - - prefill_tokens = None - if len(prefill_req_ids) > 0: - prefill_tokens = self.tokenizer( - prefill_input_text, return_tensors="pt", padding=True - ) - return prefill_tokens, prefill_seq_ids, decode_seq_ids - - def inference(self, inputs): - prefill_tokens, prefill_seq_ids, req_decode_seq_ids = inputs - results = {} - # Test if this is the beginning of a continuous batching - go_to_decode = True if len(req_decode_seq_ids) > 0 else False - if len(prefill_seq_ids) > 0: - prefill_next_tokens, prefill_cache_ids = self._run_prefill( - prefill_tokens, prefill_seq_ids - ) - for i, prefill_seq_id in enumerate(prefill_seq_ids): - self._update_results( - results, prefill_seq_id, i, prefill_cache_ids, prefill_next_tokens - ) - - if go_to_decode: - decode_seq_ids = torch.where(self.decode_seq_ids > -1) - decode_cache_ids = torch.where(self.decode_cache_ids > 0) - decode_next_tokens = torch.where(self.decode_next_tokens > 0) - next_tokens = self._run_decode( - decode_next_tokens, decode_cache_ids, decode_seq_ids - ) - - filter_prefill_seq_ids = ( - torch.isin(decode_seq_ids, torch.as_tensor(prefill_seq_ids)) - if len(prefill_seq_ids) > 0 - else torch.full(decode_seq_ids.shape, False) - ) - - decode_cache_ids = decode_cache_ids + 1 - for i, is_prefill_seq_id in enumerate(filter_prefill_seq_ids): - if not is_prefill_seq_id: - seq_id = decode_seq_ids[i] - if seq_id in req_decode_seq_ids: - self._update_results( - results, seq_id, i, decode_cache_ids, next_tokens - ) - else: - req_id = self._get_req_id(seq_id) - logger.warning( - f"Found request id:{req_id} in cache, but not in batch requests. Delete it" - ) - self.clean_up(self, seq_id, req_id) - - return results - - def postprocess(self, inference_output): - self.context.stopping_criteria = [ - self.batch_store[self._get_seq_id(req_id)]["stopping_criteria"]( - req_id, - ) - for req_id in self.context.request_ids.values() - ] - - return inference_output - - def _get_empty_seq_id(self): - assert len(self.empty_seq_ids) > 0 - return self.empty_seq_ids.pop() - - def _add_empty_seq_id(self, seq_id): - self.empty_seq_ids.append(seq_id) - - def _get_seq_id(self, req_id): - seq_id = None - cache = self.context.cache.get(req_id, None) - if cache: - seq_id = cache["seq_id"] - assert seq_id is not None, "{req_id} must have seq_id" - return seq_id - - def _get_req_id(self, seq_id): - req_id = self.seq_id_to_req_id(seq_id, None) - assert req_id is not None - return req_id - - def _pad_to_max(self, x): - for idx, item in enumerate(x): - x[idx] = x[idx] + [0] * (self.max_length - len(x[idx])) - return x - - def _run_prefill(self, tokens, seq_ids): - input_ids, attention_mask = tokens["input_ids"], tokens["attention_mask"] - logger.info( - f"before padding: input_ids={input_ids}, attention_mask={attention_mask}" - ) - input_ids = self._pad_to_max(input_ids) - attention_mask = self._pad_to_max(attention_mask) - logger.info( - f"after padding: input_ids={input_ids}, attention_mask={attention_mask}" - ) - n_active_seqs, context_len = input_ids.shape - cache_ids = ( - torch.arange(context_len) - .reshape(1, context_len) - .expand(n_active_seqs, context_len) - .mul(attention_mask) - ) - with torch.inference_mode(): - logits = self.model( - input_ids, cache_ids=self.cache_ids, start_ids=torch.as_tensor(seq_ids) - ) - next_tokens = select_tokens(logits) - - return next_tokens, cache_ids.max(dim=1, keepdim=True).values + 1 - - def _run_decode(self, next_tokens, cache_ids, seq_ids): - with torch.inference_mode(): - logits = self.model(next_tokens, cache_ids=cache_ids, start_ids=seq_ids) - next_tokens = select_tokens(logits) - return next_tokens - - def clean_up(self, seq_id, req_id): - # clean up - del self.seq_id_to_req_id[seq_id] - del self.context.cache[req_id] - self.decode_seq_ids[seq_id] = -1 - self.decode_cache_ids[seq_id, :] = torch.zeros( - 1, dtype=torch.int64, device="cpu" + super(BaseNeuronXContinuousBatchingHandler, self).__init__() + self.model_class = import_class( + class_name="llama.model.LlamaForSampling", + module_prefix="transformers_neuronx", ) - self.decode_next_tokens[seq_id, :] = torch.zeros( - 1, dtype=torch.int64, device="cpu" - ) - - # add seq_id back to self.empty_seq_ids - self._add_empty_seq_id(seq_id) - - def _update_results(self, results, seq_id, idx, cache_ids, next_tokens): - self.decode_cache_ids[seq_id, :] = cache_ids[idx, :] - self.decode_next_tokens[seq_id, :] = next_tokens[idx, :] - req_id = self._get_req_id(seq_id) - self.seq_id_to_req_id[seq_id] = req_id - results[req_id] = { - "text": self.tokenizer.decode( - next_tokens[idx, -1], skip_special_tokens=True - ), - "tokens": [next_tokens[idx, -1].item()], - } - - def _create_stopping_criteria(self, req_id, seq_id, max_new_tokens): - class StoppingCriteria(object): - def __init__( - self, - outer, - req_id, - seq_id, - stop_token, - max_new_tokens, - ): - self.req_id = req_id - self.seq_id = seq_id - self.outer = outer - self.max_new_tokens = max_new_tokens - self.stop_token = stop_token - - def __call__(self, res): - self.max_new_tokens -= 1 - - if self.max_new_tokens == 0 or res["ids"][-1] == self.stop_token: - self.outer.clean_up(self.req_id, self.seq_id) - return True - return False - return StoppingCriteria( - outer=self, - req_id=req_id, - seq_id=seq_id, - stop_token=self.tokenizer.eos_token_id, - max_new_tokens=max_new_tokens, + self.tokenizer_class = import_class( + class_name="transformers.LlamaTokenizer", ) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml index 897d505c89..11379589bf 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -1,6 +1,6 @@ minWorkers: 1 maxWorkers: 1 -maxBatchDelay: 100 +maxBatchDelay: 1 responseTimeout: 10800 batchSize: 8 continuousBatching: true diff --git a/ts/handler_utils/utils.py b/ts/handler_utils/utils.py new file mode 100644 index 0000000000..47e6e006e1 --- /dev/null +++ b/ts/handler_utils/utils.py @@ -0,0 +1,11 @@ +import importlib + + +def import_class(class_name: str, module_prefix=None): + module_name, class_name = class_name.rsplit(".", maxsplit=1) + if module_prefix is not None: + module = importlib.import_module(f"{module_prefix}.{module_name}") + model_class = getattr(module, class_name, None) + if model_class is None: + raise ImportError(f"{class_name} not found in {module_name}.") + return model_class diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py new file mode 100644 index 0000000000..f3b7216be3 --- /dev/null +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -0,0 +1,381 @@ +import logging +import os + +import torch +import torch_neuronx +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig +from transformers_neuronx.module import save_pretrained_split +from transformers_neuronx.sampling import select_tokens + +from ts.context import Context +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) + + +class BaseNeuronXContinuousBatchingHandler(BaseHandler): + def __init__(self): + super(BaseNeuronXContinuousBatchingHandler, self).__init__() + + self.batch_size = 2 + self.max_new_tokens = 25 + self.max_length = 100 + self.tokenizer = None + self.decode_next_tokens = None + self.decode_cache_ids = None + self.decode_seq_ids = None + # the queue of seq_ids which are available for a new request + self.empty_seq_ids = [] + # map seq_id to req_id + self.seq_id_to_req_id = {} + self.model_class = None + self.tokenizer_class = None + + def initialize(self, ctx: Context): + ctx.cache = {} + model_dir = ctx.system_properties.get("model_dir") + model_checkpoint_dir = ctx.model_yaml_config.get("handler", {}).get( + "model_checkpoint_dir", "" + ) + model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}" + model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}' + + if not os.path.exists(model_checkpoint_path): + # Load and save the CPU model + model_cpu = AutoModelForCausalLM.from_pretrained( + model_path, low_cpu_mem_usage=True + ) + save_pretrained_split(model_cpu, model_checkpoint_path) + # Load and save tokenizer for the model + tokenizer = AutoTokenizer.from_pretrained( + model_path, return_tensors="pt", padding_side="left" + ) + tokenizer.save_pretrained(model_checkpoint_path) + + os.environ["NEURONX_CACHE"] = "on" + os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" + os.environ["NEURON_CC_FLAGS"] = "-O1 --model-type=transformer" + + self.max_length = int( + ctx.model_yaml_config.get("handler", {}).get("max_length", self.max_length) + ) + self.max_new_tokens = int( + ctx.model_yaml_config.get("handler", {}).get( + "max_new_tokens", self.max_new_tokens + ) + ) + self.batch_size = int( + ctx.model_yaml_config.get("handler", {}).get("batch_size", self.batch_size) + ) + + # settings for model compilation and loading + amp = ctx.model_yaml_config.get("handler", {}).get("amp", "fp32") + tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) + + # allocate "tp_degree" number of neuron cores to the worker process + os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) + try: + num_neuron_cores_available = ( + torch_neuronx.xla_impl.data_parallel.device_count() + ) + assert num_neuron_cores_available >= int(tp_degree) + except (RuntimeError, AssertionError) as error: + logger.error( + "Required number of neuron cores for tp_degree " + + str(tp_degree) + + " are not available: " + + str(error) + ) + + raise error + + self.tokenizer = self.tokenizer_class.from_pretrained( + model_checkpoint_path, return_tensors="pt", padding_side="left" + ) + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.tokenizer.padding_side = "left" + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=self.batch_size + ) + neuron_config = NeuronConfig(continuous_batching=continuous_batching_config) + kwargs = dict( + tp_degree=tp_degree, + amp=amp, + batch_size=self.batch_size, + n_positions=[self.max_length], + context_length_estimate=[self.max_length], + neuron_config=neuron_config, + ) + self.model = self.model_class.from_pretrained(model_checkpoint_path, **kwargs) + logger.info("Starting to compile the model") + self.model.to_neuron() + logger.info("Model has been successfully compiled") + + # 1D: [seq_id] + # an empty slot if seq_id is -1 + self.decode_seq_ids = torch.full([self.batch_size], -1) + # 2D:[batch_size, next_cache_id] + self.decode_cache_ids = torch.zeros(self.batch_size, 1, dtype=torch.int64) + # 2D: [batch_size, next_token] + self.decode_next_tokens = torch.zeros(self.batch_size, 1, dtype=torch.int64) + # self.decode_next_tokens = torch.full(self.batch_size, self.tokenizer.eos_token_id) + + for seq_id, batch_id in enumerate(reversed(range(self.batch_size))): + self.empty_seq_ids.append(batch_id) + + logger.info("Model %s loaded successfully", ctx.model_name) + self.initialized = True + + def preprocess(self, requests): + prefill_req_ids, prefill_seq_ids, prefill_input_text, req_decode_seq_ids = ( + [], + [], + [], + [], + ) + for req_id, req_data in zip(self.context.request_ids.values(), requests): + if not req_id in self.context.cache: + prefill_req_ids.append(req_id) + seq_id = self._get_empty_seq_id() + self.seq_id_to_req_id[seq_id] = req_id + prefill_seq_ids.append(seq_id) + + data = req_data.get("data") or req_data.get("body") + if isinstance(data, (bytes, bytearray)): + data = data.decode("utf-8") + + max_new_tokens = int( + req_data.get("max_new_tokens", self.max_new_tokens) + ) + prefill_input_text.append(data.strip()) + + self.context.cache[req_id] = { + "seq_id": seq_id, + "stopping_criteria": self._create_stopping_criteria( + req_id=req_id, seq_id=seq_id, max_new_tokens=max_new_tokens + ), + } + else: + req_decode_seq_ids.append(self.context.cache[req_id]["seq_id"]) + + prefill_tokens = None + if len(prefill_req_ids) > 0: + prefill_tokens = self.tokenizer( + prefill_input_text, return_tensors="pt", padding=True + ) + return prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids + + def inference(self, inputs): + prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids = inputs + results = {} + # Test if this is the beginning of a continuous batching + go_to_decode = True if len(req_decode_seq_ids) > 0 else False + if len(prefill_seq_ids) > 0: + prefill_next_tokens, prefill_cache_ids = self._run_prefill( + prefill_tokens, prefill_seq_ids + ) + for i, prefill_seq_id in enumerate(prefill_seq_ids): + self._update_results( + results, + prefill_seq_id, + i, + prefill_cache_ids, + prefill_next_tokens, + prefill_tokens=prefill_tokens, + prefill_input_text=prefill_input_text, + ) + + if go_to_decode: + local_decode_seq_ids = torch.cat(torch.where(self.decode_seq_ids > -1)) + local_decode_cache_ids = self.decode_cache_ids[local_decode_seq_ids] + local_decode_next_tokens = self.decode_next_tokens[local_decode_seq_ids] + logger.info( + f"local_decode_seq_ids={local_decode_seq_ids}, local_decode_cache_ids={local_decode_cache_ids}, local_decode_next_tokens={local_decode_next_tokens}" + ) + local_next_tokens = self._run_decode( + local_decode_next_tokens, local_decode_cache_ids, local_decode_seq_ids + ) + + filter_prefill_seq_ids = ( + torch.isin(local_decode_seq_ids, torch.as_tensor(prefill_seq_ids)) + if len(prefill_seq_ids) > 0 + else torch.full(local_decode_seq_ids.shape, False) + ) + + local_decode_cache_ids = local_decode_cache_ids + 1 + for i, is_prefill_seq_id in enumerate(filter_prefill_seq_ids): + if not is_prefill_seq_id: + seq_id = local_decode_seq_ids[i].item() + logger.info( + f"is_prefill_seq_id={is_prefill_seq_id}, seq_id={seq_id}, req_decode_seq_ids={req_decode_seq_ids}" + ) + if seq_id in req_decode_seq_ids: + self._update_results( + results, + seq_id, + i, + local_decode_cache_ids, + local_next_tokens, + ) + else: + req_id = self._get_req_id(seq_id) + logger.warning( + f"Found request id:{req_id} in cache, but not in batch requests. Delete it" + ) + self.clean_up(self, seq_id, req_id) + + return [results[i] for i in self.context.request_ids.values()] + + def postprocess(self, inference_output): + self.context.stopping_criteria = [ + self.context.cache[req_id]["stopping_criteria"] + for req_id in self.context.request_ids.values() + ] + logger.info( + f"inference_output={inference_output}, stopping_criteria={self.context.stopping_criteria}" + ) + + return inference_output + + def _get_empty_seq_id(self): + assert len(self.empty_seq_ids) > 0 + return self.empty_seq_ids.pop() + + def _add_empty_seq_id(self, seq_id): + self.empty_seq_ids.append(seq_id) + + def _get_seq_id(self, req_id): + seq_id = None + cache = self.context.cache.get(req_id, None) + if cache: + seq_id = cache["seq_id"] + assert seq_id is not None, "{req_id} must have seq_id" + return seq_id + + def _get_req_id(self, seq_id): + req_id = self.seq_id_to_req_id.get(seq_id, None) + assert req_id is not None + return req_id + + def _pad_to_max(self, x): + z = torch.empty(x.shape[0], self.max_length, dtype=torch.int64) + for idx, item in enumerate(x): + pad = torch.zeros(self.max_length - len(x[idx]), dtype=torch.int) + z[idx] = torch.cat((x[idx], pad)) + return z + + def _run_prefill(self, tokens, seq_ids): + input_ids, attention_mask = tokens["input_ids"], tokens["attention_mask"] + logger.info( + f"before padding: input_ids={input_ids}, attention_mask={attention_mask}" + ) + input_ids = self._pad_to_max(input_ids) + attention_mask = self._pad_to_max(attention_mask) + logger.info( + f"after padding: input_ids={input_ids}, attention_mask={attention_mask}" + ) + n_active_seqs, context_len = input_ids.shape + cache_ids = ( + torch.arange(context_len) + .reshape(1, context_len) + .expand(n_active_seqs, context_len) + .mul(attention_mask) + ) + with torch.inference_mode(): + logits = self.model( + input_ids, cache_ids=cache_ids, start_ids=torch.as_tensor(seq_ids) + ) + next_tokens = select_tokens(logits) + output_tokens = [[t] for t in next_tokens.flatten().tolist()] + + return next_tokens, cache_ids.max(dim=1, keepdim=True).values + 1 + + def _run_decode(self, next_tokens, cache_ids, seq_ids): + with torch.inference_mode(): + logits = self.model(next_tokens, cache_ids=cache_ids, start_ids=seq_ids) + next_tokens = select_tokens(logits) + output_tokens = [[t] for t in next_tokens.flatten().tolist()] + return next_tokens + + def clean_up(self, seq_id, req_id): + # clean up + del self.seq_id_to_req_id[seq_id] + del self.context.cache[req_id] + self.decode_seq_ids[seq_id] = -1 + self.decode_cache_ids[seq_id, :] = torch.zeros(1, dtype=torch.int64) + # self.decode_next_tokens[seq_id, :] = torch.zeros( + # 1, dtype=torch.int64 + # ) + self.decode_next_tokens[seq_id, :] = torch.tensor( + [self.tokenizer.eos_token_id], dtype=torch.int64 + ) + # add seq_id back to self.empty_seq_ids + self._add_empty_seq_id(seq_id) + + def _update_results( + self, + results, + seq_id, + idx, + cache_ids, + next_tokens, + prefill_tokens=None, + prefill_input_text=None, + ): + self.decode_seq_ids[seq_id] = 0 + self.decode_cache_ids[seq_id, :] = cache_ids[idx, :] + self.decode_next_tokens[seq_id, :] = next_tokens[idx, :] + req_id = self._get_req_id(seq_id) + cur_text = self.tokenizer.decode(next_tokens[idx, :], skip_special_tokens=False) + if not (cur_text.startswith(" ") or cur_text.endswith(" ")): + if prefill_tokens is None: + previous_tokens = self.decode_next_tokens[seq_id, -1] + else: + previous_tokens = prefill_tokens["input_ids"][idx, -1] + + cur_text = self.tokenizer.decode( + torch.cat((torch.tensor([previous_tokens]), next_tokens[idx, :])), + skip_special_tokens=False, + )[len(cur_text) :] + + results[req_id] = { + "text": cur_text + if prefill_input_text is None + else prefill_input_text[idx] + cur_text, + "tokens": [next_tokens[idx, -1].item()], + } + + def _create_stopping_criteria(self, req_id, seq_id, max_new_tokens): + class StoppingCriteria(object): + def __init__( + self, + outer, + req_id, + seq_id, + stop_token, + max_new_tokens, + ): + self.req_id = req_id + self.seq_id = seq_id + self.outer = outer + self.max_new_tokens = max_new_tokens + self.stop_token = stop_token + + def __call__(self, res): + self.max_new_tokens -= 1 + + if self.max_new_tokens == 0 or res["tokens"][-1] == self.stop_token: + self.outer.clean_up(self.seq_id, self.req_id) + return True + return False + + return StoppingCriteria( + outer=self, + req_id=req_id, + seq_id=seq_id, + stop_token=self.tokenizer.eos_token_id, + max_new_tokens=max_new_tokens, + ) From a3ad43a60f3c84259d133bddc9d650fc3c8a20f9 Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 5 Jan 2024 16:42:28 -0800 Subject: [PATCH 28/49] add handler utils --- ts/handler_utils/utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/ts/handler_utils/utils.py b/ts/handler_utils/utils.py index 47e6e006e1..01c0cf8e73 100644 --- a/ts/handler_utils/utils.py +++ b/ts/handler_utils/utils.py @@ -2,10 +2,19 @@ def import_class(class_name: str, module_prefix=None): - module_name, class_name = class_name.rsplit(".", maxsplit=1) + module_name = "" + arr = class_name.rsplit(".", maxsplit=1) + if len(arr) == 2: + module_name, class_name = arr + else: + class_name = arr[0] + if module_prefix is not None: module = importlib.import_module(f"{module_prefix}.{module_name}") + else: + module = importlib.import_module(module_name) + model_class = getattr(module, class_name, None) if model_class is None: - raise ImportError(f"{class_name} not found in {module_name}.") + raise ImportError(f"class:{class_name} not found in module:{module_name}.") return model_class From ae0e7d3a5646a98387c0aa77bd7601c40e1d77df Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 5 Jan 2024 16:53:04 -0800 Subject: [PATCH 29/49] fix typo --- .../inferentia2/llama2/continuous_batching/inf2_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py index f4b8afc711..6de23d4a29 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py @@ -6,7 +6,7 @@ class LlamaContinuousBatchingHandler(BaseNeuronXContinuousBatchingHandler): def __init__(self): - super(BaseNeuronXContinuousBatchingHandler, self).__init__() + super(LlamaContinuousBatchingHandler, self).__init__() self.model_class = import_class( class_name="llama.model.LlamaForSampling", module_prefix="transformers_neuronx", From 12b34b658c3749c8aa67a0abf72fd698f9c2875c Mon Sep 17 00:00:00 2001 From: lxning Date: Fri, 5 Jan 2024 21:49:23 -0800 Subject: [PATCH 30/49] fmt --- ...ase_neuronx_continuous_batching_handler.py | 30 +++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index f3b7216be3..df2cf120b9 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -115,7 +115,7 @@ def initialize(self, ctx: Context): logger.info("Model has been successfully compiled") # 1D: [seq_id] - # an empty slot if seq_id is -1 + # an empty slot if seq_id is -1, otherwise 0 self.decode_seq_ids = torch.full([self.batch_size], -1) # 2D:[batch_size, next_cache_id] self.decode_cache_ids = torch.zeros(self.batch_size, 1, dtype=torch.int64) @@ -137,7 +137,7 @@ def preprocess(self, requests): [], ) for req_id, req_data in zip(self.context.request_ids.values(), requests): - if not req_id in self.context.cache: + if req_id not in self.context.cache: prefill_req_ids.append(req_id) seq_id = self._get_empty_seq_id() self.seq_id_to_req_id[seq_id] = req_id @@ -192,9 +192,7 @@ def inference(self, inputs): local_decode_seq_ids = torch.cat(torch.where(self.decode_seq_ids > -1)) local_decode_cache_ids = self.decode_cache_ids[local_decode_seq_ids] local_decode_next_tokens = self.decode_next_tokens[local_decode_seq_ids] - logger.info( - f"local_decode_seq_ids={local_decode_seq_ids}, local_decode_cache_ids={local_decode_cache_ids}, local_decode_next_tokens={local_decode_next_tokens}" - ) + local_next_tokens = self._run_decode( local_decode_next_tokens, local_decode_cache_ids, local_decode_seq_ids ) @@ -209,9 +207,7 @@ def inference(self, inputs): for i, is_prefill_seq_id in enumerate(filter_prefill_seq_ids): if not is_prefill_seq_id: seq_id = local_decode_seq_ids[i].item() - logger.info( - f"is_prefill_seq_id={is_prefill_seq_id}, seq_id={seq_id}, req_decode_seq_ids={req_decode_seq_ids}" - ) + if seq_id in req_decode_seq_ids: self._update_results( results, @@ -234,9 +230,6 @@ def postprocess(self, inference_output): self.context.cache[req_id]["stopping_criteria"] for req_id in self.context.request_ids.values() ] - logger.info( - f"inference_output={inference_output}, stopping_criteria={self.context.stopping_criteria}" - ) return inference_output @@ -269,14 +262,10 @@ def _pad_to_max(self, x): def _run_prefill(self, tokens, seq_ids): input_ids, attention_mask = tokens["input_ids"], tokens["attention_mask"] - logger.info( - f"before padding: input_ids={input_ids}, attention_mask={attention_mask}" - ) + input_ids = self._pad_to_max(input_ids) attention_mask = self._pad_to_max(attention_mask) - logger.info( - f"after padding: input_ids={input_ids}, attention_mask={attention_mask}" - ) + n_active_seqs, context_len = input_ids.shape cache_ids = ( torch.arange(context_len) @@ -289,7 +278,6 @@ def _run_prefill(self, tokens, seq_ids): input_ids, cache_ids=cache_ids, start_ids=torch.as_tensor(seq_ids) ) next_tokens = select_tokens(logits) - output_tokens = [[t] for t in next_tokens.flatten().tolist()] return next_tokens, cache_ids.max(dim=1, keepdim=True).values + 1 @@ -297,7 +285,7 @@ def _run_decode(self, next_tokens, cache_ids, seq_ids): with torch.inference_mode(): logits = self.model(next_tokens, cache_ids=cache_ids, start_ids=seq_ids) next_tokens = select_tokens(logits) - output_tokens = [[t] for t in next_tokens.flatten().tolist()] + return next_tokens def clean_up(self, seq_id, req_id): @@ -306,9 +294,6 @@ def clean_up(self, seq_id, req_id): del self.context.cache[req_id] self.decode_seq_ids[seq_id] = -1 self.decode_cache_ids[seq_id, :] = torch.zeros(1, dtype=torch.int64) - # self.decode_next_tokens[seq_id, :] = torch.zeros( - # 1, dtype=torch.int64 - # ) self.decode_next_tokens[seq_id, :] = torch.tensor( [self.tokenizer.eos_token_id], dtype=torch.int64 ) @@ -325,6 +310,7 @@ def _update_results( prefill_tokens=None, prefill_input_text=None, ): + # 0: this seq_id is used for decoding if this slot is set 0 self.decode_seq_ids[seq_id] = 0 self.decode_cache_ids[seq_id, :] = cache_ids[idx, :] self.decode_next_tokens[seq_id, :] = next_tokens[idx, :] From c07e8a8b7c57c4f4fb15e6548f588cc1d1bd6529 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 6 Jan 2024 14:45:52 -0800 Subject: [PATCH 31/49] fmt --- ...ase_neuronx_continuous_batching_handler.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index df2cf120b9..95d061c2aa 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -123,8 +123,8 @@ def initialize(self, ctx: Context): self.decode_next_tokens = torch.zeros(self.batch_size, 1, dtype=torch.int64) # self.decode_next_tokens = torch.full(self.batch_size, self.tokenizer.eos_token_id) - for seq_id, batch_id in enumerate(reversed(range(self.batch_size))): - self.empty_seq_ids.append(batch_id) + for _, seq_id in enumerate(reversed(range(self.batch_size))): + self.empty_seq_ids.append(seq_id) logger.info("Model %s loaded successfully", ctx.model_name) self.initialized = True @@ -219,9 +219,9 @@ def inference(self, inputs): else: req_id = self._get_req_id(seq_id) logger.warning( - f"Found request id:{req_id} in cache, but not in batch requests. Delete it" + f"Found request id:{req_id} with seq_id:{seq_id} in local_decode_seq_ids, but not in batch requests. Delete it" ) - self.clean_up(self, seq_id, req_id) + self._clean_up(seq_id, req_id) return [results[i] for i in self.context.request_ids.values()] @@ -234,6 +234,10 @@ def postprocess(self, inference_output): return inference_output def _get_empty_seq_id(self): + if len(self.empty_seq_ids) == 0: + # clean up dead req_ids due to client disconnction + self._clean_dead_reqs() + assert len(self.empty_seq_ids) > 0 return self.empty_seq_ids.pop() @@ -288,7 +292,7 @@ def _run_decode(self, next_tokens, cache_ids, seq_ids): return next_tokens - def clean_up(self, seq_id, req_id): + def _clean_up(self, seq_id, req_id): # clean up del self.seq_id_to_req_id[seq_id] del self.context.cache[req_id] @@ -300,6 +304,14 @@ def clean_up(self, seq_id, req_id): # add seq_id back to self.empty_seq_ids self._add_empty_seq_id(seq_id) + def _clean_dead_reqs(self): + local_decode_seq_ids = torch.cat(torch.where(self.decode_seq_ids > -1)) + for _, seq_id in enumerate(local_decode_seq_ids): + seq_id_value = seq_id.item() + req_id = self._get_req_id(seq_id_value) + if req_id not in self.context.request_ids: + self._clean_up(seq_id_value, req_id) + def _update_results( self, results, @@ -354,7 +366,7 @@ def __call__(self, res): self.max_new_tokens -= 1 if self.max_new_tokens == 0 or res["tokens"][-1] == self.stop_token: - self.outer.clean_up(self.seq_id, self.req_id) + self.outer._clean_up(self.seq_id, self.req_id) return True return False From 6a5867af861eb7a915a6c908a5ba31849e74cd77 Mon Sep 17 00:00:00 2001 From: lxning Date: Sat, 6 Jan 2024 18:11:43 -0800 Subject: [PATCH 32/49] fmt --- .../utils/test_llm_streaming_response.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 examples/large_models/utils/test_llm_streaming_response.py diff --git a/examples/large_models/utils/test_llm_streaming_response.py b/examples/large_models/utils/test_llm_streaming_response.py new file mode 100644 index 0000000000..e8a88ec114 --- /dev/null +++ b/examples/large_models/utils/test_llm_streaming_response.py @@ -0,0 +1,124 @@ +import argparse +import random +import threading +from queue import Queue + +import requests + + +class Predictor(threading.Thread): + def __init__(self, args, queue): + threading.Thread.__init__(self) + self.args = args + self.queue = queue + + def run(self): + for _ in range(self.args.num_requests_per_thread): + self._predict(self.args, self.queue) + + def _predict(self): + payload = self._format_payload(self.args) + with requests.post(self._get_url(self.args), json=payload) as response: + combined_text = "" + for chunk in response.iter_content(chunk_size=None): + if chunk: + data = chunk.decode("utf-8") + combined_text += data["text"] + + with self.queue.mutex: + self.queue.put_nowait(f'prompt={payload["data"]}\n, output={combined_text}') + + def _get_url(self): + return f"http://localhost:8080/predictions/{self.args.model}" + + def _format_payload(self): + prompt_list = self.args.prompt.split(" ") + r = random.randint(1, len(prompt_list)) + cur_prompt_list = prompt_list[:r] + cur_prompt = " ".join(cur_prompt_list) + return { + "data": cur_prompt, + "max_new_token": random.randint(10, self.args.max_tokens), + } + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model", + type=str, + help="The model to use for generating text. If not specified we will pick the first model from the service as returned by /v1/models", + ) + parser.add_argument( + "-p", + "--prompt-tokens", + env_var="PROMPT_TOKENS", + type=int, + default=512, + help="Length of the prompt in tokens. Default 512", + ) + parser.add_argument( + "--prompt-chars", + env_var="PROMPT_CHARS", + type=int, + help="Length of the prompt in characters.", + ) + parser.add_argument( + "--prompt-text", + env_var="PROMPT_TEXT", + type=str, + help="Prompt text to use instead of generating one. It can be a file reference starting with an ampersand, e.g. `@prompt.txt`", + ) + parser.add_argument( + "--prompt-randomize", + action=argparse.BooleanOptionalAction, + default=False, + help="Include a few random numbers in the generated prompt to avoid caching", + ) + parser.add_argument( + "-o", + "--max-tokens", + env_var="MAX_TOKENS", + type=int, + default=64, + help="Max number of tokens to generate.", + ) + parser.add_argument( + "-t", + "--num-threads", + type=int, + default=1, + help="Enable the number of threads to execute prediction", + ) + parser.add_argument( + "-n", + "--num-requests-per-thread", + type=int, + default=1, + help="Execute the number of prediction in each thread", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + queue = Queue() + predictors = [] + for i in range(args.num_threads): + predictor = Predictor(args, queue) + predictor.start() + predictors.append(predictor) + + for predictor in predictors: + predictor.join() + + while not queue.empty(): + print(queue.get()) + + +if __name__ == "__main__": + main() From e0e8bae05de59b51cba1478480dbc40cb20c15f0 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 8 Jan 2024 15:18:12 -0800 Subject: [PATCH 33/49] fmt --- .../utils/test_llm_streaming_response.py | 68 ++++++++++--------- ...ase_neuronx_continuous_batching_handler.py | 16 +++-- 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/examples/large_models/utils/test_llm_streaming_response.py b/examples/large_models/utils/test_llm_streaming_response.py index e8a88ec114..c1a13eb257 100644 --- a/examples/large_models/utils/test_llm_streaming_response.py +++ b/examples/large_models/utils/test_llm_streaming_response.py @@ -1,8 +1,10 @@ import argparse +import json import random import threading from queue import Queue +import orjson import requests @@ -14,34 +16,53 @@ def __init__(self, args, queue): def run(self): for _ in range(self.args.num_requests_per_thread): - self._predict(self.args, self.queue) + self._predict() def _predict(self): - payload = self._format_payload(self.args) - with requests.post(self._get_url(self.args), json=payload) as response: + payload = self._format_payload() + with requests.post( + self._get_url(), json=json.dumps(payload), stream=True + ) as response: combined_text = "" for chunk in response.iter_content(chunk_size=None): if chunk: - data = chunk.decode("utf-8") + data = orjson.loads(chunk) combined_text += data["text"] - - with self.queue.mutex: - self.queue.put_nowait(f'prompt={payload["data"]}\n, output={combined_text}') + self.queue.put_nowait(f"payload={payload}\n, output={combined_text}\n") def _get_url(self): return f"http://localhost:8080/predictions/{self.args.model}" def _format_payload(self): - prompt_list = self.args.prompt.split(" ") - r = random.randint(1, len(prompt_list)) - cur_prompt_list = prompt_list[:r] + prompt = _load_curl_like_data(self.args.prompt_text) + prompt_list = prompt.split(" ") + rp = len(prompt_list) + rt = self.args.max_tokens + if self.args.prompt_randomize: + rp = random.randint(1, len(prompt_list)) + rt = random.randint(10, self.args.max_tokens) + cur_prompt_list = prompt_list[:rp] cur_prompt = " ".join(cur_prompt_list) return { - "data": cur_prompt, - "max_new_token": random.randint(10, self.args.max_tokens), + "prompt": cur_prompt, + "max_new_tokens": rt, } +def _load_curl_like_data(text): + """ + Either use the passed string or load from a file if the string is `@filename` + """ + if text.startswith("@"): + try: + with open(text[1:], "r") as f: + return f.read() + except Exception as e: + raise ValueError(f"Failed to read file {text[1:]}") from e + else: + return text + + def parse_args(): parser = argparse.ArgumentParser() @@ -51,23 +72,8 @@ def parse_args(): type=str, help="The model to use for generating text. If not specified we will pick the first model from the service as returned by /v1/models", ) - parser.add_argument( - "-p", - "--prompt-tokens", - env_var="PROMPT_TOKENS", - type=int, - default=512, - help="Length of the prompt in tokens. Default 512", - ) - parser.add_argument( - "--prompt-chars", - env_var="PROMPT_CHARS", - type=int, - help="Length of the prompt in characters.", - ) parser.add_argument( "--prompt-text", - env_var="PROMPT_TEXT", type=str, help="Prompt text to use instead of generating one. It can be a file reference starting with an ampersand, e.g. `@prompt.txt`", ) @@ -80,7 +86,6 @@ def parse_args(): parser.add_argument( "-o", "--max-tokens", - env_var="MAX_TOKENS", type=int, default=64, help="Max number of tokens to generate.", @@ -100,8 +105,7 @@ def parse_args(): help="Execute the number of prediction in each thread", ) - args = parser.parse_args() - return args + return parser.parse_args() def main(): @@ -110,12 +114,14 @@ def main(): predictors = [] for i in range(args.num_threads): predictor = Predictor(args, queue) - predictor.start() predictors.append(predictor) + predictor.start() for predictor in predictors: predictor.join() + print("Tasks are completed") + while not queue.empty(): print(queue.get()) diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index 95d061c2aa..fd21004344 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -1,6 +1,7 @@ import logging import os +import orjson import torch import torch_neuronx from transformers import AutoModelForCausalLM, AutoTokenizer @@ -147,10 +148,13 @@ def preprocess(self, requests): if isinstance(data, (bytes, bytearray)): data = data.decode("utf-8") - max_new_tokens = int( - req_data.get("max_new_tokens", self.max_new_tokens) + data = orjson.loads(data) + prompt = data.get("prompt") + max_new_tokens = int(data.get("max_new_tokens", self.max_new_tokens)) + logger.info( + "preprocess prompt={prompt}, max_new_tokens={max_new_tokens}" ) - prefill_input_text.append(data.strip()) + prefill_input_text.append(prompt.strip()) self.context.cache[req_id] = { "seq_id": seq_id, @@ -334,10 +338,12 @@ def _update_results( else: previous_tokens = prefill_tokens["input_ids"][idx, -1] - cur_text = self.tokenizer.decode( + text = self.tokenizer.decode( torch.cat((torch.tensor([previous_tokens]), next_tokens[idx, :])), skip_special_tokens=False, - )[len(cur_text) :] + ) + if text[: -len(cur_text)].endswith(" "): + cur_text = " " + cur_text results[req_id] = { "text": cur_text From 32adb9073653b8da6339dffec066f501021e547b Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 8 Jan 2024 16:57:21 -0800 Subject: [PATCH 34/49] fmt --- .../llama2/continuous_batching/Dockerfile | 11 --- .../llama2/continuous_batching/Readme.md | 15 ++- .../inf2-llama-2-continuous-batching.ipynb | 97 ++++++------------- .../llama2/continuous_batching/test.sh | 5 - 4 files changed, 43 insertions(+), 85 deletions(-) delete mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile delete mode 100755 examples/large_models/inferentia2/llama2/continuous_batching/test.sh diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile b/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile deleted file mode 100644 index 31c7648eec..0000000000 --- a/examples/large_models/inferentia2/llama2/continuous_batching/Dockerfile +++ /dev/null @@ -1,11 +0,0 @@ -FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04 -# workaround transformers 4.35 failure of downloading HF model bin files -RUN pip install transformers==4.34 -WORKDIR /home/model-server -RUN git clone https://github.com/pytorch/serve.git \ - && cd serve -WORKDIR /home/model-server/serve -RUN git checkout feat/inf2_cb -RUN pip install pygit2 -RUN python ts_scripts/install_from_src.py -ENV TS_INSTALL_PY_DEP_PER_MODEL true diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md index 3aa764c0b5..cb590e8b14 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md @@ -1,8 +1,19 @@ # Demo2: Llama-2 Using TorchServe continuous batching on inf2 -This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS transformers-neuronx continuous batching](https://aws.amazon.com/ec2/instance-types/inf2/). -**Note**: To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, to generate the model graph, a specific batch size is used. Following this, when running inference, we need to pass input which matches the batch size that was used during compilation. Model compilation and input padding to match compiled model batch size is taken care of by the [custom handler](inf2_handler.py) in this example. +This example can also be extended to support the following models. + + +| Model | Model Class | +| :--- | :----: | +| opt | opt.model.OPTForSampling | +| gpt2 | gpt2.model.GPT2ForSampling | +| gptj | gptj.model.GPTJForSampling | +| gpt_neox | gptneox.model.GPTNeoXForSampling | +| llama | lama.model.LlamaForSampling | +| mistral | mistral.model.MistralForSampling | +| bloom | bloom.model.BloomForSampling | The batch size [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. It is the batch size used for the Inf2 model compilation. Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index 3fa6e8da54..672d268bc5 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "source": [ "## TorchServe Continuous Batching Serve Llama-2 on Inferentia-2\n", - "This notebook demonstrates TorchServe continuous batching serving Llama-2-13b on Inferentia-2 `inf2.24xlarge`." + "This notebook demonstrates TorchServe continuous batching serving Llama-2-13b on Inferentia-2 `inf2.24xlarge` with DLAMI: Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20231226" ], "metadata": { "collapsed": false @@ -13,8 +13,8 @@ { "cell_type": "markdown", "source": [ - "### Build a customized docker container to install the code changes from this [PR](https://github.com/pytorch/serve/pull/2803).\n", - "This section can be skipped once [Neuron DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers) release TorchServe latest version." + "### Installation\n", + "Note: This section can be skipped once [Neuron DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers) release TorchServe latest version." ], "metadata": { "collapsed": false @@ -25,48 +25,36 @@ "execution_count": null, "outputs": [], "source": [ - "!aws ecr get-login-password --region us-east-1 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-east-1.amazonaws.com\n", - "!docker pull 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.15.0-ubuntu20.04" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "!cat Dockerfile" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "start_time": "2023-11-21T22:05:30.551799Z", - "end_time": "2023-11-21T22:05:30.698105Z" - } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "!docker build -t neuron-sdk-215:torchserve-cb ." + "# Install Python venv\n", + "!sudo apt-get install -y python3.9-venv g++\n", + "\n", + "# Create Python venv\n", + "!python3.9 -m venv aws_neuron_venv_pytorch\n", + "\n", + "# Activate Python venv\n", + "!source aws_neuron_venv_pytorch/bin/activate\n", + "!python -m pip install -U pip\n", + "\n", + "# Clone Torchserve git repository\n", + "!git clone https://github.com/pytorch/serve.git\n", + "\n", + "# Install dependencies\n", + "!python ~/serve/ts_scripts/install_dependencies.py --neuronx --environment=dev\n", + "\n", + "# Install torchserve and torch-model-archiver\n", + "python ts_scripts/install_from_src.py" ], "metadata": { "collapsed": false } }, { - "cell_type": "code", - "execution_count": null, - "outputs": [], + "cell_type": "markdown", "source": [ - "# Enter into docker container\n", - "!mkdir model_store\n", + "### Create model artifacts\n", "\n", - "!docker run -it -v model_store:/home/model-server/model_store --device /dev/neuron0:/dev/neuron0 --device /dev/neuron1:/dev/neuron1 --device /dev/neuron2:/dev/neuron2 --device /dev/neuron3:/dev/neuron3 --device /dev/neuron4:/dev/neuron4 --device /dev/neuron5:/dev/neuron5 neuron-sdk-215:torchserve-cb bash" + "Note: run `mv model/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55/model.safetensors.index.json model/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55/model.safetensors.index.json.bkp`\n", + " if neuron sdk does not support safetensors" ], "metadata": { "collapsed": false @@ -83,13 +71,12 @@ "\n", "# Create TorchServe model artifacts\n", "!torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", - "!mkdir -p /home/model-server/model_store\n", - "!mv llama-2-13b /home/model-server/model_store\n", + "!mv model llama-2-13b\n", + "!mkdir -p ~/serve/model_store\n", + "!mv ~/serve/llama-2-13b /home/model-server/model_store\n", "\n", "# Precompile complete once the log \"Model llama-2-13b loaded successfully\"\n", - "torchserve --ncs --start --model-store /home/model-server/model_store --models llama-2-13b --ts-config ../config.properties\n", - "\n", - "# Exit the container" + "torchserve --ncs --start --model-store /home/model-server/model_store --models llama-2-13b --ts-config ../config.properties" ], "metadata": { "collapsed": false @@ -104,37 +91,13 @@ "collapsed": false } }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# Start the container\n", - "!docker run -it -v model_store:/opt/ml/model --device /dev/neuron0:/dev/neuron0 --device /dev/neuron1:/dev/neuron1 --device /dev/neuron2:/dev/neuron2 --device /dev/neuron3:/dev/neuron3 --device /dev/neuron4:/dev/neuron4 --device /dev/neuron5:/dev/neuron5 -p 8080:8080 -p 8081:8081 -p 8082:8082 neuron-sdk-215:torchserve-cb" - ], - "metadata": { - "collapsed": false - } - }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [ "# Run single inference request\n", - "!python test_stream_response.py" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# Run multiple inference requests concurrently\n", - "!./tesh.sh" + "!python ~/serve/examples/large_models/utils/test_llm_streaming_response.py -m llama-2-13b -o 50 -t 2 -n 4 --prompt-text \"Today the weather is really nice and I am planning on \" --prompt-randomize" ], "metadata": { "collapsed": false diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/test.sh b/examples/large_models/inferentia2/llama2/continuous_batching/test.sh deleted file mode 100755 index 14de9eb993..0000000000 --- a/examples/large_models/inferentia2/llama2/continuous_batching/test.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -for i in {1..64}; do - python ../test_stream_response.py > t_${i} & -done From e83d58c261ce77d3f593454aa6215feb96fcef2e Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 8 Jan 2024 19:31:49 -0800 Subject: [PATCH 35/49] Fix lint --- ts_scripts/spellcheck_conf/wordlist.txt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 8724dd70a4..178b5c5acf 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1154,3 +1154,13 @@ compilable nightlies torchexportaotcompile autotune +BloomForSampling +ForSampling +GPTJForSampling +GPTNeoXForSampling +LlamaForSampling +MistralForSampling +OPTForSampling +gptj +gptneox +neox From 3669a0d119425830e92214706e4c8bb53da50a6c Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 8 Jan 2024 20:11:57 -0800 Subject: [PATCH 36/49] fix typo in notebook example --- .../continuous_batching/inf2-llama-2-continuous-batching.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index 672d268bc5..6d5b7d71bc 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -67,7 +67,7 @@ "source": [ "# login in Hugginface hub\n", "!huggingface-cli login --token $HUGGINGFACE_TOKEN\n", - "!python ~/serve/examples/large_models/utils/Download_model.py --model_path model --model_name models--meta-llama--Llama-2-13b-hf --use_auth_token\n", + "!python ~/serve/examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token\n", "\n", "# Create TorchServe model artifacts\n", "!torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", From 642d59a871b72c9d876225d69f2606c83fa7f4bd Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 8 Jan 2024 20:17:50 -0800 Subject: [PATCH 37/49] enable authentication --- .../continuous_batching/inf2-llama-2-continuous-batching.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index 6d5b7d71bc..c5f2470594 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -67,7 +67,7 @@ "source": [ "# login in Hugginface hub\n", "!huggingface-cli login --token $HUGGINGFACE_TOKEN\n", - "!python ~/serve/examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token\n", + "!python ~/serve/examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token True\n", "\n", "# Create TorchServe model artifacts\n", "!torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", From 1c6b2119b7cfd77170b994a0a3a94a4ca23a7e7e Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 10 Jan 2024 15:08:28 -0800 Subject: [PATCH 38/49] fmt --- examples/large_models/utils/test_llm_streaming_response.py | 5 +---- .../distributed/base_neuronx_continuous_batching_handler.py | 2 -- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/large_models/utils/test_llm_streaming_response.py b/examples/large_models/utils/test_llm_streaming_response.py index c1a13eb257..bc31b9da8b 100644 --- a/examples/large_models/utils/test_llm_streaming_response.py +++ b/examples/large_models/utils/test_llm_streaming_response.py @@ -1,5 +1,4 @@ import argparse -import json import random import threading from queue import Queue @@ -20,9 +19,7 @@ def run(self): def _predict(self): payload = self._format_payload() - with requests.post( - self._get_url(), json=json.dumps(payload), stream=True - ) as response: + with requests.post(self._get_url(), json=payload, stream=True) as response: combined_text = "" for chunk in response.iter_content(chunk_size=None): if chunk: diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index fd21004344..f936a958d3 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -1,7 +1,6 @@ import logging import os -import orjson import torch import torch_neuronx from transformers import AutoModelForCausalLM, AutoTokenizer @@ -148,7 +147,6 @@ def preprocess(self, requests): if isinstance(data, (bytes, bytearray)): data = data.decode("utf-8") - data = orjson.loads(data) prompt = data.get("prompt") max_new_tokens = int(data.get("max_new_tokens", self.max_new_tokens)) logger.info( From 78ef0ae4f54ecb9c68d76914fe1f4a829f03d6ed Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 22 Jan 2024 13:33:10 -0800 Subject: [PATCH 39/49] fmt --- ...ase_neuronx_continuous_batching_handler.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index f936a958d3..3ca3dc8869 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -9,6 +9,7 @@ from transformers_neuronx.sampling import select_tokens from ts.context import Context +from ts.handler_utils.utils import import_class from ts.torch_handler.base_handler import BaseHandler logger = logging.getLogger(__name__) @@ -101,12 +102,15 @@ def initialize(self, ctx: Context): batch_size_for_shared_caches=self.batch_size ) neuron_config = NeuronConfig(continuous_batching=continuous_batching_config) + self._set_class(ctx) kwargs = dict( tp_degree=tp_degree, amp=amp, batch_size=self.batch_size, n_positions=[self.max_length], - context_length_estimate=[self.max_length], + context_length_estimate=ctx.model_yaml_config.get("handler", {}).get( + "model_checkpoint_dir", [self.max_length] + ), neuron_config=neuron_config, ) self.model = self.model_class.from_pretrained(model_checkpoint_path, **kwargs) @@ -381,3 +385,30 @@ def __call__(self, res): stop_token=self.tokenizer.eos_token_id, max_new_tokens=max_new_tokens, ) + + def _set_class(self, ctx: Context): + model_class_name = ctx.model_yaml_config.get("handler", {}).get( + "model_class_name", None + ) + self.assertIsNotNone(model_class_name, "model_class_name is not defined") + model_module_prefix = ctx.model_yaml_config.get("handler", {}).get( + "model_module_prefix", None + ) + self.model_class = import_class( + class_name=model_class_name, + module_prefix=model_module_prefix, + ) + + tokenizer_class_name = ctx.model_yaml_config.get("handler", {}).get( + "tokenizer_class_name", None + ) + self.assertIsNotNone( + tokenizer_class_name, "tokenizer_class_name is not defined" + ) + tokenizer_module_prefix = ctx.model_yaml_config.get("handler", {}).get( + "tokenizer_module_prefix", None + ) + + self.tokenizer_class = import_class( + class_name=tokenizer_class_name, module_prefix=tokenizer_module_prefix + ) From ded1c26a8d176d21b261b1c067dda37029fb3e04 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 23 Jan 2024 15:00:25 -0800 Subject: [PATCH 40/49] fmt --- .../inferentia2/llama2/continuous_batching/Readme.md | 6 +----- .../llama2/continuous_batching/model-config.yaml | 9 ++++++--- .../base_neuronx_continuous_batching_handler.py | 6 ++++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md index cb590e8b14..1d6e8e75eb 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md @@ -7,13 +7,9 @@ This example can also be extended to support the following models. | Model | Model Class | | :--- | :----: | -| opt | opt.model.OPTForSampling | -| gpt2 | gpt2.model.GPT2ForSampling | -| gptj | gptj.model.GPTJForSampling | -| gpt_neox | gptneox.model.GPTNeoXForSampling | | llama | lama.model.LlamaForSampling | | mistral | mistral.model.MistralForSampling | -| bloom | bloom.model.BloomForSampling | + The batch size [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. It is the batch size used for the Inf2 model compilation. Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml index 11379589bf..0210b10d89 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -1,6 +1,6 @@ minWorkers: 1 maxWorkers: 1 -maxBatchDelay: 1 +maxBatchDelay: 0 responseTimeout: 10800 batchSize: 8 continuousBatching: true @@ -8,8 +8,11 @@ continuousBatching: true handler: model_path: "model/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55" model_checkpoint_dir: "llama-2-13b-split" - amp: "bf16" - tp_degree: 12 + model_module_prefix: "transformers_neuronx" + model_class_name: "llama.model.LlamaForSampling" + tokenizer_module_prefix: "transformers.LlamaTokenizer" + amp: "f16" + tp_degree: 24 max_length: 100 max_new_tokens: 50 batch_size: 8 diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index 3ca3dc8869..667a09cfb6 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -56,7 +56,9 @@ def initialize(self, ctx: Context): os.environ["NEURONX_CACHE"] = "on" os.environ["NEURON_COMPILE_CACHE_URL"] = f"{model_dir}/neuron_cache" - os.environ["NEURON_CC_FLAGS"] = "-O1 --model-type=transformer" + os.environ[ + "NEURON_CC_FLAGS" + ] = "-O1 --model-type=transformer --enable-mixed-precision-accumulation" self.max_length = int( ctx.model_yaml_config.get("handler", {}).get("max_length", self.max_length) @@ -109,7 +111,7 @@ def initialize(self, ctx: Context): batch_size=self.batch_size, n_positions=[self.max_length], context_length_estimate=ctx.model_yaml_config.get("handler", {}).get( - "model_checkpoint_dir", [self.max_length] + "context_length_estimate", [self.max_length] ), neuron_config=neuron_config, ) From 4bd9b8e5b18da3b77431edca26f698b8fed98bee Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 23 Jan 2024 17:56:40 -0800 Subject: [PATCH 41/49] update readme --- .../llama2/continuous_batching/Readme.md | 8 +++- .../inf2-llama-2-continuous-batching.ipynb | 48 +++++++++++++------ .../continuous_batching/inf2_handler.py | 17 ------- .../continuous_batching/model-config.yaml | 13 ++--- ...ase_neuronx_continuous_batching_handler.py | 15 ++---- 5 files changed, 52 insertions(+), 49 deletions(-) delete mode 100644 examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md index 1d6e8e75eb..a863629cc1 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md @@ -2,8 +2,12 @@ This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS transformers-neuronx continuous batching](https://aws.amazon.com/ec2/instance-types/inf2/). -This example can also be extended to support the following models. - +This example can also be extended to support Mistral without code changes. Customers only set the following items in model-config.yaml. For example: +* model_path: "model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939" +* model_checkpoint_dir: "llama-2-70b-split" +* model_module_prefix: "transformers_neuronx" +* model_class_name: "llama.model.LlamaForSampling" +* tokenizer_class_name: "transformers.LlamaTokenizer" | Model | Model Class | | :--- | :----: | diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index c5f2470594..555ef57559 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -3,8 +3,8 @@ { "cell_type": "markdown", "source": [ - "## TorchServe Continuous Batching Serve Llama-2 on Inferentia-2\n", - "This notebook demonstrates TorchServe continuous batching serving Llama-2-13b on Inferentia-2 `inf2.24xlarge` with DLAMI: Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20231226" + "## TorchServe Continuous Batching Serve Llama-2-70B on Inferentia-2\n", + "This notebook demonstrates TorchServe continuous batching serving Llama-2-70b on Inferentia-2 `inf2.48xlarge` with DLAMI: Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20231226" ], "metadata": { "collapsed": false @@ -14,7 +14,7 @@ "cell_type": "markdown", "source": [ "### Installation\n", - "Note: This section can be skipped once [Neuron DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#neuron-containers) release TorchServe latest version." + "Note: This section can be skipped once Neuron DLC 2.16 with TorchServe latest version is released." ], "metadata": { "collapsed": false @@ -38,8 +38,10 @@ "# Clone Torchserve git repository\n", "!git clone https://github.com/pytorch/serve.git\n", "\n", - "# Install dependencies\n", - "!python ~/serve/ts_scripts/install_dependencies.py --neuronx --environment=dev\n", + "# Install dependencies, now all commands run under serve dir\n", + "!cd serve\n", + "!git checkout feat/inf2_cb\n", + "!python ts_scripts/install_dependencies.py --neuronx --environment=dev\n", "\n", "# Install torchserve and torch-model-archiver\n", "python ts_scripts/install_from_src.py" @@ -53,7 +55,7 @@ "source": [ "### Create model artifacts\n", "\n", - "Note: run `mv model/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55/model.safetensors.index.json model/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55/model.safetensors.index.json.bkp`\n", + "Note: run `mv model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json.bkp`\n", " if neuron sdk does not support safetensors" ], "metadata": { @@ -67,16 +69,34 @@ "source": [ "# login in Hugginface hub\n", "!huggingface-cli login --token $HUGGINGFACE_TOKEN\n", - "!python ~/serve/examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token True\n", + "!python examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token True\n", "\n", "# Create TorchServe model artifacts\n", - "!torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive\n", - "!mv model llama-2-13b\n", - "!mkdir -p ~/serve/model_store\n", - "!mv ~/serve/llama-2-13b /home/model-server/model_store\n", + "!torch-model-archiver --model-name llama-2-70b --version 1.0 --handler ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py -r requirements.txt --config-file examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml --archive-format no-archive\n", "\n", - "# Precompile complete once the log \"Model llama-2-13b loaded successfully\"\n", - "torchserve --ncs --start --model-store /home/model-server/model_store --models llama-2-13b --ts-config ../config.properties" + "!mkdir -p model_store\n", + "!mv llama-2-70b model_store\n", + "!mv model model_store/llama-2-70b" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Start TorchServe" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "torchserve --ncs --start --model-store model_store --models llama-2-70b --ts-config examples/large_models/inferentia2/llama2/config.properties" ], "metadata": { "collapsed": false @@ -97,7 +117,7 @@ "outputs": [], "source": [ "# Run single inference request\n", - "!python ~/serve/examples/large_models/utils/test_llm_streaming_response.py -m llama-2-13b -o 50 -t 2 -n 4 --prompt-text \"Today the weather is really nice and I am planning on \" --prompt-randomize" + "!python examples/large_models/utils/test_llm_streaming_response.py -m llama-2-70b -o 50 -t 2 -n 4 --prompt-text \"Today the weather is really nice and I am planning on \" --prompt-randomize" ], "metadata": { "collapsed": false diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py b/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py deleted file mode 100644 index 6de23d4a29..0000000000 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2_handler.py +++ /dev/null @@ -1,17 +0,0 @@ -from ts.handler_utils.utils import import_class -from ts.torch_handler.distributed.base_neuronx_continuous_batching_handler import ( - BaseNeuronXContinuousBatchingHandler, -) - - -class LlamaContinuousBatchingHandler(BaseNeuronXContinuousBatchingHandler): - def __init__(self): - super(LlamaContinuousBatchingHandler, self).__init__() - self.model_class = import_class( - class_name="llama.model.LlamaForSampling", - module_prefix="transformers_neuronx", - ) - - self.tokenizer_class = import_class( - class_name="transformers.LlamaTokenizer", - ) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml index 0210b10d89..2a69d4de52 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml +++ b/examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml @@ -1,18 +1,19 @@ minWorkers: 1 maxWorkers: 1 maxBatchDelay: 0 -responseTimeout: 10800 batchSize: 8 +responseTimeout: 10800 +jobQueueSize: 500 continuousBatching: true handler: - model_path: "model/models--meta-llama--Llama-2-13b-hf/snapshots/dc1d3b3bfdb69df26f8fc966c16353274b138c55" - model_checkpoint_dir: "llama-2-13b-split" + model_path: "model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939" + model_checkpoint_dir: "llama-2-70b-split" model_module_prefix: "transformers_neuronx" model_class_name: "llama.model.LlamaForSampling" - tokenizer_module_prefix: "transformers.LlamaTokenizer" - amp: "f16" + tokenizer_class_name: "transformers.LlamaTokenizer" + amp: "bf16" tp_degree: 24 - max_length: 100 + max_length: 256 max_new_tokens: 50 batch_size: 8 diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index 667a09cfb6..09357a9fa9 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -92,7 +92,7 @@ def initialize(self, ctx: Context): ) raise error - + self._set_class(ctx) self.tokenizer = self.tokenizer_class.from_pretrained( model_checkpoint_path, return_tensors="pt", padding_side="left" ) @@ -104,7 +104,6 @@ def initialize(self, ctx: Context): batch_size_for_shared_caches=self.batch_size ) neuron_config = NeuronConfig(continuous_batching=continuous_batching_config) - self._set_class(ctx) kwargs = dict( tp_degree=tp_degree, amp=amp, @@ -155,9 +154,6 @@ def preprocess(self, requests): prompt = data.get("prompt") max_new_tokens = int(data.get("max_new_tokens", self.max_new_tokens)) - logger.info( - "preprocess prompt={prompt}, max_new_tokens={max_new_tokens}" - ) prefill_input_text.append(prompt.strip()) self.context.cache[req_id] = { @@ -388,11 +384,12 @@ def __call__(self, res): max_new_tokens=max_new_tokens, ) - def _set_class(self, ctx: Context): + def _set_class(self, ctx): model_class_name = ctx.model_yaml_config.get("handler", {}).get( "model_class_name", None ) - self.assertIsNotNone(model_class_name, "model_class_name is not defined") + + assert model_class_name is not None model_module_prefix = ctx.model_yaml_config.get("handler", {}).get( "model_module_prefix", None ) @@ -404,9 +401,7 @@ def _set_class(self, ctx: Context): tokenizer_class_name = ctx.model_yaml_config.get("handler", {}).get( "tokenizer_class_name", None ) - self.assertIsNotNone( - tokenizer_class_name, "tokenizer_class_name is not defined" - ) + assert tokenizer_class_name is not None tokenizer_module_prefix = ctx.model_yaml_config.get("handler", {}).get( "tokenizer_module_prefix", None ) From 932e7acfaec3f6fdc257d41817300cea0da5b5f6 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 23 Jan 2024 22:03:14 -0800 Subject: [PATCH 42/49] fix lint --- ts_scripts/spellcheck_conf/wordlist.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index f94a6a4c57..84986005bd 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1174,3 +1174,7 @@ OPTForSampling gptj gptneox neox +LlamaTokenizer +bdaacb +de +fcea From f7a55312989d5c582f425c6469c22d55ae5be5ef Mon Sep 17 00:00:00 2001 From: lxning Date: Sun, 18 Feb 2024 22:16:50 -0800 Subject: [PATCH 43/49] fmt --- .../large_models/inferentia2/llama2/Readme.md | 110 +----------------- .../llama2/continuous_batching/Readme.md | 2 +- .../inf2-llama-2-continuous-batching.ipynb | 2 +- .../inferentia2/llama2/streamer/Readme.md | 106 ++++++++++++++++- .../utils/test_llm_streaming_response.py | 25 +++- ts/tests/unit_tests/test_handler_utils.py | 16 +++ ...ase_neuronx_continuous_batching_handler.py | 31 ++--- 7 files changed, 157 insertions(+), 135 deletions(-) create mode 100644 ts/tests/unit_tests/test_handler_utils.py diff --git a/examples/large_models/inferentia2/llama2/Readme.md b/examples/large_models/inferentia2/llama2/Readme.md index 614ef04108..d737117acb 100644 --- a/examples/large_models/inferentia2/llama2/Readme.md +++ b/examples/large_models/inferentia2/llama2/Readme.md @@ -1,110 +1,6 @@ # Large model inference on Inferentia2 -This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe's features: +This folder briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe's features: -* demo1: [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support in folder streamer. -* demo2: continuous batching support in folder continuous_batching - -Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is built on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference. - -This example folder demonstrates the utilization of neuronx cache to store inf2 model compilation artifacts using the `NEURONX_CACHE` and `NEURON_COMPILE_CACHE_URL` environment variables in the custom handler. -When the model is loaded for the first time, the model is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache. -On subsequent model load, the compilation artifacts in the neuronx cache serves as `Ahead of Time(AOT)` compilation artifacts and significantly reduces the model load time. -For convenience, the compiled model artifacts for this example are made available on the Torchserve model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\ -Instructions on how to use the AOT compiled model artifacts is shown below. - -### Step 1: Inf2 instance - -Get an Inf2 instance(Note: This example was tested on instance type:`inf2.24xlarge`), ssh to it, make sure to use the following DLAMI as it comes with PyTorch and necessary packages for AWS Neuron SDK pre-installed. -DLAMI Name: ` Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher. - -**Note**: The `inf2.24xlarge` instance consists of 6 neuron chips with 2 neuron cores each. The total accelerator memory is 192GB. -Based on the configuration used in [model-config.yaml](streamer/model-config.yaml), with `tp_degree` set to 6, 3 of the 6 neuron chips are used, i.e 6 neuron cores. -On loading the model, the accelerator memory consumed is 38.1GB (12.7GB per chip). - -### Step 2: Package Installations - -Follow the steps below to complete package installations - -```bash -sudo apt-get update -sudo apt-get upgrade - -# Activate Python venv -source /opt/aws_neuron_venv_pytorch/bin/activate - -# Clone Torchserve git repository -git clone https://github.com/pytorch/serve.git -cd serve - -# Install dependencies -python ts_scripts/install_dependencies.py --neuronx --environment=dev - -# Install torchserve and torch-model-archiver -python ts_scripts/install_from_src.py - -# Navigate to `examples/large_models/inferentia2/llama2` directory -cd examples/large_models/inferentia2/llama2/ - -# Install additional necessary packages -python -m pip install -r requirements.txt -``` - -### Step 3: Save the model artifacts compatible with `transformers-neuronx` -In order to use the pre-compiled model artifacts, copy them from the model zoo using the command shown below and skip to **Step 5** -```bash -aws s3 cp s3://torchserve/mar_files/llama-2-13b-neuronx-b4/ llama-2-13b --recursive -``` - -In order to download and compile the Llama2 model from scratch for support on Inf2:\ -Request access to the Llama2 model\ -https://huggingface.co/meta-llama/Llama-2-13b-hf - -Login to Huggingface -```bash -huggingface-cli login -``` - -Run the `inf2_save_split_checkpoints.py` script -```bash -python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split' -``` - - -### Step 4: Package model artifacts - -```bash -torch-model-archiver --model-name llama-2-13b --version 1.0 --handler /PATH/TO/inf2_handler.py -r requirements.txt --config-file /PATH/TO/model-config.yaml --archive-format no-archive -mv llama-2-13b-split llama-2-13b -``` - -### Step 5: Add the model artifacts to model store - -```bash -mkdir model_store -mv llama-2-13b model_store -``` - -### Step 6: Start torchserve - -```bash -torchserve --ncs --start --model-store model_store --ts-config config.properties -``` - -### Step 7: Register model - -```bash -curl -X POST "http://localhost:8081/models?url=llama-2-13b" -``` - -### Step 8: Run inference - -```bash -python test_stream_response.py -``` - -### Step 9: Stop torchserve - -```bash -torchserve --stop -``` +* demo1: [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support in folder [streamer](streamer). +* demo2: continuous batching support in folder [continuous_batching](continuous_batching) diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md index a863629cc1..8b935b2288 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md +++ b/examples/large_models/inferentia2/llama2/continuous_batching/Readme.md @@ -15,7 +15,7 @@ This example can also be extended to support Mistral without code changes. Custo | mistral | mistral.model.MistralForSampling | -The batch size [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. It is the batch size used for the Inf2 model compilation. +The batch size in [model-config.yaml](model-config.yaml) indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. It is the batch size used for the Inf2 model compilation. Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. `inf2-llama-2-continuous-batching.ipynb` is the notebook example. diff --git a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb index 555ef57559..e6897cca85 100644 --- a/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb +++ b/examples/large_models/inferentia2/llama2/continuous_batching/inf2-llama-2-continuous-batching.ipynb @@ -72,7 +72,7 @@ "!python examples/large_models/utils/Download_model.py --model_path model --model_name meta-llama/Llama-2-13b-hf --use_auth_token True\n", "\n", "# Create TorchServe model artifacts\n", - "!torch-model-archiver --model-name llama-2-70b --version 1.0 --handler ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py -r requirements.txt --config-file examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml --archive-format no-archive\n", + "!torch-model-archiver --model-name llama-2-70b --version 1.0 --handler ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py -r examples/large_models/inferentia2/llama2/requirements.txt --config-file examples/large_models/inferentia2/llama2/continuous_batching/model-config.yaml --archive-format no-archive\n", "\n", "!mkdir -p model_store\n", "!mv llama-2-70b model_store\n", diff --git a/examples/large_models/inferentia2/llama2/streamer/Readme.md b/examples/large_models/inferentia2/llama2/streamer/Readme.md index 684b418e8b..6490420c0f 100644 --- a/examples/large_models/inferentia2/llama2/streamer/Readme.md +++ b/examples/large_models/inferentia2/llama2/streamer/Readme.md @@ -1,9 +1,113 @@ # Demo1: Llama-2 Using TorchServe micro-batching and Streamer on inf2 -This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with TorchServe [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support. + +Inferentia2 uses [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/) which is built on top of PyTorch XLA stack. For large model inference [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used that takes care of model partitioning and running inference. **Note**: To run the model on an Inf2 instance, the model gets compiled as a preprocessing step. As part of the compilation process, to generate the model graph, a specific batch size is used. Following this, when running inference, we need to pass input which matches the batch size that was used during compilation. Model compilation and input padding to match compiled model batch size is taken care of by the [custom handler](inf2_handler.py) in this example. The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay. The batch size is chosen to be a relatively large value, say 16 since micro batching enables running the preprocess(tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation. Since compilation batch size can influence compile time and also constrained by the Inf2 instance type, this is chosen to be a relatively smaller value, say 4. + +This example also demonstrates the utilization of neuronx cache to store inf2 model compilation artifacts using the `NEURONX_CACHE` and `NEURONX_DUMP_TO` environment variables in the custom handler. +When the model is loaded for the first time, the model is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache. +On subsequent model load, the compilation artifacts in the neuronx cache serves as `Ahead of Time(AOT)` compilation artifacts and significantly reduces the model load time. +For convenience, the compiled model artifacts for this example are made available on the Torchserve model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\ +Instructions on how to use the AOT compiled model artifacts is shown below. + +### Step 1: Inf2 instance + +Get an Inf2 instance(Note: This example was tested on instance type:`inf2.24xlarge`), ssh to it, make sure to use the following DLAMI as it comes with PyTorch and necessary packages for AWS Neuron SDK pre-installed. +DLAMI Name: ` Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher. + +**Note**: The `inf2.24xlarge` instance consists of 6 neuron chips with 2 neuron cores each. The total accelerator memory is 192GB. +Based on the configuration used in [model-config.yaml](model-config.yaml), with `tp_degree` set to 6, 3 of the 6 neuron chips are used, i.e 6 neuron cores. +On loading the model, the accelerator memory consumed is 38.1GB (12.7GB per chip). + +### Step 2: Package Installations + +Follow the steps below to complete package installations + +```bash +sudo apt-get update +sudo apt-get upgrade + +# Activate Python venv +source /opt/aws_neuron_venv_pytorch/bin/activate + +# Clone Torchserve git repository +git clone https://github.com/pytorch/serve.git +cd serve + +# Install dependencies +python ts_scripts/install_dependencies.py --neuronx --environment=dev + +# Install torchserve and torch-model-archiver +python ts_scripts/install_from_src.py + +# Navigate to `examples/large_models/inferentia2/llama2` directory +cd examples/large_models/inferentia2/llama2/ + +# Install additional necessary packages +python -m pip install -r requirements.txt +``` + +### Step 3: Save the model artifacts compatible with `transformers-neuronx` +In order to use the pre-compiled model artifacts, copy them from the model zoo using the command shown below and skip to **Step 5** +```bash +aws s3 cp s3://torchserve/mar_files/llama-2-13b-neuronx-b4/ llama-2-13b --recursive +``` + +In order to download and compile the Llama2 model from scratch for support on Inf2:\ +Request access to the Llama2 model\ +https://huggingface.co/meta-llama/Llama-2-13b-hf + +Login to Huggingface +```bash +huggingface-cli login +``` + +Run the `inf2_save_split_checkpoints.py` script +```bash +python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split' +``` + + +### Step 4: Package model artifacts + +```bash +torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive +mv llama-2-13b-split llama-2-13b +``` + +### Step 5: Add the model artifacts to model store + +```bash +mkdir model_store +mv llama-2-13b model_store +``` + +### Step 6: Start torchserve + +```bash +torchserve --ncs --start --model-store model_store --ts-config config.properties +``` + +### Step 7: Register model + +```bash +curl -X POST "http://localhost:8081/models?url=llama-2-13b" +``` + +### Step 8: Run inference + +```bash +python test_stream_response.py +``` + +### Step 9: Stop torchserve + +```bash +torchserve --stop +``` diff --git a/examples/large_models/utils/test_llm_streaming_response.py b/examples/large_models/utils/test_llm_streaming_response.py index bc31b9da8b..b8b7f10ec1 100644 --- a/examples/large_models/utils/test_llm_streaming_response.py +++ b/examples/large_models/utils/test_llm_streaming_response.py @@ -6,10 +6,12 @@ import orjson import requests +max_prompt_random_tokens = 20 + class Predictor(threading.Thread): def __init__(self, args, queue): - threading.Thread.__init__(self) + super().__init__() self.args = args self.queue = queue @@ -36,10 +38,11 @@ def _format_payload(self): rp = len(prompt_list) rt = self.args.max_tokens if self.args.prompt_randomize: - rp = random.randint(1, len(prompt_list)) - rt = random.randint(10, self.args.max_tokens) - cur_prompt_list = prompt_list[:rp] - cur_prompt = " ".join(cur_prompt_list) + rp = random.randint(0, max_prompt_random_tokens) + rt = rp + self.args.max_tokens + for _ in range(rp): + prompt_list.insert(0, chr(ord("a") + random.randint(0, 25))) + cur_prompt = " ".join(prompt_list) return { "prompt": cur_prompt, "max_new_tokens": rt, @@ -66,11 +69,13 @@ def parse_args(): parser.add_argument( "-m", "--model", + required=True, type=str, - help="The model to use for generating text. If not specified we will pick the first model from the service as returned by /v1/models", + help="The model to use for generating text.", ) parser.add_argument( "--prompt-text", + required=True, type=str, help="Prompt text to use instead of generating one. It can be a file reference starting with an ampersand, e.g. `@prompt.txt`", ) @@ -107,6 +112,14 @@ def parse_args(): def main(): args = parse_args() + if len(args.model) == 0: + print("model argument can not be empty.") + exit(1) + + if len(args.prompt_text) == 0: + print("prompt argument can not be empty.") + exit(1) + queue = Queue() predictors = [] for i in range(args.num_threads): diff --git a/ts/tests/unit_tests/test_handler_utils.py b/ts/tests/unit_tests/test_handler_utils.py new file mode 100644 index 0000000000..0d242fc68a --- /dev/null +++ b/ts/tests/unit_tests/test_handler_utils.py @@ -0,0 +1,16 @@ +from ts.handler_utils.utils import import_class + + +def test_import_class_no_module_prefix(): + model_class = import_class( + class_name="ts.torch_handler.base_handler.BaseHandler", + ) + assert "BaseHandler" == model_class.__class__.__name__ + + +def test_import_class_module_prefix(): + model_class = import_class( + class_name="BaseHandler", + module_prefix="ts.torch_handler.base_handler", + ) + assert "BaseHandler" == model_class.__class__.__name__ diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index 09357a9fa9..d9dc5a54bf 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -17,7 +17,7 @@ class BaseNeuronXContinuousBatchingHandler(BaseHandler): def __init__(self): - super(BaseNeuronXContinuousBatchingHandler, self).__init__() + super().__init__() self.batch_size = 2 self.max_new_tokens = 25 @@ -60,21 +60,16 @@ def initialize(self, ctx: Context): "NEURON_CC_FLAGS" ] = "-O1 --model-type=transformer --enable-mixed-precision-accumulation" - self.max_length = int( - ctx.model_yaml_config.get("handler", {}).get("max_length", self.max_length) - ) + handler_config = ctx.model_yaml_config.get("handler", {}) + self.max_length = int(handler_config.get("max_length", self.max_length)) self.max_new_tokens = int( - ctx.model_yaml_config.get("handler", {}).get( - "max_new_tokens", self.max_new_tokens - ) - ) - self.batch_size = int( - ctx.model_yaml_config.get("handler", {}).get("batch_size", self.batch_size) + handler_config.get("max_new_tokens", self.max_new_tokens) ) + self.batch_size = int(handler_config.get("batch_size", self.batch_size)) # settings for model compilation and loading - amp = ctx.model_yaml_config.get("handler", {}).get("amp", "fp32") - tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6) + amp = handler_config.get("amp", "fp32") + tp_degree = handler_config.get("tp_degree", 6) # allocate "tp_degree" number of neuron cores to the worker process os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree) @@ -98,7 +93,6 @@ def initialize(self, ctx: Context): ) self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - self.tokenizer.padding_side = "left" continuous_batching_config = ContinuousBatchingConfig( batch_size_for_shared_caches=self.batch_size @@ -109,7 +103,7 @@ def initialize(self, ctx: Context): amp=amp, batch_size=self.batch_size, n_positions=[self.max_length], - context_length_estimate=ctx.model_yaml_config.get("handler", {}).get( + context_length_estimate=handler_config.get( "context_length_estimate", [self.max_length] ), neuron_config=neuron_config, @@ -128,7 +122,7 @@ def initialize(self, ctx: Context): self.decode_next_tokens = torch.zeros(self.batch_size, 1, dtype=torch.int64) # self.decode_next_tokens = torch.full(self.batch_size, self.tokenizer.eos_token_id) - for _, seq_id in enumerate(reversed(range(self.batch_size))): + for seq_id in reversed(range(self.batch_size)): self.empty_seq_ids.append(seq_id) logger.info("Model %s loaded successfully", ctx.model_name) @@ -154,7 +148,7 @@ def preprocess(self, requests): prompt = data.get("prompt") max_new_tokens = int(data.get("max_new_tokens", self.max_new_tokens)) - prefill_input_text.append(prompt.strip()) + prefill_input_text.append(prompt) self.context.cache[req_id] = { "seq_id": seq_id, @@ -175,8 +169,7 @@ def preprocess(self, requests): def inference(self, inputs): prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids = inputs results = {} - # Test if this is the beginning of a continuous batching - go_to_decode = True if len(req_decode_seq_ids) > 0 else False + if len(prefill_seq_ids) > 0: prefill_next_tokens, prefill_cache_ids = self._run_prefill( prefill_tokens, prefill_seq_ids @@ -192,7 +185,7 @@ def inference(self, inputs): prefill_input_text=prefill_input_text, ) - if go_to_decode: + if len(req_decode_seq_ids) > 0: local_decode_seq_ids = torch.cat(torch.where(self.decode_seq_ids > -1)) local_decode_cache_ids = self.decode_cache_ids[local_decode_seq_ids] local_decode_next_tokens = self.decode_next_tokens[local_decode_seq_ids] From cbfcec496e808c5c810600f0b02446a9d50077f8 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 19 Feb 2024 13:23:40 -0800 Subject: [PATCH 44/49] update test data --- ts/tests/unit_tests/test_handler_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ts/tests/unit_tests/test_handler_utils.py b/ts/tests/unit_tests/test_handler_utils.py index 0d242fc68a..dbf16e0718 100644 --- a/ts/tests/unit_tests/test_handler_utils.py +++ b/ts/tests/unit_tests/test_handler_utils.py @@ -3,14 +3,14 @@ def test_import_class_no_module_prefix(): model_class = import_class( - class_name="ts.torch_handler.base_handler.BaseHandler", + class_name="transformers.LlamaTokenizer", ) - assert "BaseHandler" == model_class.__class__.__name__ + assert "LlamaTokenizer" == model_class.__class__.__name__ def test_import_class_module_prefix(): model_class = import_class( - class_name="BaseHandler", - module_prefix="ts.torch_handler.base_handler", + class_name="LlamaTokenizer", + module_prefix="transformers", ) - assert "BaseHandler" == model_class.__class__.__name__ + assert "LlamaTokenizer" == model_class.__class__.__name__ From da34b53f6d1eb82afd07cb5e03738c033255eac0 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 19 Feb 2024 20:06:35 -0800 Subject: [PATCH 45/49] update test --- ts/handler_utils/utils.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ts/handler_utils/utils.py b/ts/handler_utils/utils.py index 3d1bf47e46..52cb689901 100644 --- a/ts/handler_utils/utils.py +++ b/ts/handler_utils/utils.py @@ -6,6 +6,9 @@ def import_class(class_name: str, module_prefix=None): + if class_name is None or len(class_name) == 0: + raise ImportError(f"class name is not defined") + module_name = "" arr = class_name.rsplit(".", maxsplit=1) if len(arr) == 2: @@ -13,10 +16,16 @@ def import_class(class_name: str, module_prefix=None): else: class_name = arr[0] - if module_prefix is not None: - module = importlib.import_module(f"{module_prefix}.{module_name}") - else: + if module_prefix is not None and len(module_prefix) > 0: + module = ( + importlib.import_module(f"{module_prefix}.{module_name}") + if len(module_name) > 0 + else importlib.import_module(module_prefix) + ) + elif len(module_name) > 0: module = importlib.import_module(module_name) + else: + raise ImportError(f"module name is not defined.") model_class = getattr(module, class_name, None) if model_class is None: From b373077ee9985d4dfe8ae868273aab0629884672 Mon Sep 17 00:00:00 2001 From: lxning Date: Mon, 19 Feb 2024 20:13:37 -0800 Subject: [PATCH 46/49] update test --- ts/tests/unit_tests/test_handler_utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ts/tests/unit_tests/test_handler_utils.py b/ts/tests/unit_tests/test_handler_utils.py index dbf16e0718..dc0d97f384 100644 --- a/ts/tests/unit_tests/test_handler_utils.py +++ b/ts/tests/unit_tests/test_handler_utils.py @@ -1,3 +1,5 @@ +import pytest + from ts.handler_utils.utils import import_class @@ -14,3 +16,17 @@ def test_import_class_module_prefix(): module_prefix="transformers", ) assert "LlamaTokenizer" == model_class.__class__.__name__ + + +def test_import_class_no_module(): + with pytest.raises(ImportError): + model_class = import_class( + class_name="LlamaTokenizer", + ) + + +def test_import_class_no_class(): + with pytest.raises(ImportError): + model_class = import_class( + class_name="", + ) From aa3eafe4fef272aaf81a2f575895c447ee1d92d2 Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 20 Feb 2024 09:27:41 -0800 Subject: [PATCH 47/49] replace os.path with pathlib --- ...ase_neuronx_continuous_batching_handler.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index d9dc5a54bf..ff226322ed 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -1,5 +1,6 @@ import logging import os +import pathlib import torch import torch_neuronx @@ -36,21 +37,23 @@ def __init__(self): def initialize(self, ctx: Context): ctx.cache = {} model_dir = ctx.system_properties.get("model_dir") - model_checkpoint_dir = ctx.model_yaml_config.get("handler", {}).get( - "model_checkpoint_dir", "" + handler_config = ctx.model_yaml_config.get("handler", {}) + model_checkpoint_dir = handler_config.get("model_checkpoint_dir", "") + + model_checkpoint_path = pathlib.Path(model_dir).joinpath(model_checkpoint_dir) + model_path = pathlib.Path(model_dir).joinpath( + handler_config.get("model_path", "") ) - model_checkpoint_path = f"{model_dir}/{model_checkpoint_dir}" - model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}' - if not os.path.exists(model_checkpoint_path): + if not model_checkpoint_path.exists(): # Load and save the CPU model model_cpu = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True + str(model_path), low_cpu_mem_usage=True ) save_pretrained_split(model_cpu, model_checkpoint_path) # Load and save tokenizer for the model tokenizer = AutoTokenizer.from_pretrained( - model_path, return_tensors="pt", padding_side="left" + str(model_path), return_tensors="pt", padding_side="left" ) tokenizer.save_pretrained(model_checkpoint_path) @@ -60,7 +63,6 @@ def initialize(self, ctx: Context): "NEURON_CC_FLAGS" ] = "-O1 --model-type=transformer --enable-mixed-precision-accumulation" - handler_config = ctx.model_yaml_config.get("handler", {}) self.max_length = int(handler_config.get("max_length", self.max_length)) self.max_new_tokens = int( handler_config.get("max_new_tokens", self.max_new_tokens) @@ -120,7 +122,6 @@ def initialize(self, ctx: Context): self.decode_cache_ids = torch.zeros(self.batch_size, 1, dtype=torch.int64) # 2D: [batch_size, next_token] self.decode_next_tokens = torch.zeros(self.batch_size, 1, dtype=torch.int64) - # self.decode_next_tokens = torch.full(self.batch_size, self.tokenizer.eos_token_id) for seq_id in reversed(range(self.batch_size)): self.empty_seq_ids.append(seq_id) @@ -378,26 +379,19 @@ def __call__(self, res): ) def _set_class(self, ctx): - model_class_name = ctx.model_yaml_config.get("handler", {}).get( - "model_class_name", None - ) + handler_config = ctx.model_yaml_config.get("handler", {}) + model_class_name = handler_config.get("model_class_name", None) assert model_class_name is not None - model_module_prefix = ctx.model_yaml_config.get("handler", {}).get( - "model_module_prefix", None - ) + model_module_prefix = handler_config.get("model_module_prefix", None) self.model_class = import_class( class_name=model_class_name, module_prefix=model_module_prefix, ) - tokenizer_class_name = ctx.model_yaml_config.get("handler", {}).get( - "tokenizer_class_name", None - ) + tokenizer_class_name = handler_config.get("tokenizer_class_name", None) assert tokenizer_class_name is not None - tokenizer_module_prefix = ctx.model_yaml_config.get("handler", {}).get( - "tokenizer_module_prefix", None - ) + tokenizer_module_prefix = handler_config.get("tokenizer_module_prefix", None) self.tokenizer_class = import_class( class_name=tokenizer_class_name, module_prefix=tokenizer_module_prefix From 2cb22292a0e5c978ca410d9066829c65950a873b Mon Sep 17 00:00:00 2001 From: lxning Date: Tue, 20 Feb 2024 15:31:32 -0800 Subject: [PATCH 48/49] update test --- ts/tests/unit_tests/test_handler_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ts/tests/unit_tests/test_handler_utils.py b/ts/tests/unit_tests/test_handler_utils.py index dc0d97f384..96a92e7708 100644 --- a/ts/tests/unit_tests/test_handler_utils.py +++ b/ts/tests/unit_tests/test_handler_utils.py @@ -7,7 +7,7 @@ def test_import_class_no_module_prefix(): model_class = import_class( class_name="transformers.LlamaTokenizer", ) - assert "LlamaTokenizer" == model_class.__class__.__name__ + assert "LlamaTokenizer" == model_class.__name__ def test_import_class_module_prefix(): @@ -15,7 +15,7 @@ def test_import_class_module_prefix(): class_name="LlamaTokenizer", module_prefix="transformers", ) - assert "LlamaTokenizer" == model_class.__class__.__name__ + assert "LlamaTokenizer" == model_class.__name__ def test_import_class_no_module(): From a2ba12451b48867ca43e230f22db3d069d0cf91e Mon Sep 17 00:00:00 2001 From: lxning Date: Wed, 21 Feb 2024 16:50:40 -0800 Subject: [PATCH 49/49] fmt --- ts/handler_utils/utils.py | 4 ++-- .../base_neuronx_continuous_batching_handler.py | 9 +++++++-- ts_scripts/spellcheck_conf/wordlist.txt | 1 - 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ts/handler_utils/utils.py b/ts/handler_utils/utils.py index 52cb689901..3361513c40 100644 --- a/ts/handler_utils/utils.py +++ b/ts/handler_utils/utils.py @@ -6,7 +6,7 @@ def import_class(class_name: str, module_prefix=None): - if class_name is None or len(class_name) == 0: + if not class_name: raise ImportError(f"class name is not defined") module_name = "" @@ -16,7 +16,7 @@ def import_class(class_name: str, module_prefix=None): else: class_name = arr[0] - if module_prefix is not None and len(module_prefix) > 0: + if module_prefix: module = ( importlib.import_module(f"{module_prefix}.{module_name}") if len(module_name) > 0 diff --git a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py index ff226322ed..5f6c82fe0d 100644 --- a/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py +++ b/ts/torch_handler/distributed/base_neuronx_continuous_batching_handler.py @@ -382,7 +382,9 @@ def _set_class(self, ctx): handler_config = ctx.model_yaml_config.get("handler", {}) model_class_name = handler_config.get("model_class_name", None) - assert model_class_name is not None + assert ( + model_class_name + ), "model_class_name not found in the section of handler in model config yaml file" model_module_prefix = handler_config.get("model_module_prefix", None) self.model_class = import_class( class_name=model_class_name, @@ -390,7 +392,10 @@ def _set_class(self, ctx): ) tokenizer_class_name = handler_config.get("tokenizer_class_name", None) - assert tokenizer_class_name is not None + assert ( + tokenizer_class_name + ), "tokenizer_class_name not found in the section of handler in model config yaml file" + tokenizer_module_prefix = handler_config.get("tokenizer_module_prefix", None) self.tokenizer_class = import_class( diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 861824495b..0489f8d908 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1198,7 +1198,6 @@ Maher's warmup SOTA FxGraphCache -TorchInductor fx locustapache FINhR