From 6c33a6dca98c95a68f250b895182115c600f762b Mon Sep 17 00:00:00 2001
From: Shiqian Yan
Date: Wed, 11 Jan 2023 17:45:26 +0800
Subject: [PATCH] [FEATURE] OPT-175B service authentication and new priority queue (#700)

This commit also adds some checks to improve the robustness of the server
process against adversarial inputs, since there are plans to open the API to
non-key users.
---
 examples/llm_serving/launch_model_worker.py | 185 +++++++++----
 examples/llm_serving/service/constants.py   |  18 ++
 examples/llm_serving/service/scheduler.py   | 261 ++++++++++++++++++++
 3 files changed, 406 insertions(+), 58 deletions(-)
 create mode 100644 examples/llm_serving/service/scheduler.py

diff --git a/examples/llm_serving/launch_model_worker.py b/examples/llm_serving/launch_model_worker.py
index 537005083..628b7ed2b 100644
--- a/examples/llm_serving/launch_model_worker.py
+++ b/examples/llm_serving/launch_model_worker.py
@@ -1,6 +1,6 @@
 import asyncio
 import argparse
-from collections import defaultdict, namedtuple
+from collections import deque, defaultdict, namedtuple
 from dataclasses import dataclass, field
 import json
 import time
@@ -14,17 +14,16 @@
 from llm_serving.generator import Generator
 from llm_serving.service.constants import (
-    NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_PORT, USE_RECAPTCHA, KEYS_FILENAME)
+    NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_PORT, USE_RECAPTCHA, USE_API_KEYS,
+    ALLOW_NON_KEY_ACCESS, KEYS_FILENAME, AuthGroups, AUTH_GROUP_WEIGHTS,
+    AUTH_GROUP_SCHEDULER_SCALE, API_KEY_SCHEDULER_SCALE,
+    API_KEY_DEFAULT_WEIGHT)
 from llm_serving.service.recaptcha import load_recaptcha
+from llm_serving.service.scheduler import (
+    WeightedRoundRobin, NestedScheduler, FrontQueueScheduler, AsyncWrapper)
 from llm_serving.service.utils import build_logger
 
 
-@dataclass(order=True)
-class PrioritizedItem:
-    priority: int
-    item: Any=field(compare=False)
-
-
 GenerateItem = namedtuple("GenerateItem", ["uid", "return_queue", "data"])
 LogprobsItem = namedtuple("LogprobsItem", ["uid", "return_queue", "data"])
 
@@ -38,10 +37,14 @@ def __init__(self,
                  num_beams: int,
                  num_return_sequences: int,
                  use_recaptcha: bool,
+                 use_api_keys: bool,
+                 allow_non_key_access: bool,
                  max_seq_len: int = 1024,
                  max_batch_size: int = 4,
                  logprobs_past_cache_size_limit: int = 4,
-                 batch_timeout: float = 1.0):
+                 batch_wait_size_mult: int = 10,
+                 batch_timeout: float = 1.0,
+                 queue_timeout: float = 0.001):
         self.logger = build_logger()
 
         self.num_beams = num_beams
@@ -50,8 +53,9 @@ def __init__(self,
 
         # Batch queues
         self.max_bs = max_batch_size
+        self.batch_wait_size_mult = batch_wait_size_mult
         self.batch_timeout = batch_timeout
-        self.request_queue = asyncio.PriorityQueue()
+        self.queue_timeout = queue_timeout
         self.logprobs_past_cache = defaultdict(lambda: (0, None))
         self.logprobs_past_cache_size_limit = logprobs_past_cache_size_limit
         asyncio.get_event_loop().create_task(self.batch_loop())
@@ -75,39 +79,73 @@ def __init__(self,
         # Authentication
         self.allowed_api_keys = []
         self.recaptcha = load_recaptcha(use_recaptcha)
-        if use_recaptcha:
+        self.allow_non_key_access = allow_non_key_access
+        api_key_weights = {}
+        if use_api_keys:
             keys = json.load(open(KEYS_FILENAME, "r"))
             self.allowed_api_keys = keys["allowed_api_keys"]
+            if "api_key_weights" in keys:
+                api_key_weights = keys["api_key_weights"]
+
+        # Scheduling
+        # Each authentication group gets its own queue, and these queues are
+        # given fixed weights that are independent of how many requests are
+        # currently within each group. Requests that use API keys are
+        # further organized based on the API key weights.
+        inner_schedulers = {}
+        for auth_group in AuthGroups:
+            if auth_group == AuthGroups.API_KEY_USER:
+                inner_schedulers[auth_group] = WeightedRoundRobin(
+                    api_key_weights,
+                    API_KEY_SCHEDULER_SCALE,
+                    API_KEY_DEFAULT_WEIGHT)
+            else:
+                inner_schedulers[auth_group] = deque()
+        self.request_queue = NestedScheduler(
+            WeightedRoundRobin(
+                AUTH_GROUP_WEIGHTS, AUTH_GROUP_SCHEDULER_SCALE, None),
+            inner_schedulers)
+        # To support batching completion requests without shuffling the order
+        # of logprob requests, we return the temporarily unqueued logprob
+        # requests to the front of the queue.
+        self.request_queue = AsyncWrapper(FrontQueueScheduler(
+            self.request_queue))
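+        # For illustration, a keyed completion request travels through this
+        # stack as the nested tuple
+        #     (AuthGroups.API_KEY_USER, (api_key, GenerateItem(...)))
+        # The outer WeightedRoundRobin picks the auth group and the inner
+        # one picks among API keys; batch_loop() unpacks the tuple.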
 
     async def batch_loop(self):
         while True:
-            pri_item = await self.request_queue.get()
+            # Entries arrive as (auth_group, (api_key, item)) tuples
+            item = (await self.request_queue.get())[1][1]
 
             # Get the next batch
             generate_batch = []
             logprobs_item = None
             non_batch = []
-            if isinstance(pri_item.item, GenerateItem):
-                # Wait for batch opportunity
-                await asyncio.sleep(self.batch_timeout)
-                generate_batch.append(pri_item.item)
+            if isinstance(item, GenerateItem):
+                batch_wait_size = self.batch_wait_size_mult * self.max_bs
+                if self.request_queue.qsize() < batch_wait_size:
+                    # Wait for batch opportunity
+                    await asyncio.sleep(self.batch_timeout)
+                else:
+                    # Yield control until new requests are queued
+                    await asyncio.sleep(self.queue_timeout)
+                generate_batch.append(item)
 
                 while (not self.request_queue.empty() and
                        len(generate_batch) < self.max_bs):
-                    pri_item = self.request_queue.get_nowait()
-                    if isinstance(pri_item.item, GenerateItem):
-                        generate_batch.append(pri_item.item)
+                    queue_entry = self.request_queue.get_nowait()
+                    item = queue_entry[1][1]
+                    if isinstance(item, GenerateItem):
+                        generate_batch.append(item)
                     else:
-                        non_batch.append(pri_item)
+                        non_batch.append(queue_entry)
                         break
 
-                # Put non-batch items back to request queue
-                for x in non_batch:
-                    self.request_queue.put_nowait(x)
+                # Return non-batch items to the front of the request queue
+                non_batch.reverse()
+                self.request_queue.extendleft(non_batch)
-            elif isinstance(pri_item.item, LogprobsItem):
-                logprobs_item = pri_item.item
+            elif isinstance(item, LogprobsItem):
+                logprobs_item = item
             else:
-                raise RuntimeError(f"Invalid item: {pri_item.item}")
+                raise RuntimeError(f"Invalid item: {item}")
 
             # Process this batch
             if generate_batch:
@@ -146,7 +184,8 @@ async def batch_loop(self):
                     logits = output.logits[:num_inputs, -1]
                     logprobs = torch.log_softmax(logits, dim=-1)
-                    top_logprobs, top_indices = logprobs.topk(arg["top_k"], dim=1)
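+                    # Clamp top_k so that an adversarial request cannot ask
+                    # for more entries than the vocabulary holds, which
+                    # would make topk() throw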
+                    top_k = min(arg["top_k"], logprobs.shape[1])
+                    top_logprobs, top_indices = logprobs.topk(top_k, dim=1)
 
                     # return at most top_k tokens, e.g. if network limited
                     return_dict = {
@@ -158,12 +197,12 @@ async def batch_loop(self):
 
     async def handle_request(self, request):
         args = await request.json()
-        self.check_authorization(args, request)
+        authorization = self.get_authorization(args, request)
 
         if "completions" in request.url.path:
-            return await self.completions(args, request)
+            return await self.completions(args, request, authorization)
         elif "logprobs" in request.url.path:
-            return await self.logprobs(args, request)
+            return await self.logprobs(args, request, authorization)
         else:
             raise ValueError("Invalid url: {request.url}")
 
@@ -174,25 +213,34 @@ def normalize_prompts(self, prompts):
         # - case 1: str. Single prompt.
         # - case 2: List[str]. Multiple prompts.
         # - case 3: List[int]. Pretokenized. Return one generation.
         # - case 4: List[List[int]]. Pretokenized multiple generations.
         # our approach is to turn everything into the case 4
-        if isinstance(prompts, str):  # case 1
-            prompts = [self.generator.encode(prompts)]
-        elif isinstance(prompts, list) and isinstance(prompts[0], str):  # case 2
-            prompts = [self.generator.encode(p) for p in prompts]
-        elif isinstance(prompts, list) and isinstance(prompts[0], int):  # case 3
-            prompts = [prompts]
-        else:  # case 4
-            assert isinstance(prompts[0], list)
-            assert isinstance(prompts[0][0], int)
-        if len(prompts[0]) <= 0:
+        try:
+            if isinstance(prompts, str):  # case 1
+                prompts = [self.generator.encode(prompts)]
+            elif isinstance(prompts, list) and isinstance(prompts[0], str):
+                assert all(isinstance(v, str) for v in prompts)
+                prompts = [self.generator.encode(p) for p in prompts]
+            elif isinstance(prompts, list) and isinstance(prompts[0], int):
+                prompts = [prompts]
+            assert isinstance(prompts, list)
+            for sublist in prompts:
+                assert isinstance(sublist, list)
+                assert all(isinstance(v, int) for v in sublist)
+                # reject token ids too large to fit in a signed 64-bit int
+                assert all(v + (1 << 63) < (1 << 64) for v in sublist)
+        except (AssertionError, IndexError):
+            raise ValueError(
+                "The prompt must be either a string, a list of strings, a "
+                "list of integers, or a list of integer lists.")
+        if not prompts or any(len(sublist) <= 0 for sublist in prompts):
             raise ValueError("The prompt must be nonempty.")
         return prompts
 
-    async def completions(self, args, request):
+    async def completions(self, args, request, authorization):
         logger = self.logger
 
         if "redirect_logprobs" in args:
             # A redirection to workaround some security settings.
-            return await self.logprobs(args, request)
+            return await self.logprobs(args, request, authorization)
 
         # Normalize prompts
         prompts = args["prompt"]
@@ -236,9 +284,10 @@ async def completions(self, args, request):
         return_queue = asyncio.Queue()
         for i, prompt in enumerate(prompts):
             data = {"input": prompt, **args}
-            priority = 0
-            self.request_queue.put_nowait(PrioritizedItem(
-                priority, GenerateItem(i, return_queue, data)))
+            queue_entry = GenerateItem(i, return_queue, data)
+            auth_group, api_key = authorization
+            queue_entry = (auth_group, (api_key, queue_entry))
+            self.request_queue.put_nowait(queue_entry)
 
         unordered_results = []
         for i in range(len(prompts)):
@@ -263,7 +312,7 @@ async def completions(self, args, request):
             ],
         }
 
-    async def logprobs(self, args, request):
+    async def logprobs(self, args, request, authorization):
         logger = self.logger
 
         # Normalize prompts
@@ -297,11 +346,13 @@ async def logprobs(self, args, request):
         self.check_max_length_limit(cur_len, self.max_seq_len)
 
         # Push the request to the batch queue
-        cache_id = args["cache_id"] if "cache_id" in args else str(uuid.uuid4())
+        cache_id = str(args["cache_id"]) if "cache_id" in args else str(uuid.uuid4())
         ret_queue = asyncio.Queue()
         data = {"input": prompts, "cache_id": cache_id, **args}
-        self.request_queue.put_nowait(PrioritizedItem(
-            50, LogprobsItem(0, ret_queue, data)))
+        queue_entry = LogprobsItem(0, ret_queue, data)
+        auth_group, api_key = authorization
+        queue_entry = (auth_group, (api_key, queue_entry))
+        self.request_queue.put_nowait(queue_entry)
         results = await ret_queue.get()
         return {
             "cache_id": cache_id,
@@ -318,17 +369,31 @@ def check_max_length_limit(self, cur_len, max_len):
             f"If you want to try longer sequence length, "
             f"please consider hosting your own service using Alpa.")
 
-    def check_authorization(self, args, request):
-        if args.get("api_key", None) in self.allowed_api_keys:
-            return
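+    # Requests resolve to one of three AuthGroups: API_KEY_USER for a valid
+    # key, RECAPTCHA_USER for a valid captcha response, and NON_KEY_USER for
+    # anonymous access when it is allowed.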
+    def get_authorization(self, args, request):
+        api_key = args.get("api_key", None)
+        if api_key in self.allowed_api_keys:
+            return (AuthGroups.API_KEY_USER, api_key)
+        elif api_key is not None:
+            self.logger.error(f"Rejected a request with an incorrect key.")
+            raise ValueError("API key is incorrect; please verify that you "
+                             "have passed the right value (as opposed to, "
+                             "say, an OpenAI API key).")
+
+        recaptcha_response = str(args.get("g-recaptcha-response", ""))
+        if recaptcha_response == "":
+            if self.allow_non_key_access:
+                return (AuthGroups.NON_KEY_USER, None)
+            else:
+                self.logger.error(f"Rejected a request with no API key.")
+                raise ValueError("No captcha data found. If you are using "
+                                 "client APIs, please contact alpa developers "
+                                 "to get an API key.")
 
-        if not self.recaptcha.verify(
-                args.get("g-recaptcha-response", ""),
-                request.client.host):
+        if not self.recaptcha.verify(recaptcha_response, request.client.host):
             self.logger.error(f"Rejected a request with invalid captcha.")
             raise ValueError("Invalid captcha. If you are using the website, please click the "
-                             "\"I'm not a robot\" button. If you are using client APIs, please "
-                             "contact alpa developers to get an API key.")
+                             "\"I'm not a robot\" button.")
+        return (AuthGroups.RECAPTCHA_USER, None)
 
     def get_remote_ip(self, request):
         for x in request.scope['headers']:
@@ -350,6 +415,8 @@ def get_remote_ip(self, request):
     parser.add_argument("--torch-device", type=str, default="cpu")
     parser.add_argument("--tokenizer", type=str)
     parser.add_argument("--no-recaptcha", action="store_true")
+    parser.add_argument("--no-api-keys", action="store_true")
+    parser.add_argument("--block-non-key-access", action="store_true")
     parser.add_argument("--register-name", type=str, default="default")
     parser.add_argument("--ssl-keyfile", type=str)
     parser.add_argument("--ssl-certfile", type=str)
@@ -368,7 +435,9 @@ def get_remote_ip(self, request):
     t = controller.register_model.remote(
         args.register_name, LangaugeModelWorker,
         (args.model, args.path, args.torch_device, args.tokenizer,
          NUM_BEAMS, NUM_RETURN_SEQ,
-         False if args.no_recaptcha else USE_RECAPTCHA),
+         not args.no_recaptcha and USE_RECAPTCHA,
+         not args.no_api_keys and USE_API_KEYS,
+         not args.block_non_key_access and ALLOW_NON_KEY_ACCESS),
         override=True)
     ray.get(t)
     t = controller.create_replica.remote(args.register_name, group_id)
diff --git a/examples/llm_serving/service/constants.py b/examples/llm_serving/service/constants.py
index 83404f76d..1452790dc 100644
--- a/examples/llm_serving/service/constants.py
+++ b/examples/llm_serving/service/constants.py
@@ -1,4 +1,5 @@
 """Hyper params for serving Meta's OPT model."""
+from enum import Enum
 
 # Alpa serve url
 ALPA_SERVE_PORT = 20001
@@ -11,7 +12,24 @@
 
 # Authentication params
 USE_RECAPTCHA = False
+USE_API_KEYS = False
+ALLOW_NON_KEY_ACCESS = True
 KEYS_FILENAME = "/home/ubuntu/efs/alpa/examples/llm_serving/keys_file.json"
 
+# Scheduler params
+class AuthGroups(Enum):
+    RECAPTCHA_USER = 1
+    API_KEY_USER = 2
+    NON_KEY_USER = 3
+
+AUTH_GROUP_WEIGHTS = {
+    AuthGroups.RECAPTCHA_USER: 300,
+    AuthGroups.API_KEY_USER: 10,
+    AuthGroups.NON_KEY_USER: 1
+}
+AUTH_GROUP_SCHEDULER_SCALE = 300
+API_KEY_SCHEDULER_SCALE = 100
+API_KEY_DEFAULT_WEIGHT = 10
+
 # Logging params
 LOGDIR = "weblogs"
diff --git a/examples/llm_serving/service/scheduler.py b/examples/llm_serving/service/scheduler.py
new file mode 100644
index 000000000..d4074c7f6
--- /dev/null
+++ b/examples/llm_serving/service/scheduler.py
@@ -0,0 +1,261 @@
+import asyncio
+import heapq
+from collections import deque, OrderedDict
+
+
+class WeightedRoundRobin:
+    """
+    Scheduler that cycles between queues with different weights.
+    The interface is the same as if it were a queue implemented using
+    deque(). This implementation extends the original algorithm by allowing
+    effectively non-integer weights: all weights in this class are
+    implicitly divided by a scale factor - if all the queue weights are
+    integer multiples of the scale factor, the algorithm behaves just like
+    standard weighted round robin. Using smaller weights makes the scheduler
+    switch between queues more frequently, improving latency.
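+
+    Example (illustrative)::
+
+        wrr = WeightedRoundRobin({"a": 2, "b": 1}, 1)
+        wrr.extend(("a", i) for i in range(6))
+        wrr.extend(("b", i) for i in range(3))
+
+    While both queues are backlogged, "a" is served twice per simulated
+    time step and "b" once, so repeated popleft() calls drain the queues
+    in the pattern a, a, b, a, a, b, ...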
+    """
+    # The scheduling algorithm is implemented using an event list. Each queue
+    # is associated with an hourglass that fills up a certain fraction every
+    # time step. When the hourglass is filled, a task is scheduled from the
+    # corresponding queue. An hourglass is allowed to be filled faster than
+    # 100% per time step - in this case, tasks are consecutively scheduled
+    # from the same queue until the hourglass is no longer full.
+
+    class Hourglass:
+        def __init__(self, update_time, amnt_filled):
+            self.update_time = update_time
+            self.amnt_filled = amnt_filled
+            self.linked_tasks = deque()
+
+        def __repr__(self):
+            return '({}, {}, {})'.format(
+                self.update_time, self.amnt_filled, list(self.linked_tasks))
+
+    def __init__(self, weights, scale, default_weight=None,
+                 max_empty_hourglasses=100):
+        self.weights = weights
+        self.default_weight = default_weight
+        self.scale = scale
+        self.max_empty_hourglasses = max_empty_hourglasses
+        self.curr_item_num = 0
+        self.curr_simulated_time = 0
+        self.tasks = {}
+        self.hourglasses = {}
+        self.event_list = []
+        self.empty_hourglasses = OrderedDict()
+
+    def __len__(self):
+        return len(self.tasks)
+
+    def append(self, name_and_item):
+        queue_name, item = name_and_item
+        self.tasks[self.curr_item_num] = item
+        new_event = False
+        if queue_name in self.empty_hourglasses:
+            self.hourglasses[queue_name] = self.empty_hourglasses[queue_name]
+            del self.empty_hourglasses[queue_name]
+            new_event = True
+        if queue_name not in self.hourglasses:
+            self.hourglasses[queue_name] = \
+                WeightedRoundRobin.Hourglass(0, 0)
+            new_event = True
+        hourglass = self.hourglasses[queue_name]
+        hourglass.linked_tasks.append(self.curr_item_num)
+        if new_event:
+            hourglass.update_time = self.curr_simulated_time
+            self.__add_new_event(hourglass, queue_name)
+        self.curr_item_num += 1
+
+    def extend(self, items):
+        for item in items:
+            self.append(item)
+
+    def popleft(self):
+        event_entry = heapq.heappop(self.event_list)
+        queue_name = event_entry[2]
+        hourglass = self.hourglasses[queue_name]
+        if hourglass.amnt_filled >= self.scale:
+            hourglass.amnt_filled -= self.scale
+        else:
+            self.curr_simulated_time = event_entry[0]
+            weight = self.weights.get(queue_name, self.default_weight)
+            if weight is None:
+                raise KeyError(queue_name)
+            hourglass.amnt_filled += (
+                self.curr_simulated_time - hourglass.update_time) * weight
+            hourglass.amnt_filled -= self.scale
+            hourglass.update_time = self.curr_simulated_time
+        task_num = hourglass.linked_tasks.popleft()
+        task = self.tasks.pop(task_num)
+        if len(hourglass.linked_tasks) == 0:
+            del self.hourglasses[queue_name]
+            self.empty_hourglasses[queue_name] = hourglass
+            if len(self.empty_hourglasses) > self.max_empty_hourglasses:
+                self.empty_hourglasses.popitem(last=False)
+        else:
+            self.__add_new_event(hourglass, queue_name)
+        return (queue_name, task)
+
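+    # Event entries are (event_time, next_task_num, queue_name) tuples, so
+    # heapq orders them by due time and breaks ties by the arrival order of
+    # each queue's next linked task, keeping pops FIFO across equally-due
+    # queues.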
+    def __add_new_event(self, hourglass, queue_name):
+        if hourglass.amnt_filled >= self.scale:
+            event_time = self.curr_simulated_time
+            event_entry = (event_time, hourglass.linked_tasks[0], queue_name)
+            heapq.heappush(self.event_list, event_entry)
+        else:
+            weight = self.weights.get(queue_name, self.default_weight)
+            if weight is None:
+                raise KeyError(queue_name)
+            time_to_full = (
+                self.scale - hourglass.amnt_filled + weight - 1) // weight
+            event_time = self.curr_simulated_time + time_to_full
+            event_entry = (event_time, hourglass.linked_tasks[0], queue_name)
+            heapq.heappush(self.event_list, event_entry)
+
+    def verify_state(self):
+        """Checks the invariants of the class"""
+        task_nums = []
+        try:
+            assert len(self.event_list) == 0 or \
+                self.curr_simulated_time <= self.event_list[0][0]
+            for queue_name, hourglass in self.hourglasses.items():
+                assert len(hourglass.linked_tasks) > 0
+                for task_num in hourglass.linked_tasks:
+                    assert task_num in self.tasks
+                assert hourglass.amnt_filled >= 0
+                assert queue_name not in self.empty_hourglasses
+                task_nums += list(hourglass.linked_tasks)
+                if hourglass.amnt_filled >= self.scale:
+                    assert self.event_list[0][0] == self.curr_simulated_time
+                    assert self.curr_simulated_time == hourglass.update_time
+            for hourglass in self.empty_hourglasses.values():
+                assert len(hourglass.linked_tasks) == 0
+                assert hourglass.amnt_filled >= 0
+            assert sorted(task_nums) == sorted(list(self.tasks.keys()))
+        except AssertionError as e:
+            e.args += (repr(self),)
+            raise e
+
+    def __repr__(self):
+        return "Tasks: {}\nEvent list: {}\nHourglasses: {}\nTime: {}".format(
+            self.tasks, self.event_list, self.hourglasses,
+            self.curr_simulated_time)
+
+
+class NestedScheduler:
+    """
+    Scheduler where each queue is an independent inner scheduler object.
+    This can be used to implement hierarchies of weights and queues.
+    """
+    def __init__(self, outer_scheduler, inner_schedulers):
+        self.outer_scheduler = outer_scheduler
+        self.inner_schedulers = inner_schedulers
+
+    def __len__(self):
+        return len(self.outer_scheduler)
+
+    def append(self, name_and_item):
+        name, item = name_and_item
+        self.outer_scheduler.append((name, None))
+        self.inner_schedulers[name].append(item)
+
+    def extend(self, items):
+        for item in items:
+            self.append(item)
+
+    def popleft(self):
+        name = self.outer_scheduler.popleft()[0]
+        return (name, self.inner_schedulers[name].popleft())
+
+    def __repr__(self):
+        return '\n'.join(
+            ['Outer: ' + repr(self.outer_scheduler)] +
+            [repr(name) + ': ' + repr(s)
+             for (name, s) in self.inner_schedulers.items()])
+
+
+class FrontQueueScheduler:
+    """
+    Scheduler decorator that allows tasks to be placed at the front of the
+    queue. The front behaves like the front of a deque(), i.e. it is LIFO.
+    """
+    def __init__(self, scheduler):
+        self.scheduler = scheduler
+        self.front_queue = deque()
+
+    def __len__(self):
+        return len(self.front_queue) + len(self.scheduler)
+
+    def append(self, item):
+        self.scheduler.append(item)
+
+    def extend(self, items):
+        for item in items:
+            self.append(item)
+
+    def popleft(self):
+        if len(self.front_queue) > 0:
+            return self.front_queue.popleft()
+        return self.scheduler.popleft()
+
+    def appendleft(self, item):
+        self.front_queue.appendleft(item)
+
+    def extendleft(self, items):
+        self.front_queue.extendleft(items)
+
+    def __repr__(self):
+        return "Front queue:{}\n{}".format(self.front_queue, self.scheduler)
+
+
+class AsyncWrapper:
+    """
+    Decorator that makes a scheduler object behave like an asyncio.Queue().
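+
+    Example (illustrative, from inside a coroutine)::
+
+        queue = AsyncWrapper(FrontQueueScheduler(deque()))
+        queue.put_nowait(item)
+        item = await queue.get()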
+ """ + def __init__(self, scheduler): + self.schedule_waitlist = asyncio.Queue() + self.scheduler = scheduler + + @property + def maxsize(self): + return 0 + + def qsize(self): + return len(self.scheduler) + self.schedule_waitlist.qsize() + + def empty(self): + return len(self.scheduler) == 0 and self.schedule_waitlist.empty() + + def full(self): + return False + + async def put(self, item): + await self.schedule_waitlist.put(item) + + def put_nowait(self, item): + self.schedule_waitlist.put_nowait(item) + + def extendleft(self, items): + self.scheduler.extendleft(items) + + async def get(self): + if self.empty(): + self.scheduler.append(await self.schedule_waitlist.get()) + while not self.schedule_waitlist.empty(): + self.scheduler.append(self.schedule_waitlist.get_nowait()) + return self.scheduler.popleft() + + def get_nowait(self): + if self.empty(): + raise asyncio.QueueEmpty + while not self.schedule_waitlist.empty(): + self.scheduler.append(self.schedule_waitlist.get_nowait()) + return self.scheduler.popleft() + + def task_done(self): + self.scheduler_waitlist.task_done() + + async def join(self): + await self.scheduler_waitlist.join() + + def __repr__(self): + return repr(self.scheduler)