Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid top-level import of pyplot for benefit of upstream packages #9

Merged
merged 2 commits into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions Ska/Matplotlib/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Provide useful utilities for matplotlib."""

import warnings
import datetime
from matplotlib import pyplot
from matplotlib.dates import (YearLocator, MonthLocator, DayLocator,
HourLocator, MinuteLocator, SecondLocator,
DateFormatter, epoch2num)
Expand Down Expand Up @@ -56,7 +53,7 @@ def set_time_ticks(plt, ticklocs=None):
Pick nice values to show time ticks in a date plot.

Example::

x = cxctime2plotdate(np.linspace(0, 3e7, 20))
y = np.random.normal(size=len(x))

Expand Down Expand Up @@ -97,7 +94,7 @@ def remake_ticks(ax):
"""
ticklocs = set_time_ticks(ax)
ax.figure.canvas.draw()

def plot_cxctime(times, y, fmt='-b', fig=None, ax=None, yerr=None, xerr=None, tz=None,
state_codes=None, interactive=True, **kwargs):
"""Make a date plot where the X-axis values are in CXC time. If no ``fig``
Expand All @@ -123,7 +120,7 @@ def plot_cxctime(times, y, fmt='-b', fig=None, ax=None, yerr=None, xerr=None, tz
:param y: y values
:param fmt: plot format (default = '-b')
:param fig: pyplot figure object (optional)
:param yerr: error on y values, may be [ scalar | N, Nx1, or 2xN array-like ]
:param yerr: error on y values, may be [ scalar | N, Nx1, or 2xN array-like ]
:param xerr: error on x values in units of DAYS (may be [ scalar | N, Nx1, or 2xN array-like ] )
:param tz: timezone string
:param state_codes: list of (raw_count, state_code) tuples
Expand All @@ -132,6 +129,7 @@ def plot_cxctime(times, y, fmt='-b', fig=None, ax=None, yerr=None, xerr=None, tz

:rtype: ticklocs, fig, ax = tick locations, figure, and axes object.
"""
from matplotlib import pyplot

if fig is None:
fig = pyplot.gcf()
Expand Down Expand Up @@ -164,17 +162,17 @@ def cxctime2plotdate(times):
"""
Convert input CXC time (sec) to the time base required for the matplotlib
plot_date function (days since start of year 1).

:param times: iterable list of times
:rtype: plot_date times
"""

# Find the plotdate of first time and use a relative offset from there
t0 = Chandra.Time.DateTime(times[0]).unix
plotdate0 = epoch2num(t0)

return (np.asarray(times) - times[0]) / 86400. + plotdate0


def pointpair(x, y=None):
"""Interleave and then flatten two arrays ``x`` and ``y``. This is
Expand Down Expand Up @@ -226,7 +224,7 @@ def hist_outline(dataIn, *args, **kwargs):
stepSize = binsIn[1] - binsIn[0]

bins = np.zeros(len(binsIn)*2 + 2, dtype=np.float)
data = np.zeros(len(binsIn)*2 + 2, dtype=np.float)
data = np.zeros(len(binsIn)*2 + 2, dtype=np.float)
for bb in range(len(binsIn)):
bins[2*bb + 1] = binsIn[bb]
bins[2*bb + 2] = binsIn[bb] + stepSize
Expand All @@ -238,7 +236,7 @@ def hist_outline(dataIn, *args, **kwargs):
bins[-1] = bins[-2]
data[0] = 0
data[-1] = 0

return (bins, data)


4 changes: 3 additions & 1 deletion Ska/Matplotlib/lineid_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from distutils.version import LooseVersion
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from six.moves import zip

__version__ = "0.3_ska"
Expand Down Expand Up @@ -223,6 +222,8 @@ def prepare_axes(wave, flux, fig=None, ax_lower=(0.1, 0.1),
ax_dim=(0.85, 0.65)):
"""Create fig and axes if needed and layout axes in fig."""
# Axes location in figure.
from matplotlib import pyplot as plt

if not fig:
fig = plt.figure()
ax = fig.add_axes([ax_lower[0], ax_lower[1], ax_dim[0], ax_dim[1]])
Expand Down Expand Up @@ -441,6 +442,7 @@ def plot_line_ids(wave, flux, line_wave, line_label1, label1_size=None,
return fig, ax

if __name__ == "__main__":
from matplotlib import pyplot as plt
wave = 1240 + np.arange(300) * 0.1
flux = np.random.normal(size=300)
line_wave = [1242.80, 1260.42, 1264.74, 1265.00, 1265.2, 1265.3, 1265.35]
Expand Down