Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
[FEATURE] OPT-175B service authentication and new priority queue (#700)
Browse files Browse the repository at this point in the history
This commit also adds some checks to improve the robustness of the server
process against adversarial inputs, since there are plans to open the API
to non-key users.
  • Loading branch information
Shiqian Yan committed Jan 11, 2023
1 parent cd1afee commit 6c33a6d
Show file tree
Hide file tree
Showing 3 changed files with 406 additions and 58 deletions.
185 changes: 127 additions & 58 deletions examples/llm_serving/launch_model_worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import argparse
from collections import defaultdict, namedtuple
from collections import deque, defaultdict, namedtuple
from dataclasses import dataclass, field
import json
import time
Expand All @@ -14,17 +14,16 @@

from llm_serving.generator import Generator
from llm_serving.service.constants import (
NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_PORT, USE_RECAPTCHA, KEYS_FILENAME)
NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_PORT, USE_RECAPTCHA, USE_API_KEYS,
ALLOW_NON_KEY_ACCESS, KEYS_FILENAME, AuthGroups, AUTH_GROUP_WEIGHTS,
AUTH_GROUP_SCHEDULER_SCALE, API_KEY_SCHEDULER_SCALE,
API_KEY_DEFAULT_WEIGHT)
from llm_serving.service.recaptcha import load_recaptcha
from llm_serving.service.scheduler import (
WeightedRoundRobin, NestedScheduler, FrontQueueScheduler, AsyncWrapper)
from llm_serving.service.utils import build_logger


@dataclass(order=True)
class PrioritizedItem:
priority: int
item: Any=field(compare=False)


GenerateItem = namedtuple("GenerateItem", ["uid", "return_queue", "data"])
LogprobsItem = namedtuple("LogprobsItem", ["uid", "return_queue", "data"])

Expand All @@ -38,10 +37,14 @@ def __init__(self,
num_beams: int,
num_return_sequences: int,
use_recaptcha: bool,
use_api_keys: bool,
allow_non_key_access: bool,
max_seq_len: int = 1024,
max_batch_size: int = 4,
logprobs_past_cache_size_limit: int = 4,
batch_timeout: float = 1.0):
batch_wait_size_mult: int = 10,
batch_timeout: float = 1.0,
queue_timeout: float = 0.001):

self.logger = build_logger()
self.num_beams = num_beams
Expand All @@ -50,8 +53,9 @@ def __init__(self,

# Batch queues
self.max_bs = max_batch_size
self.batch_wait_size_mult = batch_wait_size_mult
self.batch_timeout = batch_timeout
self.request_queue = asyncio.PriorityQueue()
self.queue_timeout = queue_timeout
self.logprobs_past_cache = defaultdict(lambda: (0, None))
self.logprobs_past_cache_size_limit = logprobs_past_cache_size_limit
asyncio.get_event_loop().create_task(self.batch_loop())
Expand All @@ -75,39 +79,73 @@ def __init__(self,
# Authentication
self.allowed_api_keys = []
self.recaptcha = load_recaptcha(use_recaptcha)
if use_recaptcha:
self.allow_non_key_access = allow_non_key_access
api_key_weights = {}
if use_api_keys:
keys = json.load(open(KEYS_FILENAME, "r"))
self.allowed_api_keys = keys["allowed_api_keys"]
if "api_key_weights" in keys:
api_key_weights = keys["api_key_weights"]

# Scheduling
# Each authentication choice - endpoint pair contains a separate queue,
# and these queues are given fixed weights independent of how many
# requests are within each group. Requests that use API keys are
# further organized based on the API key weights.
inner_schedulers = {}
for auth_group in AuthGroups:
if auth_group == AuthGroups.API_KEY_USER:
inner_schedulers[auth_group] = WeightedRoundRobin(
api_key_weights,
API_KEY_SCHEDULER_SCALE,
API_KEY_DEFAULT_WEIGHT)
else:
inner_schedulers[auth_group] = deque()
self.request_queue = NestedScheduler(
WeightedRoundRobin(
AUTH_GROUP_WEIGHTS, AUTH_GROUP_SCHEDULER_SCALE, None),
inner_schedulers)
# To support batching completion requests without shuffling the order
# of logprob requests, we return the temporarily unqueued logprob
# requests to the front of the queue.
self.request_queue = AsyncWrapper(FrontQueueScheduler(
self.request_queue))

async def batch_loop(self):
while True:
pri_item = await self.request_queue.get()
item = (await self.request_queue.get())[1][1]

# Get the next batch
generate_batch = []
logprobs_item = None
non_batch = []
if isinstance(pri_item.item, GenerateItem):
# Wait for batch opportunity
await asyncio.sleep(self.batch_timeout)
generate_batch.append(pri_item.item)
if isinstance(item, GenerateItem):
batch_wait_size = self.batch_wait_size_mult * self.max_bs
if self.request_queue.qsize() < batch_wait_size:
# Wait for batch opportunity
await asyncio.sleep(self.batch_timeout)
else:
# Yield control until new requests are queued
await asyncio.sleep(self.queue_timeout)
generate_batch.append(item)

while (not self.request_queue.empty() and
len(generate_batch) < self.max_bs):
pri_item = self.request_queue.get_nowait()
if isinstance(pri_item.item, GenerateItem):
generate_batch.append(pri_item.item)
queue_entry = self.request_queue.get_nowait()
item = queue_entry[1][1]
if isinstance(item, GenerateItem):
generate_batch.append(item)
else:
non_batch.append(pri_item)
non_batch.append(queue_entry)
break

# Put non-batch items back to request queue
for x in non_batch:
self.request_queue.put_nowait(x)
elif isinstance(pri_item.item, LogprobsItem):
logprobs_item = pri_item.item
# Return non-batch items to the front of the request queue
non_batch.reverse()
self.request_queue.extendleft(non_batch)
elif isinstance(item, LogprobsItem):
logprobs_item = item
else:
raise RuntimeError(f"Invalid item: {pri_item.item}")
raise RuntimeError(f"Invalid item: {item}")

# Process this batch
if generate_batch:
Expand Down Expand Up @@ -146,7 +184,8 @@ async def batch_loop(self):

logits = output.logits[:num_inputs, -1]
logprobs = torch.log_softmax(logits, dim=-1)
top_logprobs, top_indices = logprobs.topk(arg["top_k"], dim=1)
top_k = min(arg["top_k"], logprobs.shape[1])
top_logprobs, top_indices = logprobs.topk(top_k, dim=1)

# return at most top_k tokens, e.g. if network limited
return_dict = {
Expand All @@ -158,12 +197,12 @@ async def batch_loop(self):

async def handle_request(self, request):
args = await request.json()
self.check_authorization(args, request)
authorization = self.get_authorization(args, request)

if "completions" in request.url.path:
return await self.completions(args, request)
return await self.completions(args, request, authorization)
elif "logprobs" in request.url.path:
return await self.logprobs(args, request)
return await self.logprobs(args, request, authorization)
else:
raise ValueError("Invalid url: {request.url}")

Expand All @@ -174,25 +213,34 @@ def normalize_prompts(self, prompts):
# - case 3: List[int]. Pretokenized. Return one generation.
# - case 4: List[List[int]]. Pretokenized multiple generations.
# our approach is to turn everything into the case 4
if isinstance(prompts, str): # case 1
prompts = [self.generator.encode(prompts)]
elif isinstance(prompts, list) and isinstance(prompts[0], str): # case 2
prompts = [self.generator.encode(p) for p in prompts]
elif isinstance(prompts, list) and isinstance(prompts[0], int): # case 3
prompts = [prompts]
else: # case 4
assert isinstance(prompts[0], list)
assert isinstance(prompts[0][0], int)
if len(prompts[0]) <= 0:
try:
if isinstance(prompts, str): # case 1
prompts = [self.generator.encode(prompts)]
elif isinstance(prompts, list) and isinstance(prompts[0], str):
assert all(isinstance(v, str) for v in prompts)
prompts = [self.generator.encode(p) for p in prompts]
elif isinstance(prompts, list) and isinstance(prompts[0], int):
prompts = [prompts]
assert isinstance(prompts, list)
for sublist in prompts:
assert isinstance(sublist, list)
assert all(isinstance(v, int) for v in sublist)
assert all(v + (1 << 63) < (1 << 64) for v in sublist)
except AssertionError:
raise ValueError(
"The prompt must be either a string, a list of strings, a "
"list of integers, or a list of integer lists.")
if len(prompts[0]) <= 0 or \
any(len(sublist) <= 0 for sublist in prompts):
raise ValueError("The prompt must be nonempty.")
return prompts

async def completions(self, args, request):
async def completions(self, args, request, authorization):
logger = self.logger

if "redirect_logprobs" in args:
# A redirection to workaround some security settings.
return await self.logprobs(args, request)
return await self.logprobs(args, request, authorization)

# Normalize prompts
prompts = args["prompt"]
Expand Down Expand Up @@ -236,9 +284,10 @@ async def completions(self, args, request):
return_queue = asyncio.Queue()
for i, prompt in enumerate(prompts):
data = {"input": prompt, **args}
priority = 0
self.request_queue.put_nowait(PrioritizedItem(
priority, GenerateItem(i, return_queue, data)))
queue_entry = GenerateItem(i, return_queue, data)
auth_group, api_key = authorization
queue_entry = (auth_group, (api_key, queue_entry))
self.request_queue.put_nowait(queue_entry)

unordered_results = []
for i in range(len(prompts)):
Expand All @@ -263,7 +312,7 @@ async def completions(self, args, request):
],
}

async def logprobs(self, args, request):
async def logprobs(self, args, request, authorization):
logger = self.logger

# Normalize prompts
Expand Down Expand Up @@ -297,11 +346,13 @@ async def logprobs(self, args, request):
self.check_max_length_limit(cur_len, self.max_seq_len)

# Push the request to the batch queue
cache_id = args["cache_id"] if "cache_id" in args else str(uuid.uuid4())
cache_id = str(args["cache_id"]) if "cache_id" in args else str(uuid.uuid4())
ret_queue = asyncio.Queue()
data = {"input": prompts, "cache_id": cache_id, **args}
self.request_queue.put_nowait(PrioritizedItem(
50, LogprobsItem(0, ret_queue, data)))
queue_entry = LogprobsItem(0, ret_queue, data)
auth_group, api_key = authorization
queue_entry = (auth_group, (api_key, queue_entry))
self.request_queue.put_nowait(queue_entry)
results = await ret_queue.get()
return {
"cache_id": cache_id,
Expand All @@ -318,17 +369,31 @@ def check_max_length_limit(self, cur_len, max_len):
f"If you want to try longer sequence length, "
f"please consider hosting your own service using Alpa.")

def check_authorization(self, args, request):
if args.get("api_key", None) in self.allowed_api_keys:
return
def get_authorization(self, args, request):
api_key = args.get("api_key", None)
if api_key in self.allowed_api_keys:
return (AuthGroups.API_KEY_USER, api_key)
elif api_key is not None:
self.logger.error(f"Rejected a request with an incorrect key.")
raise ValueError("API key is incorrect, please verify that you "
"have passed the right value (as opposed to, "
"say, an OpenAI API key).")

recaptcha_response = str(args.get("g-recaptcha-response", ""))
if recaptcha_response == "":
if self.allow_non_key_access:
return (AuthGroups.NON_KEY_USER, None)
else:
self.logger.error(f"Rejected a request with no API key.")
raise ValueError("No captcha data found. If you are using "
"client APIs, please contact alpa developers "
"to get an API key.")

if not self.recaptcha.verify(
args.get("g-recaptcha-response", ""),
request.client.host):
if not self.recaptcha.verify(recaptcha_response, request.client.host):
self.logger.error(f"Rejected a request with invalid captcha.")
raise ValueError("Invalid captcha. If you are using the website, please click the "
"\"I'm not a robot\" button. If you are using client APIs, please "
"contact alpa developers to get an API key.")
"\"I'm not a robot\" button.")
return (AuthGroups.RECAPTCHA_USER, None)

def get_remote_ip(self, request):
for x in request.scope['headers']:
Expand All @@ -350,6 +415,8 @@ def get_remote_ip(self, request):
parser.add_argument("--torch-device", type=str, default="cpu")
parser.add_argument("--tokenizer", type=str)
parser.add_argument("--no-recaptcha", action="store_true")
parser.add_argument("--no-api-keys", action="store_true")
parser.add_argument("--block-non-key-access", action="store_true")
parser.add_argument("--register-name", type=str, default="default")
parser.add_argument("--ssl-keyfile", type=str)
parser.add_argument("--ssl-certfile", type=str)
Expand All @@ -368,7 +435,9 @@ def get_remote_ip(self, request):
t = controller.register_model.remote(
args.register_name, LangaugeModelWorker,
(args.model, args.path, args.torch_device, args.tokenizer, NUM_BEAMS, NUM_RETURN_SEQ,
False if args.no_recaptcha else USE_RECAPTCHA),
not args.no_recaptcha and USE_RECAPTCHA,
not args.no_api_keys and USE_API_KEYS,
not args.block_non_key_access and ALLOW_NON_KEY_ACCESS),
override=True)
ray.get(t)
t = controller.create_replica.remote(args.register_name, group_id)
Expand Down
18 changes: 18 additions & 0 deletions examples/llm_serving/service/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Hyper params for serving Meta's OPT model."""
from enum import Enum

# Alpa serve url
ALPA_SERVE_PORT = 20001
Expand All @@ -11,7 +12,24 @@

# Authentication params
USE_RECAPTCHA = False
USE_API_KEYS = False
ALLOW_NON_KEY_ACCESS = True
KEYS_FILENAME = "/home/ubuntu/efs/alpa/examples/llm_serving/keys_file.json"

# Scheduler params
class AuthGroups(Enum):
RECAPTCHA_USER = 1
API_KEY_USER = 2
NON_KEY_USER = 3

AUTH_GROUP_WEIGHTS = {
AuthGroups.RECAPTCHA_USER: 300,
AuthGroups.API_KEY_USER: 10,
AuthGroups.NON_KEY_USER: 1
}
AUTH_GROUP_SCHEDULER_SCALE = 300
API_KEY_SCHEDULER_SCALE = 100
API_KEY_DEFAULT_WEIGHT = 10

# Logging params
LOGDIR = "weblogs"
Loading

0 comments on commit 6c33a6d

Please sign in to comment.