Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tnx] neuron rolling batch test suite #2172

Merged
merged 1 commit into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from collections import defaultdict
from typing import List, Dict
from dataclasses import dataclass
from djl_python import test_model, Input
from djl_python.request import Request
from djl_python.input_parser import format_input


@dataclass
class SimulationSchedule:
prompts: List[str]
params: List[Dict]
reqs_to_prefill: List[int]
wait_steps: List[int]


class NeuronRollingBatchGenerator:

def __init__(self):
self.rolling_batch = None
self._req_id = 0
# Store the results
self.output_all = defaultdict(list)
self.input_all = {}
self.data_collector = []
self.responses = []

# Status variables
self.input_str = []
self.params = []
self.req_ids = []

# Spec_dec
self.token_numbers = defaultdict(list)

def init_neuron_service(self, properties: dict):
from djl_python.transformers_neuronx import TransformersNeuronXService
_service = TransformersNeuronXService()
_service.initialize(properties)
self.rolling_batch = _service.rolling_batch

def get_req_id(self):
req_id = self._req_id
self._req_id = self._req_id + 1
return req_id

def collect_data(self, result):
done_requests_indices = []
for idx, item in enumerate(result):
if len(self.data_collector) <= idx:
self.data_collector.append(item["data"])
else:
self.data_collector[idx] += item["data"]
if item['last']:
done_requests_indices.append(idx)
for idx in sorted(done_requests_indices, reverse=True):
value = self.data_collector.pop(idx)
self.responses.append(value)
print(f"\nFinished request: {value}\n")
return done_requests_indices

def build_request(self, raw_input):
inputs = test_model.create_json_request(raw_input)
parsed_inputs = format_input(inputs)
request = Request(parsed_inputs)
request.id = self.get_req_id()
return request

def simulator(self, schedule: SimulationSchedule):
assert len(schedule.prompts) == len(schedule.params)
assert len(schedule.reqs_to_prefill) == len(schedule.wait_steps)
zipped_requests = zip(schedule.prompts, schedule.params)
all_requests = [{
"inputs": prompt,
"parameters": params
} for prompt, params in zipped_requests]
current_requests = []
new_requests = []
for batch_size, step in zip(schedule.reqs_to_prefill,
schedule.wait_steps):
for _ in range(batch_size):
request = self.build_request(all_requests.pop(0))
new_requests = [request] + new_requests
current_requests.append(request)

for i in range(step):
if len(current_requests) == 0:
break
generated_tokens = self.rolling_batch.inference(new_requests)
new_requests.clear()
finished_indices = self.collect_data(generated_tokens)
for idx in sorted(finished_indices, reverse=True):
current_requests.pop(idx)
while len(current_requests) > 0:
generated_tokens = self.rolling_batch.inference(new_requests)
finished_indices = self.collect_data(generated_tokens)
for idx in sorted(finished_indices, reverse=True):
current_requests.pop(idx)

def step(self, step=20, input_str_delta=None, params_delta=None):
if input_str_delta:
begin_id = max(self.input_all.keys(), default=0) + 1
req_ids_delta = list(
range(begin_id, begin_id + len(input_str_delta)))

self.input_str += input_str_delta
self.params += params_delta
self.req_ids += req_ids_delta
for req_id, input_s, param in zip(req_ids_delta, input_str_delta,
params_delta):
self.input_all[req_id] = (input_s, param)

iterator = range(step)
for i in iterator:
result = self.rolling_batch.inference(self.input_str, self.params)
for res, req_id in zip(result, self.req_ids):
self.output_all[req_id].append(res['data'])
self.token_numbers[req_id].append(res.get('step_token_num', 1))
self.req_ids = [
req_id for req_id, res in zip(self.req_ids, result)
if not res['last']
]
self.input_str = [
s for s, res in zip(self.input_str, result) if not res['last']
]
self.params = [
p for p, res in zip(self.params, result) if not res['last']
]
if not self.req_ids:
break

def is_empty(self):
return not self.req_ids

def reset(self):
self.data_collector = []
self.rolling_batch = None
# Store the results
self.output_all = defaultdict(list)
self.input_all = {}

# Status variables, the remaining
self.input_str = []
self.params = []
self.req_ids = []

self.token_numbers = defaultdict(list)
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import unittest
import json
import os

try:
import transformers_neuronx
from djl_python.transformers_neuronx import TransformersNeuronXService
SKIP_TEST = False
except ImportError:
SKIP_TEST = True

expected_text_30 = {
"TinyLlama/TinyLlama-1.1B-Chat-v0.6": {
0:
"Hello, my name is [Your Name] and I am a [Your Job Title] at [Your Company Name]. I am interested in learning more about your company'",
1:
'The president of the United States is a man named Donald Trump.\n\n2. The president of the United States is a man named Donald Trump.\n\n3. The president',
2:
'The capital of France is Paris.\n\n2. The capital of the United States is Washington, D.C.\n\n3. The capital of Canada is Ott',
3:
"The future of AI is bright, and it's already here. With the help of AI, we can create more personalized experiences, automate repetitive tasks, and even predict the future.",
}
}


@unittest.skipIf(SKIP_TEST, "Neuron dependencies are not available")
class TestNeuronRollingBatch(unittest.TestCase):

def test_models(self):
# === Preparation ===
from djl_python.tests.neuron_test_scripts.neuron_rb_generator import NeuronRollingBatchGenerator, SimulationSchedule

# --- Models ---
model_names = [
"TinyLlama/TinyLlama-1.1B-Chat-v0.6",
]

# === Test ===
for model_id in model_names:
properties = {
"tensor_parallel_degree": 2,
"n_positions": "128",
"rolling_batch": "tnx",
"max_rolling_batch_size": 4,
"model_id": model_id
}

# ===================== neuron-tnx ============================
gen = NeuronRollingBatchGenerator()
gen.init_neuron_service(properties)

print('========== init inference ===========')
input_str = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

params = [{
"max_new_tokens": 100,
"do_sample": False,
}.copy() for _ in range(len(input_str))]

test_input = SimulationSchedule(prompts=input_str,
params=params,
reqs_to_prefill=[1, 2, 1],
wait_steps=[1, 4, 5])

gen.simulator(test_input)

for i, out in enumerate(gen.responses):
out_dict = json.loads(''.join(out))
out_str = out_dict["generated_text"]
test_generation = input_str[i] + " " + out_str
print(f"\n====req_id: {i}=====\n{test_generation}\n")
if model_id in expected_text_30 and i in expected_text_30[
model_id]:
expected_prefix_30_req_id = expected_text_30[model_id][i]
assert expected_prefix_30_req_id == test_generation[:len(
expected_prefix_30_req_id)]

gen.reset()
del gen
import gc
gc.collect()

def test_tiny_models(self):
# === Preparation ===
from djl_python.tests.neuron_test_scripts.neuron_rb_generator import NeuronRollingBatchGenerator, SimulationSchedule
from djl_python.tests.neuron_test_scripts.tiny_models import artifacts
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# --- Models ---
model_names = [
"llama",
"gpt2",
"gptneox",
"bloom",
]

# === Test ===
for model_id in model_names:
properties = {
"tensor_parallel_degree": 2,
"n_positions": "128",
"rolling_batch": "tnx",
"max_rolling_batch_size": 4,
"model_loading_timeout": 3600,
"model_id": artifacts(model_id)
}

# ===================== neuron-tnx ============================
gen = NeuronRollingBatchGenerator()
gen.init_neuron_service(properties)

print('========== init inference ===========')
input_str = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

params = [{
"max_new_tokens": 100,
"do_sample": False,
"ignore_eos": True,
}.copy() for _ in range(len(input_str))]

test_input = SimulationSchedule(prompts=input_str,
params=params,
reqs_to_prefill=[1, 2, 1],
wait_steps=[1, 4, 5])

gen.simulator(test_input)
gen.reset()
del gen
import gc
gc.collect()


if __name__ == '__main__':
unittest.main()
Loading
Loading