diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py
index 8168e1367f962..7796a6d8763af 100644
--- a/pandas/tools/plotting.py
+++ b/pandas/tools/plotting.py
@@ -11,8 +11,13 @@
 from pandas.tseries.period import PeriodIndex
 from pandas.tseries.offsets import DateOffset
 
-def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
-                   **kwds):
+
+def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
+                   diagonal='hist', **kwds):
     """
     Draw a matrix of scatter plots.
+
+    Each diagonal cell shows the distribution of its column: a histogram
+    when ``diagonal='hist'`` (the default), or a Gaussian kernel density
+    estimate when ``diagonal='kde'`` (needs scipy).
 
@@ -36,6 +41,53 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
 
     for i, a in zip(range(n), df.columns):
         for j, b in zip(range(n), df.columns):
+            if i == j:
+                # Diagonal cell: draw the distribution of column `a`.
+                # Missing values are dropped so they cannot distort it.
+                values = df[a].dropna().values
+                if diagonal == 'hist':
+                    axes[i, j].hist(values)
+                elif diagonal == 'kde':
+                    # Imported lazily: scipy is only needed for KDE.
+                    from scipy.stats import gaussian_kde
+                    gkde = gaussian_kde(values)
+                    ind = np.linspace(values.min(), values.max(), 1000)
+                    axes[i, j].plot(ind, gkde.evaluate(ind), **kwds)
+                axes[i, j].yaxis.set_visible(False)
+                axes[i, j].xaxis.set_visible(False)
+                if i == 0 and j == 0:
+                    axes[i, j].yaxis.set_ticks_position('left')
+                    axes[i, j].yaxis.set_label_position('left')
+                    axes[i, j].yaxis.set_visible(True)
+                if i == n - 1 and j == n - 1:
+                    axes[i, j].yaxis.set_ticks_position('right')
+                    axes[i, j].yaxis.set_label_position('right')
+                    axes[i, j].yaxis.set_visible(True)
+            else:
+                axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)
+                axes[i, j].yaxis.set_visible(False)
+                axes[i, j].xaxis.set_visible(False)
+
+            # setup labels
+            if i == 0 and j % 2 == 1:
+                axes[i, j].set_xlabel(b, visible=True)
+                axes[i, j].xaxis.set_visible(True)
+                axes[i, j].xaxis.set_ticks_position('top')
+                axes[i, j].xaxis.set_label_position('top')
+            if i == n - 1 and j % 2 == 0:
+                axes[i, j].set_xlabel(b, visible=True)
+                axes[i, j].xaxis.set_visible(True)
+                axes[i, j].xaxis.set_ticks_position('bottom')
+                axes[i, j].xaxis.set_label_position('bottom')
+            if j == 0 and i % 2 == 0:
+                axes[i, j].set_ylabel(a, visible=True)
+                axes[i, j].yaxis.set_visible(True)
+                axes[i, j].yaxis.set_ticks_position('left')
+                axes[i, j].yaxis.set_label_position('left')
+            if j == n - 1 and i % 2 == 1:
+                axes[i, j].set_ylabel(a, visible=True)
+                axes[i, j].yaxis.set_visible(True)
+                axes[i, j].yaxis.set_ticks_position('right')
+                axes[i, j].yaxis.set_label_position('right')
-            axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)
             axes[i, j].set_xlabel('')
             axes[i, j].set_ylabel('')
@@ -87,12 +139,9 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
             axes[i, j].grid(b=grid)
 
 
-    # ensure {x,y}lim off diagonal are the same as diagonal
-    for i in range(n):
-        for j in range(n):
-            if i != j:
-                axes[i, j].set_xlim(axes[j, j].get_xlim())
-                axes[i, j].set_ylim(axes[i, i].get_ylim())
+    # NOTE(review): off-diagonal limits are no longer copied from the
+    # diagonal cells -- presumably because those now show a distribution
+    # whose y-axis is not in data units. Confirm this is intended.
 
     return axes
 
@@ -326,6 +375,49 @@ def _get_xticks(self):
 
         return x
 
+class KdePlot(MPLPlot):
+    """Line plot of a Gaussian kernel density estimate for each column."""
+
+    def __init__(self, data, **kwargs):
+        MPLPlot.__init__(self, data, **kwargs)
+
+    def _get_plot_function(self):
+        return self.plt.Axes.plot
+
+    def _make_plot(self):
+        # Imported lazily so that scipy is only required for kind='kde'.
+        from scipy.stats import gaussian_kde
+
+        plotf = self._get_plot_function()
+        for i, (label, y) in enumerate(self._iter_data()):
+            if self.subplots:
+                ax = self.axes[i]
+                style = 'k'
+            else:
+                style = ''  # empty string ignored
+                ax = self.ax
+            if self.style:
+                style = self.style
+            # Drop missing values before fitting the estimate.
+            # NOTE(review): assumes y is 1-D numeric data -- confirm
+            # against what _iter_data yields.
+            y = np.asarray(y, dtype=float)
+            y = y[~np.isnan(y)]
+            gkde = gaussian_kde(y)
+            # Evaluate half a sample range beyond each extreme.
+            low, high = y.min(), y.max()
+            sample_range = high - low
+            ind = np.linspace(low - 0.5 * sample_range,
+                              high + 0.5 * sample_range, 1000)
+            ax.set_ylabel("Density")
+            plotf(ax, ind, gkde.evaluate(ind), style, label=label,
+                  **self.kwds)
+            ax.grid(self.grid)
+
+    def _post_plot_logic(self):
+        if self.subplots and self.legend:
+            self.axes[0].legend(loc='best')
+
 class LinePlot(MPLPlot):
 
     def __init__(self, data, **kwargs):
@@ -608,6 +700,8 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
         klass = LinePlot
     elif kind in ('bar', 'barh'):
         klass = BarPlot
+    elif kind == 'kde':
+        klass = KdePlot
     else:
         raise ValueError('Invalid chart type given %s' % kind)
 
@@ -670,6 +764,10 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
         klass = LinePlot
     elif kind in ('bar', 'barh'):
         klass = BarPlot
+    elif kind == 'kde':
+        klass = KdePlot
+    else:
+        raise ValueError('Invalid chart type given %s' % kind)
 
     if ax is None:
         ax = _gca()