From b0e44cd8b0617792994cc9c9cb1c44fe939a295d Mon Sep 17 00:00:00 2001 From: circargs Date: Thu, 13 Apr 2023 21:18:46 -0400 Subject: [PATCH 1/8] llama cpp server --- src/lmql/model/serve_llama_cpp.py | 421 ++++++++++++++++++++++++++++++ 1 file changed, 421 insertions(+) create mode 100644 src/lmql/model/serve_llama_cpp.py diff --git a/src/lmql/model/serve_llama_cpp.py b/src/lmql/model/serve_llama_cpp.py new file mode 100644 index 00000000..7c58d6d3 --- /dev/null +++ b/src/lmql/model/serve_llama_cpp.py @@ -0,0 +1,421 @@ +""" +Serves a transformers model as LMQL inference API. +""" + +from dataclasses import dataclass, field +from collections import defaultdict + +import json +from http.server import BaseHTTPRequestHandler, HTTPServer +from multiprocessing import Queue as MPQueue +from queue import Empty +from queue import Queue +import multiprocessing +from typing import Dict +import atexit +import argparse +import time +from llama_cpp.llama import Llama, llama_cpp +import inspect + +@dataclass +class InferenceServerState: + model_identifier : str + tokenizer_descriptor : str + dtype: str + + queue: Queue + tokenize_queue: Queue + all_results_queue : Queue + + sample_count: int = 0 + client_results_queues: Dict[str,Queue] = field(default_factory=dict) + + exit: bool = False + +class TokenizerProcessor: + def __init__(self, state: InferenceServerState, processor: "ModelProcessor"): + self.model_identifier = state.tokenizer_descriptor + self.model = processor + self.queue = state.tokenize_queue + self.state = state + + def shutdown(self): + self.state.exit = True + + def tokenize(self, tokenizer, sample_id, client_id, item): + text = item["text"] + + if text == "": + input_ids = [tokenizer.token_eos()] + elif text == "": + input_ids = [tokenizer.token_bos()] + else: + input_ids = tokenizer.tokenize(b" " + text.encode("utf-8")) + + self.state.all_results_queue.put({ + "sample_id": sample_id, + "client_id": client_id, + "input_ids": input_ids + }) + + def detokenize(self, tokenizer, sample_id, client_id, item): + input_ids = item["input_ids"] + + text = tokenizer.detokenize(input_ids).decode('utf-8') + self.state.all_results_queue.put({ + "sample_id": sample_id, + "client_id": client_id, + "text": text + }) + + def run(self): + # load tokenizer + tokenizer = self.model + print("Tokenizer {} ready!".format(self.model_identifier)) + + while not self.state.exit: + item = self.queue.get() + if item is None: + time.sleep(0.1) + continue + + sample_id = item["sample_id"] + client_id = item["client_id"] + action = item["action"] + + if action == "tokenize": + self.tokenize(tokenizer, sample_id, client_id, item) + elif action == "detokenize": + self.detokenize(tokenizer, sample_id, client_id, item) + else: + print("error: unknown TokenizerProcessor action {}".format(action)) + + print("Tokenizer shut down.") + +class ModelProcessor: + def __init__(self, state: InferenceServerState, llama_kwargs: dict, cache: str = None): + self.model_identifier = state.model_identifier + self.llama_kwargs = llama_kwargs + self.queue = state.queue + self.state = state + + self.cache = None + if cache is not None: + from rocksdict import Rdict + self.cache = Rdict(cache) + + self.request_count = 0 + self.requests_cached = 0 + self.last_report = time.time() + self.last_request_count = 0 + + + def shutdown(self): + self.state.exit = True + + def __del__(self): + if self.cache is not None: + self.cache.close() + + def print_stats(self): + # fancy unicode based terminal spinner + terminal_spinner_chars = ["⠋", "⠙", "⠹", "⠸", 
"⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + throughput = (self.request_count - self.last_request_count) / (time.time() - self.last_report) + self.last_report = time.time() + self.last_request_count = self.request_count + + # format throughput to two decimal places + print("{} {:.2f} calls/s, Requests Served: {}, Queue: {}".format( + terminal_spinner_chars[self.request_count % len(terminal_spinner_chars)], + throughput, + self.request_count, + self.state.queue.qsize()), end="\r") + + def run(self): + + + # load model + print("Loading {} (CPU)".format(self.model_identifier)) + self.model = Llama(**{**self.llama_kwargs, 'logits_all': True}) + + print("Ready!".format(self.model_identifier)) + + while not self.state.exit: + self.print_stats() + # wait for self.queue to have an item + try: + item = self.queue.get(timeout=1.0) + except Empty: + continue + except KeyboardInterrupt: + break + + if item is None: + time.sleep(0.1) + continue + + self.request_count += 1 + + + + sample_id = item["sample_id"] + client_id = item["client_id"] + input_ids = item["input_ids"] + + if self.cache is not None: + key = str(input_ids) + if key in self.cache: + self.requests_cached += 1 + self.state.all_results_queue.put({ + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": self.cache[key] + }) + continue + + res = self.model.eval(input_ids) + + next_token_logits = self.model.all_logits[-1] + + if self.cache is not None: + key = str(input_ids.tolist()) + self.cache[key] = next_token_logits.tolist() + + self.state.all_results_queue.put({ + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": next_token_logits.detach().tolist() + }) + + print("Processor shut down") + + +class LMQLInferenceAPIHandler(BaseHTTPRequestHandler): + def __init__(self, *args, **kwargs): + self._client_id = None + super().__init__(*args, **kwargs) + + # disable logging + def log_message(self, format, *args): + return + + @property + def state(self) -> InferenceServerState: + return self.server.state + + @property + def client_id(self) -> str: + if self._client_id is None: + self.error_bad_request(msg="client_id not set (please provide the client_id in POST payload or in query.") + raise Exception("client_id not set") + return self._client_id + + def error_bad_request(self, msg="Bad request."): + self.send_response(400) + self.send_header('Content-type', 'text/plain') + self.end_headers() + self.wfile.write(msg.encode("utf-8")) + + def process_some_all_results(self, max=10): + all_results_queue = self.state.all_results_queue + i = 0 + + while not all_results_queue.empty() and i < max: + result = all_results_queue.get() + + result_client_id = result["client_id"] + + if result_client_id not in self.state.client_results_queues: + self.state.client_results_queues[result_client_id] = Queue() + self.state.client_results_queues[result_client_id].put(result) + + i += 1 + + def do_GET_results(self): + request_client_id = self.path.split("/")[2] + self._client_id = f"{self.client_address[0]}-{request_client_id}" + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + + # process some results from the model and group them by client_id + self.process_some_all_results() + + if self.client_id not in self.state.client_results_queues.keys(): + self.wfile.write(b'[]') + return + + # return all results for self.client_id currently available + self.wfile.write("[".encode()) + while not self.state.client_results_queues[self.client_id].empty(): + result = 
self.state.client_results_queues[self.client_id].get() + self.wfile.write(json.dumps(result).encode()) + + # omit last colon + if self.state.client_results_queues[self.client_id].empty(): break + else: self.wfile.write(b",") + self.wfile.write(b"]") + + def do_queue_forward(self, payload, sample_id): + try: + input_ids = payload['input_ids'] + attention_mask = payload.get('attention_mask', None) + model_identifier = payload['model_identifier'] + + if model_identifier != self.state.model_identifier: + self.error_bad_request("The inference API serves model {} not {}.".format(self.state.model_identifier, model_identifier)) + return + + self._client_id = f"{self.client_address[0]}-{payload['client_id']}" + + assert type(input_ids) == list # and all([type(i) == int for i in input_ids]) + + self.state.queue.put({ + 'client_id': self.client_id, + 'sample_id': sample_id, + 'input_ids': input_ids, + 'attention_mask': attention_mask + }) + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) + except Exception as e: + self.error_bad_request() + + def do_queue_tokenize(self, payload, sample_id): + try: + text = payload['text'] + self._client_id = f"{self.client_address[0]}-{payload['client_id']}" + + assert type(text) == str + + self.state.tokenize_queue.put({ + 'client_id': self.client_id, + 'sample_id': sample_id, + 'action': 'tokenize', + 'text': text + }) + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) + except Exception as e: + self.error_bad_request() + + def do_queue_detokenize(self, payload, sample_id): + try: + input_ids = payload['input_ids'] + self._client_id = f"{self.client_address[0]}-{payload['client_id']}" + + assert type(input_ids) == list and all([type(i) == int for i in input_ids]) + + self.state.tokenize_queue.put({ + 'client_id': self.client_id, + 'sample_id': sample_id, + 'action': 'detokenize', + 'input_ids': input_ids + }) + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) + except Exception as e: + self.error_bad_request() + + def do_POST_queue(self): + # handle POST to /queue + payload = self.rfile.read(int(self.headers['Content-Length'])) + payload = json.loads(payload) + # get client address and port + sample_id = payload["sample_id"] + + action = payload['action'] + + if action == "forward": + self.do_queue_forward(payload, sample_id) + elif action == "tokenize": + self.do_queue_tokenize(payload, sample_id) + elif action == "detokenize": + self.do_queue_detokenize(payload, sample_id) + else: + self.error_bad_request("Unknown action: {}".format(action)) + + def do_POST(self): + if self.path == "/queue": + self.do_POST_queue() + return + else: + self.send_error(404) + return + + def do_GET(self): + if self.path.startswith("/results"): + self.do_GET_results() + return + else: + self.send_error(404) + return + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model", type=str) + parser.add_argument("--port", type=int, default=8080) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--cache", type=str, default=None) + llama_kwargs = {} + sig = inspect.signature(Llama.__init__) + for name, param in sig.parameters.items(): + if name == 'self': + continue + 
llama_kwargs[name] = None + if param.default == inspect.Parameter.empty: + parser.add_argument(name) + else: + parser.add_argument(f'--{name}', default=param.default) + + args = parser.parse_args() + + llama_kwargs = {kwarg: getattr(args, kwarg) for kwarg in llama_kwargs.keys()} + + manager = multiprocessing.Manager() + + # prepare configuration + model_descriptor = args.model + state = InferenceServerState(model_descriptor, + model_descriptor, + "", + queue=manager.Queue(), + tokenize_queue=manager.Queue(), + all_results_queue=manager.Queue()) + + # run model in separate process + processor = ModelProcessor(state, cache=args.cache, llama_kwargs=llama_kwargs) + processor.run() + + # run tokenizers in separate process + tokenizer_processor = TokenizerProcessor(state, processor) + tokenizer_processor.run() + + # run inference API server in this process + server_address = (args.host, args.port) + httpd = HTTPServer(server_address, LMQLInferenceAPIHandler) + httpd.state = state + + try: + print("Serving LMQL inference API on {}:{}".format(args.host, args.port)) + httpd.serve_forever() + except KeyboardInterrupt: + # terminate server + httpd.shutdown() + httpd.server_close() + print("Server stopped") + + # terminate processors + processor.shutdown() + tokenizer_processor.shutdown() From 1ce96022acabc2a4cf03c54b516620d728fdbe3b Mon Sep 17 00:00:00 2001 From: circargs Date: Thu, 13 Apr 2023 23:11:38 -0400 Subject: [PATCH 2/8] hacking and whacking --- src/lmql/model/serve_llama_cpp.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/src/lmql/model/serve_llama_cpp.py b/src/lmql/model/serve_llama_cpp.py index 7c58d6d3..ca110464 100644 --- a/src/lmql/model/serve_llama_cpp.py +++ b/src/lmql/model/serve_llama_cpp.py @@ -93,6 +93,13 @@ def run(self): print("Tokenizer shut down.") + def run_in_parallel(self): + atexit.register(self.shutdown) + + p = multiprocessing.Process(target=self.run) + p.start() + return p + class ModelProcessor: def __init__(self, state: InferenceServerState, llama_kwargs: dict, cache: str = None): self.model_identifier = state.model_identifier @@ -174,22 +181,28 @@ def run(self): }) continue - res = self.model.eval(input_ids) + res = self.model.eval(input_ids[0]) next_token_logits = self.model.all_logits[-1] if self.cache is not None: key = str(input_ids.tolist()) - self.cache[key] = next_token_logits.tolist() + self.cache[key] = next_token_logits self.state.all_results_queue.put({ "client_id": client_id, "sample_id": sample_id, - "next_token_logits": next_token_logits.detach().tolist() + "next_token_logits": [next_token_logits] }) print("Processor shut down") + + def run_in_parallel(self): + atexit.register(self.shutdown) + p = multiprocessing.Process(target=self.run) + p.start() + return p class LMQLInferenceAPIHandler(BaseHTTPRequestHandler): def __init__(self, *args, **kwargs): @@ -374,10 +387,13 @@ def do_GET(self): if name == 'self': continue llama_kwargs[name] = None + _type = param.annotation + if hasattr(_type, '__args__'): + _type=_type.__args__[0] if param.default == inspect.Parameter.empty: - parser.add_argument(name) + parser.add_argument(name, type=_type) else: - parser.add_argument(f'--{name}', default=param.default) + parser.add_argument(f'--{name}', default=param.default, type=_type) args = parser.parse_args() @@ -396,11 +412,11 @@ def do_GET(self): # run model in separate process processor = ModelProcessor(state, cache=args.cache, llama_kwargs=llama_kwargs) - processor.run() + processor.run_in_parallel() # run 
tokenizers in separate process tokenizer_processor = TokenizerProcessor(state, processor) - tokenizer_processor.run() + tokenizer_processor.run_in_parallel() # run inference API server in this process server_address = (args.host, args.port) From aabd60854dd8a9fcc9f31b1f1d51de08a2edd376 Mon Sep 17 00:00:00 2001 From: circargs Date: Thu, 13 Apr 2023 23:35:14 -0400 Subject: [PATCH 3/8] reset model on each eval --- src/lmql/model/serve_llama_cpp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lmql/model/serve_llama_cpp.py b/src/lmql/model/serve_llama_cpp.py index ca110464..2cdae044 100644 --- a/src/lmql/model/serve_llama_cpp.py +++ b/src/lmql/model/serve_llama_cpp.py @@ -180,7 +180,7 @@ def run(self): "next_token_logits": self.cache[key] }) continue - + self.model.reset() res = self.model.eval(input_ids[0]) next_token_logits = self.model.all_logits[-1] From 33b3e05bac6d7366727c1935be34d7b2dd1565ee Mon Sep 17 00:00:00 2001 From: circargs Date: Fri, 14 Apr 2023 13:21:28 -0400 Subject: [PATCH 4/8] separate servers for llamacpp and hf --- src/lmql/model/serve.py | 437 ++---------------------------- src/lmql/model/serve_hf.py | 173 ++++++++++++ src/lmql/model/serve_llama_cpp.py | 320 ++-------------------- src/lmql/model/serve_types.py | 328 ++++++++++++++++++++++ 4 files changed, 541 insertions(+), 717 deletions(-) create mode 100644 src/lmql/model/serve_hf.py create mode 100644 src/lmql/model/serve_types.py diff --git a/src/lmql/model/serve.py b/src/lmql/model/serve.py index fba4b968..4d486ffa 100644 --- a/src/lmql/model/serve.py +++ b/src/lmql/model/serve.py @@ -24,427 +24,25 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline import torch -@dataclass -class InferenceServerState: - model_identifier : str - tokenizer_descriptor : str - dtype: str - - queue: Queue - tokenize_queue: Queue - all_results_queue : Queue - - sample_count: int = 0 - client_results_queues: Dict[str,Queue] = field(default_factory=dict) - - exit: bool = False - -class TokenizerProcessor: - def __init__(self, state: InferenceServerState): - self.model_identifier = state.tokenizer_descriptor - self.queue = state.tokenize_queue - self.state = state - - def shutdown(self): - self.state.exit = True - - def tokenize(self, tokenizer, sample_id, client_id, item): - text = item["text"] - - if text == "": - input_ids = [tokenizer.eos_token_id] - elif text == "": - input_ids = [tokenizer.bos_token_id] - else: - input_ids = tokenizer(text)["input_ids"] - - self.state.all_results_queue.put({ - "sample_id": sample_id, - "client_id": client_id, - "input_ids": input_ids - }) - - def detokenize(self, tokenizer, sample_id, client_id, item): - input_ids = item["input_ids"] - - text = tokenizer.decode(input_ids) - self.state.all_results_queue.put({ - "sample_id": sample_id, - "client_id": client_id, - "text": text - }) - - def run(self, index): - # load tokenizer - tokenizer = AutoTokenizer.from_pretrained(self.model_identifier) - print("Tokenizer #{} {} ready!".format(index, self.model_identifier)) - - while not self.state.exit: - item = self.queue.get() - if item is None: - time.sleep(0.1) - continue - - sample_id = item["sample_id"] - client_id = item["client_id"] - action = item["action"] - - if action == "tokenize": - self.tokenize(tokenizer, sample_id, client_id, item) - elif action == "detokenize": - self.detokenize(tokenizer, sample_id, client_id, item) - else: - print("error: unknown TokenizerProcessor action {}".format(action)) - - print("Tokenizer #{} shut down.".format(index)) - - def 
run_in_parallel(self, n=2): - atexit.register(self.shutdown) - - workers = [] - - for i in range(n): - p = multiprocessing.Process(target=self.run, args=(i,)) - p.start() - workers.append(p) - - return workers - -class ModelProcessor: - def __init__(self, state: InferenceServerState, cuda: bool = False, cache: str = None): - self.model_identifier = state.model_identifier - self.queue = state.queue - self.state = state - self.cuda = cuda - - self.cache = None - if cache is not None: - from rocksdict import Rdict - self.cache = Rdict(cache) - - self.request_count = 0 - self.requests_cached = 0 - self.last_report = time.time() - self.last_request_count = 0 - - try: - self.nvidia_logging = subprocess.Popen(["nvidia-smi"], stdout=subprocess.PIPE).wait() == 0 - except: - self.nvidia_logging = False - - def shutdown(self): - self.state.exit = True - - def __del__(self): - if self.cache is not None: - self.cache.close() - - def print_stats(self): - if self.nvidia_logging: - visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) - cmds = ["nvidia-smi"] - if visible_devices is not None: - cmds.append("-i={}".format(visible_devices)) - cmds += ["--query-gpu=name,memory.used,memory.total,utilization.gpu", "--format=csv,noheader"] - output = [l.split(", ") for l in subprocess.check_output(cmds).decode("utf-8").split("\n") if l.strip() != ""] - gpu_usage = ["GPU {} {}, util {}".format(i, row[1] + "/" + row[2], row[3]) for i, row in enumerate(output)] - else: - gpu_usage = ["GPU monitoring not available on non-CUDA systems"] - - print(" " * 100, end="\r") - # fancy unicode based terminal spinner - terminal_spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] - throughput = (self.request_count - self.last_request_count) / (time.time() - self.last_report) - self.last_report = time.time() - self.last_request_count = self.request_count - - # format throughput to two decimal places - print("{} {:.2f} calls/s, Requests Served: {}, Queue: {} [{}]".format( - terminal_spinner_chars[self.request_count % len(terminal_spinner_chars)], - throughput, - self.request_count, - self.state.queue.qsize(), - ", ".join(gpu_usage)), end="\r") - - def run(self): - dtype = self.state.dtype - if dtype == "float16": - dtype = torch.float16 - else: - dtype = None - - # load model - if not self.cuda: - print("Loading {} (CPU)".format(self.model_identifier)) - self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, torch_dtype=dtype, resume_download=True) - else: - print("Loading {} (Multi-GPU)".format(self.model_identifier)) - self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, torch_dtype=dtype, resume_download=True, device_map="auto") - self.model.eval() - - print("Ready!".format(self.model_identifier)) - - while not self.state.exit: - self.print_stats() - # wait for self.queue to have an item - try: - item = self.queue.get(timeout=1.0) - except Empty: - continue - except KeyboardInterrupt: - break - - if item is None: - time.sleep(0.1) - continue - - self.request_count += 1 - - device = "cuda" if self.cuda else "cpu" - - sample_id = item["sample_id"] - client_id = item["client_id"] - input_ids = torch.tensor(item["input_ids"], dtype=torch.long).to(device) - attention_mask = item.get("attention_mask", None) - - if attention_mask is None: - attention_mask = torch.ones_like(input_ids).to(device) - else: - attention_mask = torch.tensor(attention_mask, dtype=torch.long).to(device) - - if self.cache is not None: - key = "IDs:" + str(input_ids.tolist()) + " MASK:" + 
str(attention_mask.tolist()) - if key in self.cache: - self.requests_cached += 1 - self.state.all_results_queue.put({ - "client_id": client_id, - "sample_id": sample_id, - "next_token_logits": self.cache[key] - }) - continue - - res = self.model.forward(input_ids=input_ids, attention_mask=attention_mask) - - if input_ids.ndimension() == 2: - next_token_logits = res.logits[:,-1] - else: - next_token_logits = res.logits[-1] - - if self.cache is not None: - key = "IDs:" + str(input_ids.tolist()) + " MASK:" + str(attention_mask.tolist()) - self.cache[key] = next_token_logits.tolist() - - self.state.all_results_queue.put({ - "client_id": client_id, - "sample_id": sample_id, - "next_token_logits": next_token_logits.detach().tolist() - }) - - print("Processor shut down") - - def oom_reloading_run(self): - while True: - try: - self.run() - return - except RuntimeError as e: - if "CUDA out of memory" in str(e): - print("Crashed due to OOM, reloading model.") - continue - else: - import traceback - traceback.print_exc() - print("Crashed with", e, "reloading model...") - continue - - def run_in_parallel(self): - atexit.register(self.shutdown) - - p = multiprocessing.Process(target=self.oom_reloading_run) - p.start() - return p - -class LMQLInferenceAPIHandler(BaseHTTPRequestHandler): - def __init__(self, *args, **kwargs): - self._client_id = None - super().__init__(*args, **kwargs) - - # disable logging - def log_message(self, format, *args): - return - - @property - def state(self) -> InferenceServerState: - return self.server.state - - @property - def client_id(self) -> str: - if self._client_id is None: - self.error_bad_request(msg="client_id not set (please provide the client_id in POST payload or in query.") - raise Exception("client_id not set") - return self._client_id - - def error_bad_request(self, msg="Bad request."): - self.send_response(400) - self.send_header('Content-type', 'text/plain') - self.end_headers() - self.wfile.write(msg.encode("utf-8")) - - def process_some_all_results(self, max=10): - all_results_queue = self.state.all_results_queue - i = 0 - - while not all_results_queue.empty() and i < max: - result = all_results_queue.get() - - result_client_id = result["client_id"] - - if result_client_id not in self.state.client_results_queues: - self.state.client_results_queues[result_client_id] = Queue() - self.state.client_results_queues[result_client_id].put(result) - - i += 1 - - def do_GET_results(self): - request_client_id = self.path.split("/")[2] - self._client_id = f"{self.client_address[0]}-{request_client_id}" - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - - # process some results from the model and group them by client_id - self.process_some_all_results() - - if self.client_id not in self.state.client_results_queues.keys(): - self.wfile.write(b'[]') - return - - # return all results for self.client_id currently available - self.wfile.write("[".encode()) - while not self.state.client_results_queues[self.client_id].empty(): - result = self.state.client_results_queues[self.client_id].get() - self.wfile.write(json.dumps(result).encode()) - - # omit last colon - if self.state.client_results_queues[self.client_id].empty(): break - else: self.wfile.write(b",") - self.wfile.write(b"]") - - def do_queue_forward(self, payload, sample_id): - try: - input_ids = payload['input_ids'] - attention_mask = payload.get('attention_mask', None) - model_identifier = payload['model_identifier'] - - if model_identifier != 
self.state.model_identifier: - self.error_bad_request("The inference API serves model {} not {}.".format(self.state.model_identifier, model_identifier)) - return - - self._client_id = f"{self.client_address[0]}-{payload['client_id']}" - - assert type(input_ids) == list # and all([type(i) == int for i in input_ids]) - - self.state.queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'input_ids': input_ids, - 'attention_mask': attention_mask - }) - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: - self.error_bad_request() - - def do_queue_tokenize(self, payload, sample_id): - try: - text = payload['text'] - self._client_id = f"{self.client_address[0]}-{payload['client_id']}" - - assert type(text) == str - - self.state.tokenize_queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'action': 'tokenize', - 'text': text - }) - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: - self.error_bad_request() - - def do_queue_detokenize(self, payload, sample_id): - try: - input_ids = payload['input_ids'] - self._client_id = f"{self.client_address[0]}-{payload['client_id']}" - - assert type(input_ids) == list and all([type(i) == int for i in input_ids]) - - self.state.tokenize_queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'action': 'detokenize', - 'input_ids': input_ids - }) - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: - self.error_bad_request() - - def do_POST_queue(self): - # handle POST to /queue - payload = self.rfile.read(int(self.headers['Content-Length'])) - payload = json.loads(payload) - # get client address and port - sample_id = payload["sample_id"] - - action = payload['action'] - - if action == "forward": - self.do_queue_forward(payload, sample_id) - elif action == "tokenize": - self.do_queue_tokenize(payload, sample_id) - elif action == "detokenize": - self.do_queue_detokenize(payload, sample_id) - else: - self.error_bad_request("Unknown action: {}".format(action)) - - def do_POST(self): - if self.path == "/queue": - self.do_POST_queue() - return - else: - self.send_error(404) - return - - def do_GET(self): - if self.path.startswith("/results"): - self.do_GET_results() - return - else: - self.send_error(404) - return +from lmql.model import serve_hf, serve_llama_cpp +from lmql.model.serve_types import InferenceServerState, LMQLInferenceAPIHandler +import lmql.model.serve_hf +import lmql.model.serve_llama_cpp if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("model", type=str) - parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("model", type=str, help = "the huggingface model to use if not llama.cpp else the huggingface tokenizer to proxy the llama.cpp one.") + parser.add_argument("--llama.cpp", action='store_true', dest="llama_cpp", help="flag determining whether to use llama.cpp server or not.") parser.add_argument("--port", type=int, default=8080) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--cuda", action="store_true", default=False) parser.add_argument("--cache", type=str, default=None) - 
parser.add_argument("--dtype", type=str, default="none") parser.add_argument("--num-tokenizer-processes", type=int, default=2, dest="num_tokenizer_processes") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--dtype", type=str, default="none") + + serve_hf.add_parser(parser) + serve_llama_cpp.add_parser(parser) args = parser.parse_args() @@ -455,6 +53,7 @@ def do_GET(self): tokenizer_descriptor = args.tokenizer if tokenizer_descriptor is None: tokenizer_descriptor = model_descriptor + state = InferenceServerState(model_descriptor, tokenizer_descriptor, args.dtype, @@ -462,14 +61,10 @@ def do_GET(self): tokenize_queue=manager.Queue(), all_results_queue=manager.Queue()) - # run model in separate process - processor = ModelProcessor(state, cuda=args.cuda, cache=args.cache) - processor.run_in_parallel() - - # run tokenizers in separate process - tokenizer_processor = TokenizerProcessor(state) - tokenizer_processor.run_in_parallel(n=args.num_tokenizer_processes) - + if args.llama_cpp: + processor, tokenizer_processor = serve_llama_cpp.get_serve(state, args) + else: + processor, tokenizer_processor = serve_hf.get_serve(state, args) # run inference API server in this process server_address = (args.host, args.port) httpd = HTTPServer(server_address, LMQLInferenceAPIHandler) diff --git a/src/lmql/model/serve_hf.py b/src/lmql/model/serve_hf.py new file mode 100644 index 00000000..3201b07a --- /dev/null +++ b/src/lmql/model/serve_hf.py @@ -0,0 +1,173 @@ +""" +Serves a transformers model as LMQL inference API. +""" + +from dataclasses import dataclass, field +from collections import defaultdict + +import json +from http.server import BaseHTTPRequestHandler, HTTPServer +from multiprocessing import Queue as MPQueue +from queue import Empty +from queue import Queue +import multiprocessing +from typing import Dict, Tuple +import requests +import asyncio +import sys +import atexit +import argparse +import time +import os +import subprocess + +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +import torch +from lmql.model.serve_types import TokenizerProcessor, ModelProcessor, InferenceServerState + +class HFTokenizerProcessor(TokenizerProcessor): + + def tokenize(self, tokenizer, sample_id, client_id, item): + text = item["text"] + + if text == "": + input_ids = [tokenizer.eos_token_id] + elif text == "": + input_ids = [tokenizer.bos_token_id] + else: + input_ids = tokenizer(text)["input_ids"] + + self.state.all_results_queue.put({ + "sample_id": sample_id, + "client_id": client_id, + "input_ids": input_ids + }) + + def detokenize(self, tokenizer, sample_id, client_id, item): + input_ids = item["input_ids"] + + text = tokenizer.decode(input_ids) + self.state.all_results_queue.put({ + "sample_id": sample_id, + "client_id": client_id, + "text": text + }) + + def run(self, index): + # load tokenizer + tokenizer = AutoTokenizer.from_pretrained(self.model_identifier) + print("Tokenizer #{} {} ready!".format(index, self.model_identifier)) + + while not self.state.exit: + item = self.queue.get() + if item is None: + time.sleep(0.1) + continue + + sample_id = item["sample_id"] + client_id = item["client_id"] + action = item["action"] + + if action == "tokenize": + self.tokenize(tokenizer, sample_id, client_id, item) + elif action == "detokenize": + self.detokenize(tokenizer, sample_id, client_id, item) + else: + print("error: unknown TokenizerProcessor action {}".format(action)) + + print("Tokenizer #{} shut down.".format(index)) + +class 
HFModelProcessor(ModelProcessor): + + def run(self): + dtype = self.state.dtype + if dtype == "float16": + dtype = torch.float16 + else: + dtype = None + + # load model + if not self.cuda: + print("Loading {} (CPU)".format(self.model_identifier)) + self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, torch_dtype=dtype, resume_download=True) + else: + print("Loading {} (Multi-GPU)".format(self.model_identifier)) + self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, torch_dtype=dtype, resume_download=True, device_map="auto") + self.model.eval() + + print("Ready!".format(self.model_identifier)) + + while not self.state.exit: + self.print_stats() + # wait for self.queue to have an item + try: + item = self.queue.get(timeout=1.0) + except Empty: + continue + except KeyboardInterrupt: + break + + if item is None: + time.sleep(0.1) + continue + + self.request_count += 1 + + device = "cuda" if self.cuda else "cpu" + + sample_id = item["sample_id"] + client_id = item["client_id"] + input_ids = torch.tensor(item["input_ids"], dtype=torch.long).to(device) + attention_mask = item.get("attention_mask", None) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids).to(device) + else: + attention_mask = torch.tensor(attention_mask, dtype=torch.long).to(device) + + if self.cache is not None: + key = "IDs:" + str(input_ids.tolist()) + " MASK:" + str(attention_mask.tolist()) + if key in self.cache: + self.requests_cached += 1 + self.state.all_results_queue.put({ + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": self.cache[key] + }) + continue + + res = self.model.forward(input_ids=input_ids, attention_mask=attention_mask) + + if input_ids.ndimension() == 2: + next_token_logits = res.logits[:,-1] + else: + next_token_logits = res.logits[-1] + + if self.cache is not None: + key = "IDs:" + str(input_ids.tolist()) + " MASK:" + str(attention_mask.tolist()) + self.cache[key] = next_token_logits.tolist() + + self.state.all_results_queue.put({ + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": next_token_logits.detach().tolist() + }) + + print("Processor shut down") + + + + +def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[ModelProcessor, TokenizerProcessor]: + # run model in separate process + processor = HFModelProcessor(state, cuda=args.cuda, cache=args.cache) + processor.run_in_parallel() + + # run tokenizers in separate process + tokenizer_processor = HFTokenizerProcessor(state) + tokenizer_processor.run_in_parallel(n=args.num_tokenizer_processes) + return processor, tokenizer_processor + +def add_parser(base_parser): + ... + diff --git a/src/lmql/model/serve_llama_cpp.py b/src/lmql/model/serve_llama_cpp.py index 2cdae044..c70238db 100644 --- a/src/lmql/model/serve_llama_cpp.py +++ b/src/lmql/model/serve_llama_cpp.py @@ -1,5 +1,5 @@ """ -Serves a transformers model as LMQL inference API. +Serves a llama.cpp model as LMQL inference API. 
""" from dataclasses import dataclass, field @@ -11,37 +11,18 @@ from queue import Empty from queue import Queue import multiprocessing -from typing import Dict +from typing import Dict, Optional, Tuple import atexit import argparse import time from llama_cpp.llama import Llama, llama_cpp import inspect +from lmql.model.serve_types import TokenizerProcessor, ModelProcessor, InferenceServerState -@dataclass -class InferenceServerState: - model_identifier : str - tokenizer_descriptor : str - dtype: str - - queue: Queue - tokenize_queue: Queue - all_results_queue : Queue - - sample_count: int = 0 - client_results_queues: Dict[str,Queue] = field(default_factory=dict) - - exit: bool = False - -class TokenizerProcessor: +class LlamaCPPTokenizerProcessor(TokenizerProcessor): def __init__(self, state: InferenceServerState, processor: "ModelProcessor"): - self.model_identifier = state.tokenizer_descriptor + super().__init__(state) self.model = processor - self.queue = state.tokenize_queue - self.state = state - - def shutdown(self): - self.state.exit = True def tokenize(self, tokenizer, sample_id, client_id, item): text = item["text"] @@ -92,56 +73,14 @@ def run(self): print("error: unknown TokenizerProcessor action {}".format(action)) print("Tokenizer shut down.") - - def run_in_parallel(self): - atexit.register(self.shutdown) - - p = multiprocessing.Process(target=self.run) - p.start() - return p -class ModelProcessor: - def __init__(self, state: InferenceServerState, llama_kwargs: dict, cache: str = None): - self.model_identifier = state.model_identifier +class LlamaCPPModelProcessor(ModelProcessor): + def __init__(self, state: InferenceServerState, cuda: bool = False, cache: Optional[str] = None, llama_kwargs: Optional[dict] = None): + super().__init__(state, cuda, cache) + assert llama_kwargs is not None self.llama_kwargs = llama_kwargs - self.queue = state.queue - self.state = state - - self.cache = None - if cache is not None: - from rocksdict import Rdict - self.cache = Rdict(cache) - - self.request_count = 0 - self.requests_cached = 0 - self.last_report = time.time() - self.last_request_count = 0 - - - def shutdown(self): - self.state.exit = True - - def __del__(self): - if self.cache is not None: - self.cache.close() - - def print_stats(self): - # fancy unicode based terminal spinner - terminal_spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] - throughput = (self.request_count - self.last_request_count) / (time.time() - self.last_report) - self.last_report = time.time() - self.last_request_count = self.request_count - - # format throughput to two decimal places - print("{} {:.2f} calls/s, Requests Served: {}, Queue: {}".format( - terminal_spinner_chars[self.request_count % len(terminal_spinner_chars)], - throughput, - self.request_count, - self.state.queue.qsize()), end="\r") - def run(self): - - + def run(self): # load model print("Loading {} (CPU)".format(self.model_identifier)) self.model = Llama(**{**self.llama_kwargs, 'logits_all': True}) @@ -197,241 +136,30 @@ def run(self): print("Processor shut down") - def run_in_parallel(self): - atexit.register(self.shutdown) - - p = multiprocessing.Process(target=self.run) - p.start() - return p +base_llama_kwargs = {} -class LMQLInferenceAPIHandler(BaseHTTPRequestHandler): - def __init__(self, *args, **kwargs): - self._client_id = None - super().__init__(*args, **kwargs) +def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[ModelProcessor, TokenizerProcessor]: + # run model in separate process + 
llama_kwargs = {kwarg: getattr(args, kwarg) for kwarg in base_llama_kwargs.keys()} + processor = LlamaCPPModelProcessor(state, cuda=False, cache=args.cache, llama_kwargs=llama_kwargs) + processor.run_in_parallel() - # disable logging - def log_message(self, format, *args): - return + # run tokenizers in separate process + tokenizer_processor = LlamaCPPTokenizerProcessor(state, processor=processor) + tokenizer_processor.run_in_parallel(1) + return processor, tokenizer_processor - @property - def state(self) -> InferenceServerState: - return self.server.state +def add_parser(base_parser): - @property - def client_id(self) -> str: - if self._client_id is None: - self.error_bad_request(msg="client_id not set (please provide the client_id in POST payload or in query.") - raise Exception("client_id not set") - return self._client_id - - def error_bad_request(self, msg="Bad request."): - self.send_response(400) - self.send_header('Content-type', 'text/plain') - self.end_headers() - self.wfile.write(msg.encode("utf-8")) - - def process_some_all_results(self, max=10): - all_results_queue = self.state.all_results_queue - i = 0 - - while not all_results_queue.empty() and i < max: - result = all_results_queue.get() - - result_client_id = result["client_id"] - - if result_client_id not in self.state.client_results_queues: - self.state.client_results_queues[result_client_id] = Queue() - self.state.client_results_queues[result_client_id].put(result) - - i += 1 - - def do_GET_results(self): - request_client_id = self.path.split("/")[2] - self._client_id = f"{self.client_address[0]}-{request_client_id}" - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - - # process some results from the model and group them by client_id - self.process_some_all_results() - - if self.client_id not in self.state.client_results_queues.keys(): - self.wfile.write(b'[]') - return - - # return all results for self.client_id currently available - self.wfile.write("[".encode()) - while not self.state.client_results_queues[self.client_id].empty(): - result = self.state.client_results_queues[self.client_id].get() - self.wfile.write(json.dumps(result).encode()) - - # omit last colon - if self.state.client_results_queues[self.client_id].empty(): break - else: self.wfile.write(b",") - self.wfile.write(b"]") - - def do_queue_forward(self, payload, sample_id): - try: - input_ids = payload['input_ids'] - attention_mask = payload.get('attention_mask', None) - model_identifier = payload['model_identifier'] - - if model_identifier != self.state.model_identifier: - self.error_bad_request("The inference API serves model {} not {}.".format(self.state.model_identifier, model_identifier)) - return - - self._client_id = f"{self.client_address[0]}-{payload['client_id']}" - - assert type(input_ids) == list # and all([type(i) == int for i in input_ids]) - - self.state.queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'input_ids': input_ids, - 'attention_mask': attention_mask - }) - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: - self.error_bad_request() - - def do_queue_tokenize(self, payload, sample_id): - try: - text = payload['text'] - self._client_id = f"{self.client_address[0]}-{payload['client_id']}" - - assert type(text) == str - - self.state.tokenize_queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 
'action': 'tokenize', - 'text': text - }) - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: - self.error_bad_request() - - def do_queue_detokenize(self, payload, sample_id): - try: - input_ids = payload['input_ids'] - self._client_id = f"{self.client_address[0]}-{payload['client_id']}" - - assert type(input_ids) == list and all([type(i) == int for i in input_ids]) - - self.state.tokenize_queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'action': 'detokenize', - 'input_ids': input_ids - }) - - self.send_response(200) - self.send_header('Content-type', 'application/json') - self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: - self.error_bad_request() - - def do_POST_queue(self): - # handle POST to /queue - payload = self.rfile.read(int(self.headers['Content-Length'])) - payload = json.loads(payload) - # get client address and port - sample_id = payload["sample_id"] - - action = payload['action'] - - if action == "forward": - self.do_queue_forward(payload, sample_id) - elif action == "tokenize": - self.do_queue_tokenize(payload, sample_id) - elif action == "detokenize": - self.do_queue_detokenize(payload, sample_id) - else: - self.error_bad_request("Unknown action: {}".format(action)) - - def do_POST(self): - if self.path == "/queue": - self.do_POST_queue() - return - else: - self.send_error(404) - return - - def do_GET(self): - if self.path.startswith("/results"): - self.do_GET_results() - return - else: - self.send_error(404) - return - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("model", type=str) - parser.add_argument("--port", type=int, default=8080) - parser.add_argument("--host", type=str, default="localhost") - parser.add_argument("--cache", type=str, default=None) - llama_kwargs = {} sig = inspect.signature(Llama.__init__) for name, param in sig.parameters.items(): if name == 'self': continue - llama_kwargs[name] = None + base_llama_kwargs[name] = None _type = param.annotation if hasattr(_type, '__args__'): _type=_type.__args__[0] if param.default == inspect.Parameter.empty: - parser.add_argument(name, type=_type) + base_parser.add_argument(f'--{name}', default=_type(), type=_type, help="required for llama.cpp") else: - parser.add_argument(f'--{name}', default=param.default, type=_type) - - args = parser.parse_args() - - llama_kwargs = {kwarg: getattr(args, kwarg) for kwarg in llama_kwargs.keys()} - - manager = multiprocessing.Manager() - - # prepare configuration - model_descriptor = args.model - state = InferenceServerState(model_descriptor, - model_descriptor, - "", - queue=manager.Queue(), - tokenize_queue=manager.Queue(), - all_results_queue=manager.Queue()) - - # run model in separate process - processor = ModelProcessor(state, cache=args.cache, llama_kwargs=llama_kwargs) - processor.run_in_parallel() - - # run tokenizers in separate process - tokenizer_processor = TokenizerProcessor(state, processor) - tokenizer_processor.run_in_parallel() - - # run inference API server in this process - server_address = (args.host, args.port) - httpd = HTTPServer(server_address, LMQLInferenceAPIHandler) - httpd.state = state - - try: - print("Serving LMQL inference API on {}:{}".format(args.host, args.port)) - httpd.serve_forever() - except KeyboardInterrupt: - # terminate server - httpd.shutdown() - httpd.server_close() - 
print("Server stopped") - - # terminate processors - processor.shutdown() - tokenizer_processor.shutdown() + base_parser.add_argument(f'--{name}', default=param.default, type=_type, help="optional for llama.cpp") diff --git a/src/lmql/model/serve_types.py b/src/lmql/model/serve_types.py new file mode 100644 index 00000000..9778e919 --- /dev/null +++ b/src/lmql/model/serve_types.py @@ -0,0 +1,328 @@ +""" +Serves a transformers model as LMQL inference API. +""" + +from dataclasses import dataclass, field +from collections import defaultdict + +import json +from http.server import BaseHTTPRequestHandler, HTTPServer +from multiprocessing import Queue as MPQueue +from queue import Empty +from queue import Queue +import multiprocessing +from typing import Dict, Optional +import requests +import asyncio +import sys +import atexit +import argparse +import time +import os +import subprocess + +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +import torch +from abc import ABC, abstractmethod + +@dataclass +class InferenceServerState: + model_identifier : str + tokenizer_descriptor : str + dtype: str + + queue: Queue + tokenize_queue: Queue + all_results_queue : Queue + + sample_count: int = 0 + client_results_queues: Dict[str,Queue] = field(default_factory=dict) + + exit: bool = False + + +class TokenizerProcessor(ABC): + def __init__(self, state: InferenceServerState): + self.model_identifier = state.tokenizer_descriptor + self.queue = state.tokenize_queue + self.state = state + + def shutdown(self): + self.state.exit = True + + @abstractmethod + def tokenize(self, tokenizer, sample_id, client_id, item): + """Tokenize item input""" + + @abstractmethod + def detokenize(self, tokenizer, sample_id, client_id, item): + """Detokenize item input""" + + @abstractmethod + def run(self, index): + """Evaluate items from the queue""" + + def run_in_parallel(self, n=2): + atexit.register(self.shutdown) + + workers = [] + + for i in range(n): + p = multiprocessing.Process(target=self.run, args=(i,)) + p.start() + workers.append(p) + + return workers + +class ModelProcessor(ABC): + def __init__(self, state: InferenceServerState, cuda: bool = False, cache: Optional[str] = None): + self.model_identifier = state.model_identifier + self.queue = state.queue + self.state = state + self.cuda = cuda + + self.cache = None + if cache is not None: + from rocksdict import Rdict + self.cache = Rdict(cache) + + self.request_count = 0 + self.requests_cached = 0 + self.last_report = time.time() + self.last_request_count = 0 + + try: + self.nvidia_logging = subprocess.Popen(["nvidia-smi"], stdout=subprocess.PIPE).wait() == 0 + except: + self.nvidia_logging = False + + def shutdown(self): + self.state.exit = True + + def __del__(self): + if self.cache is not None: + self.cache.close() + + def print_stats(self): + if self.nvidia_logging: + visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + cmds = ["nvidia-smi"] + if visible_devices is not None: + cmds.append("-i={}".format(visible_devices)) + cmds += ["--query-gpu=name,memory.used,memory.total,utilization.gpu", "--format=csv,noheader"] + output = [l.split(", ") for l in subprocess.check_output(cmds).decode("utf-8").split("\n") if l.strip() != ""] + gpu_usage = ["GPU {} {}, util {}".format(i, row[1] + "/" + row[2], row[3]) for i, row in enumerate(output)] + else: + gpu_usage = ["GPU monitoring not available on non-CUDA systems"] + + print(" " * 100, end="\r") + # fancy unicode based terminal spinner + terminal_spinner_chars = ["⠋", "⠙", "⠹", 
"⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + throughput = (self.request_count - self.last_request_count) / (time.time() - self.last_report) + self.last_report = time.time() + self.last_request_count = self.request_count + + # format throughput to two decimal places + print("{} {:.2f} calls/s, Requests Served: {}, Queue: {} [{}]".format( + terminal_spinner_chars[self.request_count % len(terminal_spinner_chars)], + throughput, + self.request_count, + self.state.queue.qsize(), + ", ".join(gpu_usage)), end="\r") + + @abstractmethod + def run(self): + """Infer on the queue""" + + def oom_reloading_run(self): + while True: + try: + self.run() + return + except RuntimeError as e: + if "CUDA out of memory" in str(e): + print("Crashed due to OOM, reloading model.") + continue + else: + import traceback + traceback.print_exc() + print("Crashed with", e, "reloading model...") + continue + + def run_in_parallel(self): + atexit.register(self.shutdown) + + p = multiprocessing.Process(target=self.oom_reloading_run) + p.start() + return p + +class LMQLInferenceAPIHandler(BaseHTTPRequestHandler): + def __init__(self, *args, **kwargs): + self._client_id = None + super().__init__(*args, **kwargs) + + # disable logging + def log_message(self, format, *args): + return + + @property + def state(self) -> InferenceServerState: + return self.server.state + + @property + def client_id(self) -> str: + if self._client_id is None: + self.error_bad_request(msg="client_id not set (please provide the client_id in POST payload or in query.") + raise Exception("client_id not set") + return self._client_id + + def error_bad_request(self, msg="Bad request."): + self.send_response(400) + self.send_header('Content-type', 'text/plain') + self.end_headers() + self.wfile.write(msg.encode("utf-8")) + + def process_some_all_results(self, max=10): + all_results_queue = self.state.all_results_queue + i = 0 + + while not all_results_queue.empty() and i < max: + result = all_results_queue.get() + + result_client_id = result["client_id"] + + if result_client_id not in self.state.client_results_queues: + self.state.client_results_queues[result_client_id] = Queue() + self.state.client_results_queues[result_client_id].put(result) + + i += 1 + + def do_GET_results(self): + request_client_id = self.path.split("/")[2] + self._client_id = f"{self.client_address[0]}-{request_client_id}" + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + + # process some results from the model and group them by client_id + self.process_some_all_results() + + if self.client_id not in self.state.client_results_queues.keys(): + self.wfile.write(b'[]') + return + + # return all results for self.client_id currently available + self.wfile.write("[".encode()) + while not self.state.client_results_queues[self.client_id].empty(): + result = self.state.client_results_queues[self.client_id].get() + self.wfile.write(json.dumps(result).encode()) + + # omit last colon + if self.state.client_results_queues[self.client_id].empty(): break + else: self.wfile.write(b",") + self.wfile.write(b"]") + + def do_queue_forward(self, payload, sample_id): + try: + input_ids = payload['input_ids'] + attention_mask = payload.get('attention_mask', None) + model_identifier = payload['model_identifier'] + + if model_identifier != self.state.model_identifier: + self.error_bad_request("The inference API serves model {} not {}.".format(self.state.model_identifier, model_identifier)) + return + + self._client_id = 
f"{self.client_address[0]}-{payload['client_id']}" + + assert type(input_ids) == list # and all([type(i) == int for i in input_ids]) + + self.state.queue.put({ + 'client_id': self.client_id, + 'sample_id': sample_id, + 'input_ids': input_ids, + 'attention_mask': attention_mask + }) + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) + except Exception as e: + self.error_bad_request() + + def do_queue_tokenize(self, payload, sample_id): + try: + text = payload['text'] + self._client_id = f"{self.client_address[0]}-{payload['client_id']}" + + assert type(text) == str + + self.state.tokenize_queue.put({ + 'client_id': self.client_id, + 'sample_id': sample_id, + 'action': 'tokenize', + 'text': text + }) + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) + except Exception as e: + self.error_bad_request() + + def do_queue_detokenize(self, payload, sample_id): + try: + input_ids = payload['input_ids'] + self._client_id = f"{self.client_address[0]}-{payload['client_id']}" + + assert type(input_ids) == list and all([type(i) == int for i in input_ids]) + + self.state.tokenize_queue.put({ + 'client_id': self.client_id, + 'sample_id': sample_id, + 'action': 'detokenize', + 'input_ids': input_ids + }) + + self.send_response(200) + self.send_header('Content-type', 'application/json') + self.end_headers() + self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) + except Exception as e: + self.error_bad_request() + + def do_POST_queue(self): + # handle POST to /queue + payload = self.rfile.read(int(self.headers['Content-Length'])) + payload = json.loads(payload) + # get client address and port + sample_id = payload["sample_id"] + + action = payload['action'] + + if action == "forward": + self.do_queue_forward(payload, sample_id) + elif action == "tokenize": + self.do_queue_tokenize(payload, sample_id) + elif action == "detokenize": + self.do_queue_detokenize(payload, sample_id) + else: + self.error_bad_request("Unknown action: {}".format(action)) + + def do_POST(self): + if self.path == "/queue": + self.do_POST_queue() + return + else: + self.send_error(404) + return + + def do_GET(self): + if self.path.startswith("/results"): + self.do_GET_results() + return + else: + self.send_error(404) + return \ No newline at end of file From 93f3947e3c097567fc4d935d2f0077e7a95093a8 Mon Sep 17 00:00:00 2001 From: circargs Date: Fri, 14 Apr 2023 13:52:13 -0400 Subject: [PATCH 5/8] small serve fixes --- src/lmql/model/serve.py | 2 +- src/lmql/model/serve_llama_cpp.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/lmql/model/serve.py b/src/lmql/model/serve.py index 4d486ffa..35abfee1 100644 --- a/src/lmql/model/serve.py +++ b/src/lmql/model/serve.py @@ -1,5 +1,5 @@ """ -Serves a transformers model as LMQL inference API. +Serves a model as LMQL inference API. 
""" from dataclasses import dataclass, field diff --git a/src/lmql/model/serve_llama_cpp.py b/src/lmql/model/serve_llama_cpp.py index c70238db..dba2cd7a 100644 --- a/src/lmql/model/serve_llama_cpp.py +++ b/src/lmql/model/serve_llama_cpp.py @@ -73,6 +73,13 @@ def run(self): print("error: unknown TokenizerProcessor action {}".format(action)) print("Tokenizer shut down.") + + def run_in_parallel(self): + atexit.register(self.shutdown) + + p = multiprocessing.Process(target=self.run) + p.start() + return p class LlamaCPPModelProcessor(ModelProcessor): def __init__(self, state: InferenceServerState, cuda: bool = False, cache: Optional[str] = None, llama_kwargs: Optional[dict] = None): @@ -135,7 +142,14 @@ def run(self): }) print("Processor shut down") - + + def run_in_parallel(self): + atexit.register(self.shutdown) + + p = multiprocessing.Process(target=self.run) + p.start() + return p + base_llama_kwargs = {} def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[ModelProcessor, TokenizerProcessor]: @@ -146,7 +160,7 @@ def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[Mode # run tokenizers in separate process tokenizer_processor = LlamaCPPTokenizerProcessor(state, processor=processor) - tokenizer_processor.run_in_parallel(1) + tokenizer_processor.run_in_parallel() return processor, tokenizer_processor def add_parser(base_parser): From 21a806dba3e5b8054d0de7a37033138373e3e01d Mon Sep 17 00:00:00 2001 From: circargs Date: Fri, 14 Apr 2023 13:54:37 -0400 Subject: [PATCH 6/8] update requirements --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 21978d13..e14b20bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ astunparse==1.6.3 openai termcolor pydot -transformers==4.25.1 +transformers==4.28.1 pandas accelerate +llama-cpp-python From 9aa76dd79a2e3cdb79a5061a40956497eaeb054e Mon Sep 17 00:00:00 2001 From: circargs Date: Mon, 17 Apr 2023 19:35:29 -0400 Subject: [PATCH 7/8] clean_up_tokenization_spaces --- src/lmql/runtime/tokenizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lmql/runtime/tokenizer.py b/src/lmql/runtime/tokenizer.py index e13625ff..60494548 100644 --- a/src/lmql/runtime/tokenizer.py +++ b/src/lmql/runtime/tokenizer.py @@ -108,7 +108,7 @@ def decode(self, input_ids): if input_ids[-1] >= self.tokenizer_impl.vocab_size: extended = self.detokenizer_cache[n-1][key] + "<" + reverse_special_token_mappings[input_ids[-1]] + "/>" else: - extended = self.detokenizer_cache[n-1][key] + self.tokenizer_impl.decode([input_ids[-1]]) + extended = self.detokenizer_cache[n-1][key] + self.tokenizer_impl.decode([input_ids[-1]], clean_up_tokenization_spaces=True) if not n in self.detokenizer_cache.keys(): self.detokenizer_cache[n] = {} self.detokenizer_cache[n][str(input_ids)] = extended @@ -119,7 +119,7 @@ def decode(self, input_ids): if type(chunk) is str: s += chunk else: - s += self.tokenizer_impl.decode(chunk) + s += self.tokenizer_impl.decode(chunk, clean_up_tokenization_spaces=True) if not n in self.detokenizer_cache.keys(): self.detokenizer_cache[n] = {} From 07190a30a71066a40493079f836e00ee04363ac3 Mon Sep 17 00:00:00 2001 From: circargs Date: Mon, 17 Apr 2023 19:47:06 -0400 Subject: [PATCH 8/8] cleanup serve files --- src/lmql/model/serve.py | 62 ++++----- src/lmql/model/serve_hf.py | 129 ++++++++++-------- src/lmql/model/serve_llama_cpp.py | 118 +++++++++------- src/lmql/model/serve_types.py | 219 
+++++++++++++++++------------- 4 files changed, 294 insertions(+), 234 deletions(-) diff --git a/src/lmql/model/serve.py b/src/lmql/model/serve.py index 35abfee1..f5c01d94 100644 --- a/src/lmql/model/serve.py +++ b/src/lmql/model/serve.py @@ -2,27 +2,12 @@ Serves a model as LMQL inference API. """ -from dataclasses import dataclass, field -from collections import defaultdict +from http.server import HTTPServer -import json -from http.server import BaseHTTPRequestHandler, HTTPServer -from multiprocessing import Queue as MPQueue -from queue import Empty -from queue import Queue import multiprocessing -from typing import Dict -import requests -import asyncio -import sys -import atexit + import argparse -import time -import os -import subprocess -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline -import torch from lmql.model import serve_hf, serve_llama_cpp from lmql.model.serve_types import InferenceServerState, LMQLInferenceAPIHandler @@ -31,35 +16,48 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("model", type=str, help = "the huggingface model to use if not llama.cpp else the huggingface tokenizer to proxy the llama.cpp one.") - parser.add_argument("--llama.cpp", action='store_true', dest="llama_cpp", help="flag determining whether to use llama.cpp server or not.") + parser.add_argument( + "model", + type=str, + help="the huggingface model to use if not llama.cpp else the huggingface tokenizer to proxy the llama.cpp one.", + ) + parser.add_argument( + "--llama.cpp", + action="store_true", + dest="llama_cpp", + help="flag determining whether to use llama.cpp server or not.", + ) parser.add_argument("--port", type=int, default=8080) parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--cuda", action="store_true", default=False) parser.add_argument("--cache", type=str, default=None) - parser.add_argument("--num-tokenizer-processes", type=int, default=2, dest="num_tokenizer_processes") + parser.add_argument( + "--num-tokenizer-processes", type=int, default=2, dest="num_tokenizer_processes" + ) parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument("--dtype", type=str, default="none") - + serve_hf.add_parser(parser) serve_llama_cpp.add_parser(parser) - + args = parser.parse_args() - + manager = multiprocessing.Manager() - + # prepare configuration model_descriptor = args.model tokenizer_descriptor = args.tokenizer if tokenizer_descriptor is None: tokenizer_descriptor = model_descriptor - - state = InferenceServerState(model_descriptor, - tokenizer_descriptor, - args.dtype, - queue=manager.Queue(), - tokenize_queue=manager.Queue(), - all_results_queue=manager.Queue()) + + state = InferenceServerState( + model_descriptor, + tokenizer_descriptor, + args.dtype, + queue=manager.Queue(), + tokenize_queue=manager.Queue(), + all_results_queue=manager.Queue(), + ) if args.llama_cpp: processor, tokenizer_processor = serve_llama_cpp.get_serve(state, args) @@ -69,7 +67,7 @@ server_address = (args.host, args.port) httpd = HTTPServer(server_address, LMQLInferenceAPIHandler) httpd.state = state - + try: print("Serving LMQL inference API on {}:{}".format(args.host, args.port)) httpd.serve_forever() diff --git a/src/lmql/model/serve_hf.py b/src/lmql/model/serve_hf.py index 3201b07a..b752aba8 100644 --- a/src/lmql/model/serve_hf.py +++ b/src/lmql/model/serve_hf.py @@ -2,31 +2,21 @@ Serves a transformers model as LMQL inference API. 
""" -from dataclasses import dataclass, field -from collections import defaultdict - -import json -from http.server import BaseHTTPRequestHandler, HTTPServer -from multiprocessing import Queue as MPQueue from queue import Empty -from queue import Queue -import multiprocessing -from typing import Dict, Tuple -import requests -import asyncio -import sys -import atexit -import argparse +from typing import Tuple +import argparse import time -import os -import subprocess -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from transformers import AutoModelForCausalLM, AutoTokenizer import torch -from lmql.model.serve_types import TokenizerProcessor, ModelProcessor, InferenceServerState +from lmql.model.serve_types import ( + TokenizerProcessor, + ModelProcessor, + InferenceServerState, +) -class HFTokenizerProcessor(TokenizerProcessor): +class HFTokenizerProcessor(TokenizerProcessor): def tokenize(self, tokenizer, sample_id, client_id, item): text = item["text"] @@ -37,21 +27,17 @@ def tokenize(self, tokenizer, sample_id, client_id, item): else: input_ids = tokenizer(text)["input_ids"] - self.state.all_results_queue.put({ - "sample_id": sample_id, - "client_id": client_id, - "input_ids": input_ids - }) + self.state.all_results_queue.put( + {"sample_id": sample_id, "client_id": client_id, "input_ids": input_ids} + ) def detokenize(self, tokenizer, sample_id, client_id, item): input_ids = item["input_ids"] text = tokenizer.decode(input_ids) - self.state.all_results_queue.put({ - "sample_id": sample_id, - "client_id": client_id, - "text": text - }) + self.state.all_results_queue.put( + {"sample_id": sample_id, "client_id": client_id, "text": text} + ) def run(self, index): # load tokenizer @@ -74,27 +60,34 @@ def run(self, index): self.detokenize(tokenizer, sample_id, client_id, item) else: print("error: unknown TokenizerProcessor action {}".format(action)) - + print("Tokenizer #{} shut down.".format(index)) -class HFModelProcessor(ModelProcessor): +class HFModelProcessor(ModelProcessor): def run(self): dtype = self.state.dtype if dtype == "float16": dtype = torch.float16 else: dtype = None - + # load model if not self.cuda: print("Loading {} (CPU)".format(self.model_identifier)) - self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, torch_dtype=dtype, resume_download=True) + self.model = AutoModelForCausalLM.from_pretrained( + self.model_identifier, torch_dtype=dtype, resume_download=True + ) else: print("Loading {} (Multi-GPU)".format(self.model_identifier)) - self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier, torch_dtype=dtype, resume_download=True, device_map="auto") + self.model = AutoModelForCausalLM.from_pretrained( + self.model_identifier, + torch_dtype=dtype, + resume_download=True, + device_map="auto", + ) self.model.eval() - + print("Ready!".format(self.model_identifier)) while not self.state.exit: @@ -107,58 +100,74 @@ def run(self): except KeyboardInterrupt: break - if item is None: + if item is None: time.sleep(0.1) continue self.request_count += 1 - + device = "cuda" if self.cuda else "cpu" sample_id = item["sample_id"] client_id = item["client_id"] input_ids = torch.tensor(item["input_ids"], dtype=torch.long).to(device) attention_mask = item.get("attention_mask", None) - + if attention_mask is None: attention_mask = torch.ones_like(input_ids).to(device) else: - attention_mask = torch.tensor(attention_mask, dtype=torch.long).to(device) + attention_mask = torch.tensor(attention_mask, dtype=torch.long).to( + device + ) if 
self.cache is not None: - key = "IDs:" + str(input_ids.tolist()) + " MASK:" + str(attention_mask.tolist()) + key = ( + "IDs:" + + str(input_ids.tolist()) + + " MASK:" + + str(attention_mask.tolist()) + ) if key in self.cache: self.requests_cached += 1 - self.state.all_results_queue.put({ - "client_id": client_id, - "sample_id": sample_id, - "next_token_logits": self.cache[key] - }) + self.state.all_results_queue.put( + { + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": self.cache[key], + } + ) continue - + res = self.model.forward(input_ids=input_ids, attention_mask=attention_mask) - + if input_ids.ndimension() == 2: - next_token_logits = res.logits[:,-1] + next_token_logits = res.logits[:, -1] else: next_token_logits = res.logits[-1] - + if self.cache is not None: - key = "IDs:" + str(input_ids.tolist()) + " MASK:" + str(attention_mask.tolist()) + key = ( + "IDs:" + + str(input_ids.tolist()) + + " MASK:" + + str(attention_mask.tolist()) + ) self.cache[key] = next_token_logits.tolist() - self.state.all_results_queue.put({ - "client_id": client_id, - "sample_id": sample_id, - "next_token_logits": next_token_logits.detach().tolist() - }) - - print("Processor shut down") - + self.state.all_results_queue.put( + { + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": next_token_logits.detach().tolist(), + } + ) + print("Processor shut down") -def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[ModelProcessor, TokenizerProcessor]: +def get_serve( + state: InferenceServerState, args: argparse.Namespace +) -> Tuple[ModelProcessor, TokenizerProcessor]: # run model in separate process processor = HFModelProcessor(state, cuda=args.cuda, cache=args.cache) processor.run_in_parallel() @@ -168,6 +177,6 @@ def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[Mode tokenizer_processor.run_in_parallel(n=args.num_tokenizer_processes) return processor, tokenizer_processor + def add_parser(base_parser): ... - diff --git a/src/lmql/model/serve_llama_cpp.py b/src/lmql/model/serve_llama_cpp.py index dba2cd7a..c6a7fb7a 100644 --- a/src/lmql/model/serve_llama_cpp.py +++ b/src/lmql/model/serve_llama_cpp.py @@ -2,22 +2,20 @@ Serves a llama.cpp model as LMQL inference API. 
""" -from dataclasses import dataclass, field -from collections import defaultdict - -import json -from http.server import BaseHTTPRequestHandler, HTTPServer -from multiprocessing import Queue as MPQueue from queue import Empty -from queue import Queue import multiprocessing -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import atexit import argparse import time -from llama_cpp.llama import Llama, llama_cpp +from llama_cpp.llama import Llama import inspect -from lmql.model.serve_types import TokenizerProcessor, ModelProcessor, InferenceServerState +from lmql.model.serve_types import ( + TokenizerProcessor, + ModelProcessor, + InferenceServerState, +) + class LlamaCPPTokenizerProcessor(TokenizerProcessor): def __init__(self, state: InferenceServerState, processor: "ModelProcessor"): @@ -34,21 +32,17 @@ def tokenize(self, tokenizer, sample_id, client_id, item): else: input_ids = tokenizer.tokenize(b" " + text.encode("utf-8")) - self.state.all_results_queue.put({ - "sample_id": sample_id, - "client_id": client_id, - "input_ids": input_ids - }) + self.state.all_results_queue.put( + {"sample_id": sample_id, "client_id": client_id, "input_ids": input_ids} + ) def detokenize(self, tokenizer, sample_id, client_id, item): input_ids = item["input_ids"] - text = tokenizer.detokenize(input_ids).decode('utf-8') - self.state.all_results_queue.put({ - "sample_id": sample_id, - "client_id": client_id, - "text": text - }) + text = tokenizer.detokenize(input_ids).decode("utf-8") + self.state.all_results_queue.put( + {"sample_id": sample_id, "client_id": client_id, "text": text} + ) def run(self): # load tokenizer @@ -71,7 +65,7 @@ def run(self): self.detokenize(tokenizer, sample_id, client_id, item) else: print("error: unknown TokenizerProcessor action {}".format(action)) - + print("Tokenizer shut down.") def run_in_parallel(self): @@ -80,17 +74,24 @@ def run_in_parallel(self): p = multiprocessing.Process(target=self.run) p.start() return p - + + class LlamaCPPModelProcessor(ModelProcessor): - def __init__(self, state: InferenceServerState, cuda: bool = False, cache: Optional[str] = None, llama_kwargs: Optional[dict] = None): + def __init__( + self, + state: InferenceServerState, + cuda: bool = False, + cache: Optional[str] = None, + llama_kwargs: Optional[dict] = None, + ): super().__init__(state, cuda, cache) assert llama_kwargs is not None self.llama_kwargs = llama_kwargs - def run(self): + def run(self): # load model print("Loading {} (CPU)".format(self.model_identifier)) - self.model = Llama(**{**self.llama_kwargs, 'logits_all': True}) + self.model = Llama(**{**self.llama_kwargs, "logits_all": True}) print("Ready!".format(self.model_identifier)) @@ -104,13 +105,11 @@ def run(self): except KeyboardInterrupt: break - if item is None: + if item is None: time.sleep(0.1) continue self.request_count += 1 - - sample_id = item["sample_id"] client_id = item["client_id"] @@ -120,27 +119,31 @@ def run(self): key = str(input_ids) if key in self.cache: self.requests_cached += 1 - self.state.all_results_queue.put({ - "client_id": client_id, - "sample_id": sample_id, - "next_token_logits": self.cache[key] - }) + self.state.all_results_queue.put( + { + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": self.cache[key], + } + ) continue self.model.reset() res = self.model.eval(input_ids[0]) - + next_token_logits = self.model.all_logits[-1] - + if self.cache is not None: key = str(input_ids.tolist()) self.cache[key] = next_token_logits - 
self.state.all_results_queue.put({ - "client_id": client_id, - "sample_id": sample_id, - "next_token_logits": [next_token_logits] - }) - + self.state.all_results_queue.put( + { + "client_id": client_id, + "sample_id": sample_id, + "next_token_logits": [next_token_logits], + } + ) + print("Processor shut down") def run_in_parallel(self): @@ -149,13 +152,19 @@ def run_in_parallel(self): p = multiprocessing.Process(target=self.run) p.start() return p - + + base_llama_kwargs = {} -def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[ModelProcessor, TokenizerProcessor]: + +def get_serve( + state: InferenceServerState, args: argparse.Namespace +) -> Tuple[ModelProcessor, TokenizerProcessor]: # run model in separate process llama_kwargs = {kwarg: getattr(args, kwarg) for kwarg in base_llama_kwargs.keys()} - processor = LlamaCPPModelProcessor(state, cuda=False, cache=args.cache, llama_kwargs=llama_kwargs) + processor = LlamaCPPModelProcessor( + state, cuda=False, cache=args.cache, llama_kwargs=llama_kwargs + ) processor.run_in_parallel() # run tokenizers in separate process @@ -163,17 +172,24 @@ def get_serve(state: InferenceServerState, args: argparse.Namespace)->Tuple[Mode tokenizer_processor.run_in_parallel() return processor, tokenizer_processor + def add_parser(base_parser): - sig = inspect.signature(Llama.__init__) for name, param in sig.parameters.items(): - if name == 'self': + if name == "self": continue base_llama_kwargs[name] = None _type = param.annotation - if hasattr(_type, '__args__'): - _type=_type.__args__[0] + if hasattr(_type, "__args__"): + _type = _type.__args__[0] if param.default == inspect.Parameter.empty: - base_parser.add_argument(f'--{name}', default=_type(), type=_type, help="required for llama.cpp") + base_parser.add_argument( + f"--{name}", default=_type(), type=_type, help="required for llama.cpp" + ) else: - base_parser.add_argument(f'--{name}', default=param.default, type=_type, help="optional for llama.cpp") + base_parser.add_argument( + f"--{name}", + default=param.default, + type=_type, + help="optional for llama.cpp", + ) diff --git a/src/lmql/model/serve_types.py b/src/lmql/model/serve_types.py index 9778e919..e20d9479 100644 --- a/src/lmql/model/serve_types.py +++ b/src/lmql/model/serve_types.py @@ -3,44 +3,36 @@ """ from dataclasses import dataclass, field -from collections import defaultdict import json -from http.server import BaseHTTPRequestHandler, HTTPServer -from multiprocessing import Queue as MPQueue -from queue import Empty +from http.server import BaseHTTPRequestHandler from queue import Queue import multiprocessing from typing import Dict, Optional -import requests -import asyncio -import sys import atexit -import argparse import time import os import subprocess -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline -import torch from abc import ABC, abstractmethod + @dataclass class InferenceServerState: - model_identifier : str - tokenizer_descriptor : str + model_identifier: str + tokenizer_descriptor: str dtype: str queue: Queue tokenize_queue: Queue - all_results_queue : Queue - + all_results_queue: Queue + sample_count: int = 0 - client_results_queues: Dict[str,Queue] = field(default_factory=dict) - + client_results_queues: Dict[str, Queue] = field(default_factory=dict) + exit: bool = False - + class TokenizerProcessor(ABC): def __init__(self, state: InferenceServerState): self.model_identifier = state.tokenizer_descriptor @@ -49,7 +41,7 @@ def __init__(self, state: InferenceServerState): def 
shutdown(self): self.state.exit = True - + @abstractmethod def tokenize(self, tokenizer, sample_id, client_id, item): """Tokenize item input""" @@ -57,78 +49,105 @@ def tokenize(self, tokenizer, sample_id, client_id, item): @abstractmethod def detokenize(self, tokenizer, sample_id, client_id, item): """Detokenize item input""" - + @abstractmethod def run(self, index): """Evaluate items from the queue""" def run_in_parallel(self, n=2): atexit.register(self.shutdown) - + workers = [] for i in range(n): p = multiprocessing.Process(target=self.run, args=(i,)) p.start() workers.append(p) - + return workers + class ModelProcessor(ABC): - def __init__(self, state: InferenceServerState, cuda: bool = False, cache: Optional[str] = None): + def __init__( + self, + state: InferenceServerState, + cuda: bool = False, + cache: Optional[str] = None, + ): self.model_identifier = state.model_identifier self.queue = state.queue self.state = state self.cuda = cuda - + self.cache = None if cache is not None: from rocksdict import Rdict + self.cache = Rdict(cache) - + self.request_count = 0 self.requests_cached = 0 self.last_report = time.time() self.last_request_count = 0 - + try: - self.nvidia_logging = subprocess.Popen(["nvidia-smi"], stdout=subprocess.PIPE).wait() == 0 + self.nvidia_logging = ( + subprocess.Popen(["nvidia-smi"], stdout=subprocess.PIPE).wait() == 0 + ) except: self.nvidia_logging = False def shutdown(self): self.state.exit = True - + def __del__(self): if self.cache is not None: self.cache.close() - + def print_stats(self): if self.nvidia_logging: visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) cmds = ["nvidia-smi"] if visible_devices is not None: cmds.append("-i={}".format(visible_devices)) - cmds += ["--query-gpu=name,memory.used,memory.total,utilization.gpu", "--format=csv,noheader"] - output = [l.split(", ") for l in subprocess.check_output(cmds).decode("utf-8").split("\n") if l.strip() != ""] - gpu_usage = ["GPU {} {}, util {}".format(i, row[1] + "/" + row[2], row[3]) for i, row in enumerate(output)] + cmds += [ + "--query-gpu=name,memory.used,memory.total,utilization.gpu", + "--format=csv,noheader", + ] + output = [ + l.split(", ") + for l in subprocess.check_output(cmds).decode("utf-8").split("\n") + if l.strip() != "" + ] + gpu_usage = [ + "GPU {} {}, util {}".format(i, row[1] + "/" + row[2], row[3]) + for i, row in enumerate(output) + ] else: gpu_usage = ["GPU monitoring not available on non-CUDA systems"] - + print(" " * 100, end="\r") # fancy unicode based terminal spinner terminal_spinner_chars = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] - throughput = (self.request_count - self.last_request_count) / (time.time() - self.last_report) + throughput = (self.request_count - self.last_request_count) / ( + time.time() - self.last_report + ) self.last_report = time.time() self.last_request_count = self.request_count - + # format throughput to two decimal places - print("{} {:.2f} calls/s, Requests Served: {}, Queue: {} [{}]".format( - terminal_spinner_chars[self.request_count % len(terminal_spinner_chars)], - throughput, - self.request_count, - self.state.queue.qsize(), - ", ".join(gpu_usage)), end="\r") + print( + "{} {:.2f} calls/s, Requests Served: {}, Queue: {} [{}]".format( + terminal_spinner_chars[ + self.request_count % len(terminal_spinner_chars) + ], + throughput, + self.request_count, + self.state.queue.qsize(), + ", ".join(gpu_usage), + ), + end="\r", + ) @abstractmethod def run(self): @@ -145,6 +164,7 @@ def oom_reloading_run(self): continue else: 
import traceback + traceback.print_exc() print("Crashed with", e, "reloading model...") continue @@ -156,6 +176,7 @@ def run_in_parallel(self): p.start() return p + class LMQLInferenceAPIHandler(BaseHTTPRequestHandler): def __init__(self, *args, **kwargs): self._client_id = None @@ -168,29 +189,31 @@ def log_message(self, format, *args): @property def state(self) -> InferenceServerState: return self.server.state - + @property def client_id(self) -> str: if self._client_id is None: - self.error_bad_request(msg="client_id not set (please provide the client_id in POST payload or in query.") + self.error_bad_request( + msg="client_id not set (please provide the client_id in POST payload or in query." + ) raise Exception("client_id not set") return self._client_id def error_bad_request(self, msg="Bad request."): self.send_response(400) - self.send_header('Content-type', 'text/plain') + self.send_header("Content-type", "text/plain") self.end_headers() self.wfile.write(msg.encode("utf-8")) def process_some_all_results(self, max=10): all_results_queue = self.state.all_results_queue i = 0 - + while not all_results_queue.empty() and i < max: result = all_results_queue.get() result_client_id = result["client_id"] - + if result_client_id not in self.state.client_results_queues: self.state.client_results_queues[result_client_id] = Queue() self.state.client_results_queues[result_client_id].put(result) @@ -202,14 +225,14 @@ def do_GET_results(self): self._client_id = f"{self.client_address[0]}-{request_client_id}" self.send_response(200) - self.send_header('Content-type', 'application/json') + self.send_header("Content-type", "application/json") self.end_headers() # process some results from the model and group them by client_id self.process_some_all_results() if self.client_id not in self.state.client_results_queues.keys(): - self.wfile.write(b'[]') + self.wfile.write(b"[]") return # return all results for self.client_id currently available @@ -217,90 +240,104 @@ def do_GET_results(self): while not self.state.client_results_queues[self.client_id].empty(): result = self.state.client_results_queues[self.client_id].get() self.wfile.write(json.dumps(result).encode()) - + # omit last colon - if self.state.client_results_queues[self.client_id].empty(): break - else: self.wfile.write(b",") + if self.state.client_results_queues[self.client_id].empty(): + break + else: + self.wfile.write(b",") self.wfile.write(b"]") def do_queue_forward(self, payload, sample_id): - try: - input_ids = payload['input_ids'] - attention_mask = payload.get('attention_mask', None) - model_identifier = payload['model_identifier'] + try: + input_ids = payload["input_ids"] + attention_mask = payload.get("attention_mask", None) + model_identifier = payload["model_identifier"] if model_identifier != self.state.model_identifier: - self.error_bad_request("The inference API serves model {} not {}.".format(self.state.model_identifier, model_identifier)) + self.error_bad_request( + "The inference API serves model {} not {}.".format( + self.state.model_identifier, model_identifier + ) + ) return self._client_id = f"{self.client_address[0]}-{payload['client_id']}" - assert type(input_ids) == list # and all([type(i) == int for i in input_ids]) + assert ( + type(input_ids) == list + ) # and all([type(i) == int for i in input_ids]) - self.state.queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'input_ids': input_ids, - 'attention_mask': attention_mask - }) + self.state.queue.put( + { + "client_id": self.client_id, + "sample_id": 
sample_id, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + ) self.send_response(200) - self.send_header('Content-type', 'application/json') + self.send_header("Content-type", "application/json") self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: + self.wfile.write(json.dumps({"sample_id": sample_id}).encode()) + except Exception as e: self.error_bad_request() def do_queue_tokenize(self, payload, sample_id): - try: - text = payload['text'] + try: + text = payload["text"] self._client_id = f"{self.client_address[0]}-{payload['client_id']}" assert type(text) == str - self.state.tokenize_queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'action': 'tokenize', - 'text': text - }) + self.state.tokenize_queue.put( + { + "client_id": self.client_id, + "sample_id": sample_id, + "action": "tokenize", + "text": text, + } + ) self.send_response(200) - self.send_header('Content-type', 'application/json') + self.send_header("Content-type", "application/json") self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: + self.wfile.write(json.dumps({"sample_id": sample_id}).encode()) + except Exception as e: self.error_bad_request() - + def do_queue_detokenize(self, payload, sample_id): - try: - input_ids = payload['input_ids'] + try: + input_ids = payload["input_ids"] self._client_id = f"{self.client_address[0]}-{payload['client_id']}" assert type(input_ids) == list and all([type(i) == int for i in input_ids]) - self.state.tokenize_queue.put({ - 'client_id': self.client_id, - 'sample_id': sample_id, - 'action': 'detokenize', - 'input_ids': input_ids - }) + self.state.tokenize_queue.put( + { + "client_id": self.client_id, + "sample_id": sample_id, + "action": "detokenize", + "input_ids": input_ids, + } + ) self.send_response(200) - self.send_header('Content-type', 'application/json') + self.send_header("Content-type", "application/json") self.end_headers() - self.wfile.write(json.dumps({'sample_id': sample_id}).encode()) - except Exception as e: + self.wfile.write(json.dumps({"sample_id": sample_id}).encode()) + except Exception as e: self.error_bad_request() def do_POST_queue(self): # handle POST to /queue - payload = self.rfile.read(int(self.headers['Content-Length'])) + payload = self.rfile.read(int(self.headers["Content-Length"])) payload = json.loads(payload) # get client address and port sample_id = payload["sample_id"] - action = payload['action'] + action = payload["action"] if action == "forward": self.do_queue_forward(payload, sample_id) @@ -325,4 +362,4 @@ def do_GET(self): return else: self.send_error(404) - return \ No newline at end of file + return
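
A minimal client sketch of the queue/poll protocol implemented by LMQLInferenceAPIHandler above: POST a payload with an "action" of "forward", "tokenize", or "detokenize" to /queue, then poll GET /results for the answers, which the server groups per client_id. The endpoint paths and payload fields are taken from the handlers in this patch; the base URL assumes the serve.py argparse defaults (localhost:8080), and the query-parameter name used to pass the client id on GET /results is an assumption, since that parsing is not shown in these hunks.

    import json
    import time
    import urllib.request

    BASE = "http://localhost:8080"   # default --host/--port from serve.py (assumed unchanged)
    CLIENT_ID = "example-client"     # hypothetical client identifier

    def post_queue(payload):
        # POST a request to /queue; the server replies with {"sample_id": ...}
        req = urllib.request.Request(
            BASE + "/queue",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())

    # enqueue a tokenize request (fields as read by do_POST_queue / do_queue_tokenize)
    post_queue({
        "action": "tokenize",
        "text": "Hello world",
        "sample_id": 0,
        "client_id": CLIENT_ID,
    })

    # poll /results until the tokenization comes back; the body is a JSON list of
    # {"sample_id": ..., "client_id": ..., "input_ids": [...]} objects
    while True:
        # client_id query-parameter name is an assumption; the GET handler only
        # shows that it combines the client address with a request-supplied id
        with urllib.request.urlopen(BASE + "/results?client_id=" + CLIENT_ID) as resp:
            results = json.loads(resp.read() or b"[]")
        if results:
            print(results)
            break
        time.sleep(0.1)

A forward request follows the same pattern with "action": "forward", "input_ids", an optional "attention_mask", and the serving "model_identifier"; its result carries "next_token_logits" instead of "input_ids".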