FEAT: support sparse vector for bge-m3 (#2540)
Co-authored-by: pengjunfeng11 <179464367@qq.com>
Co-authored-by: Junfeng Peng <pengjunfeng@hyperchain.cn>
3 people authored Nov 22, 2024
1 parent c456ef9 commit f2b22bb
Showing 6 changed files with 335 additions and 5 deletions.
45 changes: 45 additions & 0 deletions xinference/api/restful_api.py
@@ -489,6 +489,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
self._router.add_api_route(
"/v1/convert_ids_to_tokens",
self.convert_ids_to_tokens,
methods=["POST"],
dependencies=(
[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/rerank",
self.rerank,
@@ -1312,6 +1322,41 @@ async def create_embedding(self, request: Request) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def convert_ids_to_tokens(self, request: Request) -> Response:
payload = await request.json()
body = CreateEmbeddingRequest.parse_obj(payload)
model_uid = body.model
exclude = {
"model",
"input",
"user",
}
kwargs = {key: value for key, value in payload.items() if key not in exclude}

try:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
return Response(decoded_texts, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
self.handle_request_limit_error(re)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def rerank(self, request: Request) -> Response:
payload = await request.json()
body = RerankRequest.parse_obj(payload)
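For reference, a minimal sketch of calling the new endpoint over HTTP. The path and the model/input fields follow the handler above; the host, port, and model UID are illustrative assumptions:

import requests

# Assumed local Xinference server (default port 9997) and a previously
# launched bge-m3 model; replace the URL and UID with your deployment's values.
url = "http://127.0.0.1:9997/v1/convert_ids_to_tokens"
request_body = {
    "model": "my-bge-m3",  # hypothetical model UID
    "input": [[6, 23706, 5473], [83203, 26588]],  # token IDs to decode
}
response = requests.post(url, json=request_body)
response.raise_for_status()
print(response.json())  # a list of decoded token sequences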
37 changes: 37 additions & 0 deletions xinference/client/restful/restful_client.py
@@ -126,6 +126,43 @@ def create_embedding(self, input: Union[str, List[str]], **kwargs) -> "Embedding
response_data = response.json()
return response_data

def convert_ids_to_tokens(
self, input: Union[List, List[List]], **kwargs
) -> List[str]:
"""
Convert token IDs to human readable tokens via RESTful APIs.
Parameters
----------
input: Union[List, List[List]]
Input token IDs to convert, can be a single list of token IDs or a list of token ID lists.
To convert multiple sequences in a single request, pass a list of token ID lists.
Returns
-------
list
A list of decoded tokens in human readable format.
Raises
------
RuntimeError
Report the failure of token conversion and provide the error message.
"""
url = f"{self._base_url}/v1/convert_ids_to_tokens"
request_body = {
"model": self._model_uid,
"input": input,
}
request_body.update(kwargs)
response = requests.post(url, json=request_body, headers=self.auth_headers)
if response.status_code != 200:
raise RuntimeError(
f"Failed to decode token ids, detail: {_get_error_string(response)}"
)
response_data = response.json()
return response_data


class RESTfulRerankModelHandle(RESTfulModelHandle):
def rerank(
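A minimal usage sketch for the new client method, assuming a server at the default local address and an already-launched bge-m3 embedding model (the UID is illustrative):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("my-bge-m3")  # hypothetical model UID

# Decode one sequence of token IDs ...
tokens = model.convert_ids_to_tokens([6, 23706, 5473])

# ... or several sequences in a single request.
batch_tokens = model.convert_ids_to_tokens([[6, 23706], [5473, 83203]])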
13 changes: 13 additions & 0 deletions xinference/core/model.py
@@ -794,6 +794,19 @@ async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
f"Model {self._model.model_spec} is not for creating embedding."
)

@request_limit
@log_async(logger=logger)
async def convert_ids_to_tokens(
self, input: Union[List, List[List]], *args, **kwargs
):
kwargs.pop("request_id", None)
if hasattr(self._model, "convert_ids_to_tokens"):
return await self._call_wrapper_json(
self._model.convert_ids_to_tokens, input, *args, **kwargs
)

raise AttributeError(f"Model {self._model.model_spec} does not support converting token ids.")

@request_limit
@log_async(logger=logger)
async def rerank(
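The actor only forwards the call when the wrapped model implements convert_ids_to_tokens; for other model types the AttributeError surfaces to the caller as a failed request. A hedged sketch of guarding for this on the client side (the REST client raises RuntimeError on any non-200 response, as shown above):

# Not every embedding model exposes convert_ids_to_tokens, so guard the
# call when the model type is not known in advance.
try:
    tokens = model.convert_ids_to_tokens([6, 23706, 5473])
except RuntimeError as err:
    print(f"model cannot decode token ids: {err}")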
226 changes: 222 additions & 4 deletions xinference/model/embedding/core.py
@@ -193,6 +193,27 @@ def to(self, *args, **kwargs):
device=self._device,
model_kwargs=model_kwargs,
)
elif (
self._kwargs.get("hybrid_mode")
and "m3" in self._model_spec.model_name.lower()
):
try:
from FlagEmbedding import BGEM3FlagModel
except ImportError:
error_message = "Failed to import module 'BGEM3FlagModel'"
installation_guide = [
"Please make sure 'FlagEmbedding' is installed. ",
"You can install it by `pip install FlagEmbedding`\n",
]
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
self._model = BGEM3FlagModel(
self._model_path,
device=self._device,
model_kwargs=model_kwargs,
trust_remote_code=True,
)
else:
model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
self._model = SentenceTransformer(
@@ -203,10 +224,155 @@ def to(self, *args, **kwargs):
)
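The BGEM3FlagModel branch above is selected by the hybrid_mode flag in the model's launch kwargs together with an "m3" model name. A hedged sketch of enabling it from the client, assuming extra keyword arguments to launch_model are forwarded to the embedding model:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local deployment
# hybrid_mode=True is forwarded to the model and triggers loading via
# FlagEmbedding's BGEM3FlagModel instead of SentenceTransformer.
model_uid = client.launch_model(
    model_name="bge-m3",
    model_type="embedding",
    hybrid_mode=True,
)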

def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
from FlagEmbedding import BGEM3FlagModel
from sentence_transformers import SentenceTransformer

kwargs.setdefault("normalize_embeddings", True)

@no_type_check
def _encode_bgem3(
model: Union[SentenceTransformer, BGEM3FlagModel],
sentences: Union[str, List[str]],
batch_size: int = 32,
show_progress_bar: bool = None,
output_value: str = "sparse_embedding",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
device: str = None,
normalize_embeddings: bool = False,
**kwargs,
):
"""
Computes sentence embeddings with bge-m3 model
Nothing special here, just replace sentence-transformer with FlagEmbedding
TODO: think about how to solve the redundant code of encode method in the future
:param sentences: the sentences to embed
:param batch_size: the batch size used for the computation
:param show_progress_bar: Output a progress bar when encode sentences
:param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
:param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
:param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
:param device: Which torch.device to use for the computation
:param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
:return:
By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
"""
import torch
from tqdm.autonotebook import trange

if show_progress_bar is None:
show_progress_bar = (
logger.getEffectiveLevel() == logging.INFO
or logger.getEffectiveLevel() == logging.DEBUG
)

if convert_to_tensor:
convert_to_numpy = False

if output_value != "sparse_embedding":
convert_to_tensor = False
convert_to_numpy = False

input_was_string = False
if isinstance(sentences, str) or not hasattr(
sentences, "__len__"
): # Cast an individual sentence to a list with length 1
sentences = [sentences]
input_was_string = True

if device is None:
# Same as SentenceTransformer.py
from sentence_transformers.util import get_device_name

device = get_device_name()
logger.info(f"Use pytorch device_name: {device}")

all_embeddings = []
all_token_nums = 0

# The original code does not support other inference engines
def _text_length(text):
if isinstance(text, dict): # {key: value} case
return len(next(iter(text.values())))
elif not hasattr(text, "__len__"): # Object has no len() method
return 1
elif len(text) == 0 or isinstance(
text[0], int
): # Empty string or list of ints
return len(text)
else:
return sum(
[len(t) for t in text]
) # Sum of length of individual strings

length_sorted_idx = np.argsort([-_text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

for start_index in trange(
0,
len(sentences),
batch_size,
desc="Batches",
disable=not show_progress_bar,
):
sentences_batch = sentences_sorted[
start_index : start_index + batch_size
]

with torch.no_grad():
out_features = model.encode(sentences_batch, **kwargs)

if output_value == "token_embeddings":
embeddings = []
for token_emb, attention in zip(
out_features[output_value], out_features["attention_mask"]
):
last_mask_id = len(attention) - 1
while (
last_mask_id > 0 and attention[last_mask_id].item() == 0
):
last_mask_id -= 1

embeddings.append(token_emb[0 : last_mask_id + 1])
elif output_value is None: # Return all outputs
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {
name: out_features[name][sent_idx]
for name in out_features
}
embeddings.append(row)
# for sparse embedding
else:
if kwargs.get("return_sparse"):
embeddings = out_features["lexical_weights"]
else:
embeddings = out_features["dense_vecs"]

if convert_to_numpy:
embeddings = embeddings.cpu()

all_embeddings.extend(embeddings)

all_embeddings = [
all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
]

if convert_to_tensor:
if len(all_embeddings):
all_embeddings = torch.stack(all_embeddings)
else:
all_embeddings = torch.Tensor()
elif convert_to_numpy:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

if input_was_string:
all_embeddings = all_embeddings[0]

return all_embeddings, all_token_nums

# copied from sentence-transformers, and modify it to return tokens num
@no_type_check
def encode(
@@ -390,6 +556,10 @@ def encode(
convert_to_numpy=False,
**kwargs,
)
elif isinstance(self._model, BGEM3FlagModel):
all_embeddings, all_token_nums = _encode_bgem3(
self._model, sentences, convert_to_numpy=False, **kwargs
)
else:
all_embeddings, all_token_nums = encode(
self._model,
@@ -401,14 +571,30 @@ def encode(
all_embeddings = [all_embeddings]
embedding_list = []
for index, data in enumerate(all_embeddings):
embedding_list.append(
EmbeddingData(index=index, object="embedding", embedding=data.tolist())
)
if kwargs.get("return_sparse") and isinstance(self._model, BGEM3FlagModel):
embedding_list.append(
EmbeddingData(
index=index,
object="embedding",
embedding={k: float(v) for k, v in data.items()},
)
)
else:
embedding_list.append(
EmbeddingData(
index=index, object="embedding", embedding=data.tolist()
)
)
usage = EmbeddingUsage(
prompt_tokens=all_token_nums, total_tokens=all_token_nums
)
result = Embedding(
object="list",
object=(
"list" # type: ignore
if not isinstance(self._model, BGEM3FlagModel)
and not kwargs.get("return_sparse")
else "dict"
),
model=self._model_uid,
data=embedding_list,
usage=usage,
@@ -430,6 +616,38 @@ def encode(

return result

def convert_ids_to_tokens(
self,
batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]],
**kwargs,
) -> Union[str, List[str], List[List[str]]]:
batch_decoded_texts: List[str] = []

assert self._model is not None

if isinstance(batch_token_ids, (int, str)):
return self._model.tokenizer.convert_ids_to_tokens(
[int(batch_token_ids)]
)[0]

# check if it's a nested list
if (
isinstance(batch_token_ids, list)
and batch_token_ids
and isinstance(batch_token_ids[0], list)
):
for token_ids in batch_token_ids:
token_ids = [int(token_id) for token_id in token_ids]
batch_decoded_texts.append(
self._model.tokenizer.convert_ids_to_tokens(token_ids)
)
else:
batch_token_ids = [int(token_id) for token_id in batch_token_ids]
batch_decoded_texts = self._model.tokenizer.convert_ids_to_tokens(
batch_token_ids
)
return batch_decoded_texts


def match_embedding(
model_name: str,
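Putting the pieces together, a hedged end-to-end sketch: with hybrid_mode enabled at launch, return_sparse=True makes each returned embedding a {token_id: weight} mapping of lexical weights, whose keys can then be decoded with the new API (the model UID and sample text are illustrative):

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("my-bge-m3")  # launched with hybrid_mode=True

# Sparse mode: each item's "embedding" is a {token_id: weight} mapping.
result = model.create_embedding("What is BGE M3?", return_sparse=True)
sparse = result["data"][0]["embedding"]

# Map the sparse token IDs back to human-readable tokens.
token_ids = list(sparse.keys())
tokens = model.convert_ids_to_tokens(token_ids)
for token_id, token in zip(token_ids, tokens):
    print(token_id, token, sparse[token_id])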
