Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: JSON serialization problem #96

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nano_graphrag/_op.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
import json
import json5
import asyncio
import tiktoken
from typing import Union
Expand Down Expand Up @@ -680,7 +680,7 @@ async def _find_most_related_community_from_entities(
for node_d in node_datas:
if "clusters" not in node_d:
continue
related_communities.extend(json.loads(node_d["clusters"]))
related_communities.extend(json5.loads(node_d["clusters"]))
related_community_dup_keys = [
str(dp["cluster"])
for dp in related_communities
Expand Down
6 changes: 3 additions & 3 deletions nano_graphrag/_storage/gdb_networkx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import html
import json
import json5
import os
from collections import defaultdict
from dataclasses import dataclass
Expand Down Expand Up @@ -154,7 +154,7 @@ async def community_schema(self) -> dict[str, SingleCommunitySchema]:
for node_id, node_data in self._graph.nodes(data=True):
if "clusters" not in node_data:
continue
clusters = json.loads(node_data["clusters"])
clusters = json5.loads(node_data["clusters"])
this_node_edges = self._graph.edges(node_id)

for cluster in clusters:
Expand Down Expand Up @@ -195,7 +195,7 @@ async def community_schema(self) -> dict[str, SingleCommunitySchema]:

def _cluster_data_to_subgraphs(self, cluster_data: dict[str, list[dict[str, str]]]):
for node_id, clusters in cluster_data.items():
self._graph.nodes[node_id]["clusters"] = json.dumps(clusters)
self._graph.nodes[node_id]["clusters"] = json5.dumps(clusters)

async def _leiden_clustering(self):
from graspologic.partition import hierarchical_leiden
Expand Down
5 changes: 3 additions & 2 deletions nano_graphrag/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import html
import json5
import json
import logging
import os
Expand Down Expand Up @@ -151,14 +152,14 @@ def compute_mdhash_id(content, prefix: str = ""):

def write_json(json_obj, file_name):
with open(file_name, "w", encoding="utf-8") as f:
json.dump(json_obj, f, indent=2, ensure_ascii=False)
json5.dump(json_obj, f, indent=2, ensure_ascii=False)


def load_json(file_name):
if not os.path.exists(file_name):
return None
with open(file_name, encoding="utf-8") as f:
return json.load(f)
return json5.load(f)


# it's dirty to type, so it's a good way to have fun
Expand Down