diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index eb15d7f77252d..c3efa97d47b75 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -341,3 +341,29 @@ confidence band. @savefig autocorrelation_plot.png width=6in autocorrelation_plot(data) + +RadViz +~~~~~~ + +RadViz is a way of visualizing multi-variate data. It is based on a simple +spring tension minimization algorithm. Basically you set up a bunch of points in +a plane. In our case they are equally spaced on a unit circle. Each point +represents a single attribute. You then pretend that each sample in the data set +is attached to each of these points by a spring, the stiffness of which is +proportional to the numerical value of that attribute (they are normalized to +unit interval). The point in the plane, where our sample settles to (where the +forces acting on our sample are at an equilibrium) is where a dot representing +our sample will be drawn. Depending on which class that sample belongs it will +be colored differently. + +.. ipython:: python + + from pandas import read_csv + from pandas.tools.plotting import radviz + + data = read_csv('data/iris.data') + + plt.figure() + + @savefig radviz.png width=6in + radviz(data, 'Name') diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 34ef3c9ae1cc1..0585a6d2115f2 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -314,6 +314,14 @@ def test_andrews_curves(self): df = read_csv(path) _check_plot_works(andrews_curves, df, 'Name') + @slow + def test_radviz(self): + from pandas import read_csv + from pandas.tools.plotting import radviz + path = os.path.join(curpath(), 'data/iris.csv') + df = read_csv(path) + _check_plot_works(radviz, df, 'Name') + @slow def test_plot_int_columns(self): df = DataFrame(np.random.randn(100, 4)).cumsum() diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 9ab4dccbbe61a..cd98c24d4fec7 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -147,6 +147,65 @@ def _get_marker_compat(marker): return 'o' return marker +def radviz(frame, class_column, ax=None, **kwds): + """RadViz - a multivariate data visualization algorithm + + Parameters: + ----------- + frame: DataFrame object + class_column: Column name that contains information about class membership + ax: Matplotlib axis object, optional + kwds: Matplotlib scatter method keyword arguments, optional + + Returns: + -------- + ax: Matplotlib axis object + """ + import matplotlib.pyplot as plt + import matplotlib.patches as patches + import matplotlib.text as text + import random + def random_color(column): + random.seed(column) + return [random.random() for _ in range(3)] + def normalize(series): + a = min(series) + b = max(series) + return (series - a) / (b - a) + column_names = [column_name for column_name in frame.columns if column_name != class_column] + columns = [normalize(frame[column_name]) for column_name in column_names] + if ax == None: + ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1]) + classes = set(frame[class_column]) + to_plot = {} + for class_ in classes: + to_plot[class_] = [[], []] + n = len(frame.columns) - 1 + s = np.array([(np.cos(t), np.sin(t)) for t in [2.0 * np.pi * (i / float(n)) for i in range(n)]]) + for i in range(len(frame)): + row = np.array([column[i] for column in columns]) + row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1) + y = (s * row_).sum(axis=0) / row.sum() + class_name = frame[class_column][i] + to_plot[class_name][0].append(y[0]) + to_plot[class_name][1].append(y[1]) + for class_ in classes: + ax.scatter(to_plot[class_][0], to_plot[class_][1], color=random_color(class_), label=str(class_), **kwds) + ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none')) + for xy, name in zip(s, column_names): + ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray')) + if xy[0] < 0.0 and xy[1] < 0.0: + ax.text(xy[0] - 0.025, xy[1] - 0.025, name, ha='right', va='top', size='small') + elif xy[0] < 0.0 and xy[1] >= 0.0: + ax.text(xy[0] - 0.025, xy[1] + 0.025, name, ha='right', va='bottom', size='small') + elif xy[0] >= 0.0 and xy[1] < 0.0: + ax.text(xy[0] + 0.025, xy[1] - 0.025, name, ha='left', va='top', size='small') + elif xy[0] >= 0.0 and xy[1] >= 0.0: + ax.text(xy[0] + 0.025, xy[1] + 0.025, name, ha='left', va='bottom', size='small') + ax.legend(loc='upper right') + ax.axis('equal') + return ax + def andrews_curves(data, class_column, ax=None, samples=200): """ Parameters: