Skip to content

Commit

Permalink
add embedding support in main.py (#22)
Browse files Browse the repository at this point in the history
* add embedding support in main.py

* use llmspec structs, add engine route

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* use sentence_transformers

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* fix EmbeddingData import

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* fix EmbeddingData

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* fix engine routes

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* remove duplicated line

Co-authored-by: Keming <kemingy94@gmail.com>

* fixes

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* add token count support

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* merge embedding classes

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* fix imports

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* remove SentenceTransformer code

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

* add requirements-cpu.txt

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>

---------

Signed-off-by: Teddy Xinyuan Chen <45612704+tddschn@users.noreply.github.com>
Signed-off-by: Keming <kemingyang@tensorchord.ai>
Co-authored-by: Keming <kemingy94@gmail.com>
Co-authored-by: Keming <kemingyang@tensorchord.ai>
  • Loading branch information
3 people authored May 25, 2023
1 parent c528ffe commit e64293d
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
84 changes: 84 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import falcon
import msgspec
import torch # type: ignore
import torch.nn.functional as F
import transformers
from falcon.asgi import App, Request, Response
from llmspec import (
Expand All @@ -14,6 +15,9 @@
ChatMessage,
CompletionChoice,
CompletionResponse,
EmbeddingData,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
LanguageModels,
PromptCompletionRequest,
Expand All @@ -22,8 +26,10 @@
)

# Default language model used for chat/completions when no override is given.
DEFAULT_MODEL = "THUDM/chatglm-6b-int4"
# Default sentence-embedding model served by the /embeddings endpoints.
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# All three are overridable at deploy time via MODELZ_* environment variables.
TOKENIZER = os.environ.get("MODELZ_TOKENIZER", DEFAULT_MODEL)
MODEL = os.environ.get("MODELZ_MODEL", DEFAULT_MODEL)
EMBEDDING_MODEL = os.environ.get("MODELZ_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)


logger = logging.getLogger()
Expand Down Expand Up @@ -157,14 +163,92 @@ async def on_post(self, req: Request, resp: Response):
resp.data = completion.to_json()


class Embeddings:
    """Falcon ASGI resource serving OpenAI-compatible ``/embeddings`` requests.

    Loads a Hugging Face tokenizer/model pair and produces mean-pooled,
    L2-normalized sentence embeddings plus token-usage accounting.
    """

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        # Bug fix: the original passed ``self.model`` to ``from_pretrained``
        # before ``self.model`` was ever assigned, raising AttributeError on
        # construction.  Use the ``model_name`` argument instead.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.model = transformers.AutoModel.from_pretrained(model_name)

    def embed_and_get_token_count(self, sentences):
        """Return ``(token_count, embeddings)`` for ``sentences``.

        ``sentences`` is whatever the tokenizer accepts (a string or a list
        of strings).  ``token_count`` is a 1-D tensor holding the number of
        non-padding tokens per sentence; ``embeddings`` is a 2-D tensor of
        L2-normalized sentence embeddings.
        """

        def mean_pooling(model_output, attention_mask):
            # Mean pooling that respects padding: weight each token embedding
            # by the attention mask, then divide by the (clamped) mask sum so
            # padded positions do not dilute the average.
            token_embeddings = model_output[
                0
            ]  # First element of model_output contains all token embeddings
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
                input_mask_expanded.sum(1), min=1e-9
            )

        # Tokenize sentences
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        # Count only real (non-padding) tokens per sentence.
        token_count = encoded_input["attention_mask"].sum(dim=1)

        # Compute token embeddings (inference only — no autograd graph).
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Perform pooling
        sentence_embeddings = mean_pooling(
            model_output, encoded_input["attention_mask"]
        )

        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return token_count, sentence_embeddings

    async def on_post(self, req: Request, resp: Response, engine: str = ""):
        """Handle ``POST /embeddings`` and the ``/engines/{engine}/embeddings``
        alias (``engine`` is accepted for route compatibility but unused)."""
        buf = await req.stream.readall()
        try:
            embedding_req = EmbeddingRequest.from_bytes(buf=buf)
        except msgspec.ValidationError as err:
            logger.info(f"Failed to parse request: {err}")
            resp.status = falcon.HTTP_400
            resp.data = ErrorResponse.from_validation_err(err, str(buf)).to_json()
            return

        token_count, embeddings = self.embed_and_get_token_count(embedding_req.input)
        # Convert embeddings (list[Tensor] | Tensor | ndarray) to plain nested
        # lists of floats so the response is JSON-serializable.  The original
        # had a duplicated elif/else branch performing the identical call.
        if isinstance(embeddings, list):
            embeddings = [e.tolist() for e in embeddings]
        else:
            embeddings = embeddings.tolist()

        # ``token_count`` is a per-sentence tensor; usage fields must be plain
        # ints, so sum it and unbox.
        total_tokens = int(token_count.sum().item())
        # NOTE(review): OpenAI's API returns ``data`` as a *list* of
        # EmbeddingData objects — confirm against llmspec whether a bare
        # object is accepted here.
        embedding_data = EmbeddingData(embedding=embeddings, index=0)
        embedding_resp = EmbeddingResponse(
            data=embedding_data,
            model=self.model_name,
            usage=TokenUsage(
                prompt_tokens=total_tokens,
                # No completions performed, only embeddings generated.
                completion_tokens=0,
                total_tokens=total_tokens,
            ),
        )
        resp.data = embedding_resp.to_json()


# One shared Embeddings resource: model loading is expensive, and the handler
# holds no per-request state, so all routes can reuse a single instance.
embeddings = Embeddings(EMBEDDING_MODEL)

app = App()
app.add_route("/", Ping())

# Register every endpoint both at the root and under the ``/v1`` prefix, to be
# fully compatible with the current OpenAI API endpoints
# (refer to https://platform.openai.com/docs/api-reference/chat).
for prefix in ("", "/v1"):
    app.add_route(f"{prefix}/completions", Completions(model_name=MODEL))
    app.add_route(f"{prefix}/chat/completions", ChatCompletions(model_name=MODEL))
    app.add_route(f"{prefix}/embeddings", embeddings)
    app.add_route(f"{prefix}/engines/{{engine}}/embeddings", embeddings)


if __name__ == "__main__":
Expand Down
8 changes: 8 additions & 0 deletions requirements-cpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
msgpack
mosec
torch --extra-index-url https://download.pytorch.org/whl/cpu
diffusers[torch]
transformers
llmspec
falcon
uvicorn
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ accelerate
llmspec
falcon
uvicorn
sentence_transformers

0 comments on commit e64293d

Please sign in to comment.