Support JetStream MaxText user guide (#28)
JoeZijunZhou authored Apr 3, 2024
1 parent 426c915 commit 90b2a9d
Showing 6 changed files with 128 additions and 85 deletions.
160 changes: 98 additions & 62 deletions benchmarks/benchmark_serving.py
@@ -54,22 +54,21 @@
"""


import tensorflow as tf
import tensorflow_text as tftxt

import argparse
import asyncio

from dataclasses import dataclass
from datetime import datetime
import json
import random
import time
from typing import Any, AsyncGenerator, List, Optional

import grpc
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
import numpy as np
import tensorflow as tf
import tensorflow_text as tftxt
from tqdm.asyncio import tqdm


@@ -96,6 +95,7 @@ class InputRequest:
output: str = ""
output_len: int = 0


@dataclass
class RequestFuncOutput:
input_request: InputRequest = None
@@ -109,12 +109,12 @@ class RequestFuncOutput:
# Flatten the structure and return only the necessary results
def to_dict(self):
return {
"prompt": self.input_request.prompt,
"original_output": self.input_request.output,
"generated_text": self.generated_text,
"success": self.success,
"latency": self.latency,
"prompt_len": self.prompt_len
"prompt": self.input_request.prompt,
"original_output": self.input_request.output,
"generated_text": self.generated_text,
"success": self.success,
"latency": self.latency,
"prompt_len": self.prompt_len,
}


@@ -123,12 +123,14 @@ def get_tokenizer(tokenizer_name: str) -> Any:
if tokenizer_name == "test":
return "test"
else:
with tf.io.gfile.GFile(tokenizer_name, 'rb') as model_fp:
with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
model=sp_model, add_bos=True, add_eos=False, reverse=False)
model=sp_model, add_bos=True, add_eos=False, reverse=False
)
return sp_tokenizer


def load_sharegpt_dataset(
dataset_path: str,
conversation_starter: str,
@@ -141,7 +143,11 @@ def load_sharegpt_dataset(

# Filter based on conversation starter
if conversation_starter != "both":
dataset = [data for data in dataset if data["conversations"][0]["from"] == conversation_starter]
dataset = [
data
for data in dataset
if data["conversations"][0]["from"] == conversation_starter
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
@@ -151,9 +157,7 @@ def load_sharegpt_dataset(
return dataset


def load_openorca_dataset(
dataset_path: str
) -> List[tuple[str]]:
def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
@@ -187,23 +191,31 @@ def tokenize_dataset(
prompt_len = len(prompt_token_ids[i])
output_len = len(outputs_token_ids[i])
tokenized_dataset.append(
(prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
(prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
)
return tokenized_dataset


def filter_dataset(
tokenized_dataset: List[tuple[Any]],
max_output_length: Optional[int] = None
tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
) -> List[InputRequest]:
if max_output_length is None:
print("In InputRequest, pass in actual output_length for each sample")
else:
print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
print(
f"In InputRequest, pass in max_output_length: {max_output_length} for"
" each sample"
)

# Filter out too long sequences.
filtered_dataset: List[InputRequest] = []
for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
for (
prompt,
prompt_token_ids,
output,
prompt_len,
output_len,
) in tokenized_dataset:
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
@@ -212,7 +224,9 @@ def filter_dataset(
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
request = InputRequest(
prompt, prompt_len, output, max_output_length or output_len
)
filtered_dataset.append(request)

print(f"The dataset contains {len(tokenized_dataset)} samples.")
@@ -226,20 +240,26 @@ def sample_requests(
tokenizer: Any,
num_requests: int,
max_output_length: Optional[int] = None,
oversample_multiplier: float=1.2,
) -> List[InputRequest]:
oversample_multiplier: float = 1.2,
) -> List[InputRequest]:

# Original dataset size
n = len(dataset)

# Create necessary number of requests even if bigger than dataset size
sampled_indices = random.sample(
range(n), min(int(num_requests * oversample_multiplier), n))
range(n), min(int(num_requests * oversample_multiplier), n)
)

if num_requests > len(sampled_indices):
print(f"Number of requests {num_requests} is larger than size of dataset {n}.\n",
f"Repeating data to meet number of requests.\n")
sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
print(
f"Number of requests {num_requests} is larger than size of dataset"
f" {n}.\n",
f"Repeating data to meet number of requests.\n",
)
sampled_indices = sampled_indices * int(
np.ceil(num_requests / len(sampled_indices))
)

print(f"{len(sampled_indices)=}")
# some of these will be filtered out, so sample more than we need
@@ -315,7 +335,9 @@ def calculate_metrics(
return metrics


async def grpc_async_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
async def grpc_async_request(
api_url: str, request: Any
) -> tuple[list[str], float, float]:
"""Send grpc synchronous request since the current grpc server is sync."""
options = [("grpc.keepalive_timeout_ms", 10000)]
async with grpc.aio.insecure_channel(api_url, options=options) as channel:
@@ -351,7 +373,9 @@ async def send_request(
output = RequestFuncOutput()
output.input_request = input_request
output.prompt_len = input_request.prompt_len
generated_token_list, ttft, latency = await grpc_async_request(api_url, request)
generated_token_list, ttft, latency = await grpc_async_request(
api_url, request
)
output.ttft = ttft
output.latency = latency
output.generated_token_list = generated_token_list
@@ -453,14 +477,15 @@ def mock_requests(total_mock_requests: int):

def sample_warmup_requests(requests):
interesting_buckets = [
0,
16,
32,
64,
128,
256,
512,
1024,]
0,
16,
32,
64,
128,
256,
512,
1024,
]

for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
for request in requests:
@@ -481,28 +506,30 @@ def main(args: argparse.Namespace):

tokenizer = get_tokenizer(tokenizer_id)
if tokenizer == "test" or args.dataset == "test":
input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
input_requests = mock_requests(
args.total_mock_requests
) # e.g. [("AB", 2, "AB", 3)]
else:
if args.dataset == "openorca":
dataset = load_openorca_dataset(args.dataset_path)
elif args.dataset == "sharegpt":
dataset = load_sharegpt_dataset(
args.dataset_path,
args.conversation_starter,
args.dataset_path,
args.conversation_starter,
)

# A given args.max_output_length value is the max generation step.
# When args.max_output_length is left at its default of None, the sample's
# golden output length is used to decide the generation step.
input_requests = sample_requests(
dataset=dataset,
tokenizer=tokenizer,
num_requests=args.num_prompts,
max_output_length=args.max_output_length
dataset=dataset,
tokenizer=tokenizer,
num_requests=args.num_prompts,
max_output_length=args.max_output_length,
)

if args.warmup_first:
print('Warm up start:' )
print("Warm up start:")
warmup_requests = list(sample_warmup_requests(input_requests)) * 2
benchmark_result, request_outputs = asyncio.run(
benchmark(
Expand All @@ -516,7 +543,7 @@ def main(args: argparse.Namespace):
threads=args.threads,
)
)
print('Warm up done')
print("Warm up done")

benchmark_result, request_outputs = asyncio.run(
benchmark(
Expand Down Expand Up @@ -561,7 +588,11 @@ def main(args: argparse.Namespace):
if args.save_request_outputs:
file_path = args.request_outputs_file_path
with open(file_path, "w") as output_file:
json.dump([output.to_dict() for output in request_outputs], output_file, indent=4)
json.dump(
[output.to_dict() for output in request_outputs],
output_file,
indent=4,
)


if __name__ == "__main__":
@@ -576,11 +607,13 @@ def main(args: argparse.Namespace):
)
parser.add_argument("--port", type=str, default=9000)
parser.add_argument(
"--dataset", type=str, default="test", choices=["test", "sharegpt", "openorca"], help="The dataset name."
)
parser.add_argument(
"--dataset-path", type=str, help="Path to the dataset."
"--dataset",
type=str,
default="test",
choices=["test", "sharegpt", "openorca"],
help="The dataset name.",
)
parser.add_argument("--dataset-path", type=str, help="Path to the dataset.")
parser.add_argument(
"--model",
type=str,
@@ -637,7 +670,16 @@ def main(args: argparse.Namespace):
"--max-output-length",
type=int,
default=None,
help="The maximum output length for reference request.",
help=(
"The maximum output length for reference request. It would be passed"
" to `max_tokens` parameter of the JetStream's DecodeRequest proto,"
" and used in JetStream to control the output/decode length of a"
" sequence. It would not be used in the engine. We should always set"
" max_tokens <= (max_target_length - max_prefill_predict_length)."
" max_target_length is the maximum length of a sequence;"
" max_prefill_predict_length is the maximum length of the"
" input/prefill of a sequence."
),
)

parser.add_argument("--seed", type=int, default=0)
@@ -678,26 +720,20 @@ def main(args: argparse.Namespace):
"--request-outputs-file-path",
type=str,
default="/tmp/request-outputs.json",
help=(
"File path to store request outputs"
),
help="File path to store request outputs",
)
parser.add_argument(
"--warmup-first",
type=bool,
default=False,
help=(
"Whether to send warmup req first"
),
help="Whether to send warmup req first",
)
parser.add_argument(
"--conversation-starter",
type=str,
default="human",
choices=["human", "gpt", "both"],
help=(
"What entity should be the one starting the conversations."
),
help="What entity should be the one starting the conversations.",
)

args = parser.parse_args()
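To make the decode-length behavior above concrete, here is a minimal, self-contained sketch. It is not part of this commit; the helper name pick_decode_length and the sample lengths are illustrative assumptions. It mirrors the `max_output_length or output_len` expression in filter_dataset: an explicit --max-output-length caps every request, while the default of None falls back to each sample's golden output length.

from typing import Optional


def pick_decode_length(
    golden_output_len: int, max_output_length: Optional[int] = None
) -> int:
  # Mirrors `max_output_length or output_len` in filter_dataset: an explicit
  # --max-output-length wins; otherwise the sample's golden output length
  # decides the generation step.
  return max_output_length or golden_output_len


assert pick_decode_length(37) == 37        # default: golden output length
assert pick_decode_length(37, 128) == 128  # explicit cap from the flag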
7 changes: 6 additions & 1 deletion jetstream/core/proto/jetstream.proto
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.


syntax = "proto3";

package jetstream_proto;
@@ -29,6 +28,12 @@ message DecodeRequest {
// New text from a user or tool.
string additional_text = 2;
int32 priority = 3;
// The maximum output length of a sequence. It is used in JetStream to control
// the output/decode length of a sequence; it is not used in the engine.
// Always set max_tokens <= (max_target_length - max_prefill_predict_length),
// where max_target_length is the maximum length of a sequence and
// max_prefill_predict_length is the maximum length of the input/prefill of a
// sequence.
int32 max_tokens = 4;
}
message DecodeResponse {
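As a usage sketch for the new field (the configuration numbers below are assumptions for illustration, not values taken from this commit), a client could populate max_tokens while respecting the decode budget described in the comment above:

from jetstream.core.proto import jetstream_pb2

# Assumed server configuration, for illustration only.
max_target_length = 2048           # maximum total length of a sequence
max_prefill_predict_length = 1024  # maximum length of the input/prefill

request = jetstream_pb2.DecodeRequest(
    additional_text="What is JetStream?",
    priority=1,
    # Keep max_tokens within the decode budget:
    # max_tokens <= max_target_length - max_prefill_predict_length.
    max_tokens=max_target_length - max_prefill_predict_length,
)

The same bound applies when the benchmark's --max-output-length flag is forwarded into this field.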
