Skip to content

Commit

Permalink
fix ci
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Jan 25, 2024
1 parent cf7ebee commit 2be34cd
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 335 deletions.
1 change: 1 addition & 0 deletions emodel_generalisation/bluecellulab_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Compute the threshold and holding current using bluecellulab, adapted from BluePyThresh."""
import logging
from copy import copy
from multiprocessing.context import TimeoutError # pylint: disable=redefined-builtin
from pathlib import Path

import bluecellulab
Expand Down
305 changes: 0 additions & 305 deletions emodel_generalisation/information.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from scipy.stats import pearsonr
from tqdm import tqdm

from emodel_generalisation import ALL_LABELS
from emodel_generalisation import FEATURE_LABELS_LONG
from emodel_generalisation import PARAM_LABELS
from emodel_generalisation.utils import cluster_matrix
Expand Down Expand Up @@ -685,307 +684,3 @@ def reduce_features(df, threshold=0.9):
selected_features.append(feature1)
print(f"Found {len(selected_features)} out of {len(df.index)}")
return selected_features, feature_map


def bin_data(p1, p2, f, n=20, mode="mean", _min1=-1.0, _max1=1.0, _min2=-1.0, _max2=1.0):
    """Bin data on a regular 2D grid to make a heatmap.

    Each sample is assigned to a bin along ``p1`` and ``p2`` and the values ``f`` of the
    samples falling in the same bin are aggregated with the pandas groupby method ``mode``.

    Args:
        p1 (array-like): first coordinate of each sample
        p2 (array-like): second coordinate of each sample
        f (array-like): value attached to each sample
        n (int): number of bins along each axis
        mode (str): name of the groupby aggregation to apply (e.g. "mean" or "sum")
        _min1, _max1 (float): range of ``p1`` mapped onto the bins
        _min2, _max2 (float): range of ``p2`` mapped onto the bins

    Returns:
        numpy.ndarray: array of shape (n, n), first axis indexed by the ``p1`` bin, with NaN
        in bins that received no sample.
    """
    _df = pd.DataFrame()
    # the small epsilon keeps a coordinate equal to the upper bound inside the last bin
    _df["p1"] = np.array(n * (p1 - _min1) / (_max1 - _min1 + 1e-10), dtype=int)
    _df["p2"] = np.array(n * (p2 - _min2) / (_max2 - _min2 + 1e-10), dtype=int)
    _df["f"] = f
    _df = getattr(_df.groupby(["p1", "p2"]), mode)().reset_index()

    # Start from NaN so that only bins without samples are empty: masking zeros afterwards
    # would also discard bins whose aggregated value is genuinely 0.
    m = np.full([n + 1, n + 1], np.nan)
    m[_df["p1"].to_numpy(), _df["p2"].to_numpy()] = _df["f"].to_numpy()
    # the extra row/column is dropped so that the result has n bins per axis
    return m[:-1][:, :-1]


def _get_2d_data(df, feature, param1, param2, n_bins=20, perc=10):
    """Build the clipped 2D heatmap for a pair of normalized parameters.

    Without a feature the bins hold the sum of ones of the samples they contain, otherwise
    they hold the mean of the feature values. The heatmap is then clipped between its
    ``perc`` and ``100 - perc`` percentiles, computed on the non-NaN bins.
    """
    x_values = df[("normalized_parameters", param1)].to_numpy()
    y_values = df[("normalized_parameters", param2)].to_numpy()

    if feature is not None:
        heatmap = bin_data(x_values, y_values, feature, n=n_bins, mode="mean")
    else:
        heatmap = bin_data(x_values, y_values, np.ones(len(x_values)), n=n_bins, mode="sum")

    lower = np.percentile(heatmap[~np.isnan(heatmap)], perc)
    upper = np.percentile(heatmap[~np.isnan(heatmap)], 100 - perc)
    return np.clip(heatmap, lower, upper)


def _plot_2d_data(ax, m, vmin, vmax, rev=True, cmap="gnuplot", normalize=False):
    """Draw the heatmap ``m`` on ``ax`` with imshow and return the image.

    With ``normalize`` the data is divided by its NaN-aware maximum and the color range is
    fixed to [0, 1], otherwise ``vmin`` and ``vmax`` are used. With ``rev`` the reversed
    variant of the colormap (``_r`` suffix) is used.
    """
    dx = 1.0 / (len(m) - 1.0)
    low, high = -1 - dx, 1 + dx

    if normalize:
        image = m.T / np.nanmax(np.nanmax(m))
        color_min, color_max = 0, 1
    else:
        image = m.T
        color_min, color_max = vmin, vmax

    colormap = f"{cmap}_r" if rev else cmap
    return ax.imshow(
        image,
        origin="lower",
        aspect="auto",
        extent=(low, high, low, high),
        cmap=colormap,
        interpolation="nearest",
        vmin=color_min,
        vmax=color_max,
    )


def plot_corner(
df,
feature=None,
filename="corner.pdf",
n_bins=20,
cmap="gnuplot",
normalize=False,
highlights=None,
):
"""Make a corner plot of binned heatmaps for all pairs of normalized parameters.

Cells below the diagonal show a 2D heatmap of the parameter pair. Without a feature the
heatmaps hold the per-bin sum of ones (colorbar label "number of models") and the diagonal
cells show a histogram of the parameter; with a feature the heatmaps hold the per-bin mean
of that feature. All heatmaps share a global color range.

Args:
df (pandas.DataFrame): data with a ``normalized_parameters`` column group
feature: column key of ``df`` used to color the heatmaps; ``feature[1]`` is used as
colorbar label (presumably a ``(group, name)`` tuple -- confirm with callers)
filename (str): path where the figure is saved
n_bins (int): number of bins of the heatmaps and of the histograms
cmap (str): colormap name, forwarded to ``_plot_2d_data``
normalize (bool): forwarded to ``_plot_2d_data``
highlights: if not None, unpacked with ``zip(*highlights)`` into (row label, color) pairs
drawn as scatter points on top of the heatmaps

Returns:
the matplotlib figure, or None if the standard deviation of the feature is below 1e-5.
"""
# order the parameters by their display label rather than by their raw name
params = np.array(sorted(df.normalized_parameters.columns.to_list()))
_params = np.array([PARAM_LABELS.get(p, p) for p in params])
params = params[np.argsort(_params)]
n_params = len(params)

# get feature data
_feature = None
if feature is not None:
_feature = df[feature].to_numpy()

# a (nearly) constant feature gives nothing to plot
if np.std(_feature) < 1e-5:
print("no data to plot")
return None

# precompute heatmaps to get a global vmin/vmax
# m[i][j] is only filled for j < i, i.e. for the cells below the diagonal
m = []
vmin = 1e10
vmax = -1e10
for i, param1 in enumerate(params):
m.append([])
for j, param2 in enumerate(params):
if j < i:
# param2 (column index j) is binned first, param1 (row index i) second
_m = _get_2d_data(df, _feature, param2, param1, n_bins=n_bins)
m[i].append(_m)
vmin = min(vmin, min(_m[~np.isnan(_m)].flatten())) # pylint: disable=nested-min-max
vmax = max(vmax, max(_m[~np.isnan(_m)].flatten())) # pylint: disable=nested-min-max
# the figure size grows with the number of parameters
fig = plt.figure(figsize=(5 + 0.5 * n_params, 5 + 0.5 * n_params))
gs = fig.add_gridspec(n_params, n_params, hspace=0.1, wspace=0.1)
# im keeps the last drawn heatmap, it is used below to build the colorbar
im = None
for i, param1 in enumerate(params):
_param1 = PARAM_LABELS.get(param1, param1)
for j, param2 in enumerate(params):
_param2 = PARAM_LABELS.get(param2, param2)
ax = plt.subplot(gs[i, j])

# no ticks on any cell, only parameter names are shown as axis labels
ax.set_xticks([])
ax.set_yticks([])
ax.set_yticklabels([])
ax.set_xticklabels([])

# cells above the diagonal are left empty
if j >= i + 1:
ax.set_frame_on(False)
# cells below the diagonal show the precomputed heatmap
elif j < i:
ax.set_frame_on(True)
im = _plot_2d_data(
ax, m[i][j], vmin, vmax, rev=feature is not None, cmap=cmap, normalize=normalize
)
if highlights is not None:
for _i, c in zip(*highlights):
plt.scatter(
df.loc[_i, ("normalized_parameters", param2)],
df.loc[_i, ("normalized_parameters", param1)],
c=c,
s=20,
)
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
# diagonal cells: histogram of the parameter when no feature is given, empty otherwise
else:
if _feature is None:
ax.hist(
df[("normalized_parameters", param1)].to_numpy(),
bins=n_bins,
color="k",
histtype="step",
)
ax.set_xlim(-1, 1)
ax.set_frame_on(True)
else:
ax.set_frame_on(False)

# parameter names on the first column and on the last row
if j == 0:
ax.set_ylabel(_param1, rotation="horizontal", horizontalalignment="right")
if i == n_params - 1:
ax.set_xlabel(_param2, rotation="vertical")
if j == 0 and i == 0:
ax.set_ylabel(_param1, rotation="horizontal", horizontalalignment="left")
# label on the right side of the diagonal cell (no feature), or of the cell just left of
# the diagonal (with a feature)
if (i == j) if _feature is None else (i == j + 1):
ax.set_ylabel(_param1, rotation="horizontal", horizontalalignment="left")
ax.yaxis.set_label_position("right")

# the colorbar is attached to cells of the last columns when there are more than 6
# parameters, else to the bottom right cell
if im is not None:
if n_params > 6:
axs = [
plt.subplot(gs[i + 3, n_params - j - 1])
for i in range(n_params - 6)
for j in range(3)
]
else:
axs = [plt.subplot(gs[-1, -1])]
plt.colorbar(
im,
orientation="vertical",
ax=axs,
label=feature[1] if feature is not None else "number of models",
)
plt.tight_layout()
plt.savefig(filename)

return fig


def _top_params(cor, sd=5):
cor_n = cor.to_numpy()
if np.shape(cor_n)[0] == np.shape(cor_n)[1]:
_x = np.triu(cor, 1).flatten()
_x = _x[abs(_x) > 0]
thresh = sd * np.std(_x)
else:
thresh = sd * np.std(cor_n.flatten())
if sd > 0:
_cor = cor[cor > thresh]
else:
_cor = cor[cor < thresh]
tuples = []
for x in cor.index:
for y in cor.columns:
if not np.isnan(_cor.loc[x, y]):
if (y, x) not in tuples:
tuples += [(x, y)]
return tuples


def plot_best_corr(df, cor, x_col, y_col, filename, sd=5):
    """Save one scatter plot per top correlated pair of ``cor`` in a multi-page pdf.

    The pairs are selected with ``_top_params(cor, sd)``; each page shows ``df[(x_col, x)]``
    against ``df[(y_col, y)]`` colored by ``df["cost"]``, with the value of ``cor`` as title.
    """
    pairs = _top_params(cor, sd)
    with PdfPages(filename) as pdf:
        for x_name, y_name in tqdm(pairs):
            x_values = df[(x_col, x_name)]
            y_values = df[(y_col, y_name)]

            plt.figure(figsize=(4, 3))
            plt.scatter(
                x_values,
                y_values,
                c=df["cost"],
                s=0.2,
                marker=".",
                rasterized=True,
            )
            plt.suptitle(f"MI: {cor.loc[x_name, y_name]}")
            plt.colorbar()
            plt.xlabel(ALL_LABELS.get(x_name, x_name))
            plt.ylabel(ALL_LABELS.get(y_name, y_name))
            plt.tight_layout()
            pdf.savefig()
            plt.close()


def plot_MI(MI, with_cluster=False):
    """Plot the MI matrix as a heatmap on a new figure.

    With ``with_cluster`` the rows and columns are first reordered with ``cluster_matrix``
    applied to the absolute values. A matrix without negative minimum is drawn with the
    "viridis" colormap starting at 0, otherwise with "bwr" on a range symmetric around 0.
    """
    if with_cluster:
        order = cluster_matrix(abs(MI))
        MI = MI.loc[order, order]

    plt.figure(figsize=(15, 15))
    ax = plt.gca()

    # relabel a copy so that the input keeps its original labels
    labelled = MI.copy()
    labelled.index = [ALL_LABELS.get(name, name) for name in MI.index]
    labelled.columns = [ALL_LABELS.get(name, name) for name in MI.columns]

    non_negative = MI.min().min() >= 0
    if non_negative:
        vmin, vmax, cmap = 0, None, "viridis"
    else:
        vmax = MI.abs().max().max()
        vmin = -vmax
        cmap = "bwr"

    heatmap_options = {
        "data": labelled,
        "ax": ax,
        "vmin": vmin,
        "vmax": vmax,
        "cmap": cmap,
        "linewidths": 0.5,
        "linecolor": "k",
        "cbar_kws": {"label": "MI", "shrink": 0.3},
        "xticklabels": True,
        "yticklabels": True,
        "square": True,
    }
    sns.heatmap(**heatmap_options)
    plt.tight_layout()


def get_2d_correlations(
df,
x_col="normalized_parameters",
y_col="normalized_parameters",
mi_max=1,
feature=None,
tpe="MI",
):
"""Get 2d correlations between the columns of ``df[x_col]`` and ``df[y_col]``.

Without a feature, the score of a pair is ``mi_gaussian`` of the two columns (``tpe="MI"``)
or the first element of the ``pearsonr`` result (``tpe="pearson"``), and 0 if one of the
columns has zero standard deviation. With a feature, the standardized feature column
``df[("features", feature)]`` is passed along with the pair to ``rsi_gaussian``
(``tpe="MI"``) or to ``pearsonr`` (``tpe="pearson"``).

With ``tpe="MI"`` the score goes through ``min(mi, mi_max)``; NaN scores are replaced by 0.

Args:
df (pandas.DataFrame): data with two-level columns
x_col (str): column group used for the rows of the result
y_col (str): column group used for the columns of the result
mi_max (float): upper bound used in ``min(mi, mi_max)``
feature (str): name of a column of the "features" group, or None
tpe (str): "MI" or "pearson"

Returns:
pandas.DataFrame: float matrix indexed by ``df[x_col].columns`` with columns
``df[y_col].columns``.
"""
# with identical column groups only unordered pairs are evaluated, the result is mirrored
# in the loop below
if x_col == y_col:
tuples = itertools.combinations(df[x_col].columns, 2)
else:
tuples = itertools.product(df[x_col].columns, df[y_col].columns)

# created without data, entries are filled pair by pair below
MI = pd.DataFrame(index=df[x_col].columns, columns=df[y_col].columns, dtype=float)
for x, y in tuples:
if feature is None:
# a constant column gets a score of 0 without calling the estimators
if df[(x_col, x)].std() == 0 or df[(y_col, y)].std() == 0:
mi = 0
else:
# NOTE(review): no branch assigns `mi` when tpe is neither "MI" nor "pearson"
if tpe == "MI":
mi = mi_gaussian(
np.vstack([df[(x_col, x)].to_numpy(), df[(y_col, y)].to_numpy()]).T
)
if tpe == "pearson":
mi = pearsonr(df[(x_col, x)].to_numpy(), df[(y_col, y)])[0]
else:
# standardize the feature, a constant feature is replaced by zeros
feat = df[("features", feature)].to_numpy()
if feat.std() > 0:
feat = (feat - feat.mean()) / feat.std()
else:
feat = 0 * feat
if tpe == "MI":
# the feature is stacked first, followed by the two columns of the pair
mi = rsi_gaussian(
np.vstack([feat, df[(x_col, x)].to_numpy(), df[(y_col, y)].to_numpy()]).T
)
if tpe == "pearson":
# NOTE(review): pearsonr receives a list of two arrays as second argument here, while
# the call above passes two 1D inputs -- confirm this is the intended usage
mi = pearsonr(feat, [df[(x_col, x)].to_numpy(), df[(y_col, y)].to_numpy()])[0]
if tpe == "MI":
mi = min(mi, mi_max)
if np.isnan(mi):
mi = 0
MI.loc[x, y] = mi
# mirror the score and zero the diagonal entries of the two labels of the pair
if x_col == y_col:
MI.loc[y, x] = mi
MI.loc[x, x] = 0.0
MI.loc[y, y] = 0.0
return MI


def plot_top_corner(df, out_path):
"""Plot subcorner of top correlations.

For each column of ``df["features"]``, the matrix returned by ``get_2d_correlations`` for
that feature is plotted and saved as ``RSI_param_param_{feat}.pdf``, then a corner plot
restricted to the parameters selected by ``_top_params`` is saved as
``RSI_top_corner_{feat}.pdf``.

Args:
df (pandas.DataFrame): data with "features" and "normalized_parameters" column groups
out_path: output folder, combined with file names via the ``/`` operator
(presumably a ``pathlib.Path`` -- confirm with callers)
"""
for feat in df["features"].columns:
cor = get_2d_correlations(df, feature=feat)
plot_MI(cor)
plt.savefig(out_path / f"RSI_param_param_{feat}.pdf")
# a negative sd makes _top_params keep the entries below the (negative) threshold
tuples = _top_params(cor, sd=-5)
# unique parameter names appearing in any selected pair
params = set(np.array(tuples).flatten())
# drop the normalized parameters that are not part of a selected pair
_df = df.drop(
columns=[
c for c in df.columns if (c[0] == "normalized_parameters") and (c[1] not in params)
]
)
plot_corner(
_df,
feature=("features", feat),
filename=out_path / f"RSI_top_corner_{feat}.pdf",
)
plt.close()
32 changes: 2 additions & 30 deletions emodel_generalisation/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@

from emodel_generalisation import ALL_LABELS
from emodel_generalisation import PARAM_LABELS
from emodel_generalisation.information import mi_gaussian
from emodel_generalisation.information import rsi_gaussian
from emodel_generalisation.model.access_point import AccessPoint
from emodel_generalisation.model.evaluation import get_evaluator_from_access_point
from emodel_generalisation.parallel import evaluate
Expand Down Expand Up @@ -1031,36 +1033,6 @@ def plot_feature_distributions(
plt.close()


def log(x, unit="nats"):
    """Logarithm of ``x`` in the given information unit.

    Args:
        x (float or array-like): value(s) to take the logarithm of
        unit (str): "nats" for the natural logarithm, "bits" for the base 2 logarithm

    Returns:
        the logarithm of ``x`` computed with numpy.

    Raises:
        ValueError: if ``unit`` is neither "nats" nor "bits".
    """
    if unit == "nats":
        return np.log(x)
    if unit == "bits":
        return np.log2(x)
    # ValueError is a subclass of Exception, existing handlers keep working
    raise ValueError("Unknown unit")


def mi_gaussian(x):
    """Mutual information of two variables under a gaussian approximation.

    The covariance matrix is estimated from ``x.T`` and the result is half of
    ``-log(det(cov)) + log(cov[0, 0]) + log(cov[1, 1])``.
    """
    covariance = np.cov(x.T)
    log_det = log(np.linalg.det(covariance))
    log_var_first = log(covariance[0, 0])
    log_var_second = log(covariance[1, 1])
    # same summation order as the formula in the docstring
    total = -log_det + log_var_first + log_var_second
    return 0.5 * total


def rsi_gaussian(x):
    """RSI under a gaussian approximation, the first variable of ``x`` being y.

    All terms are logarithms of variances or of determinants of sub-blocks of the covariance
    matrix estimated from ``x.T``; the result is half of their signed sum.
    """
    full_cov = np.cov(x.T)
    n_vars = len(full_cov)
    # covariance of all variables but the first one
    other_cov = full_cov[1:][:, 1:]

    total = (n_vars - 2) * log(full_cov[0, 0])
    total += sum(log(np.diag(other_cov))) - log(np.linalg.det(other_cov))
    for k in range(1, n_vars):
        # 2x2 covariance of the first variable with variable k
        pair = [0, k]
        total -= log(np.linalg.det(full_cov[pair][:, pair]))
    total += log(np.linalg.det(full_cov))
    return 0.5 * total


def bin_data(p1, p2, f, n=20, mode="mean", _min1=-1.0, _max1=1.0, _min2=-1.0, _max2=1.0):
"""Bin data to make heatmap."""
_df = pd.DataFrame()
Expand Down

0 comments on commit 2be34cd

Please sign in to comment.