Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow control of font size and the x- and y-axis limits #68

Merged
merged 1 commit into from
Nov 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions cinnabar/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def _master_plot(
dpi: float = "figure",
data_labels: list = [],
axis_padding: float = 0.5,
xy_lim: list = [],
font_sizes: dict = {"title": 12, "labels": 9, "other": 12},
):
"""Handles the aesthetics of the plots in one place.

Expand Down Expand Up @@ -81,26 +83,34 @@ def _master_plot(
list of labels for each data point
axis_padding : float, default = 0.5
padding to add to maximum axis value and subtract from the minimum axis value
xy_lim : list, default []
contains the minimium and maximum values to use for the x and y axes. if specified, axis_padding is ignored
font_sizes : dict, default {"title": 12, "labels": 9, "other": 12}
font sizes to use for the title ("title"), the data labels ("labels"), and the rest of the plot ("other")

Returns
-------

"""
nsamples = len(x)
# aesthetics
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 12
plt.rcParams["font.size"] = 12
plt.rcParams["xtick.labelsize"] = font_sizes["other"]
plt.rcParams["ytick.labelsize"] = font_sizes["other"]
plt.rcParams["font.size"] = font_sizes["other"]

fig = plt.figure(figsize=(figsize, figsize))
plt.subplots_adjust(left=0.2, right=0.8, bottom=0.2, top=0.8)

plt.xlabel(f"{xlabel} {quantity} ({units})")
plt.ylabel(f"{ylabel} {quantity} ({units})")

ax_min = min(min(x), min(y)) - axis_padding
ax_max = max(max(x), max(y)) + axis_padding
scale = [ax_min, ax_max]
if xy_lim:
ax_min, ax_max = xy_lim
scale = xy_lim
else:
ax_min = min(min(x), min(y)) - axis_padding
ax_max = max(max(x), max(y)) + axis_padding
scale = [ax_min, ax_max]

plt.xlim(scale)
plt.ylim(scale)
Expand Down Expand Up @@ -151,7 +161,7 @@ def _master_plot(
# Label points
texts = []
for i, label in enumerate(data_labels):
texts.append(plt.text(x[i] + 0.03, y[i] + 0.03, label, fontsize=9))
texts.append(plt.text(x[i] + 0.03, y[i] + 0.03, label, fontsize=font_sizes["labels"]))
adjust_text(texts)

# stats and title
Expand All @@ -165,7 +175,7 @@ def _master_plot(

plt.title(
long_title,
fontsize=12,
fontsize=font_sizes["title"],
loc="right",
horizontalalignment="right",
family="monospace",
Expand Down