Allow factor sorting by R2 and renaming by enrichment significance
Update `muvi.pl.factors_overview`
arberqoku committed Jun 1, 2024
1 parent cdb4fa8 commit 87cf036
Showing 5 changed files with 324 additions and 86 deletions.
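This commit exposes public `factor_order`, `factor_names`, and `factor_signs` properties (see the `muvi/core/models.py` diff below). The following is a minimal, hypothetical usage sketch of the new sorting and renaming; `model` is assumed to be a trained `muvi.MuVI` instance and `r2_per_factor` is an illustrative array (e.g. taken from the cached factor metadata), not computed here:

```python
import numpy as np

# `model`: assumed trained muvi.MuVI instance; `r2_per_factor`: illustrative values
r2_per_factor = np.array([0.02, 0.31, 0.12, 0.05])

# reorder factors by descending R2; the setter composes the permutation with
# the current order and propagates it to the cache
model.factor_order = np.argsort(-r2_per_factor)

# rename factors, e.g. by appending enrichment direction markers
model.factor_names = [f"{name} (+)" for name in model.factor_names]

# downstream getters respect the new order, names and signs
loadings = model.get_factor_loadings()
scores = model.get_factor_scores()
```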
84 changes: 70 additions & 14 deletions muvi/core/models.py
@@ -126,15 +126,19 @@ def __init__(
self.device = self._setup_device(device)
self.to(self.device)

self._old_factor_names = self._factor_names.copy()

self._model = None
self._guide = None
self._built = False
self._trained = False
self._training_log: dict[str, Any] = {}
self._cache = None

self._version = version("muvi")

self._reset_cache()
self._reset_factors()

def __repr__(self):
table = [
["n_views", self.n_views],
@@ -171,8 +175,48 @@ def __repr__(self):
output = header + body + "\n" + empty_line
return output

def _get_factor_signs(self):
signs = np.ones(self.n_factors, dtype=np.float32)
@property
def factor_order(self):
return self._factor_order

@factor_order.setter
def factor_order(self, value):
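# compose the new permutation with the current order so that repeated reorderings stack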
self._factor_order = self._factor_order[np.array(value)]
if self._cache is not None:
self._cache.reorder_factors(self._factor_order)

@property
def factor_names(self):
return self._factor_names[self.factor_order]

@factor_names.setter
def factor_names(self, value):
self._factor_names = _make_index_unique(pd.Index(value))
if self._cache is not None:
self._cache.rename_factors(self._factor_names)

@property
def factor_signs(self):
return self._factor_signs[self.factor_order]

@factor_signs.setter
def factor_signs(self, value):
self._factor_signs = value

def _reset_cache(self):
self._cache = None

def _reset_factors(self):
self._factor_order = np.arange(self.n_factors)
self._factor_names = self._old_factor_names.copy()
self._factor_signs = np.ones(self.n_factors, dtype=np.float32)

def _compute_factor_signs(self):
self.factor_signs = np.ones(self.n_factors, dtype=np.float32)

# if nmf is enabled, all factors are positive
if any(self.nmf.values()):
return self.factor_signs
# only compute signs if model is trained
# and there is no nmf constraint
if self._trained:
@@ -182,12 +226,9 @@ def _get_factor_signs(self):
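# sign heuristic: for each factor, sum its 100 largest-magnitude loadings;
# a positive sum keeps the factor's sign, a negative sum flips it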
w = np.array(
list(map(lambda x, y: y[x], np.argsort(-np.abs(w), axis=1)[:, :100], w))
)
signs = (w.sum(axis=1) > 0) * 2 - 1
self.factor_signs = (w.sum(axis=1) > 0) * 2 - 1

# if nmf is enabled, all factors are positive
if any(self.nmf.values()):
signs = np.ones(self.n_factors, dtype=np.float32)
return pd.Series(signs, index=self.factor_names, dtype=np.float32)
return self.factor_signs

def _setup_device(self, device):
cuda_available = torch.cuda.is_available()
@@ -472,7 +513,7 @@ def _setup_prior_masks(self, masks, n_factors):
self.n_factors = n_factors
self.n_dense_factors = n_factors
# TODO: duplicate line...see below
self.factor_names = pd.Index([f"factor_{k}" for k in range(n_factors)])
self._factor_names = pd.Index([f"factor_{k}" for k in range(n_factors)])
return None, None

# if list convert to dict
@@ -481,7 +522,10 @@ def _setup_prior_masks(self, masks, n_factors):

masks = self._merge(masks)

informed_views = [vn for vn in self.view_names if vn in masks]
informed_views = []
for vn in self.view_names:
if vn in masks and np.any(masks[vn]):
informed_views.append(vn)

n_prior_factors = masks[informed_views[0]].shape[0]

@@ -620,7 +664,7 @@ def _setup_prior_masks(self, masks, n_factors):
factor_names = list(factor_names) + [
f"dense_{k}" for k in range(n_dense_factors)
]
self.factor_names = self._validate_index(pd.Index(factor_names))
self._factor_names = self._validate_index(pd.Index(factor_names))

# keep only numpy arrays
prior_masks = {
@@ -655,6 +699,7 @@ def _setup_prior_masks(self, masks, n_factors):

self.n_factors = n_factors
self.n_dense_factors = n_dense_factors
self.informed_views = informed_views
return prior_masks, prior_scales

def _setup_likelihoods(self, likelihoods):
@@ -1087,7 +1132,7 @@ def get_factor_loadings(
self._raise_untrained_error()

ws = self._guide.get_w(as_list=True)
ws = [w * self._get_factor_signs().to_numpy()[:, np.newaxis] for w in ws]
ws = [w[self.factor_order, :] * self.factor_signs[:, np.newaxis] for w in ws]
return self._get_view_attr(
{vn: ws[m] for m, vn in enumerate(self.view_names)},
view_idx,
@@ -1158,8 +1203,11 @@ def get_factor_scores(
"""
self._raise_untrained_error()

z = self._guide.get_z()
z = z[:, self.factor_order] * self.factor_signs

return self._get_sample_attr(
self._guide.get_z() * self._get_factor_signs().to_numpy(),
z,
sample_idx,
other_idx=factor_idx,
other_names=self.factor_names,
@@ -1463,6 +1511,7 @@ def _step():
pyro.clear_param_store()

logger.info("Starting training...")
# needs to be set here, otherwise the log callback fails
self._trained = True
stop_early = False
history = []
@@ -1501,7 +1550,14 @@ def _step():
}
logger.info("Call `model._training_log` to inspect the training progress.")
# reset cache in case it was initialized by any of the callbacks
self._cache = None
self._post_fit()

def _post_fit(self):
"""Post fit method."""
self._trained = True
self._reset_cache()
self._reset_factors()
self._compute_factor_signs()


class MuVIModel(PyroModule):
8 changes: 8 additions & 0 deletions muvi/tools/cache.py
@@ -93,6 +93,14 @@ def update_cov_metadata(self, scores):
if self.cov_adata is not None:
self.cov_adata.varm[Cache.META_KEY].update(scores.astype(np.float32))

def reorder_factors(self, order):
if self.factor_adata is not None:
self.factor_adata = self.factor_adata[:, order].copy()

def rename_factors(self, factor_names):
if self.factor_adata is not None:
self.factor_adata.var_names = factor_names

def filter_factors(self, factor_idx):
self.factor_adata.obsm[Cache.FILTERED_KEY] = (
self.factor_adata.to_df().loc[:, factor_idx].copy()
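The two new cache helpers mirror the model-level reordering and renaming on the cached AnnData of factor scores. Below is a small standalone sketch of the same idea using `anndata` directly on toy data (not the muvi `Cache` class itself):

```python
import anndata as ad
import numpy as np

# toy AnnData of factor scores: 5 samples x 3 factors
adata = ad.AnnData(np.random.rand(5, 3).astype(np.float32))
adata.var_names = ["factor_0", "factor_1", "factor_2"]

order = np.array([2, 0, 1])        # new factor order
adata = adata[:, order].copy()     # mirrors Cache.reorder_factors
adata.var_names = ["set_C", "set_A", "set_B"]  # mirrors Cache.rename_factors
```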
178 changes: 114 additions & 64 deletions muvi/tools/plotting.py
@@ -262,7 +262,7 @@ def variance_explained_grouped(

def factors_overview(
model,
view_idx=0,
view_idx="all",
one_sided=True,
alpha=0.1,
sig_only=False,
@@ -278,77 +278,127 @@
if isinstance(view_idx, int):
view_idx = model.view_names[view_idx]

view_indices = _normalize_index(view_idx, model.view_names, as_idx=False)
n_views = len(view_indices)

model_cache = _get_model_cache(model)
data = model_cache.factor_metadata.copy()

figsize = (8 * n_views, 8)
fig, axs = plt.subplots(1, n_views, figsize=figsize, squeeze=False, sharey=False)

factor_metadata = model_cache.factor_metadata.copy()

name_col = "Factor"
data[name_col] = data.index.astype(str)
factor_metadata[name_col] = factor_metadata.index.astype(str)

if prior_only:
data = data.loc[~data[name_col].str.contains("dense", case=False), :].copy()

size_col = None
if model._informed:
size_col = "Size"
data[size_col] = model.get_prior_masks(view_idx, as_df=True)[view_idx].sum(1)
data.loc[data[name_col].str.contains("dense", case=False), size_col] = 0

sign_dict = {model_cache.TEST_ALL: " (*)"}
if one_sided:
sign_dict = {model_cache.TEST_NEG: " (-)", model_cache.TEST_POS: " (+)"}
joint_p_col = "p_min"
direction_col = "direction"

p_col = "p"
if adjusted:
p_col = "p_adj"

p_df = data[[f"{p_col}_{sign}_{view_idx}" for sign in sign_dict]]
if p_df.isna().all(None) and sig_only:
raise ValueError("No test results found in model cache, rerun `muvi.tl.test`.")

p_df = p_df.fillna(1.0)
data[joint_p_col] = p_df.min(axis=1).clip(1e-10, 1.0)
data[direction_col] = p_df.idxmin(axis=1).str[len(p_col) + 1 : len(p_col) + 4]

if alpha is None:
alpha = 1.0
if alpha <= 0:
logger.warning("Negative or zero `alpha`, setting `alpha` to 0.01.")
alpha = 0.01
if alpha > 1.0:
logger.warning("`alpha` larger than 1.0, setting `alpha` to 1.0.")
alpha = 1.0
data.loc[data[joint_p_col] > alpha, direction_col] = ""
if sig_only:
data = data.loc[data[direction_col] != "", :]

data[name_col] = data[name_col] + data[direction_col].map(sign_dict).fillna("")

neg_log_col = r"$-\log_{10}(FDR)$"
data[neg_log_col] = -np.log10(data[joint_p_col])
r2_col = f"r2_{view_idx}"
if top > 0:
data = data.sort_values(r2_col, ascending=True)
factor_metadata = factor_metadata.loc[
~factor_metadata[name_col].str.contains("dense", case=False), :
].copy()

for m, view_idx in enumerate(view_indices):
factor_metadata_view = factor_metadata.copy()
size_col = None
if model._informed:
size_col = "Size"
factor_metadata_view[size_col] = model.get_prior_masks(
view_idx, as_df=True
)[view_idx].sum(1)
factor_metadata_view.loc[
factor_metadata_view[name_col].str.contains("dense", case=False),
size_col,
] = 0

sign_dict = {model_cache.TEST_ALL: " (*)"}
if one_sided:
sign_dict = {model_cache.TEST_NEG: " (-)", model_cache.TEST_POS: " (+)"}
joint_p_col = "p_min"
direction_col = "direction"

p_col = model_cache.TEST_P
if adjusted:
p_col = model_cache.TEST_P_ADJ

p_df = factor_metadata_view[
[f"{p_col}_{sign}_{view_idx}" for sign in sign_dict]
]
if p_df.isna().all(None) and sig_only:
raise ValueError(
"No test results found in model cache, rerun `muvi.tl.test`."
)

g = sns.scatterplot(
data=data.iloc[-top:],
x=r2_col,
y=name_col,
hue=neg_log_col,
palette=kwargs.pop("palette", "flare"),
size=size_col,
sizes=kwargs.pop("sizes", (50, 350)),
**kwargs,
)
g.set_title(
rf"Overview top factors in {view_idx} $\alpha = {alpha}$"
# rf"($\alpha = {alpha:.{max(1, int(-np.log10(alpha)))}f}$)" # noqa: ERA001
)
g.set(xlabel=r"$R^2$")
p_df = p_df.fillna(1.0)
factor_metadata_view[joint_p_col] = p_df.min(axis=1).clip(1e-10, 1.0)
factor_metadata_view[direction_col] = p_df.idxmin(axis=1).str[
len(p_col) + 1 : len(p_col) + 4
]

if alpha is None:
alpha = 1.0
if alpha <= 0:
logger.warning("Negative or zero `alpha`, setting `alpha` to 0.01.")
alpha = 0.01
if alpha > 1.0:
logger.warning("`alpha` larger than 1.0, setting `alpha` to 1.0.")
alpha = 1.0
factor_metadata_view.loc[
factor_metadata_view[joint_p_col] > alpha, direction_col
] = ""
if sig_only:
factor_metadata_view = factor_metadata_view.loc[
factor_metadata_view[direction_col] != "", :
]

factor_metadata_view[name_col] = factor_metadata_view[
name_col
] + factor_metadata_view[direction_col].map(sign_dict).fillna("")

neg_log_col = r"$-\log_{10}(FDR)$"
factor_metadata_view[neg_log_col] = -np.log10(factor_metadata_view[joint_p_col])

r2_col = f"{model_cache.METRIC_R2}_{view_idx}"
if top > 0:
factor_metadata_view = factor_metadata_view.sort_values(
r2_col, ascending=True
)

g = sns.scatterplot(
ax=axs[0][m],
data=factor_metadata_view.iloc[-top:],
x=r2_col,
y=name_col,
hue=neg_log_col,
palette=kwargs.pop("palette", "flare"),
size=size_col,
sizes=kwargs.pop("sizes", (50, 350)),
**kwargs,
)
g.set_title(
rf"Overview top factors in {view_idx} $\alpha = {alpha}$"
# rf"($\alpha = {alpha:.{max(1, int(-np.log10(alpha)))}f}$)" # noqa: ERA001
)
g.set(xlabel=r"$R^2$")

# Format the legend labels
new_labels = []
for text in g.legend_.get_texts():
if "size" in text.get_text().lower():
break
try:
new_label = float(text.get_text())
new_label = f"{new_label:.3f}"
except ValueError:
new_label = text.get_text()
new_labels.append(new_label)

# Set the new labels
for text, new_label in zip(g.legend_.get_texts(), new_labels):
text.set_text(new_label)

fig.tight_layout()
savefig_or_show(f"overview_view_{view_idx}", show=show, save=save)
if not show:
return g
return fig, axs


def inspect_factor(
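With `view_idx` now defaulting to "all", the updated overview plot draws one panel per view and sorts factors by R2 within each panel. A hedged usage sketch follows; parameter names are taken from the diff above, their values here are illustrative, and `model` is assumed to be trained and tested with `muvi.tl.test`:

```python
import muvi

# one panel per view ("all"), only enrichment-significant factors,
# top 15 factors per view ranked by R2, adjusted p-values for the hue
muvi.pl.factors_overview(
    model,
    view_idx="all",
    sig_only=True,
    adjusted=True,
    top=15,
)
```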