-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FaqGen Accuracy scripts & Refine Ragas (#91)
* fix ragas to align latest code Signed-off-by: Xinyao Wang <xinyao.wang@intel.com> * add FaqGen Accuracy scripts Signed-off-by: Xinyao Wang <xinyao.wang@intel.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix bug Signed-off-by: Xinyao Wang <xinyao.wang@intel.com> --------- Signed-off-by: Xinyao Wang <xinyao.wang@intel.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
514a6d6
commit 4df6438
Showing
8 changed files
with
227 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
## Dataset | ||
We evaluate performance on QA dataset [Squad_v2](https://huggingface.co/datasets/rajpurkar/squad_v2). Generate FAQs on "context" columns in validation dataset, which contains 1204 unique records. | ||
|
||
First download dataset and put at "./data". | ||
|
||
Extract unique "context" columns, which will be save to 'data/sqv2_context.json': | ||
``` | ||
python get_context.py | ||
``` | ||
|
||
## Generate FAQs | ||
|
||
### Launch FaQGen microservice | ||
Please refer to [FaQGen microservice](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/faq-generation/tgi), set up an microservice endpoint. | ||
``` | ||
export FAQ_ENDPOINT = "http://${your_ip}:9000/v1/faqgen" | ||
``` | ||
|
||
### Generate FAQs with microservice | ||
Use the microservice endpoint to generate FAQs for dataset. | ||
``` | ||
python generate_FAQ.py | ||
``` | ||
|
||
Post-process the output to get the right data, which will be save to 'data/sqv2_faq.json'. | ||
``` | ||
python post_process_FAQ.py | ||
``` | ||
|
||
## Evaluate with Ragas | ||
|
||
### Launch TGI service | ||
We use "mistralai/Mixtral-8x7B-Instruct-v0.1" as LLM referee to evaluate the model. First we need to launch a LLM endpoint on Gaudi. | ||
``` | ||
export HUGGING_FACE_HUB_TOKEN="your_huggingface_token" | ||
bash launch_tgi.sh | ||
``` | ||
Get the endpoint: | ||
``` | ||
export LLM_ENDPOINT = "http://${ip_address}:8082" | ||
``` | ||
|
||
Verify the service: | ||
```bash | ||
curl http://${ip_address}:8082/generate \ | ||
-X POST \ | ||
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":128}}' \ | ||
-H 'Content-Type: application/json' | ||
``` | ||
|
||
### Evaluate | ||
evaluate the performance with the LLM: | ||
``` | ||
python evaluate.py | ||
``` | ||
|
||
### Performance Result | ||
Here is the tested result for your reference | ||
| answer_relevancy | faithfulness | context_utilization | reference_free_rubrics_score | | ||
| ---- | ---- |---- |---- | | ||
| 0.7191 | 0.9681 | 0.8964 | 4.4125| |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import os | ||
|
||
from langchain_community.embeddings import HuggingFaceBgeEmbeddings | ||
|
||
from evals.metrics.ragas import RagasMetric | ||
|
||
llm_endpoint = os.getenv("LLM_ENDPOINT", "http://0.0.0.0:8082") | ||
|
||
f = open("data/sqv2_context.json", "r") | ||
sqv2_context = json.load(f) | ||
|
||
f = open("data/sqv2_faq.json", "r") | ||
sqv2_faq = json.load(f) | ||
|
||
templ = """Create a concise FAQs (frequently asked questions and answers) for following text: | ||
TEXT: {text} | ||
Do not use any prefix or suffix to the FAQ. | ||
""" | ||
|
||
number = 1204 | ||
question = [] | ||
answer = [] | ||
ground_truth = ["None"] * number | ||
contexts = [] | ||
for i in range(number): | ||
inputs = sqv2_context[str(i)] | ||
inputs_faq = templ.format_map({"text": inputs}) | ||
actual_output = sqv2_faq[str(i)] | ||
|
||
question.append(inputs_faq) | ||
answer.append(actual_output) | ||
contexts.append([inputs_faq]) | ||
|
||
embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5") | ||
metrics_faq = ["answer_relevancy", "faithfulness", "context_utilization", "reference_free_rubrics_score"] | ||
metric = RagasMetric(threshold=0.5, model=llm_endpoint, embeddings=embeddings, metrics=metrics_faq) | ||
|
||
test_case = {"question": question, "answer": answer, "ground_truth": ground_truth, "contexts": contexts} | ||
|
||
metric.measure(test_case) | ||
print(metric.score) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import os | ||
import time | ||
|
||
import requests | ||
|
||
llm_endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen") | ||
|
||
f = open("data/sqv2_context.json", "r") | ||
sqv2_context = json.load(f) | ||
|
||
start_time = time.time() | ||
headers = {"Content-Type": "application/json"} | ||
for i in range(1204): | ||
start_time_tmp = time.time() | ||
print(i) | ||
inputs = sqv2_context[str(i)] | ||
data = {"query": inputs, "max_new_tokens": 128} | ||
response = requests.post(llm_endpoint, json=data, headers=headers) | ||
f = open(f"data/result/sqv2_faq_{i}", "w") | ||
f.write(inputs) | ||
f.write(str(response.content, encoding="utf-8")) | ||
f.close() | ||
print(f"Cost {time.time()-start_time_tmp} seconds") | ||
print(f"\n Finished! \n Totally Cost {time.time()-start_time} seconds\n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
import os | ||
|
||
import pandas as pd | ||
|
||
data_path = "./data" | ||
data = pd.read_parquet(os.path.join(data_path, "squad_v2/squad_v2/validation-00000-of-00001.parquet")) | ||
sq_context = list(data["context"].unique()) | ||
sq_context_d = dict() | ||
for i in range(len(sq_context)): | ||
sq_context_d[i] = sq_context[i] | ||
|
||
with open(os.path.join(data_path, "sqv2_context.json"), "w") as outfile: | ||
json.dump(sq_context_d, outfile) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
max_input_tokens=3072 | ||
max_total_tokens=4096 | ||
port_number=8082 | ||
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1" | ||
volume="./data" | ||
docker run -it --rm \ | ||
--name="tgi_Mixtral" \ | ||
-p $port_number:80 \ | ||
-v $volume:/data \ | ||
--runtime=habana \ | ||
--restart always \ | ||
-e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \ | ||
-e HABANA_VISIBLE_DEVICES=all \ | ||
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \ | ||
-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \ | ||
--cap-add=sys_nice \ | ||
--ipc=host \ | ||
-e HTTPS_PROXY=$https_proxy \ | ||
-e HTTP_PROXY=$https_proxy \ | ||
ghcr.io/huggingface/tgi-gaudi:2.0.1 \ | ||
--model-id $model_name \ | ||
--max-input-tokens $max_input_tokens \ | ||
--max-total-tokens $max_total_tokens \ | ||
--sharded true \ | ||
--num-shard 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import json | ||
|
||
faq_dict = {} | ||
fails = [] | ||
for i in range(1204): | ||
data = open(f"data/result/sqv2_faq_{i}", "r").readlines() | ||
result = data[-6][6:] | ||
# print(result) | ||
if "LLMChain/final_output" not in result: | ||
print(f"error1: fail for {i}") | ||
fails.append(i) | ||
continue | ||
try: | ||
result2 = json.loads(result) | ||
result3 = result2["ops"][0]["value"]["text"] | ||
faq_dict[str(i)] = result3 | ||
except: | ||
print(f"error2: fail for {i}") | ||
fails.append(i) | ||
continue | ||
with open("data/sqv2_faq.json", "w") as outfile: | ||
json.dump(faq_dict, outfile) | ||
print("Failure index:") | ||
print(fails) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters