
Commit

Fix IPEX embedders performance (#52)
peteriz authored Jul 8, 2024
1 parent ed85cdb commit dc96dad
Showing 2 changed files with 83 additions and 8 deletions.
15 changes: 12 additions & 3 deletions examples/optimized-embeddings.ipynb
@@ -75,7 +75,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Bi-encoders are implemented as two classes, one encoding the documents and the other encoding the queries. We load our quantized embedding model for both:"
"Bi-encoders are implemented as two classes, one encoding the documents and the other encoding the queries.\n",
"Embedding performance on Intel Hardware depends on the data input strategy. It is recommended to calibrate the batch size and padding strategy to maximize the latency or throughput when embedding.\n",
"\n",
"If the length of the sequences is shorter than the maximum length of the model (for example shorter than 512 for BGE), it is recommended to truncate it to speed up encoding. (via `max_sequence_length` argument)\n",
"Padding can be set to `True` so that each batch is padded to the maximum length (could vary between batches) or to `max_length` that will pad the batch to the maximum set length.\n",
"Varying with batch size and `padding=True` will affect the throughput of the embedding model, as larger batches could be encoded to larger sequences and smaller batches could produce a large number of varying in sizes batches.\n",
"\n",
"Experimentation on your data is key to maximize performance!\n",
"\n",
"We load our quantized embedding model for both:"
]
},
{
@@ -84,7 +93,7 @@
"metadata": {},
"outputs": [],
"source": [
"query_embedder = IPEXSentenceTransformersTextEmbedder(model=\"Intel/bge-small-en-v1.5-rag-int8-static\")"
"query_embedder = IPEXSentenceTransformersTextEmbedder(model=\"Intel/bge-small-en-v1.5-rag-int8-static\", batch_size=1, max_seq_length=512, padding=True)"
]
},
{
@@ -93,7 +102,7 @@
"metadata": {},
"outputs": [],
"source": [
"doc_embedder = IPEXSentenceTransformersDocumentEmbedder(model=\"Intel/bge-small-en-v1.5-rag-int8-static\")"
"doc_embedder = IPEXSentenceTransformersDocumentEmbedder(model=\"Intel/bge-small-en-v1.5-rag-int8-static\", batch_size=32, max_seq_length=512, padding=True)"
]
},
{
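As a quick illustration of the batch-size and padding tuning described in the notebook cell above, here is a minimal sketch (not part of this commit) that measures document-embedding throughput under the two padding modes. It assumes fastrag and Haystack are installed, imports the class from the module path shown in this diff, and uses a hypothetical `docs` corpus; absolute numbers depend entirely on your hardware and data.

from time import perf_counter

from haystack import Document
from fastrag.embedders.ipex_embedder import IPEXSentenceTransformersDocumentEmbedder

# Hypothetical corpus; replace with your own documents.
docs = [Document(content="Intel Xeon processors support AMX and VNNI instructions.")] * 256

for padding in (True, "max_length"):
    # padding=True pads each batch to its longest sequence;
    # padding="max_length" pads every batch to max_seq_length.
    embedder = IPEXSentenceTransformersDocumentEmbedder(
        model="Intel/bge-small-en-v1.5-rag-int8-static",
        batch_size=32,
        max_seq_length=512,
        padding=padding,
    )
    embedder.warm_up()
    start = perf_counter()
    embedder.run(documents=docs)
    print(f"padding={padding!r}: {len(docs) / (perf_counter() - start):.1f} docs/sec")
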
76 changes: 71 additions & 5 deletions fastrag/embedders/ipex_embedder.py
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Dict, List, Optional, Tuple, Union

from haystack.components.embedders import (
SentenceTransformersDocumentEmbedder,
@@ -27,6 +27,8 @@ def __init__(
device: Optional[str] = None,
auth_token: Optional[Secret] = None,
trust_remote_code: bool = False,
max_seq_length: Optional[int] = None,
padding: Optional[bool] = True,
):
import sentence_transformers

@@ -39,6 +41,46 @@ def _load_model(self, model_name_or_path, config, cache_dir, **model_args):
)
self.auto_model.eval()

def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
"""
Override of the original sentence_transformers.models.Transformer `tokenize` method, adding support for fixed-length (padded) tokenization.
"""
output = {}
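# Inputs may be plain strings, {name: text} dicts, or (text, text) pairs; normalize all three into token batches.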
if isinstance(texts[0], str):
to_tokenize = [texts]
elif isinstance(texts[0], dict):
to_tokenize = []
output["text_keys"] = []
for lookup in texts:
text_key, text = next(iter(lookup.items()))
to_tokenize.append(text)
output["text_keys"].append(text_key)
to_tokenize = [to_tokenize]
else:
batch1, batch2 = [], []
for text_tuple in texts:
batch1.append(text_tuple[0])
batch2.append(text_tuple[1])
to_tokenize = [batch1, batch2]

# strip
to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

# Lowercase
if self.do_lower_case:
to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

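# Tokenize with the configured padding strategy and truncate to max_seq_length (fixed-length when padding="max_length").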
output.update(
self.tokenizer(
*to_tokenize,
padding=self.padding,
truncation=True,
return_tensors="pt",
max_length=self.max_seq_length,
)
)
return output

class _IPEXSentenceTransformer(sentence_transformers.SentenceTransformer):
def _load_auto_model(
self,
@@ -81,6 +123,10 @@ def device(self):
trust_remote_code=trust_remote_code,
)

if max_seq_length is not None:
self.model._first_module().max_seq_length = max_seq_length
self.model._first_module().padding = padding


def ipex_model_warm_up(self):
"""
@@ -91,31 +137,51 @@ def ipex_model_warm_up(self):
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
max_seq_length=self.max_seq_length,
padding=self.padding,
)


class IPEXSentenceTransformersDocumentEmbedder(SentenceTransformersDocumentEmbedder):
"""
A document embedder that uses IPEX for efficient computation.
A document embedder that uses the IPEX backend for efficient computation.
This class extends the base `SentenceTransformersDocumentEmbedder` class and provides an implementation
that utilizes IPEX for faster document embedding computation.
Parameters:
max_seq_length (int, optional): The maximum sequence length of the input documents. Defaults to None.
padding (bool or str, optional): Padding strategy for tokenization. If True, each batch is padded to the
length of its longest sequence. If False, no padding is applied. If "max_length", every batch is padded
to the configured `max_seq_length`. Defaults to True.
**kwargs: Additional keyword arguments to be passed to the base class constructor.
"""

def __init__(self, **kwargs):
def __init__(self, max_seq_length=None, padding=True, **kwargs):
super().__init__(**kwargs)
self.max_seq_length = max_seq_length
self.padding = padding


class IPEXSentenceTransformersTextEmbedder(SentenceTransformersTextEmbedder):
"""
A text embedder that uses IPEX for efficient text embedding.
A text embedder that uses the IPEX backend for efficient text embedding.
This class extends the `SentenceTransformersTextEmbedder` class and provides
an implementation that utilizes IPEX for faster and more efficient text embedding.
Parameters:
max_seq_length (int, optional): The maximum sequence length of the input text. Defaults to None.
padding (bool or str, optional): Padding strategy for tokenization. If True, each batch is padded to the
length of its longest sequence. If False, no padding is applied. If "max_length", every batch is padded
to the configured `max_seq_length`. Defaults to True.
**kwargs: Additional keyword arguments to be passed to the parent class.
"""

def __init__(self, **kwargs):
def __init__(self, max_seq_length=None, padding=True, **kwargs):
super().__init__(**kwargs)
self.max_seq_length = max_seq_length
self.padding = padding


IPEXSentenceTransformersDocumentEmbedder.warm_up = ipex_model_warm_up
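To make the `padding` semantics described in the embedder docstrings above concrete, the following small sketch (illustrative only, using the plain Hugging Face tokenizer for the base BGE checkpoint rather than the fastrag classes) shows how the two padding modes change the shape of a tokenized batch, mirroring what the overridden `tokenize` method does internally.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
texts = ["a short query", "a slightly longer query about Intel Xeon processors"]

# padding=True: pad only to the longest sequence in this batch (shape varies per batch).
dynamic = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

# padding="max_length": pad every batch to max_length (fixed shape across batches).
fixed = tokenizer(texts, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

print(dynamic["input_ids"].shape)  # varies with the longest sequence in the batch
print(fixed["input_ids"].shape)    # torch.Size([2, 512])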
