diff --git a/radiospectra/spectrogram2/spectrogram.py b/radiospectra/spectrogram2/spectrogram.py index 7579048..83881be 100644 --- a/radiospectra/spectrogram2/spectrogram.py +++ b/radiospectra/spectrogram2/spectrogram.py @@ -150,6 +150,7 @@ def plot(self, axes=None, **kwargs): Returns ------- + im : `matplotlib.cm.ScalarMappable` """ if axes is None: fig, axes = plt.subplots() @@ -167,7 +168,7 @@ def plot(self, axes=None, **kwargs): axes.set_title(title) axes.plot(self.times.datetime[[0, -1]], self.frequencies[[0, -1]], linestyle="None", marker="None") - axes.pcolormesh(self.times.datetime, self.frequencies.value, data[:-1, :-1], shading="auto", **kwargs) + im = axes.pcolormesh(self.times.datetime, self.frequencies.value, data[:-1, :-1], shading="auto", **kwargs) axes.set_xlim(self.times.datetime[0], self.times.datetime[-1]) locator = mdates.AutoDateLocator(minticks=4, maxticks=8) formatter = mdates.ConciseDateFormatter(locator) @@ -175,6 +176,13 @@ def plot(self, axes=None, **kwargs): axes.xaxis.set_major_formatter(formatter) fig.autofmt_xdate() + for i in plt.get_fignums(): + if axes in plt.figure(i).axes: + plt.sca(axes) + plt.sci(im) + + return im + class NonUniformImagePlotMixin: """