Skip to content

Commit

Permalink
Back to pyplot
Browse files Browse the repository at this point in the history
  • Loading branch information
adryyan committed Dec 12, 2024
1 parent a50b5e5 commit b1fbd3f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 105 deletions.
114 changes: 32 additions & 82 deletions src/iminuit/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,7 @@ class documentation for details.
TypeVar,
Callable,
cast,
TYPE_CHECKING,
)

if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
import warnings
from ._deprecated import deprecated_parameter

Expand Down Expand Up @@ -744,10 +739,7 @@ def __getitem__(self, key):
return self._items.__getitem__(key)

def visualize(
self,
args: Sequence[float],
component_kwargs: Dict[int, Dict[str, Any]] = None,
fig: Figure = None,
self, args: Sequence[float], component_kwargs: Dict[int, Dict[str, Any]] = None
):
"""
Visualize data and model agreement (requires matplotlib).
Expand All @@ -765,27 +757,16 @@ def visualize(
Dict that maps an index to dict of keyword arguments. This can be
used to pass keyword arguments to a visualize method of a component with
that index.
fig : Figure, optional
The matplotlib figure into which the visualization is drawn.
If None is passed, the current figure is used. If the passed figure
has the same number of axes as there are components with a visualize
method, the axes are reused. Otherwise, new axes are created.
**kwargs :
Other keyword arguments are forwarded to all components.
"""
if fig is None:
from matplotlib import pyplot as plt

fig = plt.gcf()
from matplotlib import pyplot as plt

n = sum(hasattr(comp, "visualize") for comp in self)

fig = plt.gcf()
fig.set_figwidth(n * fig.get_figwidth() / 1.5)
ax = fig.get_axes()
if len(ax) != n:
fig.clear()
_, ax = fig.subplots(1, n)
# For some reason fig.subplots does not return axes array but only
# a single axes even for n > 1
ax = fig.get_axes()
_, ax = plt.subplots(1, n, num=fig.number)

if component_kwargs is None:
component_kwargs = {}
Expand All @@ -795,7 +776,8 @@ def visualize(
if not hasattr(comp, "visualize"):
continue
kwargs = component_kwargs.get(k, {})
comp.visualize(cargs, ax=ax[i], **kwargs)
plt.sca(ax[i])
comp.visualize(cargs, **kwargs)
i += 1


Expand Down Expand Up @@ -949,7 +931,6 @@ def visualize(
args: Sequence[float],
model_points: Union[int, Sequence[float]] = 0,
bins: int = 50,
ax: Axes = None,
):
"""
Visualize data and model agreement (requires matplotlib).
Expand All @@ -966,14 +947,8 @@ def visualize(
it is interpreted as the point locations.
bins : int, optional
number of bins. Default is 50 bins.
ax : Axes, optional
The matplotlib axes into which the visualization is drawn.
If None is passed, the current axes is used.
"""
if ax is None:
from matplotlib import pyplot as plt

ax = plt.gca()
from matplotlib import pyplot as plt

x = np.sort(self.data)

Expand All @@ -999,8 +974,8 @@ def visualize(
cx = 0.5 * (xe[1:] + xe[:-1])
dx = xe[1] - xe[0]

ax.errorbar(cx, n, n**0.5, fmt="ok")
ax.fill_between(xm, 0, ym * dx, fc="C0")
plt.errorbar(cx, n, n**0.5, fmt="ok")
plt.fill_between(xm, 0, ym * dx, fc="C0")

def fisher_information(self, *args: float) -> NDArray:
"""
Expand Down Expand Up @@ -1400,7 +1375,7 @@ def prediction(
"""
return self._pred(args)

def visualize(self, args: Sequence[float], ax: Axes = None) -> None:
def visualize(self, args: Sequence[float]) -> None:
"""
Visualize data and model agreement (requires matplotlib).
Expand All @@ -1410,9 +1385,6 @@ def visualize(self, args: Sequence[float], ax: Axes = None) -> None:
----------
args : sequence of float
Parameter values.
ax : Axes, optional
The matplotlib axes into which the visualization is drawn.
If None is passed, the current axes is used.
Notes
-----
Expand All @@ -1422,13 +1394,10 @@ def visualize(self, args: Sequence[float], ax: Axes = None) -> None:
comparison to a model, the visualization shows all data bins as a single
sequence.
"""
return self._visualize(args, ax)

def _visualize(self, args: Sequence[float], ax: Axes) -> None:
if ax is None:
from matplotlib import pyplot as plt
return self._visualize(args)

ax = plt.gca()
def _visualize(self, args: Sequence[float]) -> None:
from matplotlib import pyplot as plt

n, ne = self._n_err()
mu = self.prediction(args)
Expand All @@ -1445,9 +1414,8 @@ def _visualize(self, args: Sequence[float], ax: Axes) -> None:
else:
xe = self.xe
cx = 0.5 * (xe[1:] + xe[:-1])

ax.errorbar(cx, n, ne, fmt="ok")
ax.stairs(mu, xe, fill=True, color="C0")
plt.errorbar(cx, n, ne, fmt="ok")
plt.stairs(mu, xe, fill=True, color="C0")

@abc.abstractmethod
def _pred(
Expand Down Expand Up @@ -1890,11 +1858,8 @@ def prediction(self, args: Sequence[float]) -> Tuple[NDArray, NDArray]:
mu, mu_var = self._pred(args)
return mu, np.sqrt(mu_var)

def _visualize(self, args: Sequence[float], ax: Axes) -> None:
if ax is None:
from matplotlib import pyplot as plt

ax = plt.gca()
def _visualize(self, args: Sequence[float]) -> None:
from matplotlib import pyplot as plt

n, ne = self._n_err()
mu, mue = self.prediction(args) # type: ignore
Expand All @@ -1911,11 +1876,11 @@ def _visualize(self, args: Sequence[float], ax: Axes) -> None:
xe = self.xe
cx = 0.5 * (xe[1:] + xe[:-1])

ax.errorbar(cx, n, ne, fmt="ok")
plt.errorbar(cx, n, ne, fmt="ok")

# need fill=True and fill=False so that bins with mue=0 show up
for fill in (False, True):
ax.stairs(mu + mue, xe, baseline=mu - mue, fill=fill, color="C0")
plt.stairs(mu + mue, xe, baseline=mu - mue, fill=fill, color="C0")

def _pulls(self, args: Sequence[float]) -> NDArray:
mu, mue = self.prediction(args)
Expand Down Expand Up @@ -2316,10 +2281,7 @@ def _ndata(self):
return len(self._masked)

def visualize(
self,
args: ArrayLike,
model_points: Union[int, Sequence[float]] = 0,
ax: Axes = None,
self, args: ArrayLike, model_points: Union[int, Sequence[float]] = 0
) -> Tuple[Tuple[NDArray, NDArray, NDArray], Tuple[NDArray, NDArray]]:
"""
Visualize data and model agreement (requires matplotlib).
Expand All @@ -2335,20 +2297,14 @@ def visualize(
How many points to use to draw the model. Default is 0, in this case
an smart sampling algorithm selects the number of points. If array-like,
it is interpreted as the point locations.
ax : Axes, optional
The matplotlib axes into which the visualization is drawn.
If None is passed, the current axes is used.
"""
if ax is None:
from matplotlib import pyplot as plt

ax = plt.gca()
from matplotlib import pyplot as plt

if self._ndim > 1:
raise ValueError("visualize is not implemented for multi-dimensional data")

x, y, ye = self._masked.T
ax.errorbar(x, y, ye, fmt="ok")
plt.errorbar(x, y, ye, fmt="ok")
if isinstance(model_points, Iterable):
xm = np.array(model_points)
ym = self.model(xm, *args)
Expand All @@ -2360,7 +2316,7 @@ def visualize(
ym = self.model(xm, *args)
else:
xm, ym = _smart_sampling(lambda x: self.model(x, *args), x[0], x[-1])
ax.plot(xm, ym)
plt.plot(xm, ym)
return (x, y, ye), (xm, ym)

def prediction(self, args: Sequence[float]) -> NDArray:
Expand Down Expand Up @@ -2530,7 +2486,7 @@ def _has_grad(self) -> bool:
def _ndata(self):
return len(self._expected)

def visualize(self, args: ArrayLike, ax: Axes = None):
def visualize(self, args: ArrayLike):
"""
Visualize data and model agreement (requires matplotlib).
Expand All @@ -2540,14 +2496,8 @@ def visualize(self, args: ArrayLike, ax: Axes = None):
----------
args : array-like
Parameter values.
ax : Axes, optional
The matplotlib axes into which the visualization is drawn.
If None is passed, the current axes is used.
"""
if ax is None:
from matplotlib import pyplot as plt

ax = plt.gca()
from matplotlib import pyplot as plt

args = np.atleast_1d(args)

Expand All @@ -2565,14 +2515,14 @@ def visualize(self, args: ArrayLike, ax: Axes = None):
for v, e, a in zip(val, err, args):
pull = (a - v) / e
max_pull = max(abs(pull), max_pull)
ax.errorbar(pull, -i, 0, 1, fmt="o", color="C0")
plt.errorbar(pull, -i, 0, 1, fmt="o", color="C0")
i += 1
ax.axvline(0, color="k")
ax.set_xlim(-max_pull - 1.1, max_pull + 1.1)
yaxis = ax.yaxis
plt.axvline(0, color="k")
plt.xlim(-max_pull - 1.1, max_pull + 1.1)
yaxis = plt.gca().yaxis
yaxis.set_ticks(-np.arange(n))
yaxis.set_ticklabels(par)
ax.set_ylim(-n + 0.5, 0.5)
plt.ylim(-n + 0.5, 0.5)


def _norm(value: ArrayLike) -> NDArray:
Expand Down
37 changes: 14 additions & 23 deletions src/iminuit/qtwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,24 +205,14 @@ def __init__(self):
)
plot_group.setSizePolicy(size_policy)
plot_layout = QtWidgets.QVBoxLayout(plot_group)
# Use pyplot here to allow users to use pyplot in the plot
# function (not recommended / unstable)
self.fig, ax = plt.subplots()
self.canvas = FigureCanvasQTAgg(self.fig)
fig = plt.figure()
self.figsize = fig.get_size_inches()
manager = plt.get_current_fig_manager()
self.canvas = FigureCanvasQTAgg(fig)
self.canvas.manager = manager
plot_layout.addWidget(self.canvas)
plot_layout.addStretch()
interactive_layout.addWidget(plot_group, 0, 0, 2, 1)
try:
plot(minuit.values, fig=self.fig)
kwargs["fig"] = self.fig
except Exception:
pass
try:
plot(minuit.values, ax=ax)
kwargs["ax"] = ax
except Exception:
pass
self.fig_width = self.fig.get_figwidth()

button_group = QtWidgets.QGroupBox("", parent=self)
size_policy = QtWidgets.QSizePolicy(
Expand Down Expand Up @@ -285,7 +275,7 @@ def __init__(self):
self.plot_with_frame(from_fit=False, report_success=False)

def plot_with_frame(self, from_fit, report_success):
self.fig.set_figwidth(self.fig_width)
trans = plt.gca().transAxes
try:
with warnings.catch_warnings():
minuit.visualize(plot, **kwargs)
Expand All @@ -295,7 +285,7 @@ def plot_with_frame(self, from_fit, report_success):

import traceback

self.fig.text(
plt.figtext(
0,
0.5,
traceback.format_exc(limit=-1),
Expand All @@ -308,19 +298,19 @@ def plot_with_frame(self, from_fit, report_success):
return

fval = minuit.fmin.fval if from_fit else minuit._fcn(minuit.values)
self.fig.get_axes()[0].text(
plt.text(
0.05,
1.05,
f"FCN = {fval:.3f}",
transform=self.fig.get_axes()[0].transAxes,
transform=trans,
fontsize="x-large",
)
if from_fit and report_success:
self.fig.get_axes()[-1].text(
plt.text(
0.95,
1.05,
f"{'success' if minuit.valid and minuit.accurate else 'FAILURE'}",
transform=self.fig.get_axes()[-1].transAxes,
transform=trans,
fontsize="x-large",
ha="right",
)
Expand Down Expand Up @@ -353,8 +343,9 @@ def on_parameter_change(self, from_fit=False, report_success=False):
else:
self.results_text.clear()

for ax in self.fig.get_axes():
ax.clear()

plt.clf()
plt.gcf().set_size_inches(self.figsize)
self.plot_with_frame(from_fit, report_success)
self.canvas.draw_idle()

Expand Down

0 comments on commit b1fbd3f

Please sign in to comment.