diff --git a/docs/benchmark-dspy-entity-extraction.md b/docs/benchmark-dspy-entity-extraction.md new file mode 100644 index 0000000..e98f2f2 --- /dev/null +++ b/docs/benchmark-dspy-entity-extraction.md @@ -0,0 +1,267 @@ +# Main Takeaways +- Time difference: 156.99 seconds +- Execution time with DSPy-AI: 304.38 seconds +- Execution time without DSPy-AI: 147.39 seconds +- Entities extracted: 22 (without DSPy-AI) vs 37 (with DSPy-AI) +- Relationships extracted: 21 (without DSPy-AI) vs 36 (with DSPy-AI) + + +# Results +```markdown +> python examples/benchmarks/dspy_entity.py + +Running benchmark with DSPy-AI: +INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK" +INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK" +INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK" +DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 23 | Total Entities: 37 +DEBUG:nano-graphrag:Relationships: 13 | Missed Relationships: 23 | Total Relationships: 36 +DEBUG:nano-graphrag:Direct Relationships: 31 | Second-order: 5 | Third-order: 0 | Total Relationships: 36 +⠙ Processed 1 chunks, 37 entities(duplicated), 36 relations(duplicated) +Execution time with DSPy-AI: 304.38 seconds + +Entities: +- 朱元璋 (PERSON): + 明朝开国皇帝,原名朱重八,后改名朱元璋。他出身贫农,经历了从放牛娃到皇帝的传奇人生。在元朝末年,他参加了红巾军起义,最终推翻元朝,建立了明朝。 +- 朱五四 (PERSON): + 朱元璋的父亲,农民出身,家境贫寒。他在朱元璋幼年时去世,对朱元璋的成长和人生选择产生了深远影响。 +- 陈氏 (PERSON): + 朱元璋的母亲,农民出身,家境贫寒。她在朱元璋幼年时去世,对朱元璋的成长和人生选择产生了深远影响。 +- 汤和 (PERSON): + 朱元璋的幼年朋友,后来成为朱元璋起义军中的重要将领。他在朱元璋早期的发展中起到了关键作用。 +- 郭子兴 (PERSON): + 红巾军起义的领导人之一,朱元璋的岳父。他在朱元璋早期的发展中起到了重要作用,但后来与朱元璋产生了矛盾。 +- 马姑娘 (PERSON): + 郭子兴的义女,朱元璋的妻子。她在朱元璋最困难的时候给予了极大的支持,是朱元璋成功的重要因素之一。 +- 元朝 (ORGANIZATION): + 中国历史上的一个朝代,由蒙古族建立。元朝末年,社会矛盾激化,最终导致了红巾军起义和明朝的建立。 +- 红巾军 (ORGANIZATION): + 元朝末年起义军的一支,主要由农民组成。朱元璋最初加入的就是红巾军,并在其中逐渐崭露头角。 +- 皇觉寺 (LOCATION): + 朱元璋早年出家的地方,位于安徽凤阳。他在寺庙中度过了几年的时光,这段经历对他的人生观和价值观产生了深远影响。 +- 濠州 (LOCATION): + 
朱元璋早期活动的重要地点,也是红巾军的重要据点之一。朱元璋在这里经历了许多重要事件,包括与郭子兴的矛盾和最终的离开。 +- 1328年 (DATE): + 朱元璋出生的年份。这一年标志着明朝开国皇帝传奇人生的开始。 +- 1344年 (DATE): + 朱元璋家庭遭遇重大变故的年份,他的父母在这一年相继去世。这一事件对朱元璋的人生选择产生了深远影响。 +- 1352年 (DATE): + 朱元璋正式加入红巾军起义的年份。这一年标志着朱元璋从农民到起义军领袖的转变。 +- 1368年 (DATE): + 朱元璋推翻元朝,建立明朝的年份。这一年标志着朱元璋从起义军领袖到皇帝的转变。 +- 朱百六 (PERSON): + 朱元璋的高祖,名字具有元朝时期老百姓命名的特点,即以数字命名。 +- 朱四九 (PERSON): + 朱元璋的曾祖,名字同样具有元朝时期老百姓命名的特点,即以数字命名。 +- 朱初一 (PERSON): + 朱元璋的祖父,名字具有元朝时期老百姓命名的特点,即以数字命名。 +- 刘德 (PERSON): + 朱元璋早年为其放牛的地主,对朱元璋的童年生活有重要影响。 +- 韩山童 (PERSON): + 红巾军起义的早期领导人之一,与刘福通共同起义,对朱元璋的起义选择有间接影响。 +- 刘福通 (PERSON): + 红巾军起义的早期领导人之一,与韩山童共同起义,对朱元璋的起义选择有间接影响。 +- 脱脱 (PERSON): + 元朝末年的著名宰相,主张治理黄河,但他的政策间接导致了红巾军起义的爆发。 +- 元顺帝 (PERSON): + 元朝末代皇帝,他在位期间元朝社会矛盾激化,最终导致了红巾军起义和明朝的建立。 +- 孙德崖 (PERSON): + 红巾军起义的领导人之一,与郭子兴有矛盾,曾绑架郭子兴,对朱元璋的早期发展有重要影响。 +- 周德兴 (PERSON): + 朱元璋的早期朋友,曾为朱元璋算卦,对朱元璋的人生选择有一定影响。 +- 徐达 (PERSON): + 朱元璋早期的重要将领,后来成为明朝的开国功臣之一。 +- 明教 (RELIGION): + 朱元璋在起义过程中接触到的宗教信仰,对他的思想和行动有一定影响。 +- 弥勒佛 (RELIGION): + 明教中的重要神祇,朱元璋相信弥勒佛会降世,对他的信仰和行动有一定影响。 +- 颖州 (LOCATION): + 朱元璋早年讨饭的地方,也是红巾军起义的重要地点之一。 +- 定远 (LOCATION): + 朱元璋早期攻打的地点之一,是他军事生涯的起点。 +- 怀远 (LOCATION): + 朱元璋早期攻打的地点之一,是他军事生涯的起点。 +- 安奉 (LOCATION): + 朱元璋早期攻打的地点之一,是他军事生涯的起点。 +- 含山 (LOCATION): + 朱元璋早期攻打的地点之一,是他军事生涯的起点。 +- 虹县 (LOCATION): + 朱元璋早期攻打的地点之一,是他军事生涯的起点。 +- 钟离 (LOCATION): + 朱元璋的家乡,他在此地召集了二十四位重要将领。 +- 黄河 (LOCATION): + 元朝末年黄河泛滥,导致了严重的社会问题,间接引发了红巾军起义。 +- 淮河 (LOCATION): + 元朝末年淮河沿岸遭遇严重瘟疫和旱灾,加剧了社会矛盾。 +- 1351年 (DATE): + 红巾军起义爆发的年份,对朱元璋的人生选择产生了重要影响。 + +Relationships: +- 朱元璋 -> 朱五四: + 朱元璋是朱五四的儿子,朱五四的去世对朱元璋的成长和人生选择产生了深远影响。 +- 朱元璋 -> 陈氏: + 朱元璋是陈氏的儿子,陈氏的去世对朱元璋的成长和人生选择产生了深远影响。 +- 朱元璋 -> 汤和: + 汤和是朱元璋的幼年朋友,后来成为朱元璋起义军中的重要将领,对朱元璋早期的发展起到了关键作用。 +- 朱元璋 -> 郭子兴: + 郭子兴是朱元璋的岳父,也是红巾军起义的领导人之一。他在朱元璋早期的发展中起到了重要作用,但后来与朱元璋产生了矛盾。 +- 朱元璋 -> 马姑娘: + 马姑娘是朱元璋的妻子,她在朱元璋最困难的时候给予了极大的支持,是朱元璋成功的重要因素之一。 +- 朱元璋 -> 元朝: + 朱元璋在元朝末年参加了红巾军起义,最终推翻了元朝,建立了明朝。 +- 朱元璋 -> 红巾军: + 朱元璋最初加入的是红巾军,并在其中逐渐崭露头角,最终成为起义军的重要领导人。 +- 朱元璋 -> 皇觉寺: + 朱元璋早年出家的地方是皇觉寺,这段经历对他的人生观和价值观产生了深远影响。 +- 朱元璋 -> 濠州: + 
濠州是朱元璋早期活动的重要地点,也是红巾军的重要据点之一。朱元璋在这里经历了许多重要事件,包括与郭子兴的矛盾和最终的离开。 +- 朱元璋 -> 1328年: + 1328年是朱元璋出生的年份,这一年标志着明朝开国皇帝传奇人生的开始。 +- 朱元璋 -> 1344年: + 1344年是朱元璋家庭遭遇重大变故的年份,他的父母在这一年相继去世,这一事件对朱元璋的人生选择产生了深远影响。 +- 朱元璋 -> 1352年: + 1352年是朱元璋正式加入红巾军起义的年份,这一年标志着朱元璋从农民到起义军领袖的转变。 +- 朱元璋 -> 1368年: + 1368年是朱元璋推翻元朝,建立明朝的年份,这一年标志着朱元璋从起义军领袖到皇帝的转变。 +- 朱元璋 -> 朱百六: + 朱百六是朱元璋的高祖,对朱元璋的家族背景有重要影响。 +- 朱元璋 -> 朱四九: + 朱四九是朱元璋的曾祖,对朱元璋的家族背景有重要影响。 +- 朱元璋 -> 朱初一: + 朱初一是朱元璋的祖父,对朱元璋的家族背景有重要影响。 +- 朱元璋 -> 刘德: + 刘德是朱元璋早年为其放牛的地主,对朱元璋的童年生活有重要影响。 +- 朱元璋 -> 韩山童: + 韩山童是红巾军起义的早期领导人之一,对朱元璋的起义选择有间接影响。 +- 朱元璋 -> 刘福通: + 刘福通是红巾军起义的早期领导人之一,对朱元璋的起义选择有间接影响。 +- 朱元璋 -> 脱脱: + 脱脱是元朝末年的著名宰相,他的政策间接导致了红巾军起义的爆发,对朱元璋的起义选择有间接影响。 +- 朱元璋 -> 元顺帝: + 元顺帝是元朝末代皇帝,他在位期间社会矛盾激化,最终导致了红巾军起义和明朝的建立,对朱元璋的起义选择有重要影响。 +- 朱元璋 -> 孙德崖: + 孙德崖是红巾军起义的领导人之一,与郭子兴有矛盾,曾绑架郭子兴,对朱元璋的早期发展有重要影响。 +- 朱元璋 -> 周德兴: + 周德兴是朱元璋的早期朋友,曾为朱元璋算卦,对朱元璋的人生选择有一定影响。 +- 朱元璋 -> 徐达: + 徐达是朱元璋早期的重要将领,后来成为明朝的开国功臣之一,对朱元璋的军事生涯有重要影响。 +- 朱元璋 -> 明教: + 朱元璋在起义过程中接触到的宗教信仰,对他的思想和行动有一定影响。 +- 朱元璋 -> 弥勒佛: + 朱元璋相信弥勒佛会降世,对他的信仰和行动有一定影响。 +- 朱元璋 -> 颖州: + 颖州是朱元璋早年讨饭的地方,也是红巾军起义的重要地点之一,对朱元璋的早期生活有重要影响。 +- 朱元璋 -> 定远: + 定远是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。 +- 朱元璋 -> 怀远: + 怀远是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。 +- 朱元璋 -> 安奉: + 安奉是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。 +- 朱元璋 -> 含山: + 含山是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。 +- 朱元璋 -> 虹县: + 虹县是朱元璋早期攻打的地点之一,是他军事生涯的起点,对朱元璋的军事发展有重要影响。 +- 朱元璋 -> 钟离: + 钟离是朱元璋的家乡,他在此地召集了二十四位重要将领,对朱元璋的军事发展有重要影响。 +- 朱元璋 -> 黄河: + 元朝末年黄河泛滥,导致了严重的社会问题,间接引发了红巾军起义,对朱元璋的起义选择有重要影响。 +- 朱元璋 -> 淮河: + 元朝末年淮河沿岸遭遇严重瘟疫和旱灾,加剧了社会矛盾,对朱元璋的起义选择有重要影响。 +- 朱元璋 -> 1351年: + 1351年是红巾军起义爆发的年份,对朱元璋的人生选择产生了重要影响。 +Running benchmark without DSPy-AI: +INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK" +INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK" +⠙ Processed 1 chunks, 22 entities(duplicated), 21 relations(duplicated) +Execution time without DSPy-AI: 147.39 seconds + +Entities: +- "朱元璋" ("PERSON"): + 
"朱元璋,原名朱重八,后改名朱元璋,是明朝的开国皇帝。他出身贫农,经历了从放牛娃到和尚,再到起义军领袖,最终成为皇帝的传奇人生。" +- "朱五四" ("PERSON"): + "朱五四,朱元璋的父亲,是一个农民,为地主种地,家境贫寒。" +- "陈氏" ("PERSON"): + "陈氏,朱元璋的母亲,是一个农民,与丈夫朱五四一起辛勤劳作,家境贫寒。" +- "汤和" ("PERSON"): + "汤和,朱元璋的幼年朋友,后来成为朱元璋的战友,在朱元璋的崛起过程中起到了重要作用。" +- "郭子兴" ("PERSON"): + "郭子兴,濠州城的守卫者,是朱元璋的岳父,也是朱元璋早期的重要支持者。" +- "韩山童" ("PERSON"): + "韩山童,与刘福通一起起义反抗元朝统治,是元末农民起义的重要领袖之一。""韩山童,元末农民起义的领袖之一,自称宋朝皇室后裔,与刘福通一起起义。" +- "刘福通" ("PERSON"): + "刘福通,与韩山童一起起义反抗元朝统治,是元末农民起义的重要领袖之一。""刘福通,元末农民起义的领袖之一,自称刘光世大将的后人,与韩山童一起起义。" +- "元朝" ("ORGANIZATION"): + "元朝,由蒙古族建立的王朝,统治中国时期实行了严格的等级制度,导致社会矛盾激化,最终被朱元璋领导的起义军推翻。" +- "皇觉寺" ("ORGANIZATION"): + "皇觉寺,朱元璋曾经在此当和尚,从事杂役工作,后来因饥荒严重,和尚们都被派出去化缘。" +- "白莲教" ("ORGANIZATION"): + "白莲教,元末农民起义中的一种宗教组织,韩山童和刘福通起义时利用了这一宗教信仰。" +- "濠州城" ("GEO"): + "濠州城,位于今安徽省,是朱元璋早期活动的重要地点,也是郭子兴的驻地。" +- "定远" ("GEO"): + "定远,朱元璋奉命攻击的地方,成功攻克后在元军回援前撤出,显示了其军事才能。" +- "钟离" ("GEO"): + "钟离,朱元璋的家乡,他在此招收了二十四名壮丁,这些人后来成为明朝的高级干部。" +- "元末农民起义" ("EVENT"): + "元末农民起义,是元朝末年由韩山童、刘福通等人领导的反抗元朝统治的大规模起义,最终导致了元朝的灭亡。" +- "马姑娘" ("PERSON"): + "马姑娘,郭子兴的义女,后来成为朱元璋的妻子,在朱元璋被关押时,她冒着危险送饭给朱元璋,表现出深厚的感情。" +- "孙德崖" ("PERSON"): + "孙德崖,与郭子兴有矛盾的起义军领袖之一,曾参与绑架郭子兴。" +- "徐达" ("PERSON"): + "徐达,朱元璋的二十四名亲信之一,后来成为明朝的重要将领。" +- "周德兴" ("PERSON"): + "周德兴,朱元璋的二十四名亲信之一,曾为朱元璋算过命。" +- "脱脱" ("PERSON"): + "脱脱,元朝的著名宰相,主张治理黄河,但他的政策间接导致了元朝的灭亡。" +- "元顺帝" ("PERSON"): + "元顺帝,元朝的最后一位皇帝,统治时期元朝社会矛盾激化,最终导致了元朝的灭亡。" +- "刘德" ("PERSON"): + "刘德,地主,朱元璋早年为其放牛。" +- "吴老太" ("PERSON"): + "吴老太,村口的媒人,朱元璋曾希望托她找一个媳妇。" + +Relationships: +- "朱元璋" -> "朱五四": + "朱元璋的父亲,对他的成长和早期生活有重要影响。" +- "朱元璋" -> "陈氏": + "朱元璋的母亲,对他的成长和早期生活有重要影响。" +- "朱元璋" -> "汤和": + "朱元璋的幼年朋友,后来成为他的战友,在朱元璋的崛起过程中起到了重要作用。" +- "朱元璋" -> "郭子兴": + "朱元璋的岳父,是他在起义军中的重要支持者。" +- "朱元璋" -> "韩山童": + "朱元璋在起义过程中与韩山童有间接联系,韩山童的起义对朱元璋的崛起有重要影响。" +- "朱元璋" -> "刘福通": + "朱元璋在起义过程中与刘福通有间接联系,刘福通的起义对朱元璋的崛起有重要影响。" +- "朱元璋" -> "元朝": + "朱元璋最终推翻了元朝的统治,建立了明朝。" +- "朱元璋" -> "皇觉寺": + "朱元璋曾经在此当和尚,这段经历对他的成长有重要影响。" +- "朱元璋" -> "白莲教": + "朱元璋在起义过程中接触到了白莲教,虽然他本人可能并不信仰,但白莲教的起义对他有重要影响。" +- "朱元璋" -> "濠州城": + "朱元璋在濠州城的活动对其早期军事和政治生涯有重要影响。" +- "朱元璋" -> 
async def deepseepk_model_if_cache(
    prompt: str,
    model: str = "deepseek-chat",
    system_prompt: str = None,
    history_messages: list = None,
    **kwargs,
) -> str:
    """Send a chat completion request to DeepSeek, with optional caching.

    Args:
        prompt: The user message appended last to the conversation.
        model: Model name forwarded to the chat completions endpoint.
        system_prompt: Optional system message placed first when truthy.
        history_messages: Optional earlier messages inserted between the
            system message and ``prompt``. ``None`` means no history.
        **kwargs: Extra keyword arguments forwarded to
            ``chat.completions.create``. A ``hashing_kv`` entry, if present,
            is removed and used as the response cache.

    Returns:
        The content of the first choice, either from the cache or from a
        fresh API response.
    """
    # The cache store travels inside kwargs; pop it so it is not forwarded
    # to the API call below.
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # A None default avoids sharing one list object across calls.
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    args_hash = None
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        cached = await hashing_kv.get_by_id(args_hash)
        if cached is not None:
            return cached["return"]

    # Build the client only after a cache miss: a cache hit needs neither
    # the API key nor a client instance.
    openai_async_client = AsyncOpenAI(
        api_key=os.environ.get("DEEPSEEK_API_KEY"),
        base_url="https://api.deepseek.com",
    )
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    content = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": content, "model": model}})
    return content


async def benchmark_entity_extraction(text: str, system_prompt: str, use_dspy: bool = False):
    """Run one entity-extraction pass over ``text`` and time it.

    Any existing working directory for the chosen mode is deleted first.

    Args:
        text: Document text, processed as a single chunk.
        system_prompt: System prompt bound into both model functions.
        use_dspy: When true, call ``extract_entities_dspy``; otherwise call
            ``extract_entities``.

    Returns:
        A ``(graph_storage, execution_time)`` tuple, where ``execution_time``
        is the elapsed time in seconds.
    """
    working_dir = os.path.join(WORKING_DIR, f"use_dspy={use_dspy}")
    if os.path.exists(working_dir):
        shutil.rmtree(working_dir)

    def model_func(*args, **kwargs):
        # Both the "cheap" and "best" slots use the same DeepSeek wrapper.
        return deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs)

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    graph_storage = NetworkXStorage(
        namespace="test",
        global_config={
            "working_dir": working_dir,
            "entity_summary_to_max_tokens": 500,
            "cheap_model_func": model_func,
            "best_model_func": model_func,
            "cheap_model_max_token_size": 4096,
            "best_model_max_token_size": 4096,
            "tiktoken_model_name": "gpt-4o",
            "hashing_kv": BaseKVStorage(
                namespace="test", global_config={"working_dir": working_dir}
            ),
            "entity_extract_max_gleaning": 1,
            "entity_extract_max_tokens": 4096,
            "entity_extract_max_entities": 100,
            "entity_extract_max_relationships": 100,
        },
    )
    chunks = {compute_mdhash_id(text, prefix="chunk-"): {"content": text}}

    extract = extract_entities_dspy if use_dspy else extract_entities
    graph_storage = await extract(chunks, graph_storage, None, graph_storage.global_config)

    execution_time = time.perf_counter() - start_time
    return graph_storage, execution_time


def print_extraction_results(graph_storage: "NetworkXStorage"):
    """Print every node and edge of the storage's graph to stdout.

    Nodes are printed under an "Entities:" heading with their
    ``entity_type`` and ``description``; edges are printed under a
    "Relationships:" heading with their ``description``. Missing
    attributes fall back to ``'Unknown'`` / ``'No description'``.
    """
    print("\nEntities:")
    entities = [
        f"- {node} ({data.get('entity_type', 'Unknown')}):\n {data.get('description', 'No description')}"
        for node, data in graph_storage._graph.nodes(data=True)
    ]
    print("\n".join(entities))

    print("\nRelationships:")
    relationships = [
        f"- {source} -> {target}:\n {data.get('description', 'No description')}"
        for source, target, data in graph_storage._graph.edges(data=True)
    ]
    print("\n".join(relationships))


async def run_benchmark(text: str):
    """Benchmark entity extraction with and without DSPy and print a comparison.

    Runs the DSPy pipeline first, then the non-DSPy pipeline, printing each
    run's timing and extracted graph, followed by the time difference and
    the entity/relationship counts of both runs.

    Args:
        text: Document text to extract entities from.

    Raises:
        KeyError: If ``DEEPSEEK_API_KEY`` is not set in the environment.
    """
    print("\nRunning benchmark with DSPy-AI:")
    system_prompt = """
    You are a world-class AI system, capable of complex rationale and reflection.
    Reason through the query, and then provide your final response.
    If you detect that you made a mistake in your rationale at any point, correct yourself.
    Think carefully.
    """
    # A timestamp makes each run's prompt, and therefore its hash, unique.
    system_prompt_dspy = f"{system_prompt} Time: {time.time()}."
    lm = dspy.OpenAI(
        model="deepseek-chat",
        model_type="chat",
        api_key=os.environ["DEEPSEEK_API_KEY"],
        # Fall back to the endpoint hard-coded in deepseepk_model_if_cache
        # so both runs work when only the API key is configured.
        base_url=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
        system_prompt=system_prompt_dspy,
        temperature=1.0,
        top_p=1,
        max_tokens=4096,
    )
    dspy.settings.configure(lm=lm)
    graph_storage_with_dspy, time_with_dspy = await benchmark_entity_extraction(
        text, system_prompt_dspy, use_dspy=True
    )
    print(f"Execution time with DSPy-AI: {time_with_dspy:.2f} seconds")
    print_extraction_results(graph_storage_with_dspy)

    print("Running benchmark without DSPy-AI:")
    system_prompt_no_dspy = f"{system_prompt} Time: {time.time()}."
    graph_storage_without_dspy, time_without_dspy = await benchmark_entity_extraction(
        text, system_prompt_no_dspy, use_dspy=False
    )
    print(f"Execution time without DSPy-AI: {time_without_dspy:.2f} seconds")
    print_extraction_results(graph_storage_without_dspy)

    print("\nComparison:")
    print(f"Time difference: {abs(time_with_dspy - time_without_dspy):.2f} seconds")
    print(f"DSPy-AI is {'faster' if time_with_dspy < time_without_dspy else 'slower'}")

    entities_without_dspy = len(graph_storage_without_dspy._graph.nodes())
    entities_with_dspy = len(graph_storage_with_dspy._graph.nodes())
    relationships_without_dspy = len(graph_storage_without_dspy._graph.edges())
    relationships_with_dspy = len(graph_storage_with_dspy._graph.edges())

    print(f"Entities extracted: {entities_without_dspy} (without DSPy-AI) vs {entities_with_dspy} (with DSPy-AI)")
    print(f"Relationships extracted: {relationships_without_dspy} (without DSPy-AI) vs {relationships_with_dspy} (with DSPy-AI)")
import load_dataset\n", + "import logging\n", + "import pickle\n", + "\n", + "from nano_graphrag._utils import compute_mdhash_id\n", + "from nano_graphrag.entity_extraction.extract import generate_dataset\n", + "from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor\n", + "from nano_graphrag.entity_extraction.metric import relationship_similarity_metric, entity_recall_metric" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "WORKING_DIR = \"./nano_graphrag_cache_finetune_entity_relationship_dspy\"\n", + "\n", + "load_dotenv()\n", + "\n", + "logging.basicConfig(level=logging.WARNING)\n", + "logging.getLogger(\"nano-graphrag\").setLevel(logging.DEBUG)\n", + "\n", + "np.random.seed(1337)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "system_prompt = \"\"\"\n", + " You are a world-class AI system, capable of complex reasoning and reflection. \n", + " Reason through the query, and then provide your final response. 
\n", + " If you detect that you made a mistake in your reasoning at any point, correct yourself.\n", + " Think carefully.\n", + "\"\"\"\n", + "lm = dspy.OpenAI(\n", + " model=\"deepseek-chat\", \n", + " model_type=\"chat\", \n", + " api_key=os.environ[\"DEEPSEEK_API_KEY\"], \n", + " base_url=os.environ[\"DEEPSEEK_BASE_URL\"], \n", + " system_prompt=system_prompt, \n", + " temperature=1.0,\n", + " top_p=1.0,\n", + " max_tokens=4096\n", + ")\n", + "llama_lm = dspy.OllamaLocal(\n", + " model=\"llama3.1\", \n", + " model_type=\"chat\",\n", + " system=system_prompt,\n", + " max_tokens=4096\n", + ")\n", + "dspy.settings.configure(lm=lm)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/Caskroom/miniconda/base/envs/nano-graphrag/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.\n", + " table = cls._concat_blocks(blocks, axis=0)\n" + ] + } + ], + "source": [ + "os.makedirs(WORKING_DIR, exist_ok=True)\n", + "train_len = 20\n", + "val_len = 2\n", + "dev_len = 3\n", + "entity_relationship_trainset_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news_trainset.pkl\")\n", + "entity_relationship_valset_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news_valset.pkl\")\n", + "entity_relationship_devset_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news_devset.pkl\")\n", + "entity_relationship_module_path = os.path.join(WORKING_DIR, \"entity_relationship_extraction_news.json\")\n", + "fin_news = load_dataset(\"ashraq/financial-news-articles\")\n", + "cnn_news = load_dataset(\"AyoubChLin/CNN_News_Articles_2011-2022\")\n", + "fin_shuffled_indices = np.random.permutation(len(fin_news['train']))\n", + "cnn_train_shuffled_indices = np.random.permutation(len(cnn_news['train']))\n", + "cnn_test_shuffled_indices = 
np.random.permutation(len(cnn_news['test']))\n", + "train_data = cnn_news['train'].select(cnn_train_shuffled_indices[:train_len])\n", + "val_data = cnn_news['test'].select(cnn_test_shuffled_indices[:val_len])\n", + "dev_data = fin_news['train'].select(fin_shuffled_indices[:dev_len])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_data['text'][:2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "val_data['text']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dev_data['text'][:2]" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:nano-graphrag:Entities: 17 | Missed Entities: 15 | Total Entities: 32\n", + "DEBUG:nano-graphrag:Entities: 9 | Missed Entities: 7 | Total Entities: 16\n", + "DEBUG:nano-graphrag:Entities: 27 | Missed Entities: 21 | Total Entities: 48\n", + "DEBUG:nano-graphrag:Entities: 18 | Missed Entities: 10 | Total Entities: 28\n", + "DEBUG:nano-graphrag:Entities: 9 | Missed Entities: 9 | Total Entities: 18\n", + "DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 6 | Total Entities: 19\n", + "DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 7 | Total Entities: 21\n", + "DEBUG:nano-graphrag:Entities: 8 | Missed Entities: 10 | Total Entities: 18\n", + "DEBUG:nano-graphrag:Entities: 28 | Missed Entities: 6 | Total Entities: 34\n", + "DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 5 | Total Entities: 18\n", + "DEBUG:nano-graphrag:Entities: 15 | Missed Entities: 8 | Total Entities: 23\n", + "DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 5 | Total Entities: 19\n", + "DEBUG:nano-graphrag:Entities: 21 | Missed Entities: 5 | Total Entities: 26\n", + "DEBUG:nano-graphrag:Entities: 11 | Missed Entities: 6 | Total Entities: 17\n", 
+ "DEBUG:nano-graphrag:Entities: 16 | Missed Entities: 9 | Total Entities: 25\n", + "DEBUG:nano-graphrag:Entities: 25 | Missed Entities: 10 | Total Entities: 35\n", + "DEBUG:nano-graphrag:Relationships: 27 | Missed Relationships: 22 | Total Relationships: 49\n", + "DEBUG:nano-graphrag:Relationships: 11 | Missed Relationships: 9 | Total Relationships: 20\n", + "DEBUG:nano-graphrag:Relationships: 18 | Missed Relationships: 20 | Total Relationships: 38\n", + "DEBUG:nano-graphrag:Relationships: 15 | Missed Relationships: 7 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Relationships: 12 | Missed Relationships: 10 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Relationships: 15 | Missed Relationships: 9 | Total Relationships: 24\n", + "DEBUG:nano-graphrag:Relationships: 12 | Missed Relationships: 9 | Total Relationships: 21\n", + "DEBUG:nano-graphrag:Relationships: 7 | Missed Relationships: 8 | Total Relationships: 15\n", + "DEBUG:nano-graphrag:Relationships: 17 | Missed Relationships: 6 | Total Relationships: 23\n", + "DEBUG:nano-graphrag:Relationships: 10 | Missed Relationships: 6 | Total Relationships: 16\n", + "DEBUG:nano-graphrag:Relationships: 16 | Missed Relationships: 8 | Total Relationships: 24\n", + "DEBUG:nano-graphrag:Relationships: 15 | Missed Relationships: 5 | Total Relationships: 20\n", + "DEBUG:nano-graphrag:Relationships: 19 | Missed Relationships: 5 | Total Relationships: 24\n", + "DEBUG:nano-graphrag:Relationships: 10 | Missed Relationships: 8 | Total Relationships: 18\n", + "DEBUG:nano-graphrag:Relationships: 13 | Missed Relationships: 9 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Relationships: 22 | Missed Relationships: 10 | Total Relationships: 32\n", + "DEBUG:nano-graphrag:Direct Relationships: 44 | Second-order: 5 | Third-order: 0 | Total Relationships: 49\n", + "DEBUG:nano-graphrag:Direct Relationships: 16 | Second-order: 4 | Third-order: 0 | Total Relationships: 20\n", + "DEBUG:nano-graphrag:Direct Relationships: 38 | 
Second-order: 0 | Third-order: 0 | Total Relationships: 38\n", + "DEBUG:nano-graphrag:Direct Relationships: 22 | Second-order: 0 | Third-order: 0 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Direct Relationships: 22 | Second-order: 0 | Third-order: 0 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Direct Relationships: 24 | Second-order: 0 | Third-order: 0 | Total Relationships: 24\n", + "DEBUG:nano-graphrag:Direct Relationships: 21 | Second-order: 0 | Third-order: 0 | Total Relationships: 21\n", + "DEBUG:nano-graphrag:Direct Relationships: 15 | Second-order: 0 | Third-order: 0 | Total Relationships: 15\n", + "DEBUG:nano-graphrag:Direct Relationships: 23 | Second-order: 0 | Third-order: 0 | Total Relationships: 23\n", + "DEBUG:nano-graphrag:Direct Relationships: 16 | Second-order: 0 | Third-order: 0 | Total Relationships: 16\n", + "DEBUG:nano-graphrag:Direct Relationships: 24 | Second-order: 0 | Third-order: 0 | Total Relationships: 24\n", + "DEBUG:nano-graphrag:Direct Relationships: 17 | Second-order: 3 | Third-order: 0 | Total Relationships: 20\n", + "DEBUG:nano-graphrag:Direct Relationships: 24 | Second-order: 0 | Third-order: 0 | Total Relationships: 24\n", + "DEBUG:nano-graphrag:Direct Relationships: 13 | Second-order: 5 | Third-order: 0 | Total Relationships: 18\n", + "DEBUG:nano-graphrag:Direct Relationships: 22 | Second-order: 0 | Third-order: 0 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Direct Relationships: 32 | Second-order: 0 | Third-order: 0 | Total Relationships: 32\n", + "DEBUG:nano-graphrag:Entities: 10 | Missed Entities: 5 | Total Entities: 15\n", + "DEBUG:nano-graphrag:Entities: 6 | Missed Entities: 5 | Total Entities: 11\n", + "DEBUG:nano-graphrag:Entities: 18 | Missed Entities: 15 | Total Entities: 33\n", + "DEBUG:nano-graphrag:Entities: 15 | Missed Entities: 10 | Total Entities: 25\n", + "DEBUG:nano-graphrag:Relationships: 11 | Missed Relationships: 5 | Total Relationships: 16\n", + "DEBUG:nano-graphrag:Relationships: 5 | 
Missed Relationships: 5 | Total Relationships: 10\n", + "DEBUG:nano-graphrag:Relationships: 13 | Missed Relationships: 15 | Total Relationships: 28\n", + "DEBUG:nano-graphrag:Relationships: 16 | Missed Relationships: 10 | Total Relationships: 26\n", + "DEBUG:nano-graphrag:Direct Relationships: 11 | Second-order: 5 | Third-order: 0 | Total Relationships: 16\n", + "DEBUG:nano-graphrag:Direct Relationships: 9 | Second-order: 1 | Third-order: 0 | Total Relationships: 10\n", + "DEBUG:nano-graphrag:Direct Relationships: 28 | Second-order: 0 | Third-order: 0 | Total Relationships: 28\n", + "DEBUG:nano-graphrag:Direct Relationships: 26 | Second-order: 0 | Third-order: 0 | Total Relationships: 26\n", + "INFO:nano-graphrag:Saved 20 examples with keys: ['input_text', 'entities', 'relationships']\n" + ] + } + ], + "source": [ + "train_chunks = {compute_mdhash_id(text, prefix=f\"chunk-\"): {\"content\": text} for text in train_data[\"text\"]}\n", + "trainset = asyncio.run(generate_dataset(chunks=train_chunks, filepath=entity_relationship_trainset_path))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for example in trainset:\n", + " for relationship in example.relationships.context:\n", + " if relationship.order == 2:\n", + " print(relationship)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for example in trainset:\n", + " for relationship in example.relationships.context:\n", + " if relationship.order == 3:\n", + " print(relationship)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainset[0].relationships.context[:2]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:nano-graphrag:Entities: 21 | Missed Entities: 14 | Total Entities: 35\n", + 
"DEBUG:nano-graphrag:Entities: 10 | Missed Entities: 5 | Total Entities: 15\n", + "DEBUG:nano-graphrag:Relationships: 22 | Missed Relationships: 14 | Total Relationships: 36\n", + "DEBUG:nano-graphrag:Relationships: 10 | Missed Relationships: 5 | Total Relationships: 15\n", + "DEBUG:nano-graphrag:Direct Relationships: 36 | Second-order: 0 | Third-order: 0 | Total Relationships: 36\n", + "DEBUG:nano-graphrag:Direct Relationships: 12 | Second-order: 3 | Third-order: 0 | Total Relationships: 15\n", + "INFO:nano-graphrag:Saved 2 examples with keys: ['input_text', 'entities', 'relationships']\n" + ] + } + ], + "source": [ + "val_chunks = {compute_mdhash_id(text, prefix=f\"chunk-\"): {\"content\": text} for text in val_data[\"text\"]}\n", + "valset = asyncio.run(generate_dataset(chunks=val_chunks, filepath=entity_relationship_valset_path))" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Relationship(src_id='PORTUGAL', tgt_id='EURO 2016', description='Portugal qualified for the final of Euro 2016.', weight=0.9, order=1),\n", + " Relationship(src_id='PORTUGAL', tgt_id='WALES', description='Portugal defeated Wales in the semifinal of Euro 2016.', weight=0.9, order=1)]" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valset[0].relationships.context[:2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for example in valset:\n", + " for relationship in example.relationships.context:\n", + " if relationship.order == 2:\n", + " print(relationship)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for example in valset:\n", + " for relationship in example.relationships.context:\n", + " if relationship.order == 3:\n", + " print(relationship)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + 
"metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:nano-graphrag:Entities: 27 | Missed Entities: 9 | Total Entities: 36\n", + "DEBUG:nano-graphrag:Entities: 14 | Missed Entities: 7 | Total Entities: 21\n", + "DEBUG:nano-graphrag:Entities: 7 | Missed Entities: 4 | Total Entities: 11\n", + "DEBUG:nano-graphrag:Relationships: 19 | Missed Relationships: 8 | Total Relationships: 27\n", + "DEBUG:nano-graphrag:Relationships: 14 | Missed Relationships: 8 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Relationships: 8 | Missed Relationships: 8 | Total Relationships: 16\n", + "DEBUG:nano-graphrag:Direct Relationships: 27 | Second-order: 0 | Third-order: 0 | Total Relationships: 27\n", + "DEBUG:nano-graphrag:Direct Relationships: 18 | Second-order: 4 | Third-order: 0 | Total Relationships: 22\n", + "DEBUG:nano-graphrag:Direct Relationships: 12 | Second-order: 4 | Third-order: 0 | Total Relationships: 16\n", + "INFO:nano-graphrag:Saved 3 examples with keys: ['input_text', 'entities', 'relationships']\n" + ] + } + ], + "source": [ + "dev_chunks = {compute_mdhash_id(text, prefix=f\"chunk-\"): {\"content\": text} for text in dev_data[\"text\"]}\n", + "devset = asyncio.run(generate_dataset(chunks=dev_chunks, filepath=entity_relationship_devset_path))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "devset[0].relationships.context[:2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for example in devset:\n", + " for relationship in example.relationships.context:\n", + " if relationship.order == 2:\n", + " print(relationship)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for example in devset:\n", + " for relationship in example.relationships.context:\n", + " if relationship.order == 3:\n", + " print(relationship)" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "extractor.predictor = Predict(CombinedExtraction(input_text, entity_types -> entities, relationships\n", + " instructions='Signature for extracting both entities and relationships from input text.'\n", + " input_text = Field(annotation=str required=True json_schema_extra={'desc': 'The text to extract entities and relationships from.', '__dspy_field_type': 'input', 'prefix': 'Input Text:'})\n", + " entity_types = Field(annotation=EntityTypes required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Entity Types:', 'desc': '${entity_types}'})\n", + " entities = Field(annotation=Entities required=True json_schema_extra={'desc': '\\n Format:\\n {\\n \"context\": [\\n {\\n \"entity_name\": \"ENTITY NAME\",\\n \"entity_type\": \"ENTITY TYPE\",\\n \"description\": \"Detailed description\",\\n \"importance_score\": 0.8\\n },\\n ...\\n ]\\n }\\n Each entity name should be an actual atomic word from the input text. Avoid duplicates and generic terms.\\n Make sure descriptions are concise and specific, and all entity types are included from the text. \\n Entities must have an importance score greater than 0.5.\\n IMPORTANT: Only use entity types from the provided \\'entity_types\\' list. 
Do not introduce new entity types.\\n Ensure the output is strictly JSON formatted without any trailing text or comments.\\n ', '__dspy_field_type': 'output', 'prefix': 'Entities:'})\n", + " relationships = Field(annotation=Relationships required=True json_schema_extra={'desc': '\\n Format:\\n {\\n \"context\": [\\n {\\n \"src_id\": \"SOURCE ENTITY\",\\n \"tgt_id\": \"TARGET ENTITY\",\\n \"description\": \"Detailed description of the relationship\",\\n \"weight\": 0.7,\\n \"order\": 1 # 1 for direct relationships, 2 for second-order, 3 for third-order, etc.\\n },\\n ...\\n ]\\n }\\n Make sure relationships are detailed and specific.\\n Include direct relationships (order 1) as well as higher-order relationships (order 2 and 3):\\n - Direct relationships: Immediate connections between entities.\\n - Second-order relationships: Indirect effects or connections that result from direct relationships.\\n - Third-order relationships: Further indirect effects that result from second-order relationships.\\n IMPORTANT: Only include relationships between existing entities from the extracted entities. 
Do not introduce new entities here.\\n The \"src_id\" and \"tgt_id\" fields must exactly match entity names from the extracted entities list.\\n Ensure the output is strictly JSON formatted without any trailing text or comments.\\n ', '__dspy_field_type': 'output', 'prefix': 'Relationships:'})\n", + "))\n", + "self_reflection.predictor = Predict(CombinedSelfReflection(input_text, entity_types, entities, relationships -> missing_entities, missing_relationships\n", + " instructions='Signature for combined self-reflection on extracted entities and relationships.\\nSelf-reflection is on the completeness and quality of both the extracted entities and relationships.'\n", + " input_text = Field(annotation=str required=True json_schema_extra={'desc': 'The original input text.', '__dspy_field_type': 'input', 'prefix': 'Input Text:'})\n", + " entity_types = Field(annotation=EntityTypes required=True json_schema_extra={'__dspy_field_type': 'input', 'prefix': 'Entity Types:', 'desc': '${entity_types}'})\n", + " entities = Field(annotation=Entities required=True json_schema_extra={'desc': 'List of extracted entities.', '__dspy_field_type': 'input', 'prefix': 'Entities:'})\n", + " relationships = Field(annotation=Relationships required=True json_schema_extra={'desc': 'List of extracted relationships.', '__dspy_field_type': 'input', 'prefix': 'Relationships:'})\n", + " missing_entities = Field(annotation=Entities required=True json_schema_extra={'desc': '\\n Format:\\n {\\n \"context\": [\\n {\\n \"entity_name\": \"ENTITY NAME\",\\n \"entity_type\": \"ENTITY TYPE\",\\n \"description\": \"Detailed description\",\\n \"importance_score\": 0.8\\n },\\n ...\\n ]\\n }\\n More specifically:\\n 1. Entities mentioned in the text but not captured in the initial extraction.\\n 2. Implicit entities that are crucial to the context but not explicitly mentioned.\\n 3. Entities that belong to the identified entity types but were overlooked.\\n 4. 
Subtypes or more specific instances of the already extracted entities.\\n Ensure the output is strictly JSON formatted without any trailing text or comments.\\n ', '__dspy_field_type': 'output', 'prefix': 'Missing Entities:'})\n", + " missing_relationships = Field(annotation=Relationships required=True json_schema_extra={'desc': '\\n Format:\\n {\\n \"context\": [\\n {\\n \"src_id\": \"SOURCE ENTITY\",\\n \"tgt_id\": \"TARGET ENTITY\",\\n \"description\": \"Detailed description of the relationship\",\\n \"weight\": 0.7,\\n \"order\": 1 # 1 for direct, 2 for second-order, 3 for third-order\\n },\\n ...\\n ]\\n }\\n More specifically:\\n 1. Direct relationships (order 1) between entities that were not captured initially.\\n 2. Second-order relationships (order 2): Indirect effects or connections resulting from direct relationships.\\n 3. Third-order relationships (order 3): Further indirect effects resulting from second-order relationships.\\n 4. Implicit relationships that can be inferred from the context.\\n 5. Hierarchical, causal, or temporal relationships that may have been overlooked.\\n 6. 
Relationships involving the newly identified missing entities.\\n Only include relationships between entities in the combined entities list (extracted + missing).\\n Ensure the output is strictly JSON formatted without any trailing text or comments.\\n ', '__dspy_field_type': 'output', 'prefix': 'Missing Relationships:'})\n", + "))" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = EntityRelationshipExtractor()\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 10 | Total Entities: 23\n", + "DEBUG:nano-graphrag:Entities: 22 | Missed Entities: 14 | Total Entities: 36\n", + " 0%| | 0/3 [00:00\n", + "#T_34e56 th {\n", + " text-align: left;\n", + "}\n", + "#T_34e56 td {\n", + " text-align: left;\n", + "}\n", + "#T_34e56_row0_col0, #T_34e56_row0_col1, #T_34e56_row0_col2, #T_34e56_row0_col3, #T_34e56_row0_col4, #T_34e56_row0_col5, #T_34e56_row1_col0, #T_34e56_row1_col1, #T_34e56_row1_col2, #T_34e56_row1_col3, #T_34e56_row1_col4, #T_34e56_row1_col5, #T_34e56_row2_col0, #T_34e56_row2_col1, #T_34e56_row2_col2, #T_34e56_row2_col3, #T_34e56_row2_col4, #T_34e56_row2_col5 {\n", + " text-align: left;\n", + " white-space: pre-wrap;\n", + " word-wrap: break-word;\n", + " max-width: 400px;\n", + "}\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 input_textexample_entitiesexample_relationshipspred_entitiespred_relationshipsentity_recall_metric
0As students from Marjory Stoneman Douglas High School confront lawmakers with demands to restrict sales of assault rifles, there were warnings by the president of...context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...context=[Relationship(src_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', tgt_id='NIKOLAS CRUZ', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='NIKOLAS CRUZ', tgt_id='FLORIDA',...context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...context=[Relationship(src_id='NIKOLAS CRUZ', tgt_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='LAURENZO PRADO', tgt_id='MARJORY...✔️ [0.8055555555555556]
1From ferrying people to and from their place of work to transporting nuclear waste and coal, railways are not only an integral part of 21st...context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='Transportation system used for ferrying people and transporting nuclear waste and coal.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='Country where a business is looking to innovate...context=[Relationship(src_id='RAILNOVA', tgt_id='BRUSSELS', description='Railnova is based in Brussels.', weight=0.9, order=1), Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova serves Deutsche Bahn as a client.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova serves...context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='A mode of transportation that involves trains running on tracks, used for various purposes including passenger and cargo transport.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='A...context=[Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova provides innovative technology solutions to Deutsche Bahn, a German railway company.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova offers its technology services to...✔️ [0.8095238095238095]
2Jan 22 (Reuters) - Shanghai Stock Exchange Filing * SHOWS BLOCK TRADE OF YONGHUI SUPERSTORES Co LTd's 166.3 MILLION SHARES INVOLVING 1.63 BILLION YUAN ($254.68...context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='SHANGHAI STOCK EXCHANGE', description=\"YONGHUI SUPERSTORES' shares were traded on the SHANGHAI STOCK EXCHANGE.\", weight=0.9, order=1), Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was...context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was involved in a block trade of 166.3 million shares.', weight=0.9, order=1), Relationship(src_id='166.3 MILLION SHARES', tgt_id='1.63 BILLION YUAN',...✔️ [0.7272727272727273]
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DEBUG:nano-graphrag:Entities: 22 | Missed Entities: 14 | Total Entities: 36\n", + "DEBUG:nano-graphrag:Entities: 13 | Missed Entities: 10 | Total Entities: 23\n", + " 0%| | 0/3 [00:00\n", + "#T_465ae th {\n", + " text-align: left;\n", + "}\n", + "#T_465ae td {\n", + " text-align: left;\n", + "}\n", + "#T_465ae_row0_col0, #T_465ae_row0_col1, #T_465ae_row0_col2, #T_465ae_row0_col3, #T_465ae_row0_col4, #T_465ae_row0_col5, #T_465ae_row1_col0, #T_465ae_row1_col1, #T_465ae_row1_col2, #T_465ae_row1_col3, #T_465ae_row1_col4, #T_465ae_row1_col5, #T_465ae_row2_col0, #T_465ae_row2_col1, #T_465ae_row2_col2, #T_465ae_row2_col3, #T_465ae_row2_col4, #T_465ae_row2_col5 {\n", + " text-align: left;\n", + " white-space: pre-wrap;\n", + " word-wrap: break-word;\n", + " max-width: 400px;\n", + "}\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 input_textexample_entitiesexample_relationshipspred_entitiespred_relationshipsrelationship_similarity_metric
0As students from Marjory Stoneman Douglas High School confront lawmakers with demands to restrict sales of assault rifles, there were warnings by the president of...context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...context=[Relationship(src_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', tgt_id='NIKOLAS CRUZ', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='NIKOLAS CRUZ', tgt_id='FLORIDA',...context=[Entity(entity_name='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', entity_type='ORGANIZATION', description='A high school in Florida where a mass shooting occurred.', importance_score=0.9), Entity(entity_name='NIKOLAS CRUZ', entity_type='PERSON', description='The gunman who carried out...context=[Relationship(src_id='NIKOLAS CRUZ', tgt_id='MARJORY STONEMAN DOUGLAS HIGH SCHOOL', description='Nikolas Cruz carried out a mass shooting at Marjory Stoneman Douglas High School.', weight=0.9, order=1), Relationship(src_id='LAURENZO PRADO', tgt_id='MARJORY...✔️ [0.946203351020813]
1From ferrying people to and from their place of work to transporting nuclear waste and coal, railways are not only an integral part of 21st...context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='Transportation system used for ferrying people and transporting nuclear waste and coal.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='Country where a business is looking to innovate...context=[Relationship(src_id='RAILNOVA', tgt_id='BRUSSELS', description='Railnova is based in Brussels.', weight=0.9, order=1), Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova serves Deutsche Bahn as a client.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova serves...context=[Entity(entity_name='RAILWAYS', entity_type='VEHICLE', description='A mode of transportation that involves trains running on tracks, used for various purposes including passenger and cargo transport.', importance_score=0.9), Entity(entity_name='BELGIUM', entity_type='LOCATION', description='A...context=[Relationship(src_id='RAILNOVA', tgt_id='DEUTSCHE BAHN', description='Railnova provides innovative technology solutions to Deutsche Bahn, a German railway company.', weight=0.8, order=1), Relationship(src_id='RAILNOVA', tgt_id='SNCF', description='Railnova offers its technology services to...✔️ [0.9310485124588013]
2Jan 22 (Reuters) - Shanghai Stock Exchange Filing * SHOWS BLOCK TRADE OF YONGHUI SUPERSTORES Co LTd's 166.3 MILLION SHARES INVOLVING 1.63 BILLION YUAN ($254.68...context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='SHANGHAI STOCK EXCHANGE', description=\"YONGHUI SUPERSTORES' shares were traded on the SHANGHAI STOCK EXCHANGE.\", weight=0.9, order=1), Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was...context=[Entity(entity_name='YONGHUI SUPERSTORES', entity_type='ORGANIZATION', description='A company involved in a block trade of its shares.', importance_score=0.9), Entity(entity_name='SHANGHAI STOCK EXCHANGE', entity_type='ORGANIZATION', description='The stock exchange where the block trade...context=[Relationship(src_id='YONGHUI SUPERSTORES', tgt_id='166.3 MILLION SHARES', description='YONGHUI SUPERSTORES was involved in a block trade of 166.3 million shares.', weight=0.9, order=1), Relationship(src_id='166.3 MILLION SHARES', tgt_id='1.63 BILLION YUAN',...✔️ [0.9334976673126221]
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "metrics = [entity_recall_metric, relationship_similarity_metric]\n", + "for metric in metrics:\n", + " evaluate = Evaluate(\n", + " devset=devset, \n", + " metric=metric, \n", + " num_threads=os.cpu_count(), \n", + " display_progress=True,\n", + " display_table=5,\n", + " )\n", + " evaluate(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = BootstrapFewShotWithRandomSearch(\n", + " metric=relationship_similarity_metric, \n", + " num_threads=os.cpu_count(),\n", + " num_candidate_programs=4,\n", + " max_labeled_demos=5,\n", + " max_bootstrapped_demos=3,\n", + ")\n", + "rs_model = optimizer.compile(model, trainset=trainset, valset=valset)\n", + "rs_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics = [entity_recall_metric, relationship_similarity_metric]\n", + "for metric in metrics:\n", + " evaluate = Evaluate(\n", + " devset=devset, \n", + " metric=metric, \n", + " num_threads=os.cpu_count(), \n", + " display_progress=True,\n", + " display_table=5,\n", + " )\n", + " evaluate(rs_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = MIPROv2(\n", + " prompt_model=lm,\n", + " task_model=llama_lm,\n", + " metric=relationship_similarity_metric,\n", + " init_temperature=1.0,\n", + " num_candidates=4\n", + ")\n", + "miprov2_model = optimizer.compile(model, trainset=trainset, valset=valset, num_batches=5, max_labeled_demos=5, max_bootstrapped_demos=3)\n", + "miprov2_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metrics = [entity_recall_metric, relationship_similarity_metric]\n", + "for metric in metrics:\n", + " evaluate = Evaluate(\n", + " 
devset=devset, \n", + " metric=metric, \n", + " num_threads=os.cpu_count(), \n", + " display_progress=True,\n", + " display_table=5,\n", + " )\n", + " evaluate(miprov2_model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nano-graphrag", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/using_dspy_entity_extraction.py b/examples/using_dspy_entity_extraction.py new file mode 100644 index 0000000..7e5a29c --- /dev/null +++ b/examples/using_dspy_entity_extraction.py @@ -0,0 +1,151 @@ +import os +from openai import AsyncOpenAI +from dotenv import load_dotenv +import logging +import numpy as np +import dspy +from sentence_transformers import SentenceTransformer +from nano_graphrag import GraphRAG, QueryParam +from nano_graphrag._llm import gpt_4o_mini_complete +from nano_graphrag._storage import HNSWVectorStorage +from nano_graphrag.base import BaseKVStorage +from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs +from nano_graphrag.entity_extraction.extract import extract_entities_dspy + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("nano-graphrag").setLevel(logging.DEBUG) + +WORKING_DIR = "./nano_graphrag_cache_using_dspy_entity_extraction" + +load_dotenv() + + +EMBED_MODEL = SentenceTransformer( + "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu" +) + + +@wrap_embedding_func_with_attrs( + embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(), + max_token_size=EMBED_MODEL.max_seq_length, +) +async def local_embedding(texts: list[str]) -> np.ndarray: + return EMBED_MODEL.encode(texts, normalize_embeddings=True) + + +async def 
deepseepk_model_if_cache( + prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs +) -> str: + openai_async_client = AsyncOpenAI( + api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com" + ) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + # Get the cached response if having------------------- + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + # ----------------------------------------------------- + + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + + # Cache the response if having------------------- + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response.choices[0].message.content, "model": model}} + ) + # ----------------------------------------------------- + return response.choices[0].message.content + + + +def remove_if_exist(file): + if os.path.exists(file): + os.remove(file) + + +def insert(): + from time import time + + with open("./tests/mock_data.txt", encoding="utf-8-sig") as f: + FAKE_TEXT = f.read() + + remove_if_exist(f"{WORKING_DIR}/vdb_entities.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json") + remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml") + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=HNSWVectorStorage, + vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50}, + best_model_max_async=10, + 
cheap_model_max_async=10, + best_model_func=deepseepk_model_if_cache, + cheap_model_func=deepseepk_model_if_cache, + embedding_func=local_embedding, + entity_extraction_func=extract_entities_dspy + ) + start = time() + rag.insert(FAKE_TEXT) + print("indexing time:", time() - start) + + +def query(): + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=HNSWVectorStorage, + vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50}, + best_model_max_token_size=8196, + cheap_model_max_token_size=8196, + best_model_max_async=4, + cheap_model_max_async=4, + best_model_func=gpt_4o_mini_complete, + cheap_model_func=gpt_4o_mini_complete, + embedding_func=local_embedding, + entity_extraction_func=extract_entities_dspy + + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + +if __name__ == "__main__": + system_prompt = """ + You are a world-class AI system, capable of complex rationale and reflection. + Reason through the query, and then provide your final response. + If you detect that you made a mistake in your rationale at any point, correct yourself. + Think carefully. 
+ """ + lm = dspy.OpenAI( + model="deepseek-chat", + model_type="chat", + api_key=os.environ["DEEPSEEK_API_KEY"], + base_url=os.environ["DEEPSEEK_BASE_URL"], + system_prompt=system_prompt, + temperature=1.0, + top_p=1, + max_tokens=4096 + ) + dspy.settings.configure(lm=lm) + insert() + query() diff --git a/examples/using_hnsw_as_vectorDB.py b/examples/using_hnsw_as_vectorDB.py index 8914da5..0320d1e 100644 --- a/examples/using_hnsw_as_vectorDB.py +++ b/examples/using_hnsw_as_vectorDB.py @@ -2,12 +2,13 @@ from openai import AsyncOpenAI from dotenv import load_dotenv import logging - +import numpy as np +from sentence_transformers import SentenceTransformer from nano_graphrag import GraphRAG, QueryParam from nano_graphrag._llm import gpt_4o_mini_complete from nano_graphrag._storage import HNSWVectorStorage from nano_graphrag.base import BaseKVStorage -from nano_graphrag._utils import compute_args_hash +from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs logging.basicConfig(level=logging.WARNING) logging.getLogger("nano-graphrag").setLevel(logging.DEBUG) @@ -17,6 +18,19 @@ load_dotenv() +EMBED_MODEL = SentenceTransformer( + "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu" +) + + +@wrap_embedding_func_with_attrs( + embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(), + max_token_size=EMBED_MODEL.max_seq_length, +) +async def local_embedding(texts: list[str]) -> np.ndarray: + return EMBED_MODEL.encode(texts, normalize_embeddings=True) + + async def deepseepk_model_if_cache( prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -77,6 +91,7 @@ def insert(): cheap_model_max_async=10, best_model_func=deepseepk_model_if_cache, cheap_model_func=deepseepk_model_if_cache, + embedding_func=local_embedding ) start = time() rag.insert(FAKE_TEXT) @@ -95,7 +110,7 @@ def query(): cheap_model_max_async=4, best_model_func=gpt_4o_mini_complete, 
cheap_model_func=gpt_4o_mini_complete, - + embedding_func=local_embedding ) print( rag.query( diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index 21ffe82..66c6f51 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -15,7 +15,7 @@ @retry( - stop=stop_after_attempt(3), + stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError)), ) @@ -72,7 +72,7 @@ async def gpt_4o_mini_complete( @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( - stop=stop_after_attempt(3), + stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError)), ) diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py index ac762af..691e5d8 100644 --- a/nano_graphrag/_op.py +++ b/nano_graphrag/_op.py @@ -14,7 +14,7 @@ list_of_list_to_csv, pack_user_ass_to_openai_messages, split_string_by_multi_markers, - truncate_list_by_token_size, + truncate_list_by_token_size ) from .base import ( BaseGraphStorage, @@ -177,6 +177,7 @@ async def _merge_edges_then_upsert( already_weights = [] already_source_ids = [] already_description = [] + already_order = [] if await knwoledge_graph_inst.has_edge(src_id, tgt_id): already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id) already_weights.append(already_edge["weight"]) @@ -184,7 +185,10 @@ async def _merge_edges_then_upsert( split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP]) ) already_description.append(already_edge["description"]) + already_order.append(already_edge.get("order", 1)) + # [numberchiffre]: `Relationship.order` is only returned from DSPy's predictions + order = min([dp.get("order", 1) for dp in edges_data] + already_order) weight = sum([dp["weight"] for dp in edges_data] + already_weights) description = GRAPH_FIELD_SEP.join( sorted(set([dp["description"] for dp in edges_data] + 
already_description)) @@ -212,6 +216,7 @@ async def _merge_edges_then_upsert( weight=weight, description=description, source_id=source_id, + order=order ), ) diff --git a/nano_graphrag/entity_extraction/__init__.py b/nano_graphrag/entity_extraction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nano_graphrag/entity_extraction/extract.py b/nano_graphrag/entity_extraction/extract.py new file mode 100644 index 0000000..0462b59 --- /dev/null +++ b/nano_graphrag/entity_extraction/extract.py @@ -0,0 +1,137 @@ +from typing import Union +import pickle +import asyncio +from collections import defaultdict +import dspy +from nano_graphrag._storage import BaseGraphStorage +from nano_graphrag.base import ( + BaseGraphStorage, + BaseVectorStorage, + TextChunkSchema, +) +from nano_graphrag.prompt import PROMPTS +from nano_graphrag._utils import logger, compute_mdhash_id +from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor +from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert + + +async def generate_dataset( + chunks: dict[str, TextChunkSchema], + filepath: str, + save_dataset: bool = True +) -> list[dspy.Example]: + entity_extractor = EntityRelationshipExtractor() + ordered_chunks = list(chunks.items()) + + async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]) -> dspy.Example: + chunk_dp = chunk_key_dp[1] + content = chunk_dp["content"] + prediction = await asyncio.to_thread( + entity_extractor, input_text=content + ) + example = dspy.Example( + input_text=content, + entities=prediction.entities, + relationships=prediction.relationships + ).with_inputs("input_text") + return example + + examples = await asyncio.gather( + *[_process_single_content(c) for c in ordered_chunks] + ) + if save_dataset: + with open(filepath, 'wb') as f: + pickle.dump(examples, f) + logger.info(f"Saved {len(examples)} examples with keys: {examples[0].keys()}") + + return examples + + +async def 
extract_entities_dspy( + chunks: dict[str, TextChunkSchema], + knwoledge_graph_inst: BaseGraphStorage, + entity_vdb: BaseVectorStorage, + global_config: dict, +) -> Union[BaseGraphStorage, None]: + entity_extractor = EntityRelationshipExtractor() + + if global_config.get("use_compiled_dspy_entity_relationship", False): + entity_extractor.load(global_config["entity_relationship_module_path"]) + + ordered_chunks = list(chunks.items()) + already_processed = 0 + already_entities = 0 + already_relations = 0 + + async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + nonlocal already_processed, already_entities, already_relations + chunk_key = chunk_key_dp[0] + chunk_dp = chunk_key_dp[1] + content = chunk_dp["content"] + prediction = await asyncio.to_thread( + entity_extractor, input_text=content + ) + + maybe_nodes = defaultdict(list) + maybe_edges = defaultdict(list) + + for entity in prediction.entities.context: + entity_dict = entity.dict() + entity_dict["source_id"] = chunk_key + maybe_nodes[entity_dict['entity_name']].append(entity_dict) + already_entities += 1 + + for relationship in prediction.relationships.context: + relationship_dict = relationship.dict() + relationship_dict["source_id"] = chunk_key + maybe_edges[(relationship_dict['src_id'], relationship_dict['tgt_id'])].append(relationship_dict) + already_relations += 1 + + already_processed += 1 + now_ticks = PROMPTS["process_tickers"][ + already_processed % len(PROMPTS["process_tickers"]) + ] + print( + f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r", + end="", + flush=True, + ) + return dict(maybe_nodes), dict(maybe_edges) + + results = await asyncio.gather( + *[_process_single_content(c) for c in ordered_chunks] + ) + print() + maybe_nodes = defaultdict(list) + maybe_edges = defaultdict(list) + for m_nodes, m_edges in results: + for k, v in m_nodes.items(): + maybe_nodes[k].extend(v) + for k, 
v in m_edges.items(): + maybe_edges[k].extend(v) + all_entities_data = await asyncio.gather( + *[ + _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config) + for k, v in maybe_nodes.items() + ] + ) + await asyncio.gather( + *[ + _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config) + for k, v in maybe_edges.items() + ] + ) + if not len(all_entities_data): + logger.warning("Didn't extract any entities, maybe your LLM is not working") + return None + if entity_vdb is not None: + data_for_vdb = { + compute_mdhash_id(dp["entity_name"], prefix="ent-"): { + "content": dp["entity_name"] + dp["description"], + "entity_name": dp["entity_name"], + } + for dp in all_entities_data + } + await entity_vdb.upsert(data_for_vdb) + + return knwoledge_graph_inst diff --git a/nano_graphrag/entity_extraction/metric.py b/nano_graphrag/entity_extraction/metric.py new file mode 100644 index 0000000..7b9712c --- /dev/null +++ b/nano_graphrag/entity_extraction/metric.py @@ -0,0 +1,36 @@ +import dspy +import numpy as np + + +class AssessRelationship(dspy.Signature): + """Assess the similarity of two relationships.""" + gold_relationship = dspy.InputField() + predicted_relationship = dspy.InputField() + similarity_score = dspy.OutputField(desc="Similarity score between 0 and 1") + + +def relationship_similarity_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float: + similarity_scores = [] + + for gold_rel, pred_rel in zip(gold.relationships.context, pred.relationships.context): + assessment = dspy.Predict(AssessRelationship)( + gold_relationship=gold_rel, + predicted_relationship=pred_rel + ) + + try: + score = float(assessment.similarity_score) + similarity_scores.append(score) + except ValueError: + similarity_scores.append(0.0) + + return np.mean(similarity_scores) if similarity_scores else 0.0 + + +def entity_recall_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float: + true_set = set(item.entity_name for item in 
gold.entities.context) + pred_set = set(item.entity_name for item in pred.entities.context) + true_positives = len(pred_set.intersection(true_set)) + false_negatives = len(true_set - pred_set) + recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0 + return recall diff --git a/nano_graphrag/entity_extraction/module.py b/nano_graphrag/entity_extraction/module.py new file mode 100644 index 0000000..fbcd31e --- /dev/null +++ b/nano_graphrag/entity_extraction/module.py @@ -0,0 +1,218 @@ +import dspy +from pydantic import BaseModel, Field +from nano_graphrag._utils import logger, clean_str + + +class EntityTypes(BaseModel): + """ + Obtained from: + https://github.com/SciPhi-AI/R2R/blob/6e958d1e451c1cb10b6fc868572659785d1091cb/r2r/providers/prompts/defaults.jsonl + """ + context: list[str] = Field( + default=[ + "PERSON", "ORGANIZATION", "LOCATION", "DATE", "TIME", "MONEY", + "PERCENTAGE", "PRODUCT", "EVENT", "LANGUAGE", "NATIONALITY", + "RELIGION", "TITLE", "PROFESSION", "ANIMAL", "PLANT", "DISEASE", + "MEDICATION", "CHEMICAL", "MATERIAL", "COLOR", "SHAPE", + "MEASUREMENT", "WEATHER", "NATURAL_DISASTER", "AWARD", "LAW", + "CRIME", "TECHNOLOGY", "SOFTWARE", "HARDWARE", "VEHICLE", + "FOOD", "DRINK", "SPORT", "MUSIC_GENRE", "INSTRUMENT", + "ARTWORK", "BOOK", "MOVIE", "TV_SHOW", "ACADEMIC_SUBJECT", + "SCIENTIFIC_THEORY", "POLITICAL_PARTY", "CURRENCY", + "STOCK_SYMBOL", "FILE_TYPE", "PROGRAMMING_LANGUAGE", + "MEDICAL_PROCEDURE", "CELESTIAL_BODY" + ], + description="List of entity types used for extraction." 
+ ) + + +class Entity(BaseModel): + entity_name: str = Field(..., description="Cleaned and uppercased entity name, strictly upper case") + entity_type: str = Field(..., description="Cleaned and uppercased entity type, strictly upper case") + description: str = Field(..., description="Detailed and specific description of the entity") + importance_score: float = Field(ge=0.0, le=1.0, description="0 to 1, with 1 being most important") + + +class Relationship(BaseModel): + src_id: str = Field(..., description="Cleaned and uppercased source entity, strictly upper case") + tgt_id: str = Field(..., description="Cleaned and uppercased target entity, strictly upper case") + description: str = Field(..., description="Detailed and specific description of the relationship") + weight: float = Field(ge=0.0, le=1.0, description="0 to 1, with 1 being most important") + order: int = Field(..., description="1 for direct relationships, 2 for second-order, 3 for third-order, etc") + + +class Entities(BaseModel): + context: list[Entity] + + +class Relationships(BaseModel): + context: list[Relationship] + + +class CombinedExtraction(dspy.Signature): + """Signature for extracting both entities and relationships from input text.""" + + input_text: str = dspy.InputField(desc="The text to extract entities and relationships from.") + entity_types: EntityTypes = dspy.InputField() + entities: Entities = dspy.OutputField( + desc=""" + Format: + { + "context": [ + { + "entity_name": "ENTITY NAME", + "entity_type": "ENTITY TYPE", + "description": "Detailed description", + "importance_score": 0.8 + }, + ... + ] + } + Each entity name should be an actual atomic word from the input text. Avoid duplicates and generic terms. + Make sure descriptions are detailed and comprehensive, including: + 1. The entity's role or significance in the context + 2. Key attributes or characteristics + 3. Relationships to other entities (if applicable) + 4. Historical or cultural relevance (if applicable) + 5. 
Any notable actions or events associated with the entity + All entity types from the text must be included. + Entities must have an importance score greater than 0.5. + IMPORTANT: Only use entity types from the provided 'entity_types' list. Do not introduce new entity types. + Ensure the output is strictly JSON formatted without any trailing text or comments. + """ + ) + relationships: Relationships = dspy.OutputField( + desc=""" + Format: + { + "context": [ + { + "src_id": "SOURCE ENTITY", + "tgt_id": "TARGET ENTITY", + "description": "Detailed description of the relationship", + "weight": 0.7, + "order": 1 # 1 for direct relationships, 2 for second-order, 3 for third-order, etc. + }, + ... + ] + } + Make sure relationship descriptions are detailed and comprehensive, including: + 1. The nature of the relationship (e.g., familial, professional, causal) + 2. The impact or significance of the relationship on both entities + 3. Any historical or contextual information relevant to the relationship + 4. How the relationship evolved over time (if applicable) + 5. Any notable events or actions that resulted from this relationship + Include direct relationships (order 1) as well as higher-order relationships (order 2 and 3): + - Direct relationships: Immediate connections between entities. + - Second-order relationships: Indirect effects or connections that result from direct relationships. + - Third-order relationships: Further indirect effects that result from second-order relationships. + IMPORTANT: Only include relationships between existing entities from the extracted entities. Do not introduce new entities here. + The "src_id" and "tgt_id" fields must exactly match entity names from the extracted entities list. + Ensure the output is strictly JSON formatted without any trailing text or comments. + """ + ) + + +class CombinedSelfReflection(dspy.Signature): + """ + Signature for combined self-reflection on extracted entities and relationships. 
+ Self-reflection is on the completeness and quality of both the extracted entities and relationships. + """ + + input_text: str = dspy.InputField(desc="The original input text.") + entity_types: EntityTypes = dspy.InputField() + entities: Entities = dspy.InputField(desc="List of extracted entities.") + relationships: Relationships = dspy.InputField(desc="List of extracted relationships.") + missing_entities: Entities = dspy.OutputField( + desc=""" + Format: + { + "context": [ + { + "entity_name": "ENTITY NAME", + "entity_type": "ENTITY TYPE", + "description": "Detailed description", + "importance_score": 0.8 + }, + ... + ] + } + More specifically: + 1. Entities mentioned in the text but not captured in the initial extraction. + 2. Implicit entities that are crucial to the context but not explicitly mentioned. + 3. Entities that belong to the identified entity types but were overlooked. + 4. Subtypes or more specific instances of the already extracted entities. + Ensure the output is strictly JSON formatted without any trailing text or comments. + """ + ) + missing_relationships: Relationships = dspy.OutputField( + desc=""" + Format: + { + "context": [ + { + "src_id": "SOURCE ENTITY", + "tgt_id": "TARGET ENTITY", + "description": "Detailed description of the relationship", + "weight": 0.7, + "order": 1 # 1 for direct, 2 for second-order, 3 for third-order + }, + ... + ] + } + More specifically: + 1. Direct relationships (order 1) between entities that were not captured initially. + 2. Second-order relationships (order 2): Indirect effects or connections resulting from direct relationships. + 3. Third-order relationships (order 3): Further indirect effects resulting from second-order relationships. + 4. Implicit relationships that can be inferred from the context. + 5. Hierarchical, causal, or temporal relationships that may have been overlooked. + 6. Relationships involving the newly identified missing entities. 
+ Only include relationships between entities in the combined entities list (extracted + missing). + Ensure the output is strictly JSON formatted without any trailing text or comments. + """ + ) + + +class EntityRelationshipExtractor(dspy.Module): + def __init__(self): + super().__init__() + self.entity_types = EntityTypes() + self.extractor = dspy.TypedPredictor(CombinedExtraction) + self.self_reflection = dspy.TypedPredictor(CombinedSelfReflection) + + def forward(self, input_text: str) -> dspy.Prediction: + extraction_result = self.extractor(input_text=input_text, entity_types=self.entity_types) + reflection_result = self.self_reflection( + input_text=input_text, + entity_types=self.entity_types, + entities=extraction_result.entities, + relationships=extraction_result.relationships + ) + entities = extraction_result.entities + missing_entities = reflection_result.missing_entities + relationships = extraction_result.relationships + missing_relationships = reflection_result.missing_relationships + all_entities = Entities(context=entities.context + missing_entities.context) + all_relationships = Relationships(context=relationships.context + missing_relationships.context) + logger.debug(f"Entities: {len(entities.context)} | Missed Entities: {len(missing_entities.context)} | Total Entities: {len(all_entities.context)}") + logger.debug(f"Relationships: {len(relationships.context)} | Missed Relationships: {len(missing_relationships.context)} | Total Relationships: {len(all_relationships.context)}") + + for entity in all_entities.context: + entity.entity_name = clean_str(entity.entity_name.upper()) + entity.entity_type = clean_str(entity.entity_type.upper()) + entity.description = clean_str(entity.description) + entity.importance_score = float(entity.importance_score) + + for relationship in all_relationships.context: + relationship.src_id = clean_str(relationship.src_id.upper()) + relationship.tgt_id = clean_str(relationship.tgt_id.upper()) + relationship.description = 
clean_str(relationship.description) + relationship.weight = float(relationship.weight) + relationship.order = int(relationship.order) + + direct_relationships = sum(1 for r in all_relationships.context if r.order == 1) + second_order_relationships = sum(1 for r in all_relationships.context if r.order == 2) + third_order_relationships = sum(1 for r in all_relationships.context if r.order == 3) + logger.debug(f"Direct Relationships: {direct_relationships} | Second-order: {second_order_relationships} | Third-order: {third_order_relationships} | Total Relationships: {len(all_relationships.context)}") + return dspy.Prediction(entities=all_entities, relationships=all_relationships) + \ No newline at end of file diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py index a73398a..53d14fd 100644 --- a/nano_graphrag/graphrag.py +++ b/nano_graphrag/graphrag.py @@ -112,6 +112,9 @@ class GraphRAG: cheap_model_max_token_size: int = 32768 cheap_model_max_async: int = 16 + # entity extraction + entity_extraction_func: callable = extract_entities + # storage key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage @@ -293,7 +296,7 @@ async def ainsert(self, string_or_strings): # ---------- extract/summary entity and upsert to graph logger.info("[Entity Extraction]...") - maybe_new_kg = await extract_entities( + maybe_new_kg = await self.entity_extraction_func( inserting_chunks, knwoledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb, diff --git a/requirements-dev.txt b/requirements-dev.txt index 5806a58..28e1f07 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,4 @@ flake8 pytest future pytest-asyncio -pytest-cov \ No newline at end of file +pytest-cov diff --git a/requirements.txt b/requirements.txt index 8a74d5e..11d738b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,5 @@ nano-vectordb hnswlib xxhash tenacity +dspy-ai 
+sentence-transformers diff --git a/tests/entity_extraction/__init__.py b/tests/entity_extraction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/entity_extraction/test_extract.py b/tests/entity_extraction/test_extract.py new file mode 100644 index 0000000..10d3a6e --- /dev/null +++ b/tests/entity_extraction/test_extract.py @@ -0,0 +1,104 @@ +import pytest +import dspy +from unittest.mock import Mock, patch, AsyncMock +from nano_graphrag.entity_extraction.module import ( + Entities, + Relationships, +) +from nano_graphrag.entity_extraction.extract import generate_dataset, extract_entities_dspy +from nano_graphrag.base import TextChunkSchema, BaseGraphStorage, BaseVectorStorage + + +@pytest.fixture +def mock_chunks(): + return { + "chunk1": TextChunkSchema(content="Apple announced a new iPhone model."), + "chunk2": TextChunkSchema(content="Google released an update for Android.") + } + + +@pytest.fixture +def mock_entity_extractor(): + with patch('nano_graphrag.entity_extraction.extract.EntityRelationshipExtractor') as mock: + mock_instance = Mock() + mock.return_value = mock_instance + yield mock_instance + + +@pytest.fixture +def mock_graph_storage(): + return Mock(spec=BaseGraphStorage) + + +@pytest.fixture +def mock_vector_storage(): + return Mock(spec=BaseVectorStorage) + + +@pytest.fixture +def mock_global_config(): + return { + "use_compiled_dspy_entity_relationship": False, + "entity_relationship_module_path": "path/to/module" + } + + +@pytest.mark.asyncio +async def test_generate_dataset(mock_chunks, mock_entity_extractor, tmp_path): + mock_prediction = Mock( + entities=Mock(context=[{"entity_name": "APPLE", "entity_type": "ORGANIZATION"}]), + relationships=Mock(context=[{"src_id": "APPLE", "tgt_id": "IPHONE"}]) + ) + mock_entity_extractor.return_value = mock_prediction + + filepath = tmp_path / "test_dataset.pkl" + + with patch('nano_graphrag.entity_extraction.extract.pickle.dump') as mock_dump: + result = await 
generate_dataset(mock_chunks, str(filepath)) + + assert len(result) == 2 + assert isinstance(result[0], dspy.Example) + assert hasattr(result[0], 'input_text') + assert hasattr(result[0], 'entities') + assert hasattr(result[0], 'relationships') + assert result[0].input_text == "Apple announced a new iPhone model." + assert result[0].entities.context == [{"entity_name": "APPLE", "entity_type": "ORGANIZATION"}] + assert result[0].relationships.context == [{"src_id": "APPLE", "tgt_id": "IPHONE"}] + + +@pytest.mark.asyncio +async def test_extract_entities_dspy(mock_chunks, mock_graph_storage, mock_vector_storage, mock_global_config): + mock_entity = { + "entity_name": "APPLE", + "entity_type": "ORGANIZATION", + "description": "A tech company", + "importance_score": 0.9 + } + mock_relationship = { + "src_id": "APPLE", + "tgt_id": "IPHONE", + "description": "Produces", + "weight": 0.8, + "order": 1 + } + mock_prediction = Mock( + entities=Entities(context=[mock_entity]), + relationships=Relationships(context=[mock_relationship]) + ) + + with patch('nano_graphrag.entity_extraction.extract.EntityRelationshipExtractor') as mock_extractor_class: + mock_extractor_instance = Mock() + mock_extractor_instance.return_value = mock_prediction + mock_extractor_class.return_value = mock_extractor_instance + + with patch('nano_graphrag.entity_extraction.extract._merge_nodes_then_upsert', new_callable=AsyncMock) as mock_merge_nodes, \ + patch('nano_graphrag.entity_extraction.extract._merge_edges_then_upsert', new_callable=AsyncMock) as mock_merge_edges: + mock_merge_nodes.return_value = mock_entity + result = await extract_entities_dspy(mock_chunks, mock_graph_storage, mock_vector_storage, mock_global_config) + + assert result == mock_graph_storage + mock_extractor_class.assert_called_once() + mock_extractor_instance.assert_called() + mock_merge_nodes.assert_called() + mock_merge_edges.assert_called() + mock_vector_storage.upsert.assert_called_once() diff --git 
a/tests/entity_extraction/test_metric.py b/tests/entity_extraction/test_metric.py new file mode 100644 index 0000000..8c1fea4 --- /dev/null +++ b/tests/entity_extraction/test_metric.py @@ -0,0 +1,150 @@ +import pytest +import numpy as np +import dspy +from unittest.mock import Mock, patch +from nano_graphrag.entity_extraction.metric import ( + relationship_similarity_metric, + entity_recall_metric, +) + + +@pytest.fixture +def relationship(): + class Relationship: + def __init__(self, src_id, tgt_id, description=None): + self.src_id = src_id + self.tgt_id = tgt_id + self.description = description + return Relationship + + +@pytest.fixture +def entity(): + class Entity: + def __init__(self, entity_name): + self.entity_name = entity_name + return Entity + + +@pytest.fixture +def example(): + class Example: + def __init__(self, items): + self.relationships = type('obj', (object,), {'context': items}) + self.entities = type('obj', (object,), {'context': items}) + return Example + + +@pytest.fixture +def prediction(): + class Prediction: + def __init__(self, items): + self.relationships = type('obj', (object,), {'context': items}) + self.entities = type('obj', (object,), {'context': items}) + return Prediction + + +@pytest.fixture +def sample_texts(): + return ["Hello", "World", "Test"] + + +@pytest.fixture +def mock_dspy_predict(): + with patch('nano_graphrag.entity_extraction.metric.dspy.Predict') as mock_predict: + mock_instance = Mock() + mock_instance.return_value = dspy.Prediction(similarity_score="0.75") + mock_predict.return_value = mock_instance + yield mock_predict + + +@pytest.mark.asyncio +async def test_relationship_similarity_metric(relationship, example, prediction, mock_dspy_predict): + gold = example([ + relationship("1", "2", "is related to"), + relationship("2", "3", "is connected with"), + ]) + pred = prediction([ + relationship("1", "2", "is connected to"), + relationship("2", "3", "is linked with"), + ]) + + similarity = 
relationship_similarity_metric(gold, pred) + assert np.isclose(similarity, 0.75, atol=1e-6) + + +@pytest.mark.asyncio +async def test_entity_recall_metric(entity, example, prediction): + gold = example([ + entity("Entity1"), + entity("Entity2"), + entity("Entity3"), + ]) + pred = prediction([ + entity("Entity1"), + entity("Entity3"), + entity("Entity4"), + ]) + + recall = entity_recall_metric(gold, pred) + assert recall == 2/3 + + +@pytest.mark.asyncio +async def test_relationship_similarity_metric_no_common_keys(relationship, example, prediction, mock_dspy_predict): + gold = example([relationship("1", "2", "is related to")]) + pred = prediction([relationship("3", "4", "is connected with")]) + + similarity = relationship_similarity_metric(gold, pred) + assert similarity == 0.75 # The mocked value + + +@pytest.mark.asyncio +async def test_entity_recall_metric_no_true_positives(entity, example, prediction): + gold = example([entity("Entity1"), entity("Entity2")]) + pred = prediction([entity("Entity3"), entity("Entity4")]) + + recall = entity_recall_metric(gold, pred) + assert recall == 0 + + +@pytest.mark.asyncio +async def test_relationship_similarity_metric_identical_descriptions(relationship, example, prediction, mock_dspy_predict): + gold = example([relationship("1", "2", "is related to")]) + pred = prediction([relationship("1", "2", "is related to")]) + + similarity = relationship_similarity_metric(gold, pred) + assert np.isclose(similarity, 0.75, atol=1e-6) + + +@pytest.mark.asyncio +async def test_entity_recall_metric_perfect_recall(entity, example, prediction): + entities = [entity("Entity1"), entity("Entity2")] + gold = example(entities) + pred = prediction(entities) + + recall = entity_recall_metric(gold, pred) + assert recall == 1.0 + + +@pytest.mark.asyncio +async def test_relationship_similarity_metric_no_relationships(example, prediction, mock_dspy_predict): + gold = example([]) + pred = prediction([]) + + similarity = 
relationship_similarity_metric(gold, pred) + assert similarity == 0.0 + + +@pytest.mark.asyncio +async def test_relationship_similarity_metric_invalid_score(relationship, example, prediction): + with patch('nano_graphrag.entity_extraction.metric.dspy.Predict') as mock_predict: + mock_instance = Mock() + mock_instance.return_value = dspy.Prediction(similarity_score="invalid") + mock_predict.return_value = mock_instance + + gold = example([relationship("1", "2", "is related to")]) + pred = prediction([relationship("1", "2", "is connected to")]) + + similarity = relationship_similarity_metric(gold, pred) + assert similarity == 0.0 diff --git a/tests/entity_extraction/test_module.py b/tests/entity_extraction/test_module.py new file mode 100644 index 0000000..1d9ace3 --- /dev/null +++ b/tests/entity_extraction/test_module.py @@ -0,0 +1,90 @@ +from unittest.mock import Mock, patch +from nano_graphrag.entity_extraction.module import ( + EntityRelationshipExtractor, + Entities, + Relationships, + Entity, + Relationship +) + + +def test_entity_relationship_extractor(): + with patch('nano_graphrag.entity_extraction.module.dspy.TypedPredictor') as mock_typed_predictor: + input_text = "Apple announced a new iPhone model." 
+ mock_extractor = Mock() + mock_self_reflection = Mock() + mock_typed_predictor.side_effect = [mock_extractor, mock_self_reflection] + + mock_entities = [ + Entity(entity_name="APPLE", entity_type="ORGANIZATION", description="A technology company", importance_score=1), + Entity(entity_name="IPHONE", entity_type="PRODUCT", description="A smartphone", importance_score=1) + ] + mock_relationships = [ + Relationship(src_id="APPLE", tgt_id="IPHONE", description="Apple manufactures iPhone", weight=1, order=1) + ] + mock_missing_entities = [ + Entity(entity_name="TIM_COOK", entity_type="PERSON", description="CEO of Apple", importance_score=0.8) + ] + mock_missing_relationships = [ + Relationship(src_id="TIM_COOK", tgt_id="APPLE", description="Tim Cook is the CEO of Apple", weight=0.9, order=1), + Relationship(src_id="APPLE", tgt_id="IPHONE", description="Apple announces new iPhone model", weight=1, order=1) + ] + + mock_extractor.return_value = Mock( + entities=Entities(context=mock_entities), + relationships=Relationships(context=mock_relationships) + ) + mock_self_reflection.return_value = Mock( + missing_entities=Entities(context=mock_missing_entities), + missing_relationships=Relationships(context=mock_missing_relationships) + ) + + extractor = EntityRelationshipExtractor() + result = extractor.forward(input_text=input_text) + + mock_extractor.assert_called_once_with( + input_text=input_text, + entity_types=extractor.entity_types + ) + + mock_self_reflection.assert_called_once_with( + input_text=input_text, + entity_types=extractor.entity_types, + entities=mock_extractor.return_value.entities, + relationships=mock_extractor.return_value.relationships + ) + + assert len(result.entities.context) == 3 + assert len(result.relationships.context) == 3 + + assert result.entities.context[0].entity_name == "APPLE" + assert result.entities.context[0].entity_type == "ORGANIZATION" + assert result.entities.context[0].description == "A technology company" + + assert 
result.entities.context[1].entity_name == "IPHONE" + assert result.entities.context[1].entity_type == "PRODUCT" + assert result.entities.context[1].description == "A smartphone" + assert result.entities.context[1].importance_score == 1 + + assert result.entities.context[2].entity_name == "TIM_COOK" + assert result.entities.context[2].entity_type == "PERSON" + assert result.entities.context[2].description == "CEO of Apple" + assert result.entities.context[2].importance_score == 0.8 + + assert result.relationships.context[0].src_id == "APPLE" + assert result.relationships.context[0].tgt_id == "IPHONE" + assert result.relationships.context[0].description == "Apple manufactures iPhone" + assert result.relationships.context[0].weight == 1 + assert result.relationships.context[0].order == 1 + + assert result.relationships.context[1].src_id == "TIM_COOK" + assert result.relationships.context[1].tgt_id == "APPLE" + assert result.relationships.context[1].description == "Tim Cook is the CEO of Apple" + assert result.relationships.context[1].weight == 0.9 + assert result.relationships.context[1].order == 1 + + assert result.relationships.context[2].src_id == "APPLE" + assert result.relationships.context[2].tgt_id == "IPHONE" + assert result.relationships.context[2].description == "Apple announces new iPhone model" + assert result.relationships.context[2].weight == 1 + assert result.relationships.context[2].order == 1