diff --git a/scripts/optimizations/embedders/embedders.py b/scripts/optimizations/embedders/embedders.py index 5583624..f9f319e 100644 --- a/scripts/optimizations/embedders/embedders.py +++ b/scripts/optimizations/embedders/embedders.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import torch @@ -115,6 +115,16 @@ def _mean_pooling(outputs, attention_mask): class EmbedderModelMTEB(EmbedderModel): + def encode( + self, sentences: list[str], batch_size=32, **kwargs: Any + ) -> torch.Tensor | np.ndarray: + return self.encode_sentences( + sentences=sentences, + batch_size=batch_size, + normalize=True, + convert_to_numpy=True, + ) + def encode_queries(self, queries: List[str], batch_size=32, **kwargs): if self.query_prompt: sentences = [self.query_prompt + q for q in queries] @@ -133,9 +143,11 @@ def encode_corpus( sep = " " if type(corpus[0]) is dict: sentences = [ - (doc["title"].strip() + sep + doc["text"]).strip() - if "title" in doc - else doc["text"].strip() + ( + (doc["title"].strip() + sep + doc["text"]).strip() + if "title" in doc + else doc["text"].strip() + ) for doc in corpus ] else: diff --git a/scripts/optimizations/embedders/quantize_embedder.py b/scripts/optimizations/embedders/quantize_embedder.py index 87ab897..ba2b473 100644 --- a/scripts/optimizations/embedders/quantize_embedder.py +++ b/scripts/optimizations/embedders/quantize_embedder.py @@ -10,7 +10,6 @@ from neural_compressor.config import PostTrainingQuantConfig from optimum.intel import INCQuantizer, IPEXModel from sentence_transformers import SentenceTransformer -from simple_parsing import field from transformers import AutoModel, AutoTokenizer @@ -125,10 +124,41 @@ def preprocess_function(examples): ], "retrieval": [ "ArguAna", - # "ClimateFEVER", - # "FEVER", - # "FiQA2018", - # "HotpotQA", + "ClimateFEVER", + "CQADupstackAndroidRetrieval", + "CQADupstackEnglishRetrieval", + "CQADupstackGamingRetrieval", + 
"CQADupstackGisRetrieval", + "CQADupstackMathematicaRetrieval", + "CQADupstackPhysicsRetrieval", + "CQADupstackProgrammersRetrieval", + "CQADupstackStatsRetrieval", + "CQADupstackTexRetrieval", + "CQADupstackUnixRetrieval", + "CQADupstackWebmastersRetrieval", + "CQADupstackWordpressRetrieval", + "DBPedia", + "FaithDial", + "FeedbackQARetrieval", + "FEVER", + "FiQA2018", + "HagridRetrieval", + "HotpotQA", + "LegalBenchConsumerContractsQA", + "LegalBenchCorporateLobbying", + "LegalSummarization", + "MLQuestions", + "MSMARCO", + "NarrativeQARetrieval", + "NFCorpus", + "NQ", + "RARbCode", + "RARbMath", + "SCIDOCS", + "SciFact", + "TopiOCQA", + "Touche2020", + "TRECCOVID", ], } @@ -137,8 +167,8 @@ def _gather_rerank_results(results): res = {} total = 0.0 for task in results: - res[task] = results[task]["test"]["map"] - total += res[task] + res[task.task_name] = task.scores["test"][0]["map"] + total += res[task.task_name] res["avg"] = total / len(results) return res @@ -147,7 +177,7 @@ def _gather_retrieval_results(results): res = {} total = 0.0 for task in results: - res[task] = results[task]["test"]["ndcg_at_10"] - total += res[task] + res[task.task_name] = task.scores["test"][0]["ndcg_at_10"] + total += res[task.task_name] res["avg"] = total / len(results) return res @@ -158,14 +188,8 @@ "retrieval": _gather_retrieval_results, } -TASK_TYPES = { - "rerank": "Reranking", - "retrieval": "Retrieval" - } - def _run_validation(model, task, model_path): - tasks = mteb.get_tasks(task_types=TASK_TYPES[task], languages=["eng"]) - evaluation = MTEB(tasks=tasks) + evaluation = MTEB(tasks=mteb.get_tasks(tasks=benchmarks[task])) results = evaluation.run( model, overwrite_results=True, output_folder=model_path, eval_splits=["test"] )