support inf2 neuronx transformer continuous batching #2803
Conversation
Overall, this looks good. I left a couple of comments. Please add at least an e2e test for this; we already have some example tests that you can refer to.
@@ -85,7 +82,7 @@ python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13
### Step 4: Package model artifacts

```bash
-torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py -r requirements.txt --config-file model-config.yaml --archive-format no-archive
+torch-model-archiver --model-name llama-2-13b --version 1.0 --handler /PATH/TO/inf2_handler.py -r requirements.txt --config-file /PATH/TO/model-config.yaml --archive-format no-archive
```
Better to set an env variable so users can switch between the two choices in a single place and then just copy the commands.

```
# Install dependencies, now all commands run under serve dir
!cd serve
!git checkout feat/inf2_cb
```
Is this still valid if merged into master?
I added a notice at the beginning of this cell: "This notebook demonstrates TorchServe continuous batching serving Llama-2-70b on Inferentia-2 inf2.48xlarge with DLAMI: Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20231226". Currently this notebook is needed for SAs to present the solution to customers.

### Create model artifacts

Note: run `mv model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939/model.safetensors.index.json.bkp` if the Neuron SDK does not support safetensors
"if" neuron sdk ....? On what does this depend?
Neuron SDK support for the model safetensors format is still in beta.
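For reference, a minimal Python sketch of that backup step using pathlib (the snapshot path is copied from the notebook cell above and is purely illustrative):

```python
from pathlib import Path

# Snapshot directory as shown in the notebook cell above; adjust to your local
# Hugging Face cache layout.
snapshot_dir = Path(
    "model/models--meta-llama--Llama-2-70b-hf/snapshots/90052941a64de02075ca800b09fcea1bdaacb939"
)
index_file = snapshot_dir / "model.safetensors.index.json"

# Rename the safetensors index so the loader falls back to the non-safetensors
# checkpoint files, per the note above.
if index_file.exists():
    index_file.rename(index_file.with_name(index_file.name + ".bkp"))
```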
```yaml
model_class_name: "llama.model.LlamaForSampling"
tokenizer_class_name: "transformers.LlamaTokenizer"
amp: "bf16"
tp_degree: 24
```
Is this the minimum number of cores possible for this model size?
According to the inf2 team's guidance, set tp_degree to 24 on inf2.48xlarge and 32 on trn1.32xlarge.
```python
prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids = inputs
results = {}
# Test if this is the beginning of a continuous batching
go_to_decode = True if len(req_decode_seq_ids) > 0 else False
```
Better to test this at line 195 directly, for clarity.
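For illustration, the flag can be computed without the ternary (a minimal sketch of the suggestion, not the merged code):

```python
def should_go_to_decode(req_decode_seq_ids) -> bool:
    # The comparison already yields a bool, so no ternary is needed; the check
    # could also be inlined directly at the point of use.
    return len(req_decode_seq_ids) > 0
```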
```python
        return prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids

    def inference(self, inputs):
        prefill_input_text, prefill_tokens, prefill_seq_ids, req_decode_seq_ids = inputs
```
Why not decode_seq_ids, like prefill_seq_ids?
The "req_" prefix is to highlight that these decode_seq_ids come from the frontend. This is different from self.decode_seq_ids, which also includes the prefill seq ids.
```python
z = torch.empty(x.shape[0], self.max_length, dtype=torch.int64)
for idx, item in enumerate(x):
    pad = torch.zeros(self.max_length - len(x[idx]), dtype=torch.int)
    z[idx] = torch.cat((x[idx], pad))
```
We should test whether stacking the concatenated tensors is faster.
The pad is added along the same dimension. Why would a stack be needed here?
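For context, here is a self-contained sketch of the two padding variants being compared (the tensor names and max_length value are illustrative, not the handler's actual fields):

```python
import torch

max_length = 8
x = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6, 7])]  # ragged token id tensors

# Variant 1 (current code): preallocate the output and fill it row by row.
z1 = torch.empty(len(x), max_length, dtype=torch.int64)
for idx, item in enumerate(x):
    pad = torch.zeros(max_length - len(item), dtype=torch.int64)
    z1[idx] = torch.cat((item, pad))

# Variant 2 (reviewer's suggestion): right-pad each row, then stack once.
z2 = torch.stack(
    [torch.nn.functional.pad(item, (0, max_length - len(item))) for item in x]
)

assert torch.equal(z1, z2)
```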
```python
    )
else:
    req_id = self._get_req_id(seq_id)
    logger.warning(
```
In which cases can this occur?
This is to guard against scenarios where the frontend deletes the request because the client disconnected.
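Roughly, the guard being discussed looks like the following sketch (the class and attribute names are illustrative, not the handler's actual code):

```python
import logging

logger = logging.getLogger(__name__)


class BatchBookkeepingSketch:
    """Illustrative seq_id -> req_id bookkeeping only; names are made up."""

    def __init__(self):
        self.seq_id_to_req_id = {}  # populated when a prefill request is admitted

    def collect(self, seq_id, text, results):
        req_id = self.seq_id_to_req_id.get(seq_id)
        if req_id is not None:
            results[req_id] = text
        else:
            # The frontend may have already removed the request (for example,
            # the client disconnected), so the seq_id no longer maps to a live
            # request; warn and drop the output rather than failing the batch.
            logger.warning("Skipping seq_id=%s: request already removed", seq_id)
```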
```python
logger = logging.getLogger(__name__)


class BaseNeuronXContinuousBatchingHandler(BaseHandler):
```
Better to at least write an e2e test for this handler. The e2e test should queue multiple requests and check the results to make sure ids are not mixed up. It would be good to test the individual methods as well.
test_llm_streaming_response.py is a tool for manual e2e testing; it is a workaround for the inf2 environment dependency.
Or we can add test data to https://github.com/pytorch/serve/blob/master/test/postman/inference_stream_data.json once we have an inf2 regression-test CI job. Currently all of the inf1/inf2 examples use the nightly benchmark dashboard as the testing tool.
This PR requires an inf2 environment to run the e2e test, and TorchServe only has an inf2 benchmark CI job. That's why I only posted the e2e test results that I ran manually on inf2.48xlarge. The notebook is an e2e test.
Thanks @lxning, it will be good to have an e2e test under test/pytest so we can repeat the test easily if necessary (e.g. when we make changes to the example). You can skip the test if no inf2 hardware is detected.
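Something along these lines could work as a skippable pytest (a sketch; the neuron-ls hardware check and file name are assumptions, not an existing TorchServe utility):

```python
# test/pytest/test_inf2_continuous_batching.py -- illustrative skeleton only
import shutil

import pytest

# Treat the presence of the neuron-ls CLI as a proxy for available Neuron
# (inf2) hardware; this heuristic is an assumption, not an official API.
INF2_AVAILABLE = shutil.which("neuron-ls") is not None


@pytest.mark.skipif(not INF2_AVAILABLE, reason="requires Inferentia-2 hardware")
def test_continuous_batching_ids_not_mixed():
    # Queue several concurrent requests with distinct prompts and verify that
    # each streamed response corresponds to the prompt that produced it.
    ...
```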
LGTM, left a couple of minor comments
### Step 8: Run inference

```bash
python test_stream_response.py
```
Path and filename have to be ../utils/test_llm_streaming_response.py
I keep the streamer section as in the original. This PR should update the entire stream section.
Description
Type of change
Feature/Issue validation/testing
Run inference:
Checklist: