This repository has been archived by the owner on Mar 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 731
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simple short-form Self-RAG Pack (#907)
- Loading branch information
1 parent
965a254
commit 20c2f59
Showing
6 changed files
with
1,025 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Simple self-RAG short form pack | ||
|
||
This LlamaPack implements (*in short form) the [self-RAG paper by Akari et al.](https://arxiv.org/pdf/2310.11511.pdf). | ||
|
||
This paper presents a novel framework called Self-Reflective Retrieval-Augmented Generation (SELF-RAG). Which aims to enhance the quality and factuality of large language models (LLMs) by combining retrieval and self-reflection mechanisms. | ||
|
||
The implementation is adapted from the author [implementation](https://github.com/AkariAsai/self-rag) | ||
A full notebook guide can be found [here](https://github.com/run-llama/llama-hub/blob/main/llama_hub/llama_packs/self_rag/self_rag.ipynb). | ||
|
||
|
||
## CLI Usage | ||
|
||
You can download llamapacks directly using `llamaindex-cli`, which comes installed with the `llama-index` python package: | ||
|
||
```bash | ||
llamaindex-cli download-llamapack SelfRAGPack --download-dir ./self_rag_pack | ||
``` | ||
|
||
You can then inspect the files at `./self_rag_pack` and use them as a template for your own project! | ||
|
||
## Code Usage | ||
|
||
We will show you how to import the agent from these files! | ||
The implementation uses llama-cpp, to download the relevant models (be sure to replace DIR_PATH) | ||
```bash | ||
pip3 install -q huggingface-hub | ||
huggingface-cli download m4r1/selfrag_llama2_7b-GGUF selfrag_llama2_7b.q4_k_m.gguf --local-dir "<DIR_PATH>" --local-dir-use-symlinks False | ||
``` | ||
|
||
```python | ||
from llama_index.llama_pack import download_llama_pack | ||
|
||
# download and install dependencies | ||
SelfRAGPack = download_llama_pack( | ||
"SelfRAGPack", "./self_rag_pack" | ||
) | ||
|
||
``` | ||
|
||
From here, you can use the pack. You can import the relevant modules from the download folder (in the example below we assume it's a relative import or the directory | ||
has been added to your system path). | ||
|
||
```python | ||
from self_rag_pack.base import SelfRAGQueryEngine | ||
|
||
query_engine = SelfRAGQueryEngine(model_path=model_path, retriever=retriever, verbose=True) | ||
|
||
response = query_engine.query("Who won best Director in the 1972 Academy Awards?") | ||
``` | ||
|
||
You can also use/initialize the pack directly. | ||
|
||
```python | ||
from llm_compiler_agent_pack.base import SelfRAGPack | ||
|
||
agent_pack = SelfRAGPack(model_path=model_path, retriever=retriever, verbose=True) | ||
``` | ||
|
||
The `run()` function is a light wrapper around `agent.chat()`. | ||
|
||
```python | ||
response = pack.run("Who won best Director in the 1972 Academy Awards?") | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Dict, List | ||
import numpy as np | ||
|
||
from llama_index import Response | ||
from llama_index.llama_pack.base import BaseLlamaPack | ||
from llama_index.bridge.pydantic import Field | ||
from llama_index.query_engine import CustomQueryEngine | ||
from llama_index.core.base_retriever import BaseRetriever | ||
from llama_index.schema import NodeWithScore, TextNode | ||
from llama_index.utils import print_text | ||
|
||
_IMPORT_ERROR_MSG = ( | ||
"`llama_cpp` package not found, please run `pip install llama_cpp_python`" | ||
) | ||
|
||
_RELEVANCE_TOKENS = ["[Irrelevant]", "[Relevant]"] | ||
|
||
_RETRIEVAL_TOKENS = ["[No Retrieval]", "[Retrieval]", "[Continue to Use Evidence]"] | ||
_UTILITY_TOKENS = [ | ||
"[Utility:1]", | ||
"[Utility:2]", | ||
"[Utility:3]", | ||
"[Utility:4]", | ||
"[Utility:5]", | ||
] | ||
_GROUND_TOKENS = [ | ||
"[Fully supported]", | ||
"[Partially supported]", | ||
"[No support / Contradictory]", | ||
] | ||
_CTRL_TOKENS = [ | ||
"[Fully supported]", | ||
"[Partially supported]", | ||
"[No support / Contradictory]", | ||
"[No Retrieval]", | ||
"[Retrieval]", | ||
"[Irrelevant]", | ||
"[Relevant]", | ||
"[Continue to Use Evidence]", | ||
"<paragraph>", | ||
"</paragraph>", | ||
"[Utility:1]", | ||
"[Utility:2]", | ||
"[Utility:3]", | ||
"[Utility:4]", | ||
"[Utility:5]", | ||
] | ||
|
||
_MODEL_KWARGS = {"logits_all": True, "n_ctx": 2048, "n_gpu_layers": -1} | ||
_GENERATE_KWARGS = { | ||
"temperature": 0.0, | ||
"top_p": 1.0, | ||
"max_tokens": 50, | ||
"logprobs": 32016, | ||
} | ||
|
||
|
||
@dataclass | ||
class CriticOutput: | ||
llm_response_per_paragraph: Dict[int, str] | ||
paragraphs_final_score: Dict[int, float] | ||
source_nodes: List[NodeWithScore] | ||
|
||
|
||
def _format_prompt(input: str, paragraph: str = None) -> str: | ||
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input) | ||
if paragraph is not None: | ||
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph) | ||
return prompt | ||
|
||
|
||
def _postprocess_answer(answer: str) -> str: | ||
for token in _CTRL_TOKENS: | ||
answer = answer.replace(token, "") | ||
|
||
if "</s>" in answer: | ||
answer = answer.replace("</s>", "") | ||
if "\n" in answer: | ||
answer = answer.replace("\n", "") | ||
|
||
if "<|endoftext|>" in answer: | ||
answer = answer.replace("<|endoftext|>", "") | ||
|
||
return answer | ||
|
||
|
||
def _relevance_score(pred_log_probs: Dict[str, float]) -> float: | ||
"""Compute relevance score | ||
Args: | ||
pred_log_probs (Dict[str, float]): log probabilities of tokens | ||
Returns: | ||
float: relevance score | ||
""" | ||
rel_prob = np.exp(float(pred_log_probs["[Relevant]"])) | ||
irel_prob = np.exp(float(pred_log_probs["[Irrelevant]"])) | ||
return rel_prob / (rel_prob + irel_prob) | ||
|
||
|
||
def _is_supported_score( | ||
pred_tokens: List[int], pred_log_probs_dict: List[Dict[str, float]] | ||
) -> float: | ||
"""Compute support score | ||
Args: | ||
pred_tokens (List[int]): List of predicted tokens | ||
pred_log_probs_dict (List[Dict[str, float]]): log probabilities of tokens for each predicted tokens | ||
Returns: | ||
float: support score | ||
""" | ||
isSup_score = 0 | ||
groundness_token_appear_id = -1 | ||
for tok_idx, token in enumerate(pred_tokens): | ||
if token in _GROUND_TOKENS: | ||
groundness_token_appear_id = tok_idx | ||
break | ||
if groundness_token_appear_id > -1: | ||
grd_score_dict = {} | ||
for token in _GROUND_TOKENS: | ||
prob = pred_log_probs_dict[groundness_token_appear_id][token] | ||
grd_score_dict[token] = np.exp(float(prob)) | ||
isSup_score = ( | ||
grd_score_dict["[Fully supported]"] | ||
+ 0.5 * grd_score_dict["[Partially supported]"] | ||
) / np.sum(list(grd_score_dict.values())) | ||
return isSup_score | ||
|
||
|
||
def _is_useful_score( | ||
pred_tokens: List[int], pred_log_probs_dict: List[Dict[str, float]] | ||
) -> float: | ||
"""Compute usefulness score | ||
Args: | ||
pred_tokens (List[int]): List of predicted tokens | ||
pred_log_probs_dict (List[Dict[str, float]]): log probabilities of tokens for each predicted tokens | ||
Returns: | ||
float: relevance score | ||
""" | ||
isUse_score = 0 | ||
utility_token_appear_id = -1 | ||
for tok_idx, tok in enumerate(pred_tokens): | ||
if tok in _UTILITY_TOKENS: | ||
utility_token_appear_id = tok_idx | ||
if utility_token_appear_id > -1: | ||
ut_score_dict = {} | ||
for token in _UTILITY_TOKENS: | ||
prob = pred_log_probs_dict[utility_token_appear_id][token] | ||
ut_score_dict[token] = np.exp(float(prob)) | ||
|
||
ut_sum = np.sum(list(ut_score_dict.values())) | ||
ut_weights = [-1, -0.5, 0, 0.5, 1] | ||
isUse_score = np.sum( | ||
[ | ||
ut_weights[i] * (ut_score_dict["[Utility:{}]".format(i + 1)] / ut_sum) | ||
for i in range(len(ut_weights)) | ||
] | ||
) | ||
return isUse_score | ||
|
||
|
||
class SelfRAGQueryEngine(CustomQueryEngine): | ||
"""Simple short form self RAG query engine.""" | ||
|
||
llm: Any = Field(default=None, description="llm") | ||
retriever: BaseRetriever = Field(default=None, description="retriever") | ||
generate_kwargs: Dict = Field(default=None, description="llm generation arguments") | ||
verbose: bool = Field(default=True, description="Verbose.") | ||
|
||
def __init__( | ||
self, | ||
model_path: str, | ||
retriever: BaseRetriever, | ||
verbose: bool = False, | ||
model_kwargs: Dict = None, | ||
generate_kwargs: Dict = None, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Init params.""" | ||
super().__init__(verbose=verbose, **kwargs) | ||
model_kwargs = model_kwargs or _MODEL_KWARGS | ||
self.generate_kwargs = generate_kwargs or _GENERATE_KWARGS | ||
try: | ||
from llama_cpp import Llama # noqa: F401 | ||
except ImportError: | ||
raise ImportError(_IMPORT_ERROR_MSG) | ||
self.llm = Llama(model_path=model_path, verbose=verbose, **model_kwargs) | ||
self.retriever = retriever | ||
|
||
def _run_critic(self, paragraphs: List[str]) -> CriticOutput: | ||
"""Run Critic component, the llm will generate responses based on the paragraphs and then evaluate them | ||
Args: | ||
paragraphs (List[str]): List of paragraphs to evaluate | ||
Returns: | ||
CriticOutput: Paragraphs final score, LLM predictions and source nodes | ||
""" | ||
paragraphs_final_score = {} | ||
llm_response_text = {} | ||
source_nodes = [] | ||
|
||
for p_idx, paragraph in enumerate(paragraphs): | ||
pred = self.llm(paragraph, **self.generate_kwargs) | ||
# Cache llm answer | ||
llm_response_text[p_idx] = pred["choices"][0]["text"] | ||
|
||
logprobs = pred["choices"][0]["logprobs"] | ||
pred_log_probs = logprobs["top_logprobs"] | ||
# Compute isRel score, on the first predicted token | ||
isRel_score = _relevance_score(pred_log_probs[0]) | ||
|
||
# Compute isSup score | ||
isSup_score = _is_supported_score(logprobs["tokens"], pred_log_probs) | ||
|
||
# Compute isUse score | ||
isUse_score = _is_useful_score(logprobs["tokens"], pred_log_probs) | ||
|
||
paragraphs_final_score[p_idx] = ( | ||
isRel_score + isSup_score + 0.5 * isUse_score | ||
) | ||
# Add the paragraph as source node with its relevance score | ||
source_nodes.append( | ||
NodeWithScore( | ||
node=TextNode(text=paragraph, id_=p_idx), | ||
score=isRel_score, | ||
) | ||
) | ||
|
||
if self.verbose: | ||
print_text( | ||
f"Input: {paragraph}\nPrediction: {llm_response_text[p_idx]}\nScore: {paragraphs_final_score[p_idx]}\n", | ||
color="blue", | ||
) | ||
print_text( | ||
f"{p_idx + 1}/{len(paragraphs)} paragraphs done\n\n", color="blue" | ||
) | ||
|
||
return CriticOutput(llm_response_text, paragraphs_final_score, source_nodes) | ||
|
||
def custom_query(self, query_str: str) -> Response: | ||
"""Run self-RAG.""" | ||
response = self.llm(prompt=_format_prompt(query_str), **_GENERATE_KWARGS) | ||
answer = response["choices"][0]["text"] | ||
source_nodes = [] | ||
|
||
if "[Retrieval]" in answer: | ||
if self.verbose: | ||
print_text("Retreival required\n", color="blue") | ||
documents = self.retriever.retrieve(query_str) | ||
if self.verbose: | ||
print_text(f"Received: {len(documents)} documents\n", color="blue") | ||
paragraphs = [ | ||
_format_prompt(query_str, document.node.text) for document in documents | ||
] | ||
|
||
if self.verbose: | ||
print_text("Start evaluation\n", color="blue") | ||
|
||
critic_output = self._run_critic(paragraphs) | ||
|
||
paragraphs_final_score = critic_output.paragraphs_final_score | ||
llm_response_per_paragraph = critic_output.llm_response_per_paragraph | ||
source_nodes = critic_output.source_nodes | ||
|
||
if self.verbose: | ||
print_text("End evaluation\n", color="blue") | ||
|
||
best_paragraph_id = max( | ||
paragraphs_final_score, key=paragraphs_final_score.get | ||
) | ||
answer = llm_response_per_paragraph[best_paragraph_id] | ||
if self.verbose: | ||
print_text(f"Selected the best answer: {answer}\n", color="blue") | ||
|
||
answer = _postprocess_answer(answer) | ||
if self.verbose: | ||
print_text(f"Final answer: {answer}\n", color="green") | ||
return Response(response=str(answer), source_nodes=source_nodes) | ||
|
||
|
||
class SelfRAGPack(BaseLlamaPack): | ||
"""Simple short form Self-RAG pack.""" | ||
|
||
def __init__( | ||
self, | ||
model_path: str, | ||
retriever: BaseRetriever, | ||
verbose: bool = False, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Init params.""" | ||
|
||
self.query_engine = SelfRAGQueryEngine(model_path, retriever, verbose) | ||
|
||
def get_modules(self) -> Dict[str, Any]: | ||
"""Get modules.""" | ||
return { | ||
"query_engine": self.query_engine, | ||
"llm": self.query_engine.llm, | ||
"retriever": self.query_engine.retriever, | ||
} | ||
|
||
def run(self, *args: Any, **kwargs: Any) -> Any: | ||
"""Run the pipeline.""" | ||
return self.query_engine.query(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
llama_cpp_python |
Oops, something went wrong.