Merge branch 'main' into nemotron
hrishi121 committed Dec 17, 2024
2 parents a81596d + 88074ea commit 25d1cf4
Showing 10 changed files with 575 additions and 32 deletions.
36 changes: 21 additions & 15 deletions python/mlc_llm/bench/__main__.py
@@ -101,17 +101,18 @@ def run_pipeline(
args.output_len_std,
)
request_records = pipeline(request_records)
assert len(request_records) == args.num_requests * args.num_gpus
sorted_requests: List[RequestRecord] = [None] * args.num_requests * args.num_gpus
num_total_requests = (
args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus
)
assert len(request_records) == num_total_requests
sorted_requests: List[RequestRecord] = [None] * num_total_requests
for request_record in request_records:
assert request_record.request_id is not None
assert sorted_requests[request_record.request_id] is None
sorted_requests[request_record.request_id] = request_record

request_records = MetricAnalyzer(tokenizer)(request_records)
report = generate_metrics_summary(
request_records, args.num_requests * args.num_gpus, args.num_gpus
)
report = generate_metrics_summary(request_records, num_total_requests, args.num_gpus)
return report, sorted_requests


@@ -221,6 +222,15 @@ def _main():
help="The number of requests for warmup. "
"It is optional when fixing the number of concurrent requests, and is required otherwise.",
)
parser.add_argument(
"--per-gpu-workload",
default=False,
action="store_true",
help='When set to True, the specified "num_concurrent_requests"/"request_rate" '
"denote the workload **per GPU**, which means that the real values of "
'"num_concurrent_requests"/"request_rate" used in benchmark'
'will be multiplied by "num_gpus".',
)
parser.add_argument(
"--num-concurrent-requests",
type=_parse_num_concurrent_requests,
@@ -354,13 +364,6 @@ def _main():
type=_parse_mlc_engine_config,
help="The engine config used when launch MLC server.",
)
parser.add_argument(
"--output",
"-o",
type=str,
default="mlc_benchmark.csv",
help="The path of the output file where to dump the benchmark results.",
)
parser.add_argument(
"--cuda-profile",
default=False,
@@ -378,13 +381,16 @@
"--multi-round",
default=False,
action="store_true",
help="Whether to chat like mulit round conversion with history log each request. "
help="Whether to chat like multi round conversion with history log each request. "
"Only enabled when benchmarked with fixed concurrent request mode."
"The --num-concurrent-requests should be provided when enabling this option.",
)

parser.add_argument(
"--testset-name", type=str, help="The name of the testset. Only used for Loogle dataset"
"--output",
"-o",
type=str,
default="mlc_benchmark.csv",
help="The path of the output file where to dump the benchmark results.",
)

main(parser.parse_args())
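
The new --per-gpu-workload flag changes how the request count and request rate are interpreted: without it they describe the whole run, with it they are multiplied by the GPU count. A minimal sketch of that scaling, assuming an args namespace exposing per_gpu_workload, num_requests, num_gpus, and a scalar request_rate (the real parser may hold a list of request rates):

from argparse import Namespace

def scale_workload(args: Namespace) -> tuple:
    """Return (total benchmark requests, effective request rate). Illustrative only;
    the benchmark wires this logic directly into run_pipeline/create_pipelines."""
    if args.per_gpu_workload:
        # User-supplied values are per GPU, so scale them by the GPU count.
        return args.num_requests * args.num_gpus, args.request_rate * args.num_gpus
    # Otherwise the values already describe the whole run.
    return args.num_requests, args.request_rate

# Example: 128 requests at rate 4.0 on 8 GPUs with --per-gpu-workload set.
print(scale_workload(Namespace(per_gpu_workload=True, num_requests=128, num_gpus=8, request_rate=4.0)))
# (1024, 32.0)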
5 changes: 2 additions & 3 deletions python/mlc_llm/bench/dataset.py
@@ -174,12 +174,11 @@ class LoogleDataset(Dataset):  # pylint: disable=too-few-public-methods
# pylint: enable=line-too-long
require_fake_warmup: bool = True

def __init__(self, tokenizer: AutoTokenizer, testset_name) -> None:
def __init__(self, tokenizer: AutoTokenizer, testset_name: str) -> None:
raw_dataset = load_dataset("bigainlco/LooGLE", testset_name, split="test")
self.tokenizer = tokenizer
self.dataset = []
self.prompt_format = self.task2prompt[testset_name]
# self.max_gen = self.task2maxlen[testset_name]
prompts = []
generate_lens = []
questions = []
@@ -806,7 +805,7 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Datas
assert (
args.apply_chat_template is False
), "Loogle dataset does not support applying chat template"
return LoogleDataset(tokenizer, args.testset_name)
return LoogleDataset(tokenizer, testset_name=args.dataset_path)
if args.dataset == "react":
assert (
args.apply_chat_template is False
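
With --testset-name gone, create_dataset now forwards args.dataset_path as the LooGLE configuration name. A minimal illustration of the updated call, assuming the module path mlc_llm.bench.dataset, the "longdep_qa" subset, and the gpt2 tokenizer (all illustrative choices, not requirements):

from transformers import AutoTokenizer

from mlc_llm.bench.dataset import LoogleDataset

# Any tokenizer works for the example; "longdep_qa" is one LooGLE configuration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The value that used to come from --testset-name is now read from --dataset-path.
dataset = LoogleDataset(tokenizer, testset_name="longdep_qa")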
13 changes: 9 additions & 4 deletions python/mlc_llm/bench/request_processor.py
@@ -622,22 +622,27 @@ def create_pipelines(
"Please specify the number of warmup requests via "
'"--num-warmup-requests" when fixing request rate.'
)
num_total_requests = int(
args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus
)
if dataset.require_fake_warmup:
num_samples = int(args.num_requests * args.num_gpus)
num_samples = num_total_requests
else:
num_samples = int(args.num_requests * args.num_gpus) + args.num_warmup_requests
num_samples = num_total_requests + args.num_warmup_requests
return [
SequentialProcessor(
LogMessage(f"Fixing request rate: {request_rate}"),
SampleRequests(num_samples),
AttachModelName(args.tokenizer),
AttachRequestRateTimestamp(request_rate * args.num_gpus),
AttachRequestRateTimestamp(
request_rate if not args.per_gpu_workload else request_rate * args.num_gpus
),
AttachStreamFlag(args.stream),
AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),
AttachExecutionFeature({"request_rate": float(request_rate)}),
WarmupAndRun(
num_warmup_requests=args.num_warmup_requests,
num_benchmark_requests=int(args.num_requests * args.num_gpus),
num_benchmark_requests=num_total_requests,
pipeline=FixTimestampExecutor(
f_create_api_endpoint,
args.num_process_workers,
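
The fixed-request-rate pipeline now derives its sample count from the same total and only adds warmup samples when the dataset cannot fabricate them. A small sketch of that branch, using names from the hunk above (illustrative, not the benchmark entry point):

def compute_num_samples(num_total_requests: int, num_warmup_requests: int,
                        require_fake_warmup: bool) -> int:
    """How many dataset samples the pipeline should draw."""
    if require_fake_warmup:
        # Datasets that synthesize their own warmup records need no extra samples.
        return num_total_requests
    # Otherwise warmup requests are drawn from the dataset on top of the benchmark requests.
    return num_total_requests + num_warmup_requests

print(compute_num_samples(1024, 16, require_fake_warmup=False))  # 1040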
Empty file.
72 changes: 72 additions & 0 deletions python/mlc_llm/model/gpt_j/gpt_j_loader.py
@@ -0,0 +1,72 @@
"""
This file specifies how MLC's GPTJ parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .gpt_j_model import GPTJConfig, GPTJForCausalLM


def huggingface(model_config: GPTJConfig, quantization: Quantization) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of HuggingFace PyTorch parameters.
Parameters
----------
model_config : GPTJConfig
The configuration of the GPTJ model.
quantization : Quantization
The quantization configuration.
Returns
-------
param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
"""
model = GPTJForCausalLM(model_config)
if quantization is not None:
model.to(quantization.model_dtype)
_, _named_params, _ = model.export_tvm( # type: ignore[misc]
spec=model.get_default_spec(),
allow_extern=True,
)
named_parameters = dict(_named_params)

mapping = ExternMapping()

for i in range(model_config.n_layer):
# Add gates in MLP
attn = f"transformer.h.{i}.attn"
mlc_name = f"{attn}.c_attn.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{attn}.q_proj.weight",
f"{attn}.k_proj.weight",
f"{attn}.v_proj.weight",
],
functools.partial(
lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
dtype=mlc_param.dtype,
),
)

for mlc_name, mlc_param in named_parameters.items():
if mlc_name not in mapping.param_map:
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: x.astype(dtype),
dtype=mlc_param.dtype,
),
)
return mapping
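
For reference, the c_attn mapping above simply stacks the three HuggingFace projection weights along the output dimension and casts them to the target dtype. A toy demonstration with made-up shapes (real GPT-J projections are square n_embd x n_embd matrices):

import numpy as np

# Toy stand-ins for q_proj/k_proj/v_proj weights; shapes are illustrative only.
q = np.ones((4, 4), dtype="float32")
k = np.full((4, 4), 2.0, dtype="float32")
v = np.full((4, 4), 3.0, dtype="float32")

# Same operation as the lambda passed to mapping.add_mapping for c_attn.weight.
c_attn = np.concatenate([q, k, v], axis=0).astype("float16")
print(c_attn.shape)  # (12, 4): the three projections fused into one MLC parameter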