Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add city doc benchmark mode #129

Merged
merged 6 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions benchmark/json_fast_forward/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,72 @@
### Dependencies

```
llama_cpp_python 0.2.32
llama_cpp_python 0.2.38
guidance 0.1.10
vllm 0.2.7
outlines 0.0.24
outlines 0.0.25
```

### Build dataset

When benchmarking long document information retrieval, run the following command to build the dataset:

```bash
pip install wikipedia
python3 build_dataset.py
```

### Benchmark sglang

Run Llama-7B

```
```bash
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Benchmark
Benchmark Character Generation

```bash
python3 bench_sglang.py --mode character
```
python3 bench_sglang.py

Benchmark City Information Retrieval

```bash
python3 bench_sglang.py --mode city
```


### Benchmark vllm

Run Llama-7B

```
```bash
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

Benchmark
Benchmark Character Generation

```bash
python3 bench_other.py --mode character --backend vllm
```
python3 bench_other.py --backend vllm

Benchmark City Information Retrieval

```bash
python3 bench_other.py --mode city --backend vllm
```

### Benchmark guidance (seems not supported)
### Benchmark guidance

Run Llama-7B and benchmark
Run Llama-7B and benchmark character generation

```bash
python3 bench_other.py --mode character --backend guidance --parallel 1
```
python3 bench_other.py --backend guidance --parallel 1

Run Llama-7B and benchmark city information retrieval

```bash
python3 bench_other.py --mode city --backend guidance --parallel 1
```
119 changes: 114 additions & 5 deletions benchmark/json_fast_forward/bench_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
add_common_other_args_and_parse,
call_generate_outlines,
)
from sglang.utils import dump_state_text
from sglang.utils import dump_state_text, read_jsonl
from tqdm import tqdm

# there are some FSM bugs with json regex converted from pydantic model
Expand All @@ -32,13 +32,32 @@
+ r"""\}"""
)

city_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "country": "[\w\d\s]{1,16}",\n"""
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
+ r""" "population": [-+]?[0-9]{1,9},\n"""
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
+ r"""\}"""
)

# fmt: off
def character_gen(name, generate):
s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += generate(s, max_tokens=256, regex=character_regex)
return s
# fmt: on

# fmt: off
def city_gen(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += generate(s, max_tokens=256, regex=city_regex)
return s
# fmt: on


@guidance
def character_maker(lm, name):
Expand All @@ -65,7 +84,31 @@ def character_maker(lm, name):
return lm


def main(args):
@guidance
def city_maker(lm, document):
regex_str_no_quote = r"[\w\d\s]+"
regex_float = r"[0-9]+\.[0-9]+"
lm += f"""\
Please extract the information of a city from the following wikipedia page.
Page begin.
{document}
Page end.
Here is the name, country, and symbol of the city in JSON format.
{{
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
"country": "{guidance.gen("country", max_tokens=16, regex=regex_str_no_quote)}",
"latitude": {guidance.gen("latitude", max_tokens=10, regex=regex_float)},
"population": {guidance.gen("population", max_tokens=10, regex=r"[0-9]+")},
"top 3 landmarks": [
"{guidance.gen("landmark1", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark2", max_tokens=16, regex=regex_str_no_quote)}", "{guidance.gen("landmark3", max_tokens=16, regex=regex_str_no_quote)}"
]
}}
"""

return lm


def bench_character(args):
arguments = []
with open(args.data_path, "r") as f:
for line in f:
Expand All @@ -85,7 +128,7 @@ def func(i):
get_one_answer = func
elif args.backend == "guidance":
model = guidance.models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
args.llama_cpp_model_path,
n_gpu_layers=-1,
n_ctx=4096,
)
Expand All @@ -110,11 +153,69 @@ def func(i):

latency = time.time() - tic

return states, latency


def bench_city_doc(args):
arguments = []
for line in read_jsonl(args.data_path):
arguments.append({"document": line["document"]})
arguments = arguments[: args.num_jsons]

states = [None] * len(arguments)

# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_outlines, url=url, temperature=0)

def func(i):
states[i] = city_gen(**arguments[i], generate=generate)

get_one_answer = func
elif args.backend == "guidance":
model = guidance.models.LlamaCpp(
args.llama_cpp_model_path,
n_gpu_layers=-1,
n_ctx=4096,
)

def func(i):
lm = model + city_maker(**arguments[i])
states[i] = lm

get_one_answer = func
else:
raise ValueError(f"Invalid backend: {args.backend}")

tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(arguments))))
for _ in rets:
pass

latency = time.time() - tic

return states, latency


def main(args):
if args.mode == "character":
args.data_path = "dataset.txt"
states, latency = bench_character(args)
elif args.mode == "city":
args.data_path = "questions.jsonl"
states, latency = bench_city_doc(args)

# Compute accuracy
print(f"Latency: {latency:.3f}")

# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)

with open(args.result_file, "a") as fout:
value = {
Expand All @@ -129,7 +230,15 @@ def func(i):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="dataset.txt")
parser.add_argument("--data-path", type=str)
parser.add_argument("--num-jsons", type=int, default=50)
parser.add_argument(
"--mode", type=str, default="character", choices=["character", "city"]
)
parser.add_argument(
"--llama-cpp-model-path",
type=str,
default="/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
)
args = add_common_other_args_and_parse(parser)
main(args)
61 changes: 56 additions & 5 deletions benchmark/json_fast_forward/bench_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
from sglang.utils import dump_state_text, read_jsonl

# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
Expand All @@ -29,13 +29,55 @@
+ r"""\}"""
)

city_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "country": "[\w\d\s]{1,16}",\n"""
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
+ r""" "population": [-+]?[0-9]{1,9},\n"""
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
+ r"""\}"""
)

# fmt: off
@sgl.function
def character_gen(s, name):
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on

# fmt: off
@sgl.function
def city_gen(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
# fmt: on


def bench_city_doc(args):
arguments = []
for line in read_jsonl(args.data_path):
arguments.append({"document": line["document"]})
arguments = arguments[: args.num_jsons]

# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)

# Run requests
tic = time.time()
states = city_gen.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=(args.parallel == 1),
)
latency = time.time() - tic

return states, latency


def bench_character(args):
arguments = []
Expand All @@ -62,14 +104,19 @@ def bench_character(args):


def main(args):
states, latency = bench_character(args)
if args.mode == "character":
args.data_path = "dataset.txt"
states, latency = bench_character(args)
elif args.mode == "city":
args.data_path = "questions.jsonl"
states, latency = bench_city_doc(args)

# Compute accuracy
print(f"Latency: {latency:.3f}")

# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(f"{args.backend}.json", "w") as fout:
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
with open(f"{args.backend}_{args.mode}.json", "w") as fout:
for state in states:
fout.write(state["json_output"] + "\n")

Expand All @@ -79,14 +126,18 @@ def main(args):
"backend": args.backend,
"latency": round(latency, 3),
"num_jsons": args.num_jsons,
"mode": args.mode,
"parallel": args.parallel,
}
fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="dataset.txt")
parser.add_argument("--data-path", type=str)
parser.add_argument("--num-jsons", type=int, default=50)
parser.add_argument(
"--mode", type=str, default="character", choices=["character", "city"]
)
args = add_common_sglang_args_and_parse(parser)
main(args)
Loading