Merge pull request #32 from qhjqhj00/tommy-dev-lite
Tommy dev lite
qhjqhj00 authored Sep 24, 2024
2 parents 3a9501f + 8168655 commit 6411dd3
Showing 6 changed files with 680 additions and 12 deletions.
21 changes: 20 additions & 1 deletion README.md
@@ -44,6 +44,8 @@ Afterwards, you can view the demo as below:
</div>

## :page_with_curl: Changelog
[21/09/24] MemoRAG introduces Lite mode, enabling memory-augmented RAG processing for millions of tokens with just a few lines of code. For more details, refer to the [`examples`](https://github.com/qhjqhj00/MemoRAG/blob/tommy-dev-lite/examples/memorag_lite.ipynb) notebook.

[13/09/24] MemoRAG adds `Meta-Llama-3.1-8B-Instruct` and `Llama3.1-8B-Chinese-Chat` as the Memory Model, see [`examples`](https://github.com/qhjqhj00/MemoRAG/blob/main/examples/longllm_as_memory.ipynb).

[10/09/24] We release MemoRAG's [`Technical Report`](https://arxiv.org/pdf/2409.05591).
@@ -115,9 +117,26 @@ We provide a notebook to illustrate all functions of MemoRAG [here](https://gith

## :notebook: Usage

### Lite Mode of MemoRAG :new: :new: :new:
We introduce the Lite Mode of MemoRAG, a quick and user-friendly way to run the MemoRAG pipeline with just a few lines of code. We recommend starting with a GPU that has 24GiB of memory, though a 16GiB GPU can also handle the pipeline under default settings in most cases.


```python
from memorag import MemoRAGLite
pipe = MemoRAGLite()
context = open("examples/harry_potter.txt").read()
pipe.memorize(context, save_dir="harry_potter", print_stats=True)

query = "What's the book's main theme?"
print(pipe(query))
```
MemoRAG Lite is simple to use, supporting English and Chinese contexts of up to millions of tokens. Although it may work with other languages, performance could degrade because the default prompts are in English. For more details on MemoRAG Lite, please refer to the [`examples`](https://github.com/qhjqhj00/MemoRAG/blob/tommy-dev-lite/examples/memorag_lite.ipynb) notebook.
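All of `MemoRAGLite`'s constructor arguments are optional. As a minimal sketch based on the argument list shown in the example notebook, you can override the defaults like this (non-default values here are purely illustrative):

```python
from memorag import MemoRAGLite

# Arguments and defaults below are taken from the MemoRAGLite signature
# listed in the example notebook; the chosen values are illustrative.
pipe = MemoRAGLite(
    gen_model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",  # default generation model
    ret_model_name_or_path="BAAI/bge-m3",                 # default dense retriever
    ret_hit=3,                  # number of chunks retrieved per query
    retrieval_chunk_size=512,   # tokens per retrieval chunk
    load_in_4bit=False,         # set True to reduce GPU memory usage
    enable_flash_attn=True,     # disable if flash-attn is unavailable
)
```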



### Basic Usage of MemoRAG
MemoRAG is easy to use and can be initialized with HuggingFace models directly. By using the `MemoRAG.memorize()` method, the memory model builds a global memory over a long input context. Empirically, with default parameter settings, `TommyChien/memorag-qwen2-7b-inst` can handle contexts of up to 400K tokens, while `TommyChien/memorag-mistral-7b-inst` can manage contexts up to 128K tokens. By increasing the `beacon_ratio` parameter, the model’s capacity to handle longer contexts can be extended. For example, `TommyChien/memorag-qwen2-7b-inst` can process up to one million tokens with `beacon_ratio=16`.


```python
from memorag import MemoRAG
```
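The snippet above is truncated in this diff view. As a sketch of the basic flow, based on the `memorize()` call shown earlier and the `__call__` signature changed in this commit (the constructor argument names here are assumptions, not verbatim from the README):

```python
from memorag import MemoRAG

# Constructor argument names are illustrative assumptions; `beacon_ratio`
# is the parameter discussed above for extending context capacity.
pipe = MemoRAG(
    mem_model_name_or_path="TommyChien/memorag-qwen2-7b-inst",
    ret_model_name_or_path="BAAI/bge-m3",
    beacon_ratio=16,  # up to ~1M tokens for memorag-qwen2-7b-inst
)

# Build and cache a global memory over the long input context.
context = open("examples/harry_potter.txt").read()
pipe.memorize(context, save_dir="harry_potter", print_stats=True)

# Query the memorized context; after this commit, task_type defaults
# to "memorag" and context may be omitted once memory is formed.
print(pipe(query="What's the book's main theme?"))
```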
266 changes: 266 additions & 0 deletions examples/memorag_lite.ipynb
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3bc35793-f2b8-4809-b682-61165ba10454",
"metadata": {},
"source": [
"## Initialize the MemoRAGLite"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "29e565b0-b0ef-4ab1-8d06-2c00438b7bd6",
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2024-09-21T12:16:19.630888Z",
"iopub.status.busy": "2024-09-21T12:16:19.630333Z",
"iopub.status.idle": "2024-09-21T12:16:23.550291Z",
"shell.execute_reply": "2024-09-21T12:16:23.549592Z",
"shell.execute_reply.started": "2024-09-21T12:16:19.630866Z"
},
"tags": []
},
"outputs": [],
"source": [
"from memorag import MemoRAGLite\n",
"pipe = MemoRAGLite()\n",
"\n",
"## All args of MemoRAGLite are:\n",
"\n",
"# gen_model_name_or_path: str=\"Qwen/Qwen2.5-1.5B-Instruct\",\n",
"# ret_model_name_or_path: str=\"BAAI/bge-m3\",\n",
"# customized_gen_model=None,\n",
"# ret_hit: int = 3,\n",
"# retrieval_chunk_size: int = 512,\n",
"# cache_dir: Optional[str] = None,\n",
"# access_token: Optional[str] = None,\n",
"# load_in_4bit: bool = False,\n",
"# enable_flash_attn: bool = True"
]
},
{
"cell_type": "markdown",
"id": "0d734fb7-f58a-4f7c-b9cd-9622f4ccbfce",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2024-09-21T11:34:29.574781Z",
"iopub.status.busy": "2024-09-21T11:34:29.574631Z",
"iopub.status.idle": "2024-09-21T11:34:33.560111Z",
"shell.execute_reply": "2024-09-21T11:34:33.559484Z",
"shell.execute_reply.started": "2024-09-21T11:34:29.574765Z"
},
"tags": []
},
"source": [
"### By default, MemoRAGLite automatically detects languages, available resources and perform memroy formation."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "639f4475-9e30-4ce2-976d-60f769490366",
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2024-09-21T12:02:35.498041Z",
"iopub.status.busy": "2024-09-21T12:02:35.496837Z",
"iopub.status.idle": "2024-09-21T12:05:46.764405Z",
"shell.execute_reply": "2024-09-21T12:05:46.763824Z",
"shell.execute_reply.started": "2024-09-21T12:02:35.498016Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detected language: en\n",
"Context length: 122591 tokens\n",
"Forming memory of the context...\n",
"Progress: 25.64% of the context memorized...\n",
"Progress: 51.28% of the context memorized...\n",
"Progress: 76.92% of the context memorized...\n",
"Context memorization completed successfully.\n",
"Dense retrieval index has been built.\n",
"Memory file size: 0.29 GB\n",
"Number of chunks in retrieval corpus: 268\n"
]
}
],
"source": [
"context = open(\"harry_potter.txt\").read()\n",
"pipe.memorize(context, save_dir=\"harry_potter\", print_stats=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6f94c6fe-2a38-4e07-9144-d7e41e8711b4",
"metadata": {
"execution": {
"iopub.execute_input": "2024-09-21T12:06:58.765444Z",
"iopub.status.busy": "2024-09-21T12:06:58.764740Z",
"iopub.status.idle": "2024-09-21T12:06:59.126161Z",
"shell.execute_reply": "2024-09-21T12:06:59.125567Z",
"shell.execute_reply.started": "2024-09-21T12:06:58.765422Z"
},
"tags": []
},
"outputs": [],
"source": [
"## You can also load a pre-cached memory from a directory\n",
"pipe.load(\"harry_potter\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "18b9fe71-57ec-4471-9445-fbfe5e383a79",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2024-09-21T12:07:50.600705Z",
"iopub.status.busy": "2024-09-21T12:07:50.600256Z",
"iopub.status.idle": "2024-09-21T12:08:03.414794Z",
"shell.execute_reply": "2024-09-21T12:08:03.414207Z",
"shell.execute_reply.started": "2024-09-21T12:07:50.600684Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"**Book:** *Harry Potter and the Chamber of Secrets*\n",
" \n",
"**Main Theme:** The story revolves around Harry discovering the mysterious Chamber of Secrets hidden within Hogwarts Castle, leading to conflicts involving students, staff members, and supernatural entities such as ghosts and creatures. The plot centers around Harry’s struggle to protect others from the Basilisk and confronts threats posed by various characters including Tom Riddle (Lord Voldemort). Additionally, themes encompass friendship, loyalty among peers, personal growth, and overcoming obstacles—highlighting Harry's journey towards maturity and understanding the complexities of the wizarding world.\n"
]
}
],
"source": [
"query = \"What's the book's main theme?\"\n",
"print(pipe(query))"
]
},
{
"cell_type": "markdown",
"id": "85bdafb1-88c9-4a6a-ab82-95ffc7057307",
"metadata": {},
"source": [
"### Chinese Case"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9061166a-a36d-4231-8455-e0a09320fd49",
"metadata": {
"execution": {
"iopub.execute_input": "2024-09-21T12:20:02.777508Z",
"iopub.status.busy": "2024-09-21T12:20:02.776711Z",
"iopub.status.idle": "2024-09-21T12:22:42.862054Z",
"shell.execute_reply": "2024-09-21T12:22:42.860595Z",
"shell.execute_reply.started": "2024-09-21T12:20:02.777484Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Detected language: zh-cn\n",
"Context length: 284878 tokens\n",
"Forming memory of the context...\n",
"Progress: 20.25% of the context memorized...\n",
"Progress: 40.51% of the context memorized...\n",
"Progress: 60.76% of the context memorized...\n",
"Progress: 81.01% of the context memorized...\n",
"Context memorization completed successfully.\n",
"Dense retrieval index has been built.\n",
"Memory file size: 0.35 GB\n",
"Number of chunks in retrieval corpus: 699\n"
]
}
],
"source": [
"context = open(\"fortress_besieged.txt\").read()\n",
"pipe.memorize(context, save_dir=\"fortress_besieged\", print_stats=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8de250c2-d36e-4bf8-9eb9-4e4f2aad8ff8",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2024-09-21T12:19:38.413060Z",
"iopub.status.busy": "2024-09-21T12:19:38.412232Z",
"iopub.status.idle": "2024-09-21T12:19:50.839744Z",
"shell.execute_reply": "2024-09-21T12:19:50.839029Z",
"shell.execute_reply.started": "2024-09-21T12:19:38.413037Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"根据提供的信息,小说的故事并没有明确给出最终的结局。然而可以推测,在那个夜晚之后,由于夫妻之间的争吵以及对方的冷漠态度,他们之间的关系变得非常紧张,并且可能已经破裂。同时,文中提到\"人生的惨淡滋味\",暗示生活在这个阶段可能会更加艰难。最后,通过描述主人公在深夜中的沉思与疲惫状态,可以看出尽管经历了许多情感起伏,但他们仍然处于一种相对消极的生活状态下。因此,总体来看,这是一个充满矛盾和挫折感的情节发展过程。\n"
]
}
],
"source": [
"query = \"故事的结局是什么?\"\n",
"print(pipe(query))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feb431bf-3db8-4e88-9cc3-ee516c469c23",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
3 changes: 2 additions & 1 deletion memorag/__init__.py
@@ -1,2 +1,3 @@
- from .memorag import Memory, MemoRAG
+ from .memorag import Memory, MemoRAG, Model
+ from .memorag_lite import MemoRAGLite
from .agent import Agent
14 changes: 8 additions & 6 deletions memorag/memorag.py
@@ -157,6 +157,7 @@ def generate(
temperature: float = None,
top_p: float = None,
do_sample: bool = False,
+ repetition_penalty: float = 1.0
) -> Union[str, List[str]]:

if isinstance(prompts, str):
@@ -167,6 +168,7 @@
"do_sample": do_sample,
"temperature": temperature,
"top_p": top_p,
"repetition_penalty": repetition_penalty
}

all_outputs = []
@@ -387,9 +389,9 @@ def load(self, save_dir: str, print_stats: bool = False):

def __call__(
self,
- context: str,
query: str = None,
- task_type: str = "rag",
+ context: str = None,
+ task_type: str = "memorag",
prompt_template: str = None,
max_new_tokens: int = 256,
reset_each_call: bool = False,
Expand All @@ -402,6 +404,8 @@ def __call__(
self.retriever.remove_all()

if not self.mem_model.memory:
+ if not context:
+     raise ValueError("Please provide your input context...")
self.memorize(context)

if task_type == 'qa':
@@ -462,12 +466,10 @@ def _generate_response(self, task_key: str, query: str, knowledge: str, prompt_t
prompt = self.prompts[task_key].format(input=query, context=knowledge) if query else self.prompts[task_key].format(context=knowledge)

if self.gen_model.__class__.__name__ == "Memory" and self.mem_model.memo_type == "beacon":
- self.gen_model._enable_beacon = False
+ # self.gen_model._enable_beacon = False
output = self.gen_model.generate(prompt, max_new_tokens=max_new_tokens)[0]
- self.gen_model._enable_beacon = True
+ # self.gen_model._enable_beacon = True
else:
output = self.gen_model.generate(prompt, max_new_tokens=max_new_tokens, with_cache=False)[0]
torch.cuda.empty_cache()
return output
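Taken together, the `__call__` changes above make `context` optional, switch the default `task_type` from `"rag"` to `"memorag"`, and guard against invoking the pipeline with neither a formed memory nor a context. A minimal sketch of the resulting calling convention (the constructor argument is an illustrative assumption):

```python
from memorag import MemoRAG

# Constructor argument name is an illustrative assumption.
pipe = MemoRAG(mem_model_name_or_path="TommyChien/memorag-qwen2-7b-inst")

context = open("examples/harry_potter.txt").read()

# No memory has been formed yet, so context is required;
# __call__ memorizes it before answering.
print(pipe(query="What's the book's main theme?", context=context))

# Later calls reuse the formed memory, so context can be omitted.
print(pipe(query="Who wrote the book?"))

# With no memory and no context, __call__ now raises:
# ValueError("Please provide your input context...")
```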


