diff --git a/DESCRIPTION b/DESCRIPTION index 8fcd0beb4..83c8479a5 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -143,6 +143,7 @@ Collate: 'interactivity.R' 'interoperability.R' 'poly_influence.R' + 'python_bento.R' 'python_environment.R' 'python_hmrf.R' 'python_scrublet.R' diff --git a/NAMESPACE b/NAMESPACE index 93d1d4589..3c92cca1f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -86,6 +86,7 @@ export(comparePolygonExpression) export(convertEnsemblToGeneSymbol) export(convertGiottoLargeImageToMG) export(createArchRProj) +export(createBentoAdata) export(createCellMetaObj) export(createCrossSection) export(createDimObj) diff --git a/R/python_bento.R b/R/python_bento.R new file mode 100644 index 000000000..773b39d91 --- /dev/null +++ b/R/python_bento.R @@ -0,0 +1,30 @@ +#' @title Create bento adata object from gobject +#' @name createBentoAdata +#' @description Create bento adata object from gobject +#' @param gobject Giotto object +#' @return bento_adata bento adata object +#' @export +createBentoAdata <- function(gobject){ + # Transcripts + transcripts_df <- as.data.frame(sf::st_as_sf(gobject@feat_info$rna@spatVector)) + coordinates_df <- lapply(transcripts_df['geometry'], sf::st_coordinates)$geometry + t_df <- as.data.frame(cbind(coordinates_df, transcripts_df[c("feat_ID")])) + colnames(t_df) <- c('x','y','gene') + + # Cell shapes + # TODO: Add batch information based on? + cell_poly <- spatVector_to_dt(gobject@spatial_info$cell@spatVector) + cell_poly <- data.frame(cell_id = cell_poly$poly_ID, x = cell_poly$x, y = cell_poly$y, batch = 0L) + + # Nuclei shapes + # TODO: Add batch information based on? + nucleus_poly <- spatVector_to_dt(gobject@spatial_info$nucleus@spatVector) + nucleus_poly <- data.frame(cell_id = nucleus_poly$poly_ID, x = nucleus_poly$x, y = nucleus_poly$y, batch = 0L) + + # Create AnnData object + g2bento_path <- system.file("python","g2bento.py",package="Giotto") + reticulate::source_python(g2bento_path) + bento_adata <- create_AnnData(trainscripts=t_df, cell_shape=cell_poly, nucleus_shape=nucleus_poly) + + return(bento_adata) +} diff --git a/R/python_environment.R b/R/python_environment.R index 84a5a8af8..044a15066 100644 --- a/R/python_environment.R +++ b/R/python_environment.R @@ -194,6 +194,10 @@ install_giotto_environment_specific = function(packages_to_install = c('pandas', python_version = python_version) } + # reticulate don't support installation from github yet + # using system call instead + config <- reticulate::py_discover_config(use_environment='giotto_env') + system2(config$python, c("-m", "pip", "install", "git+https://github.com/wwang-chcn/bento-tools.git@giotto_install")) } diff --git a/inst/python/g2bento.py b/inst/python/g2bento.py new file mode 100644 index 000000000..27b36d5f8 --- /dev/null +++ b/inst/python/g2bento.py @@ -0,0 +1,92 @@ +import bento as bt +import geopandas as gpd +import pandas as pd +from anndata import AnnData +from shapely.geometry import MultiPolygon, Polygon + +from log import debug, warning, info + +def create_seg_df(vertices_df: pd.DataFrame, x: str = 'x', y: str = 'y', cell_id: str = 'cell_id') -> pd.DataFrame: + """ + Create a dataframe with cell_id and geometry columns from a dataframe with vertices + :param vertices_df: a dataframe with columns: cell_id, x, y + :param x: the column name of x coordinates + :param y: the column name of y coordinates + :param cell_id: the column name of cell id + :param bounds: the bounds of the area (minx, miny, maxx, maxy) + :return: a dataframe with cell_id and geometry columns + """ + # --- create polygons --- + polygons = vertices_df.groupby(cell_id).apply(lambda group: Polygon(zip(group[x], group[y]))) # type: ignore + seg_df = gpd.GeoDataFrame(polygons, columns=['geometry']) + # --- correct invalid polygons --- + corrected_seg_df = seg_df.copy(deep=True) + for i in range(seg_df.shape[0]): + if not seg_df.iloc[i, 0].is_valid: # type: ignore + corrected_seg_df.iloc[i, 0] = seg_df.iloc[i, 0].buffer(0) # type: ignore + if isinstance(corrected_seg_df.iloc[i, 0], MultiPolygon): + corrected_seg_df.iloc[i, 0] = max(corrected_seg_df.iloc[i, 0], key=lambda x: x.area) # type: ignore + return corrected_seg_df + + +def add_batch(adata: AnnData, cell_shape: pd.DataFrame): + """ + Add batch information to an AnnData object + If cell_seg have batch information, add batch information to adata, else all batch will be set to 0 + :param adata: an AnnData object + :param cell_seg: the name of the cell segmentation + :return: an AnnData object with batch information + """ + if 'batch' in cell_shape.columns: + info('Batch information found in cell_shape, adding batch information to adata') + adata.obs['batch'] = [cell_shape.loc[cell_shape['cell_id']==cell,'batch'].values[0] for cell in adata.obs_names] # type: ignore + adata.uns['points']['batch'] = [cell_shape.loc[cell_shape['cell_id']==cell,'batch'].values[0] for cell in adata.uns['points']['cell']] # type: ignore + else: # Interim measures, batch information may not transfered to cell_shape + warning('Batch information not found in cell_shape, all batch will be set to 0') + adata.obs['batch'] = 0 + adata.uns['points']['batch'] = 0 + adata.obs['batch'] = adata.obs['batch'].astype('category') + adata.uns['points']['batch'] = adata.uns['points']['batch'].astype('category') + + +def create_AnnData(trainscripts, cell_shape, nucleus_shape) -> AnnData: + # --- processing input --- + trainscripts = pd.DataFrame(trainscripts) + cell_shape = pd.DataFrame(cell_shape) + cell_shape['cell_id'] = cell_shape['cell_id'].astype('category') + if 'batch' in cell_shape.columns: + cell_shape['batch'] = cell_shape['batch'].astype('category') + nucleus_shape = pd.DataFrame(nucleus_shape) + nucleus_shape['cell_id'] = nucleus_shape['cell_id'].astype('category') + + # --- create shape --- + cell_seg = create_seg_df(cell_shape, x='x', y='y', cell_id='cell_id') + nucleus_seg = create_seg_df(nucleus_shape, x='x', y='y', cell_id='cell_id') + if cell_seg.shape[0] > 500: + warning('cell_seg has more than 500 cells, processing may take a long time.') + + # --- filter cells --- + # Let Giotto perform the filtering + legal_cells = pd.Series([True] * cell_seg.shape[0]) + + # --- create AnnData --- + adata: AnnData = bt.io.prepare(molecules=trainscripts, cell_seg=cell_seg, other_seg={'nucleus': nucleus_seg}) # type: ignore + add_batch(adata, cell_shape) + + # --- filter genes --- + # Interim measures + # subsetGiottoLocs don't give wanted result, so we use a workaround here + legal_genes = adata.var_names.isin(set(adata.uns['points']['gene'].values)) + filtered_adata = adata[legal_cells,legal_genes] # type: ignore + filtered_adata.uns['points']['gene'] = filtered_adata.uns['points']['gene'].cat.remove_unused_categories() + + # Interim measures + # subsetted adata (_is_view == True) don't have _X property, which will cause unexpect error for adata.__sizeof__ + # when create object in R, reticulate will call sys.getsizeof to get the size of the object, which will call adata.__sizeof__ + # https://github.com/rstudio/reticulate/issues/1332 + # https://github.com/rstudio/rstudio/issues/13491 + # wait for AnnData to fix this issue + filtered_adata._X = filtered_adata.X + return filtered_adata + + # return adata diff --git a/inst/python/log.py b/inst/python/log.py new file mode 100644 index 000000000..e6c3f2d1c --- /dev/null +++ b/inst/python/log.py @@ -0,0 +1,41 @@ +import sys +import time + + +def get_current_time() -> str: + return time.strftime('%H:%M:%S', time.localtime()) + + +def write_direct_message(message: str): + curr_time_str = get_current_time() + sys.stdout.write(f'{curr_time_str} --- {message}\n') + sys.stdout.flush() + + +def debug(message: str): + write_direct_message(f'DEBUG: {message}') + + +def info(message: str): + write_direct_message(f'INFO: {message}') + + +def write_direct_message_err(message: str): + curr_time_str = get_current_time() + sys.stderr.write(f'{curr_time_str} --- {message}\n') + sys.stderr.flush() + + +def warning(message: str): + write_direct_message_err(f'WARNING: {message}') + + +def error(message: str): + write_direct_message_err(f'ERROR: {message}') + + +def critical(message: str): + write_direct_message_err(f'CRITICAL: {message}') + + +__all__ = ['debug', 'info', 'warning', 'error', 'critical'] diff --git a/inst/python/python_bento_analysis.py b/inst/python/python_bento_analysis.py new file mode 100644 index 000000000..45cf0e055 --- /dev/null +++ b/inst/python/python_bento_analysis.py @@ -0,0 +1,143 @@ +from typing import List, Optional +import bento as bt +from anndata import AnnData +import emoji +from bento._utils import track +from bento.tools._colocation import _colocation_tensor +from bento.tools import decompose +from kneed import KneeLocator + +import matplotlib as mpl +import matplotlib.pyplot as plt +import seaborn as sns + +from log import warning, info + + +# --------------------------------- +# modified bento and dependencies functions/classes +# --------------------------------- +@track +def colocation( + data: AnnData, + ranks: List[int], + fname: str, + iterations: int = 3, + plot_error: bool = True, + copy: bool = False, +): + """Decompose a tensor of pairwise colocalization quotients into signatures. + + Parameters + ---------- + adata : AnnData + Spatial formatted AnnData object. + ranks : list + List of ranks to decompose the tensor. + iterations : int + Number of iterations to run the decomposition. + plot_error : bool + Whether to plot the error of the decomposition. + copy : bool + Whether to return a copy of the AnnData object. Default False. + Returns + ------- + adata : AnnData + .uns['factors']: Decomposed tensor factors. + .uns['factors_error']: Decomposition error. + """ + adata = data.copy() if copy else data + + print("Preparing tensor...") + _colocation_tensor(adata, copy=copy) + + tensor = adata.uns["tensor"] + + print(emoji.emojize(":running: Decomposing tensor...")) + factors, errors = decompose(tensor, ranks, iterations=iterations) + + if plot_error and errors.shape[0] > 1: + kl = KneeLocator(errors["rank"], errors["rmse"], direction="decreasing", curve="convex") + if kl.knee is None: + warning('No knee found, please extend the ranks range.\nCurrent ranks range: [{ranks[0]},{ranks[-1]}]') + else: + info(f'Knee found at rank {kl.knee}') + sns.lineplot(data=errors, x="rank", y="rmse", ci=95, marker="o") # type: ignore + plt.axvline(kl.knee, linestyle="--") + plt.savefig(fname) + info(f"Saved to {fname}") + + adata.uns["factors"] = factors + adata.uns["factors_error"] = errors + + print(emoji.emojize(":heavy_check_mark: Done.")) + return adata if copy else None + + +# --------------------------------- +# bento wrapper functions +# --------------------------------- + + +def analysis_shape_features(adata: AnnData, feature_names: Optional[List[str]] = None) -> None: + if feature_names is None: + feature_names = list(bt.tl.list_shape_features().keys()) + bt.tl.obs_stats(adata, feature_names=feature_names) + + +def plot_shape_features_analysis_results(adata: AnnData, fname: str): + bt.pl.shapes(adata, fname=fname) + + +def analysis_points_features(adata: AnnData, + shapes_names: Optional[List[str]] = None, + feature_names: Optional[List[str]] = None) -> None: + if shapes_names is None: + shapes_names = ["cell_shape", "nucleus_shape"] + if feature_names is None: + feature_names = list(bt.tl.list_point_features().keys()) + bt.tl.analyze_points(adata, shape_names=shapes_names, feature_names=feature_names, groupby='gene') + + +def plot_points_features_analysis_results(adata: AnnData, fname: str) -> None: + bt.pl.points(adata, fname=fname) + + +def analysis_rna_forest(adata: AnnData) -> None: + bt.tl.lp(adata) + bt.tl.lp_stats(adata) + + +def plot_rna_forest_analysis_results(adata: AnnData, fname1: str, fname2: str) -> None: + bt.pl.lp_genes(adata, fname=fname1) + bt.pl.lp_dist(adata, fname=fname2) + + +def analysis_colocalization(adata: AnnData, fname: str, ranks: Optional[List[int]] = None) -> None: + if ranks is None: + ranks = list(range(1, 6)) + + # Cytoplasm = cell - nucleus + adata.obs["cytoplasm_shape"] = bt.geo.get_shape(adata, "cell_shape") - bt.geo.get_shape(adata, "nucleus_shape") + + # Create point index + adata.uns["points"]["cytoplasm"] = (adata.uns["points"]["nucleus"].astype(int) < 0).astype(int) + + bt.tl.coloc_quotient(adata, shapes=["cytoplasm_shape", "nucleus_shape"]) + + colocation(adata, ranks=ranks, fname=fname) + + +def plot_colocalization_analysis_results(adata: AnnData, fname: str, rank: int) -> None: + bt.pl.colocation(adata, rank=rank, fname=fname) + + +def chekc_genes_number(adata: AnnData) -> None: + print(f'adata shape: {adata.shape}') + print(f'adata points genes: {len(adata.uns["points"]["gene"].unique())}') + print(f'adata cell_gene_features genes: {len(adata.uns["cell_gene_features"]["gene"].unique())}') + diff_set = set(adata.uns["points"]["gene"]) - set(adata.uns["cell_gene_features"]["gene"]) + print(diff_set) + for g in diff_set: + print(f'{g}') + print(adata.uns['points'][adata.uns['points']['gene'] == g])