Skip to content

Commit

Permalink
fix: spark edge writer #42
Browse files Browse the repository at this point in the history
fix: spark edge writer
  • Loading branch information
wey-gu authored Sep 6, 2023
2 parents ad865fe + 8457c94 commit ed4d001
Show file tree
Hide file tree
Showing 14 changed files with 836 additions and 530 deletions.
142 changes: 116 additions & 26 deletions examples/ng_ai_networkx_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"config = NebulaGraphConfig(**config_dict)\n",
"reader = NebulaReader(engine=\"nebula\", config=config)\n",
"reader.query(edges=[\"follow\", \"serve\"], props=[[\"degree\"], []])\n",
"g = reader.read()\n"
"g = reader.read()"
]
},
{
Expand Down Expand Up @@ -151,7 +151,21 @@
"import matplotlib.pyplot as plt\n",
"import random\n",
"\n",
"def draw_graph(G, colors=[\"#1984c5\", \"#22a7f0\", \"#63bff0\", \"#a7d5ed\", \"#e2e2e2\", \"#e1a692\", \"#de6e56\", \"#e14b31\", \"#c23728\"]):\n",
"\n",
"def draw_graph(\n",
" G,\n",
" colors=[\n",
" \"#1984c5\",\n",
" \"#22a7f0\",\n",
" \"#63bff0\",\n",
" \"#a7d5ed\",\n",
" \"#e2e2e2\",\n",
" \"#e1a692\",\n",
" \"#de6e56\",\n",
" \"#e14b31\",\n",
" \"#c23728\",\n",
" ],\n",
"):\n",
" # Define positions for the nodes\n",
" pos = nx.spring_layout(G)\n",
"\n",
Expand All @@ -163,35 +177,62 @@
" # Draw the nodes and edges of the graph\n",
" node_colors = {}\n",
" for node in G.nodes():\n",
" if 'label' in G.nodes[node]:\n",
" label = G.nodes[node]['label']\n",
" if \"label\" in G.nodes[node]:\n",
" label = G.nodes[node][\"label\"]\n",
" else:\n",
" label = ''\n",
" label = \"\"\n",
" if label in node_colors:\n",
" node_color = node_colors[label]\n",
" else:\n",
" node_color = random.choice(colors)\n",
" node_colors[label] = node_color\n",
" nx.draw_networkx_nodes(G, pos=pos, ax=ax, nodelist=[node], node_color=node_color, node_size=4000)\n",
" \n",
" nx.draw_networkx_edges(G, pos=pos, ax=ax, edge_color='gray', width=2, connectionstyle='arc3, rad=0.1', arrowstyle='-|>', arrows=True)\n",
" nx.draw_networkx_nodes(\n",
" G,\n",
" pos=pos,\n",
" ax=ax,\n",
" nodelist=[node],\n",
" node_color=node_color,\n",
" node_size=4000,\n",
" )\n",
"\n",
" nx.draw_networkx_edges(\n",
" G,\n",
" pos=pos,\n",
" ax=ax,\n",
" edge_color=\"gray\",\n",
" width=2,\n",
" connectionstyle=\"arc3, rad=0.1\",\n",
" arrowstyle=\"-|>\",\n",
" arrows=True,\n",
" )\n",
"\n",
" # Extract edge labels as a dictionary\n",
" edge_labels = nx.get_edge_attributes(G, 'label')\n",
" edge_labels = nx.get_edge_attributes(G, \"label\")\n",
"\n",
" # Add edge labels to the graph\n",
" for edge, label in edge_labels.items():\n",
" ax.text((pos[edge[0]][0] + pos[edge[1]][0])/2,\n",
" (pos[edge[0]][1] + pos[edge[1]][1])/2,\n",
" label, fontsize=12, color='black', ha='center', va='center')\n",
" ax.text(\n",
" (pos[edge[0]][0] + pos[edge[1]][0]) / 2,\n",
" (pos[edge[0]][1] + pos[edge[1]][1]) / 2,\n",
" label,\n",
" fontsize=12,\n",
" color=\"black\",\n",
" ha=\"center\",\n",
" va=\"center\",\n",
" )\n",
"\n",
" # Add node labels to the graph\n",
" node_labels = {n: G.nodes[n]['label'] if 'label' in G.nodes[n] else n for n in G.nodes()}\n",
" nx.draw_networkx_labels(G, pos=pos, ax=ax, labels=node_labels, font_size=12, font_color='black')\n",
" node_labels = {\n",
" n: G.nodes[n][\"label\"] if \"label\" in G.nodes[n] else n for n in G.nodes()\n",
" }\n",
" nx.draw_networkx_labels(\n",
" G, pos=pos, ax=ax, labels=node_labels, font_size=12, font_color=\"black\"\n",
" )\n",
"\n",
" # Show the figure\n",
" plt.show()\n",
"\n",
"\n",
"draw_graph(G)"
]
},
Expand Down Expand Up @@ -243,7 +284,22 @@
"from matplotlib.colors import ListedColormap\n",
"\n",
"\n",
"def draw_graph_louvain_pr(G, pr_result, louvain_result, colors=[\"#1984c5\", \"#22a7f0\", \"#63bff0\", \"#a7d5ed\", \"#e2e2e2\", \"#e1a692\", \"#de6e56\", \"#e14b31\", \"#c23728\"]):\n",
"def draw_graph_louvain_pr(\n",
" G,\n",
" pr_result,\n",
" louvain_result,\n",
" colors=[\n",
" \"#1984c5\",\n",
" \"#22a7f0\",\n",
" \"#63bff0\",\n",
" \"#a7d5ed\",\n",
" \"#e2e2e2\",\n",
" \"#e1a692\",\n",
" \"#de6e56\",\n",
" \"#e14b31\",\n",
" \"#c23728\",\n",
" ],\n",
"):\n",
" # Define positions for the nodes\n",
" pos = nx.spring_layout(G)\n",
"\n",
Expand All @@ -258,28 +314,62 @@
" # Draw the nodes and edges of the graph\n",
" node_colors = [louvain_result[node] for node in G.nodes()]\n",
" node_sizes = [70000 * pr_result[node] for node in G.nodes()]\n",
" nx.draw_networkx_nodes(G, pos=pos, ax=ax, node_color=node_colors, node_size=node_sizes, cmap=cmap, vmin=0, vmax=max(louvain_result.values()))\n",
" nx.draw_networkx_nodes(\n",
" G,\n",
" pos=pos,\n",
" ax=ax,\n",
" node_color=node_colors,\n",
" node_size=node_sizes,\n",
" cmap=cmap,\n",
" vmin=0,\n",
" vmax=max(louvain_result.values()),\n",
" )\n",
"\n",
" nx.draw_networkx_edges(G, pos=pos, ax=ax, edge_color='gray', width=1, connectionstyle='arc3, rad=0.2', arrowstyle='-|>', arrows=True)\n",
" nx.draw_networkx_edges(\n",
" G,\n",
" pos=pos,\n",
" ax=ax,\n",
" edge_color=\"gray\",\n",
" width=1,\n",
" connectionstyle=\"arc3, rad=0.2\",\n",
" arrowstyle=\"-|>\",\n",
" arrows=True,\n",
" )\n",
"\n",
" # Extract edge labels as a dictionary\n",
" edge_labels = nx.get_edge_attributes(G, 'label')\n",
" edge_labels = nx.get_edge_attributes(G, \"label\")\n",
"\n",
" # Add edge labels to the graph\n",
" for edge, label in edge_labels.items():\n",
" ax.text((pos[edge[0]][0] + pos[edge[1]][0])/2,\n",
" (pos[edge[0]][1] + pos[edge[1]][1])/2,\n",
" label, fontsize=12, color='black', ha='center', va='center')\n",
" ax.text(\n",
" (pos[edge[0]][0] + pos[edge[1]][0]) / 2,\n",
" (pos[edge[0]][1] + pos[edge[1]][1]) / 2,\n",
" label,\n",
" fontsize=12,\n",
" color=\"black\",\n",
" ha=\"center\",\n",
" va=\"center\",\n",
" )\n",
"\n",
" # Add node labels to the graph\n",
" node_labels = {n: G.nodes[n]['label'] if 'label' in G.nodes[n] else n for n in G.nodes()}\n",
" nx.draw_networkx_labels(G, pos=pos, ax=ax, labels=node_labels, font_size=12, font_color='black')\n",
" node_labels = {\n",
" n: G.nodes[n][\"label\"] if \"label\" in G.nodes[n] else n for n in G.nodes()\n",
" }\n",
" nx.draw_networkx_labels(\n",
" G, pos=pos, ax=ax, labels=node_labels, font_size=12, font_color=\"black\"\n",
" )\n",
"\n",
" # Add colorbar for community colors\n",
" sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=max(louvain_result.values())))\n",
" sm = plt.cm.ScalarMappable(\n",
" cmap=cmap, norm=plt.Normalize(vmin=0, vmax=max(louvain_result.values()))\n",
" )\n",
" sm.set_array([])\n",
" cbar = plt.colorbar(sm, ax=ax, ticks=range(max(louvain_result.values()) + 1), shrink=0.5)\n",
" cbar.ax.set_yticklabels([f'Community {i}' for i in range(max(louvain_result.values()) + 1)])\n",
" cbar = plt.colorbar(\n",
" sm, ax=ax, ticks=range(max(louvain_result.values()) + 1), shrink=0.5\n",
" )\n",
" cbar.ax.set_yticklabels(\n",
" [f\"Community {i}\" for i in range(max(louvain_result.values()) + 1)]\n",
" )\n",
"\n",
" # Show the figure\n",
" plt.show()\n",
Expand Down
103 changes: 102 additions & 1 deletion examples/spark_engine.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "a54fe998",
"metadata": {},
Expand Down Expand Up @@ -90,6 +91,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3617de5f",
"metadata": {},
Expand Down Expand Up @@ -176,6 +178,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "66e70ca0",
"metadata": {},
Expand Down Expand Up @@ -212,6 +215,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3eb228f8",
"metadata": {},
Expand Down Expand Up @@ -262,6 +266,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "49becbdb",
"metadata": {},
Expand Down Expand Up @@ -339,6 +344,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "38181d45",
"metadata": {},
Expand Down Expand Up @@ -375,11 +381,12 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "3d088006",
"metadata": {},
"source": [
"### Write back algo result to NebulaGraph\n",
"### Write back algo result to NebulaGraph as TAG\n",
"\n",
"Assume that we have a Spark DataFrame `df_result` computed with `df.algo.label_propagation()` with the following schema:\n",
"\n",
Expand Down Expand Up @@ -541,6 +548,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "9da30271",
"metadata": {},
Expand Down Expand Up @@ -576,6 +584,99 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "670bd34b",
"metadata": {},
"source": [
"### Write back algo result to NebulaGraph as EDGE\n",
"\n",
"Similar to writing a TAG, we need to create the schema first:\n",
"\n",
"```\n",
"CREATE EDGE jaccard_similarity(similarity double);\n",
"```\n",
"\n",
"Then we run an algorithm whose results will be written as edges:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "493cc969",
"metadata": {},
"outputs": [],
"source": [
"# Run Jaccard Algorithm\n",
"df_result = df.algo.jaccard()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b9e76ef2",
"metadata": {},
"source": [
"Then let's write the result to NebulaGraph, mapping the DataFrame column `similarity` to the edge property `similarity`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e91af7f",
"metadata": {},
"outputs": [],
"source": [
"writer = NebulaWriter(\n",
" data=df_result, sink=\"nebulagraph_edge\", config=config, engine=\"spark\"\n",
")\n",
"\n",
"# map column similarity into property similarity\n",
"properties = {\"similarity\": \"similarity\"}\n",
"\n",
"writer.set_options(\n",
" space=\"basketballplayer\",\n",
" edge_type=\"jaccard_similarity\",\n",
" src_id=\"srcId\",\n",
" dst_id=\"dstId\",\n",
" src_id_policy=\"\",\n",
" dst_id_policy=\"\",\n",
" properties=properties,\n",
" batch_size=256,\n",
" write_mode=\"insert\",\n",
")\n",
"\n",
"# write back to NebulaGraph\n",
"writer.write()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8165fec9",
"metadata": {},
"source": [
"Check result:\n",
"\n",
"```\n",
"(root@nebula) [basketballplayer]> MATCH ()-[e:jaccard_similarity]->() RETURN e LIMIT 3\n",
"+-------------------------------------------------------------------------------------+\n",
"| e |\n",
"+-------------------------------------------------------------------------------------+\n",
"| [:jaccard_similarity \"player102\"->\"player100\" @0 {similarity: 0.07692307692307687}] |\n",
"| [:jaccard_similarity \"player102\"->\"player101\" @0 {similarity: 0.11111111111111116}] |\n",
"| [:jaccard_similarity \"player102\"->\"player104\" @0 {similarity: 0.33333333333333326}] |\n",
"+-------------------------------------------------------------------------------------+\n",
"Got 3 rows (time spent 39.984ms/44.574542ms)\n",
"\n",
"Wed, 06 Sep 2023 13:04:38 CST\n",
"\n",
"(root@nebula) [basketballplayer]>\n",
"```"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5bcb02e2",
"metadata": {},
Expand Down
Loading

0 comments on commit ed4d001

Please sign in to comment.