From 87cf036d4ee914e7f7742637cd1c1dffa3164697 Mon Sep 17 00:00:00 2001
From: Arber Qoku
Date: Sat, 1 Jun 2024 23:52:27 +0200
Subject: [PATCH] Allow factor sorting by R2 and renaming by enrichment
 significance

Update `muvi.pl.factors_overview`
---
 muvi/core/models.py    |  84 +++++++++++++++----
 muvi/tools/cache.py    |   8 ++
 muvi/tools/plotting.py | 178 ++++++++++++++++++++++++++---------------
 muvi/tools/utils.py    | 138 ++++++++++++++++++++++++++++++--
 tests/test_tools.py    |   2 +-
 5 files changed, 324 insertions(+), 86 deletions(-)

diff --git a/muvi/core/models.py b/muvi/core/models.py
index a6ab271..428682b 100755
--- a/muvi/core/models.py
+++ b/muvi/core/models.py
@@ -126,15 +126,19 @@ def __init__(
         self.device = self._setup_device(device)
         self.to(self.device)
 
+        self._old_factor_names = self._factor_names.copy()
+
         self._model = None
         self._guide = None
         self._built = False
         self._trained = False
         self._training_log: dict[str, Any] = {}
-        self._cache = None
         self._version = version("muvi")
 
+        self._reset_cache()
+        self._reset_factors()
+
     def __repr__(self):
         table = [
             ["n_views", self.n_views],
@@ -171,8 +175,48 @@ def __repr__(self):
         output = header + body + "\n" + empty_line
         return output
 
-    def _get_factor_signs(self):
-        signs = np.ones(self.n_factors, dtype=np.float32)
+    @property
+    def factor_order(self):
+        return self._factor_order
+
+    @factor_order.setter
+    def factor_order(self, value):
+        self._factor_order = self._factor_order[np.array(value)]
+        if self._cache is not None:
+            self._cache.reorder_factors(self._factor_order)
+
+    @property
+    def factor_names(self):
+        return self._factor_names[self.factor_order]
+
+    @factor_names.setter
+    def factor_names(self, value):
+        self._factor_names = _make_index_unique(pd.Index(value))
+        if self._cache is not None:
+            self._cache.rename_factors(self._factor_names)
+
+    @property
+    def factor_signs(self):
+        return self._factor_signs[self.factor_order]
+
+    @factor_signs.setter
+    def factor_signs(self, value):
+        self._factor_signs = value
+
+    def _reset_cache(self):
+        self._cache = None
+
+    def _reset_factors(self):
+        self._factor_order = np.arange(self.n_factors)
+        self._factor_names = self._old_factor_names.copy()
+        self._factor_signs = np.ones(self.n_factors, dtype=np.float32)
+
+    def _compute_factor_signs(self):
+        self.factor_signs = np.ones(self.n_factors, dtype=np.float32)
+
+        # if nmf is enabled, all factors are positive
+        if any(self.nmf.values()):
+            return self.factor_signs
         # only compute signs if model is trained
         # and there is no nmf constraint
         if self._trained:
@@ -182,12 +226,9 @@ def _get_factor_signs(self):
             w = np.array(
                 list(map(lambda x, y: y[x], np.argsort(-np.abs(w), axis=1)[:, :100], w))
             )
-            signs = (w.sum(axis=1) > 0) * 2 - 1
+            self.factor_signs = (w.sum(axis=1) > 0) * 2 - 1
 
-        # if nmf is enabled, all factors are positive
-        if any(self.nmf.values()):
-            signs = np.ones(self.n_factors, dtype=np.float32)
-        return pd.Series(signs, index=self.factor_names, dtype=np.float32)
+        return self.factor_signs
 
     def _setup_device(self, device):
         cuda_available = torch.cuda.is_available()
@@ -472,7 +513,7 @@ def _setup_prior_masks(self, masks, n_factors):
             self.n_factors = n_factors
             self.n_dense_factors = n_factors
             # TODO: duplicate line...see below
-            self.factor_names = pd.Index([f"factor_{k}" for k in range(n_factors)])
+            self._factor_names = pd.Index([f"factor_{k}" for k in range(n_factors)])
             return None, None
 
         # if list convert to dict
@@ -481,7 +522,10 @@ def _setup_prior_masks(self, masks, n_factors):
 
         masks = self._merge(masks)
 
-        informed_views = [vn for vn in self.view_names if vn in masks]
+        informed_views = []
+        for vn in self.view_names:
+            if vn in masks and np.any(masks[vn]):
+                informed_views.append(vn)
 
         n_prior_factors = masks[informed_views[0]].shape[0]
 
@@ -620,7 +664,7 @@ def _setup_prior_masks(self, masks, n_factors):
             factor_names = list(factor_names) + [
                 f"dense_{k}" for k in range(n_dense_factors)
             ]
-        self.factor_names = self._validate_index(pd.Index(factor_names))
+        self._factor_names = self._validate_index(pd.Index(factor_names))
 
         # keep only numpy arrays
         prior_masks = {
@@ -655,6 +699,7 @@ def _setup_prior_masks(self, masks, n_factors):
 
         self.n_factors = n_factors
         self.n_dense_factors = n_dense_factors
+        self.informed_views = informed_views
         return prior_masks, prior_scales
 
     def _setup_likelihoods(self, likelihoods):
@@ -1087,7 +1132,7 @@ def get_factor_loadings(
         self._raise_untrained_error()
 
         ws = self._guide.get_w(as_list=True)
-        ws = [w * self._get_factor_signs().to_numpy()[:, np.newaxis] for w in ws]
+        ws = [w[self.factor_order, :] * self.factor_signs[:, np.newaxis] for w in ws]
         return self._get_view_attr(
             {vn: ws[m] for m, vn in enumerate(self.view_names)},
             view_idx,
@@ -1158,8 +1203,11 @@ def get_factor_scores(
         """
         self._raise_untrained_error()
 
+        z = self._guide.get_z()
+        z = z[:, self.factor_order] * self.factor_signs
+
         return self._get_sample_attr(
-            self._guide.get_z() * self._get_factor_signs().to_numpy(),
+            z,
             sample_idx,
             other_idx=factor_idx,
             other_names=self.factor_names,
@@ -1463,6 +1511,7 @@ def _step():
         pyro.clear_param_store()
 
         logger.info("Starting training...")
+        # needs to be set here otherwise the logcallback fails
        self._trained = True
         stop_early = False
         history = []
@@ -1501,7 +1550,14 @@ def _step():
         }
         logger.info("Call `model._training_log` to inspect the training progress.")
         # reset cache in case it was initialized by any of the callbacks
-        self._cache = None
+        self._post_fit()
+
+    def _post_fit(self):
+        """Post fit method."""
+        self._trained = True
+        self._reset_cache()
+        self._reset_factors()
+        self._compute_factor_signs()
 
 
 class MuVIModel(PyroModule):
diff --git a/muvi/tools/cache.py b/muvi/tools/cache.py
index beeb7de..96200c5 100755
--- a/muvi/tools/cache.py
+++ b/muvi/tools/cache.py
@@ -93,6 +93,14 @@ def update_cov_metadata(self, scores):
         if self.cov_adata is not None:
             self.cov_adata.varm[Cache.META_KEY].update(scores.astype(np.float32))
 
+    def reorder_factors(self, order):
+        if self.factor_adata is not None:
+            self.factor_adata = self.factor_adata[:, order].copy()
+
+    def rename_factors(self, factor_names):
+        if self.factor_adata is not None:
+            self.factor_adata.var_names = factor_names
+
     def filter_factors(self, factor_idx):
         self.factor_adata.obsm[Cache.FILTERED_KEY] = (
             self.factor_adata.to_df().loc[:, factor_idx].copy()
diff --git a/muvi/tools/plotting.py b/muvi/tools/plotting.py
index 1077391..81a7607 100755
--- a/muvi/tools/plotting.py
+++ b/muvi/tools/plotting.py
@@ -262,7 +262,7 @@ def variance_explained_grouped(
 
 def factors_overview(
     model,
-    view_idx=0,
+    view_idx="all",
     one_sided=True,
     alpha=0.1,
     sig_only=False,
@@ -278,77 +278,127 @@
     if isinstance(view_idx, int):
         view_idx = model.view_names[view_idx]
+    view_indices = _normalize_index(view_idx, model.view_names, as_idx=False)
+    n_views = len(view_indices)
+
     model_cache = _get_model_cache(model)
-    data = model_cache.factor_metadata.copy()
+
+    figsize = (8 * n_views, 8)
+    fig, axs = plt.subplots(1, n_views, figsize=figsize, squeeze=False, sharey=False)
+
+    factor_metadata = model_cache.factor_metadata.copy()
     name_col = "Factor"
-    data[name_col] = data.index.astype(str)
+    factor_metadata[name_col] = factor_metadata.index.astype(str)
     if prior_only:
-        data = data.loc[~data[name_col].str.contains("dense", case=False), :].copy()
-
-    size_col = None
-    if model._informed:
-        size_col = "Size"
-        data[size_col] = model.get_prior_masks(view_idx, as_df=True)[view_idx].sum(1)
-        data.loc[data[name_col].str.contains("dense", case=False), size_col] = 0
-
-    sign_dict = {model_cache.TEST_ALL: " (*)"}
-    if one_sided:
-        sign_dict = {model_cache.TEST_NEG: " (-)", model_cache.TEST_POS: " (+)"}
-    joint_p_col = "p_min"
-    direction_col = "direction"
-
-    p_col = "p"
-    if adjusted:
-        p_col = "p_adj"
-
-    p_df = data[[f"{p_col}_{sign}_{view_idx}" for sign in sign_dict]]
-    if p_df.isna().all(None) and sig_only:
-        raise ValueError("No test results found in model cache, rerun `muvi.tl.test`.")
-
-    p_df = p_df.fillna(1.0)
-    data[joint_p_col] = p_df.min(axis=1).clip(1e-10, 1.0)
-    data[direction_col] = p_df.idxmin(axis=1).str[len(p_col) + 1 : len(p_col) + 4]
-
-    if alpha is None:
-        alpha = 1.0
-    if alpha <= 0:
-        logger.warning("Negative or zero `alpha`, setting `alpha` to 0.01.")
-        alpha = 0.01
-    if alpha > 1.0:
-        logger.warning("`alpha` larger than 1.0, setting `alpha` to 1.0.")
-        alpha = 1.0
-    data.loc[data[joint_p_col] > alpha, direction_col] = ""
-    if sig_only:
-        data = data.loc[data[direction_col] != "", :]
-
-    data[name_col] = data[name_col] + data[direction_col].map(sign_dict).fillna("")
-
-    neg_log_col = r"$-\log_{10}(FDR)$"
-    data[neg_log_col] = -np.log10(data[joint_p_col])
-    r2_col = f"r2_{view_idx}"
-    if top > 0:
-        data = data.sort_values(r2_col, ascending=True)
+        factor_metadata = factor_metadata.loc[
+            ~factor_metadata[name_col].str.contains("dense", case=False), :
+        ].copy()
+
+    for m, view_idx in enumerate(view_indices):
+        factor_metadata_view = factor_metadata.copy()
+        size_col = None
+        if model._informed:
+            size_col = "Size"
+            factor_metadata_view[size_col] = model.get_prior_masks(
+                view_idx, as_df=True
+            )[view_idx].sum(1)
+            factor_metadata_view.loc[
+                factor_metadata_view[name_col].str.contains("dense", case=False),
+                size_col,
+            ] = 0
+
+        sign_dict = {model_cache.TEST_ALL: " (*)"}
+        if one_sided:
+            sign_dict = {model_cache.TEST_NEG: " (-)", model_cache.TEST_POS: " (+)"}
+        joint_p_col = "p_min"
+        direction_col = "direction"
+
+        p_col = model_cache.TEST_P
+        if adjusted:
+            p_col = model_cache.TEST_P_ADJ
+
+        p_df = factor_metadata_view[
+            [f"{p_col}_{sign}_{view_idx}" for sign in sign_dict]
+        ]
+        if p_df.isna().all(None) and sig_only:
+            raise ValueError(
+                "No test results found in model cache, rerun `muvi.tl.test`."
+            )
 
-    g = sns.scatterplot(
-        data=data.iloc[-top:],
-        x=r2_col,
-        y=name_col,
-        hue=neg_log_col,
-        palette=kwargs.pop("palette", "flare"),
-        size=size_col,
-        sizes=kwargs.pop("sizes", (50, 350)),
-        **kwargs,
-    )
-    g.set_title(
-        rf"Overview top factors in {view_idx} $\alpha = {alpha}$"
-        # rf"($\alpha = {alpha:.{max(1, int(-np.log10(alpha)))}f}$)" # noqa: ERA001
-    )
-    g.set(xlabel=r"$R^2$")
+        p_df = p_df.fillna(1.0)
+        factor_metadata_view[joint_p_col] = p_df.min(axis=1).clip(1e-10, 1.0)
+        factor_metadata_view[direction_col] = p_df.idxmin(axis=1).str[
+            len(p_col) + 1 : len(p_col) + 4
+        ]
+
+        if alpha is None:
+            alpha = 1.0
+        if alpha <= 0:
+            logger.warning("Negative or zero `alpha`, setting `alpha` to 0.01.")
+            alpha = 0.01
+        if alpha > 1.0:
+            logger.warning("`alpha` larger than 1.0, setting `alpha` to 1.0.")
+            alpha = 1.0
+        factor_metadata_view.loc[
+            factor_metadata_view[joint_p_col] > alpha, direction_col
+        ] = ""
+        if sig_only:
+            factor_metadata_view = factor_metadata_view.loc[
+                factor_metadata_view[direction_col] != "", :
+            ]
+
+        factor_metadata_view[name_col] = factor_metadata_view[
+            name_col
+        ] + factor_metadata_view[direction_col].map(sign_dict).fillna("")
+
+        neg_log_col = r"$-\log_{10}(FDR)$"
+        factor_metadata_view[neg_log_col] = -np.log10(factor_metadata_view[joint_p_col])
+
+        r2_col = f"{model_cache.METRIC_R2}_{view_idx}"
+        if top > 0:
+            factor_metadata_view = factor_metadata_view.sort_values(
+                r2_col, ascending=True
+            )
+
+        g = sns.scatterplot(
+            ax=axs[0][m],
+            data=factor_metadata_view.iloc[-top:],
+            x=r2_col,
+            y=name_col,
+            hue=neg_log_col,
+            palette=kwargs.pop("palette", "flare"),
+            size=size_col,
+            sizes=kwargs.pop("sizes", (50, 350)),
+            **kwargs,
+        )
+        g.set_title(
+            rf"Overview top factors in {view_idx} $\alpha = {alpha}$"
+            # rf"($\alpha = {alpha:.{max(1, int(-np.log10(alpha)))}f}$)" # noqa: ERA001
+        )
+        g.set(xlabel=r"$R^2$")
+
+        # Format the legend labels
+        new_labels = []
+        for text in g.legend_.get_texts():
+            if "size" in text.get_text().lower():
+                break
+            try:
+                new_label = float(text.get_text())
+                new_label = f"{new_label:.3f}"
+            except ValueError:
+                new_label = text.get_text()
+            new_labels.append(new_label)
+
+        # Set the new labels
+        for text, new_label in zip(g.legend_.get_texts(), new_labels):
+            text.set_text(new_label)
+
+    fig.tight_layout()
 
     savefig_or_show(f"overview_view_{view_idx}", show=show, save=save)
     if not show:
-        return g
+        return fig, axs
 
 
 def inspect_factor(
diff --git a/muvi/tools/utils.py b/muvi/tools/utils.py
index c846839..9306881 100644
--- a/muvi/tools/utils.py
+++ b/muvi/tools/utils.py
@@ -139,9 +139,10 @@ def _recon_error(
     factor_wise,
     cov_wise,
     subsample,
-    cache,
     metric_label,
     metric_fn,
+    cache,
+    sort,
 ):
     if view_idx is None:
         raise ValueError("`view_idx` cannot be None.")
@@ -254,6 +255,12 @@ def _recon_error(
         model_cache.update_uns(view_scores_key, view_scores)
         model_cache.update_factor_metadata(factor_scores)
         model_cache.update_cov_metadata(cov_scores)
+        if sort and sort in ["ascending", "descending"]:
+            order = (
+                factor_scores.sum(1).sort_values(ascending=sort == "ascending").index
+            )
+            order = _normalize_index(order, model.factor_names, as_idx=True)
+            model.factor_order = order
 
     return view_scores, factor_scores, cov_scores
 
@@ -268,6 +275,7 @@ def rmse(
     cov_wise: bool = True,
     subsample: int = 0,
     cache: bool = True,
+    sort: bool = True,
 ):
     """Compute RMSE.
 
@@ -293,11 +301,16 @@ def rmse(
         Number of samples to estimate RMSE, by default 0 (all samples)
     cache : bool, optional
         Whether to store results in the model cache, by default True
+    sort : bool, optional
+        Whether to sort factors by RMSE, by default True
     """
 
     def _rmse(y_true, y_pred):
         return root_mean_squared_error(y_true, y_pred)
 
+    if sort:
+        sort = "ascending"
+
     return _recon_error(
         model,
         view_idx,
@@ -308,9 +321,10 @@ def _rmse(y_true, y_pred):
         factor_wise,
         cov_wise,
         subsample,
-        cache,
         metric_label=Cache.METRIC_RMSE,
         metric_fn=_rmse,
+        cache=cache,
+        sort=sort,
     )
 
 
@@ -325,6 +339,7 @@ def variance_explained(
     cov_wise: bool = True,
     subsample: int = 0,
     cache: bool = True,
+    sort: bool = True,
 ):
     """Compute R2.
 
@@ -350,6 +365,8 @@ def variance_explained(
         Number of samples to estimate R2, by default 0 (all samples)
     cache : bool, optional
         Whether to store results in the model cache, by default True
+    sort : bool, optional
+        Whether to sort factors by R2, by default True
     """
 
     def _r2(y_true, y_pred):
@@ -357,6 +374,9 @@ def _r2(y_true, y_pred):
         ss_tot = np.nansum(np.square(y_true))
         return 1.0 - (ss_res / ss_tot)
 
+    if sort:
+        sort = "descending"
+
     return _recon_error(
         model,
         view_idx,
@@ -367,9 +387,10 @@ def _r2(y_true, y_pred):
         factor_wise,
         cov_wise,
         subsample,
-        cache,
         metric_label=Cache.METRIC_R2,
         metric_fn=_r2,
+        cache=cache,
+        sort=sort,
     )
 
 
@@ -406,7 +427,7 @@ def variance_explained_grouped(model, groupby, factor_idx: Index = "all", **kwar
     return model._cache.uns[Cache.UNS_GROUPED_R2]
 
 
-def test(
+def _test_single_view(
     model,
     view_idx: Union[str, int] = 0,
     factor_idx: Index = "all",
@@ -473,7 +494,8 @@ def test(
 
     if use_prior_mask:
         logger.warning(
-            "No feature sets provided, extracting feature sets from prior mask."
+            f"No feature sets provided for `{view_idx}`, "
+            "extracting feature sets from the prior mask."
         )
         feature_sets = model.get_prior_masks(
             view_idx, factor_idx=factor_idx, as_df=True
@@ -531,9 +553,8 @@ def test(
 
     t_stat_dict = {}
     prob_dict = {}
-    i = 0
+
     for feature_set in tqdm(feature_sets.index.tolist()):
-        i += 1
         fs_features = feature_sets.loc[feature_set, :]
 
         features_in = factor_loadings.loc[:, fs_features]
@@ -600,6 +621,109 @@ def test(
     return result
 
 
+def test(
+    model,
+    view_idx: Index = "all",
+    factor_idx: Index = "all",
+    feature_sets: pd.DataFrame = None,
+    sign: str = "all",
+    corr_adjust: bool = True,
+    p_adj_method: str = "fdr_bh",
+    min_size: int = 10,
+    cache: bool = True,
+    rename: bool = True,
+):
+    """Perform significance test of factor loadings against feature sets.
+
+    Parameters
+    ----------
+    model : MuVI
+        A MuVI model
+    view_idx : Index, optional
+        View index, by default "all"
+    factor_idx : Index, optional
+        Factor index, by default "all"
+    feature_sets : pd.DataFrame, optional
+        Boolean dataframe with feature sets in each row, by default None
+    sign : str, optional
+        Two sided ("all") or one-sided ("neg" or "pos"), by default "all"
+    corr_adjust : bool, optional
+        Whether to adjust the test statistic for correlated features, by default True
+    p_adj_method : str, optional
+        Adjustment method for multiple testing, by default "fdr_bh"
+    min_size : int, optional
+        Lower size limit for feature sets to be considered, by default 10
+    cache : bool, optional
+        Whether to store results in the model cache, by default True
+    rename : bool, optional
+        Whether to rename overwritten factors (FDR > 0.05 in every view), by default True
+
+    Returns
+    -------
+    dict
+        Dictionary of test results for each view, with "t", "p" and "p_adj" keys
+        and pd.DataFrame values with factor_idx as index,
+        and index of feature_sets as columns
+    """
+
+    view_indices = _normalize_index(view_idx, model.view_names, as_idx=False)
+    use_prior_mask = feature_sets is None
+    if use_prior_mask:
+        view_indices = [vi for vi in view_indices if vi in model.informed_views]
+
+    if len(view_indices) == 0:
+        if use_prior_mask:
+            raise ValueError(
+                "`feature_sets` is None, and none of the selected views are informed."
+            )
+        raise ValueError(f"No valid views found for `view_idx={view_idx}`.")
+
+    results = {}
+    for view_idx in view_indices:
+        try:
+            results[view_idx] = _test_single_view(
+                model,
+                view_idx=view_idx,
+                factor_idx=factor_idx,
+                feature_sets=feature_sets,
+                sign=sign,
+                corr_adjust=corr_adjust,
+                p_adj_method=p_adj_method,
+                min_size=min_size,
+                cache=cache,
+            )
+        except ValueError as e:
+            logger.warning(e)
+            results[view_idx] = {
+                Cache.TEST_T: pd.DataFrame(),
+                Cache.TEST_P: pd.DataFrame(),
+            }
+            if p_adj_method is not None:
+                results[view_idx][Cache.TEST_P_ADJ] = pd.DataFrame()
+            continue
+
+    if cache and rename:
+        dfs = []
+        for view_name, view_results in results.items():
+            p_adj = view_results[Cache.TEST_P_ADJ].copy()
+            dfs.append(
+                pd.DataFrame(np.diag(p_adj), index=[p_adj.index], columns=[view_name])
+            )
+        df = pd.concat(dfs, axis=1)
+
+        new_factor_names = []
+        overwritten_idx = 0
+        for k in model.factor_names:
+            if (df.loc[k, :] > 0.05).all(None):
+                new_factor_names.append(f"factor_{overwritten_idx}")
+                overwritten_idx += 1
+            else:
+                new_factor_names.append(k)
+
+        model.factor_names = new_factor_names
+    return results
+
+
 # scanpy
 def _optional_neighbors(model, **kwargs):
     model_cache = setup_cache(model)
diff --git a/tests/test_tools.py b/tests/test_tools.py
index fe30526..291356a 100755
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -103,7 +103,7 @@ def test_test():
 
     view_idx = "view_0"
     informed_factors = model.prior_masks[view_idx].any(axis=1)
 
-    result = muvi.tl.test(model, min_size=1)
+    result = muvi.tl._test_single_view(model, min_size=1)
     assert "t" in result and "p" in result and "p_adj" in result
     assert (model.factor_names[informed_factors] == result["t"].columns).all()
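
Hypothetical usage sketch (not taken from the patch): it assumes an already trained `MuVI` instance named `model`, and that `muvi/tools/utils.py` and `muvi/tools/plotting.py` are exposed as `muvi.tl` and `muvi.pl`, as suggested by the `muvi.tl.test` reference in the code above. Only the parameters introduced in this diff are exercised: `sort=` in `muvi.tl.variance_explained`, `view_idx="all"` and `rename=` in `muvi.tl.test`, and the multi-panel `muvi.pl.factors_overview`.

import muvi

# `model` is assumed to be a trained muvi.MuVI instance (construction and
# fitting omitted here)

# R2 per view and factor; sort=True reorders factors by descending R2
# via the new `model.factor_order` setter
muvi.tl.variance_explained(model, sort=True)

# enrichment test against the prior masks on every informed view;
# rename=True renames factors whose adjusted p-value exceeds 0.05 in all
# views back to generic "factor_k" names
results = muvi.tl.test(model, view_idx="all", rename=True)

# one scatter panel per view; with show=False the figure and axes are returned
fig, axs = muvi.pl.factors_overview(model, view_idx="all", alpha=0.1, show=False)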