add embedding support in main.py #22

Merged
merged 15 commits on May 25, 2023
84 changes: 84 additions & 0 deletions main.py
@@ -6,6 +6,7 @@
import falcon
import msgspec
import torch # type: ignore
import torch.nn.functional as F
import transformers
from falcon.asgi import App, Request, Response
from llmspec import (
@@ -14,6 +15,9 @@
    ChatMessage,
    CompletionChoice,
    CompletionResponse,
    EmbeddingData,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    LanguageModels,
    PromptCompletionRequest,
@@ -22,8 +26,10 @@
)

DEFAULT_MODEL = "THUDM/chatglm-6b-int4"
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
TOKENIZER = os.environ.get("MODELZ_TOKENIZER", DEFAULT_MODEL)
MODEL = os.environ.get("MODELZ_MODEL", DEFAULT_MODEL)
EMBEDDING_MODEL = os.environ.get("MODELZ_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)


logger = logging.getLogger()
@@ -153,14 +159,92 @@ async def on_post(self, req: Request, resp: Response):
        resp.data = completion.to_json()


class Embeddings:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name)

def embed_and_get_token_count(self, sentences):
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[
0
] # First element of model_output contains all token embeddings
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)

# Tokenize sentences
encoded_input = self.tokenizer(
sentences, padding=True, truncation=True, return_tensors="pt"
)
token_count = encoded_input["attention_mask"].sum(dim=1)

# Compute token embeddings
with torch.no_grad():
model_output = self.model(**encoded_input)

# Perform pooling
sentence_embeddings = mean_pooling(
model_output, encoded_input["attention_mask"]
)

# Normalize embeddings
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

return token_count, sentence_embeddings

async def on_post(self, req: Request, resp: Response, engine: str = ""):
buf = await req.stream.readall()
try:
embedding_req = EmbeddingRequest.from_bytes(buf=buf)
except msgspec.ValidationError as err:
logger.info(f"Failed to parse request: {err}")
resp.status = falcon.HTTP_400
resp.data = ErrorResponse.from_validation_err(err, str(buf)).to_json()
return

        token_count, embeddings = self.embed_and_get_token_count(embedding_req.input)
        # Convert embeddings (list[Tensor], Tensor, or ndarray) to plain
        # Python lists of floats so they can be JSON-serialized.
        if isinstance(embeddings, list):
            embeddings = [e.tolist() for e in embeddings]
        else:
            embeddings = embeddings.tolist()
        # token_count is a per-sentence tensor; collapse it to a plain int
        # so the usage fields serialize cleanly.
        prompt_tokens = int(token_count.sum().item())

        embedding_data = EmbeddingData(embedding=embeddings, index=0)
        embedding_resp = EmbeddingResponse(
            data=embedding_data,
            model=self.model_name,
            usage=TokenUsage(
                prompt_tokens=prompt_tokens,
                # No completions performed, only embeddings generated.
                completion_tokens=0,
                total_tokens=prompt_tokens,
            ),
        )
        resp.data = embedding_resp.to_json()


embeddings = Embeddings(EMBEDDING_MODEL)

app = App()
app.add_route("/", Ping())
app.add_route("/completions", Completions(model_name=MODEL))
app.add_route("/chat/completions", ChatCompletions(model_name=MODEL))
app.add_route("/embeddings", embeddings)
app.add_route("/engines/{engine}/embeddings", embeddings)
# refer to https://platform.openai.com/docs/api-reference/chat
# make it fully compatible with the current OpenAI API endpoints
app.add_route("/v1/completions", Completions(model_name=MODEL))
app.add_route("/v1/chat/completions", ChatCompletions(model_name=MODEL))
app.add_route("/v1/embeddings", embeddings)
app.add_route("/v1/engines/{engine}/embeddings", embeddings)


if __name__ == "__main__":
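
For context, here is a minimal sketch (not part of the diff) of how the new Embeddings class can be exercised directly. The sentences are illustrative, and the 384-dimensional output applies to the default all-MiniLM-L6-v2 model:

embedder = Embeddings("sentence-transformers/all-MiniLM-L6-v2")

sentences = ["This is an example sentence", "Each sentence is converted"]
token_count, embeddings = embedder.embed_and_get_token_count(sentences)

print(token_count)       # per-sentence token counts, e.g. tensor([8, 7])
print(embeddings.shape)  # torch.Size([2, 384]) for all-MiniLM-L6-v2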
8 changes: 8 additions & 0 deletions requirements-cpu.txt
@@ -0,0 +1,8 @@
msgpack
mosec
--extra-index-url https://download.pytorch.org/whl/cpu
torch
diffusers[torch]
transformers
llmspec
falcon
uvicorn
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,3 +9,4 @@ accelerate
llmspec
falcon
uvicorn
sentence_transformers
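
To exercise the new endpoint end to end, a request like the following sketch should work once the server is running (the localhost:8000 address is an assumption for a local uvicorn run; the payload fields mirror the OpenAI embeddings API that these routes are modeled on):

import requests

# Hypothetical client call against the new route (not part of the PR).
resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "input": "Hello, world!",
    },
)
print(resp.json())  # EmbeddingResponse JSON with data, model, and usage fields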