Skip to content

Commit

Permalink
Graph.from_networkx method now extract node and edge attributes (#2714)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored and jlstevens committed May 22, 2018
1 parent 234a3e4 commit 8ce2f52
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ install:
- conda update -q conda
# Useful for debugging any issues with conda
- conda info -a
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION flake8 scipy=1.0.0 numpy freetype nose pandas=0.22.0 jupyter ipython=5.4.1 param matplotlib=2.1.2 xarray
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION flake8 scipy=1.0.0 numpy freetype nose pandas=0.22.0 jupyter ipython=5.4.1 param matplotlib=2.1.2 xarray networkx
- source activate test-environment
- conda install -c conda-forge iris plotly flexx ffmpeg netcdf4=1.3.1 --quiet
- conda install -c bokeh datashader dask bokeh=0.12.15 selenium
Expand Down
20 changes: 16 additions & 4 deletions holoviews/element/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,17 @@ def from_networkx(cls, G, layout_function, nodes=None, **kwargs):
"""
Generate a HoloViews Graph from a networkx.Graph object and
networkx layout function. Any keyword arguments will be passed
to the layout function.
to the layout function. By default it will extract all node
and edge attributes from the networkx.Graph but explicit node
information may also be supplied.
"""
positions = layout_function(G, **kwargs)
edges = G.edges()
edges = []
for start, end in G.edges():
attrs = sorted(G.adj[start][end].items())
edges.append((start, end)+tuple(v for k, v in attrs))
edge_vdims = [k for k, v in attrs] if edges else []

if nodes:
idx_dim = nodes.kdims[-1].name
xs, ys = zip(*[v for k, v in sorted(positions.items())])
Expand All @@ -398,8 +405,13 @@ def from_networkx(cls, G, layout_function, nodes=None, **kwargs):
nodes = nodes.add_dimension('x', 0, xs)
nodes = nodes.add_dimension('y', 1, ys).clone(new_type=cls.node_type)
else:
nodes = cls.node_type([tuple(pos)+(idx,) for idx, pos in sorted(positions.items())])
return cls((edges, nodes))
nodes = []
for idx, pos in sorted(positions.items()):
attrs = sorted(G.nodes[idx].items())
nodes.append(tuple(pos)+(idx,)+tuple(v for k, v in attrs))
vdims = [k for k, v in attrs] if nodes else []
nodes = cls.node_type(nodes, vdims=vdims)
return cls((edges, nodes), vdims=edge_vdims)



Expand Down
42 changes: 40 additions & 2 deletions tests/element/testgraphelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Unit tests of Graph Element.
"""
from unittest import SkipTest
from nose.plugins.attrib import attr

import numpy as np
from holoviews.core.data import Dataset
Expand Down Expand Up @@ -128,8 +129,45 @@ def test_graph_redim_nodes(self):
self.assertEqual(redimmed.nodes, graph.nodes.redim(x='x2', y='y2'))
self.assertEqual(redimmed.edgepaths, graph.edgepaths.redim(x='x2', y='y2'))



@attr(optional=1)
def test_from_networkx_with_node_attrs(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
G = nx.karate_club_graph()
graph = Graph.from_networkx(G, nx.circular_layout)
clubs = np.array([
'Mr. Hi', 'Mr. Hi', 'Mr. Hi', 'Mr. Hi', 'Mr. Hi', 'Mr. Hi',
'Mr. Hi', 'Mr. Hi', 'Mr. Hi', 'Officer', 'Mr. Hi', 'Mr. Hi',
'Mr. Hi', 'Mr. Hi', 'Officer', 'Officer', 'Mr. Hi', 'Mr. Hi',
'Officer', 'Mr. Hi', 'Officer', 'Mr. Hi', 'Officer', 'Officer',
'Officer', 'Officer', 'Officer', 'Officer', 'Officer', 'Officer',
'Officer', 'Officer', 'Officer', 'Officer'])
self.assertEqual(graph.nodes.dimension_values('club'), clubs)

@attr(optional=1)
def test_from_networkx_with_edge_attrs(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
FG = nx.Graph()
FG.add_weighted_edges_from([(1,2,0.125), (1,3,0.75), (2,4,1.2), (3,4,0.375)])
graph = Graph.from_networkx(FG, nx.circular_layout)
self.assertEqual(graph.dimension_values('weight'), np.array([0.125, 0.75, 1.2, 0.375]))

@attr(optional=1)
def test_from_networkx_only_nodes(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
G = nx.Graph()
G.add_nodes_from([1, 2, 3])
graph = Graph.from_networkx(G, nx.circular_layout)
self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2, 3]))

class ChordTests(ComparisonTestCase):

def setUp(self):
Expand Down

0 comments on commit 8ce2f52

Please sign in to comment.