diff --git a/main.py b/main.py
index 504135c..40c38e7 100644
--- a/main.py
+++ b/main.py
@@ -6,6 +6,7 @@
 import falcon
 import msgspec
 import torch  # type: ignore
+import torch.nn.functional as F
 import transformers
 from falcon.asgi import App, Request, Response
 from llmspec import (
@@ -14,6 +15,9 @@
     ChatMessage,
     CompletionChoice,
     CompletionResponse,
+    EmbeddingData,
+    EmbeddingRequest,
+    EmbeddingResponse,
     ErrorResponse,
     LanguageModels,
     PromptCompletionRequest,
@@ -22,8 +26,10 @@
 )
 
 DEFAULT_MODEL = "THUDM/chatglm-6b-int4"
+DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
 TOKENIZER = os.environ.get("MODELZ_TOKENIZER", DEFAULT_MODEL)
 MODEL = os.environ.get("MODELZ_MODEL", DEFAULT_MODEL)
+EMBEDDING_MODEL = os.environ.get("MODELZ_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
 
 logger = logging.getLogger()
 
@@ -153,14 +159,92 @@ async def on_post(self, req: Request, resp: Response):
         resp.data = completion.to_json()
 
 
+class Embeddings:
+    def __init__(self, model_name: str) -> None:
+        self.model_name = model_name
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
+        self.model = transformers.AutoModel.from_pretrained(self.model_name)
+
+    def embed_and_get_token_count(self, sentences):
+        # Mean pooling: take the attention mask into account for correct averaging
+        def mean_pooling(model_output, attention_mask):
+            token_embeddings = model_output[
+                0
+            ]  # first element of model_output contains all token embeddings
+            input_mask_expanded = (
+                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+            )
+            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+                input_mask_expanded.sum(1), min=1e-9
+            )
+
+        # Tokenize sentences
+        encoded_input = self.tokenizer(
+            sentences, padding=True, truncation=True, return_tensors="pt"
+        )
+        token_count = int(encoded_input["attention_mask"].sum())  # non-padding tokens
+
+        # Compute token embeddings
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
+
+        # Perform pooling
+        sentence_embeddings = mean_pooling(
+            model_output, encoded_input["attention_mask"]
+        )
+
+        # Normalize embeddings
+        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+
+        return token_count, sentence_embeddings
+
+    async def on_post(self, req: Request, resp: Response, engine: str = ""):
+        buf = await req.stream.readall()
+        try:
+            embedding_req = EmbeddingRequest.from_bytes(buf=buf)
+        except msgspec.ValidationError as err:
+            logger.info(f"Failed to parse request: {err}")
+            resp.status = falcon.HTTP_400
+            resp.data = ErrorResponse.from_validation_err(err, str(buf)).to_json()
+            return
+
+        token_count, embeddings = self.embed_and_get_token_count(embedding_req.input)
+        # convert embeddings of type list[Tensor] | ndarray to list[float]
+        if isinstance(embeddings, list):
+            embeddings = [e.tolist() for e in embeddings]
+        elif isinstance(embeddings, torch.Tensor):
+            embeddings = embeddings.tolist()
+        else:
+            embeddings = embeddings.tolist()
+
+        embedding_data = EmbeddingData(embedding=embeddings, index=0)
+        embedding_resp = EmbeddingResponse(
+            data=embedding_data,
+            model=self.model_name,
+            usage=TokenUsage(
+                prompt_tokens=token_count,
+                # no completions performed, only embeddings generated
+                completion_tokens=0,
+                total_tokens=token_count,
+            ),
+        )
+        resp.data = embedding_resp.to_json()
+
+
+embeddings = Embeddings(EMBEDDING_MODEL)
+
 app = App()
 app.add_route("/", Ping())
 app.add_route("/completions", Completions(model_name=MODEL))
 app.add_route("/chat/completions", ChatCompletions(model_name=MODEL))
+app.add_route("/embeddings", embeddings)
+app.add_route("/engines/{engine}/embeddings", embeddings)
 
 # refer to https://platform.openai.com/docs/api-reference/chat
 # make it fully compatible with the current OpenAI API endpoints
 app.add_route("/v1/completions", Completions(model_name=MODEL))
 app.add_route("/v1/chat/completions", ChatCompletions(model_name=MODEL))
+app.add_route("/v1/embeddings", embeddings)
+app.add_route("/v1/engines/{engine}/embeddings", embeddings)
 
 if __name__ == "__main__":
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
new file mode 100644
index 0000000..fd38477
--- /dev/null
+++ b/requirements-cpu.txt
@@ -0,0 +1,9 @@
+msgpack
+mosec
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch
+diffusers[torch]
+transformers
+llmspec
+falcon
+uvicorn
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index e37681b..8158169 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,4 @@ accelerate
 llmspec
 falcon
 uvicorn
+sentence_transformers
\ No newline at end of file
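
For reviewers unfamiliar with the pooling step in embed_and_get_token_count: mean pooling averages each sentence's token embeddings, weighting by the attention mask so that padding positions do not dilute the result. A minimal self-contained sketch of the same computation on toy tensors (illustrative only, not part of this diff):

    import torch

    # Toy batch: 2 sentences, 3 token positions, 4-dim token embeddings.
    token_embeddings = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    # The second sentence has one padding position (mask = 0).
    attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

    # Broadcast the mask over the embedding dimension, zero out padding,
    # then divide by the per-sentence count of real tokens.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    mean_pooled = summed / counts  # shape (2, 4)
    print(mean_pooled)

The clamp with min=1e-9 mirrors the handler code: it guards against division by zero for an all-padding row.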
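Once the server is running (e.g. via uvicorn), the new route can be smoke-tested with a plain standard-library client. This sketch assumes the server listens on localhost:8000 and that llmspec's EmbeddingRequest accepts OpenAI-style "input" and "model" fields (only "input" is confirmed by this diff; adjust to the actual schema if it differs):

    import json
    from urllib.request import Request, urlopen

    # Hypothetical smoke test for the new /v1/embeddings route; the host,
    # port, and request fields below are assumptions, not part of this diff.
    payload = json.dumps({
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "input": ["Hello world", "Embeddings map text to vectors"],
    }).encode()
    req = Request(
        "http://localhost:8000/v1/embeddings",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urlopen(req) as resp:
        print(json.load(resp))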