-
Notifications
You must be signed in to change notification settings - Fork 102
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #756 from wwang-chcn/suite_dev
Integrate Bento Analysis
- Loading branch information
Showing
7 changed files
with
312 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#' @title Create bento adata object from gobject
#' @name createBentoAdata
#' @description Collect the rna transcript points and the cell and nucleus
#' polygon vertices of a Giotto object and hand them to the python helper
#' \code{create_AnnData} (g2bento.py) to build a bento adata object.
#' @param gobject Giotto object
#' @return bento_adata bento adata object
#' @export
createBentoAdata <- function(gobject){
  # Transcripts: one row per molecule holding x, y and the feature ID
  rna_points <- gobject@feat_info$rna@spatVector
  transcripts_df <- as.data.frame(sf::st_as_sf(rna_points))
  coordinates_df <- lapply(transcripts_df['geometry'], sf::st_coordinates)$geometry
  t_df <- as.data.frame(cbind(coordinates_df, transcripts_df[c("feat_ID")]))
  colnames(t_df) <- c('x','y','gene')

  # Flatten a polygon spatVector into a vertex table (cell_id, x, y, batch)
  # TODO: Add batch information based on?
  polygon_vertices <- function(poly_spatvector){
    vertices <- spatVector_to_dt(poly_spatvector)
    data.frame(cell_id = vertices$poly_ID, x = vertices$x, y = vertices$y, batch = 0L)
  }

  # Cell shapes
  cell_poly <- polygon_vertices(gobject@spatial_info$cell@spatVector)

  # Nuclei shapes
  nucleus_poly <- polygon_vertices(gobject@spatial_info$nucleus@spatVector)

  # Create AnnData object
  # (the keyword "trainscripts" is the parameter name used by create_AnnData)
  g2bento_path <- system.file("python","g2bento.py",package="Giotto")
  reticulate::source_python(g2bento_path)
  bento_adata <- create_AnnData(trainscripts=t_df, cell_shape=cell_poly, nucleus_shape=nucleus_poly)

  return(bento_adata)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import bento as bt | ||
import geopandas as gpd | ||
import pandas as pd | ||
from anndata import AnnData | ||
from shapely.geometry import MultiPolygon, Polygon | ||
|
||
from log import debug, warning, info | ||
|
||
def create_seg_df(vertices_df: pd.DataFrame, x: str = 'x', y: str = 'y', cell_id: str = 'cell_id') -> pd.DataFrame:
    """
    Create a dataframe with one polygon per cell from a dataframe with vertices

    Vertices are grouped by cell id and joined, in row order, into a shapely Polygon.
    Invalid polygons are repaired with buffer(0); when the repair yields a MultiPolygon
    only its largest part (by area) is kept.

    :param vertices_df: a dataframe with columns: cell_id, x, y
    :param x: the column name of x coordinates
    :param y: the column name of y coordinates
    :param cell_id: the column name of cell id
    :return: a GeoDataFrame with one row per cell id and a single geometry column
    """
    # --- create polygons ---
    polygons = vertices_df.groupby(cell_id).apply(lambda group: Polygon(zip(group[x], group[y])))  # type: ignore
    seg_df = gpd.GeoDataFrame(polygons, columns=['geometry'])
    # --- correct invalid polygons ---
    corrected_seg_df = seg_df.copy(deep=True)
    for i in range(seg_df.shape[0]):
        shape = seg_df.iloc[i, 0]
        if shape.is_valid:  # type: ignore
            continue
        repaired = shape.buffer(0)  # type: ignore
        if isinstance(repaired, MultiPolygon):
            # Go through .geoms: multi-part geometries are not iterable in Shapely 2.0
            repaired = max(repaired.geoms, key=lambda part: part.area)
        corrected_seg_df.iloc[i, 0] = repaired
    return corrected_seg_df
|
||
|
||
def add_batch(adata: AnnData, cell_shape: pd.DataFrame):
    """
    Add batch information to an AnnData object (in place)
    If cell_shape has a batch column, every cell in adata.obs and every point in
    adata.uns['points'] receives the batch of its cell, else all batch will be set to 0.
    Both batch columns are converted to categoricals.
    :param adata: an AnnData object
    :param cell_shape: a dataframe with a cell_id column and, optionally, a batch column
    :return: None, adata is modified in place
    :raises KeyError: if a cell of adata has no row in cell_shape
    """
    if 'batch' in cell_shape.columns:
        info('Batch information found in cell_shape, adding batch information to adata')
        # cell_shape may hold several rows per cell; the first row of a cell defines its batch.
        # The lookup is built once instead of rescanning cell_shape for every cell and point.
        first_rows = cell_shape.drop_duplicates(subset='cell_id')
        batch_of = dict(zip(first_rows['cell_id'], first_rows['batch']))
        adata.obs['batch'] = [batch_of[cell] for cell in adata.obs_names]  # type: ignore
        adata.uns['points']['batch'] = [batch_of[cell] for cell in adata.uns['points']['cell']]  # type: ignore
    else:  # Interim measure, batch information may not have been transferred to cell_shape
        warning('Batch information not found in cell_shape, all batch will be set to 0')
        adata.obs['batch'] = 0
        adata.uns['points']['batch'] = 0
    adata.obs['batch'] = adata.obs['batch'].astype('category')
    adata.uns['points']['batch'] = adata.uns['points']['batch'].astype('category')
|
||
|
||
def create_AnnData(trainscripts, cell_shape, nucleus_shape) -> AnnData:
    """
    Build a bento AnnData object from transcript and polygon vertex tables
    :param trainscripts: table of transcripts, passed to bento as molecules
    :param cell_shape: table of cell polygon vertices (cell_id, x, y and optionally batch)
    :param nucleus_shape: table of nucleus polygon vertices (cell_id, x, y)
    :return: an AnnData object restricted to the genes present in its points table
    """
    # --- processing input ---
    molecules = pd.DataFrame(trainscripts)
    cell_vertices = pd.DataFrame(cell_shape)
    cell_vertices['cell_id'] = cell_vertices['cell_id'].astype('category')
    if 'batch' in cell_vertices.columns:
        cell_vertices['batch'] = cell_vertices['batch'].astype('category')
    nucleus_vertices = pd.DataFrame(nucleus_shape)
    nucleus_vertices['cell_id'] = nucleus_vertices['cell_id'].astype('category')

    # --- create shape ---
    cell_seg = create_seg_df(cell_vertices, x='x', y='y', cell_id='cell_id')
    nucleus_seg = create_seg_df(nucleus_vertices, x='x', y='y', cell_id='cell_id')
    n_cells = cell_seg.shape[0]
    if n_cells > 500:
        warning('cell_seg has more than 500 cells, processing may take a long time.')

    # --- filter cells ---
    # Cell filtering is left to Giotto, so every cell is kept here
    legal_cells = pd.Series([True] * n_cells)

    # --- create AnnData ---
    adata: AnnData = bt.io.prepare(  # type: ignore
        molecules=molecules,
        cell_seg=cell_seg,
        other_seg={'nucleus': nucleus_seg},
    )
    add_batch(adata, cell_vertices)

    # --- filter genes ---
    # Interim measure: subsetGiottoLocs does not give the wanted result, so the genes
    # are restricted here to those that occur in the points table
    genes_in_points = set(adata.uns['points']['gene'].values)
    legal_genes = adata.var_names.isin(genes_in_points)
    filtered_adata = adata[legal_cells, legal_genes]  # type: ignore
    point_genes = filtered_adata.uns['points']['gene']
    filtered_adata.uns['points']['gene'] = point_genes.cat.remove_unused_categories()

    # Interim measure: a subsetted adata (_is_view == True) has no _X property, which
    # causes an unexpected error in adata.__sizeof__. When the object is created in R,
    # reticulate calls sys.getsizeof, which in turn calls adata.__sizeof__.
    # https://github.com/rstudio/reticulate/issues/1332
    # https://github.com/rstudio/rstudio/issues/13491
    # Remove once AnnData fixes this issue.
    filtered_adata._X = filtered_adata.X
    return filtered_adata
|
||
# return adata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import sys | ||
import time | ||
|
||
|
||
def get_current_time() -> str:
    """Return the current local time formatted as HH:MM:SS."""
    return time.strftime('%H:%M:%S', time.localtime())


def _write_line(stream, message: str) -> None:
    """Write one timestamped line to *stream* and flush it immediately."""
    stream.write(f'{get_current_time()} --- {message}\n')
    stream.flush()


def write_direct_message(message: str):
    """Write a timestamped message line to standard output."""
    # sys.stdout is looked up on every call so that a redirected stream is honoured
    _write_line(sys.stdout, message)


def debug(message: str):
    """Write *message* to standard output with a DEBUG prefix."""
    write_direct_message(f'DEBUG: {message}')


def info(message: str):
    """Write *message* to standard output with an INFO prefix."""
    write_direct_message(f'INFO: {message}')


def write_direct_message_err(message: str):
    """Write a timestamped message line to standard error."""
    # sys.stderr is looked up on every call so that a redirected stream is honoured
    _write_line(sys.stderr, message)


def warning(message: str):
    """Write *message* to standard error with a WARNING prefix."""
    write_direct_message_err(f'WARNING: {message}')


def error(message: str):
    """Write *message* to standard error with an ERROR prefix."""
    write_direct_message_err(f'ERROR: {message}')


def critical(message: str):
    """Write *message* to standard error with a CRITICAL prefix."""
    write_direct_message_err(f'CRITICAL: {message}')
|
||
|
||
__all__ = ['debug', 'info', 'warning', 'error', 'critical'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from typing import List, Optional | ||
import bento as bt | ||
from anndata import AnnData | ||
import emoji | ||
from bento._utils import track | ||
from bento.tools._colocation import _colocation_tensor | ||
from bento.tools import decompose | ||
from kneed import KneeLocator | ||
|
||
import matplotlib as mpl | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
|
||
from log import warning, info | ||
|
||
|
||
# --------------------------------- | ||
# modified bento and dependencies functions/classes | ||
# --------------------------------- | ||
@track
def colocation(
    data: AnnData,
    ranks: List[int],
    fname: str,
    iterations: int = 3,
    plot_error: bool = True,
    copy: bool = False,
):
    """Decompose a tensor of pairwise colocalization quotients into signatures.

    Parameters
    ----------
    data : AnnData
        Spatial formatted AnnData object.
    ranks : list
        List of ranks to decompose the tensor.
    fname : str
        Path the decomposition error plot is saved to.
    iterations : int
        Number of iterations to run the decomposition.
    plot_error : bool
        Whether to plot the error of the decomposition. The plot is only
        produced when more than one error row exists and a knee is found.
    copy : bool
        Whether to return a copy of the AnnData object. Default False.

    Returns
    -------
    adata : AnnData
        .uns['factors']: Decomposed tensor factors.
        .uns['factors_error']: Decomposition error.
        Returned only when ``copy`` is True, otherwise None is returned.
    """
    adata = data.copy() if copy else data

    print("Preparing tensor...")
    # NOTE(review): with copy=True the tensor is presumably stored on a second copy
    # made by _colocation_tensor rather than on adata - confirm against bento.
    _colocation_tensor(adata, copy=copy)

    tensor = adata.uns["tensor"]

    print(emoji.emojize(":running: Decomposing tensor..."))
    factors, errors = decompose(tensor, ranks, iterations=iterations)

    if plot_error and errors.shape[0] > 1:
        kl = KneeLocator(errors["rank"], errors["rmse"], direction="decreasing", curve="convex")
        if kl.knee is None:
            # f-string so that the tested rank range is actually interpolated
            warning(f'No knee found, please extend the ranks range.\nCurrent ranks range: [{ranks[0]},{ranks[-1]}]')
        else:
            info(f'Knee found at rank {kl.knee}')
            # The knee marker needs a real x position, so plotting lives in this branch
            sns.lineplot(data=errors, x="rank", y="rmse", ci=95, marker="o")  # type: ignore
            plt.axvline(kl.knee, linestyle="--")
            plt.savefig(fname)
            info(f"Saved to {fname}")

    adata.uns["factors"] = factors
    adata.uns["factors_error"] = errors

    print(emoji.emojize(":heavy_check_mark: Done."))
    return adata if copy else None
|
||
|
||
# --------------------------------- | ||
# bento wrapper functions | ||
# --------------------------------- | ||
|
||
|
||
def analysis_shape_features(adata: AnnData, feature_names: Optional[List[str]] = None) -> None:
    """Run bento's ``obs_stats`` on *adata* for the requested shape features.

    When *feature_names* is None, every key of ``bt.tl.list_shape_features()`` is used.
    """
    selected = feature_names
    if selected is None:
        selected = list(bt.tl.list_shape_features().keys())
    bt.tl.obs_stats(adata, feature_names=selected)
|
||
|
||
def plot_shape_features_analysis_results(adata: AnnData, fname: str):
    """Call ``bt.pl.shapes`` on *adata*, forwarding *fname* as the ``fname`` argument."""
    plot_kwargs = {'fname': fname}
    bt.pl.shapes(adata, **plot_kwargs)
|
||
|
||
def analysis_points_features(adata: AnnData,
                             shapes_names: Optional[List[str]] = None,
                             feature_names: Optional[List[str]] = None) -> None:
    """Run bento's ``analyze_points`` on *adata*, grouped by gene.

    When *shapes_names* is None, "cell_shape" and "nucleus_shape" are used.
    When *feature_names* is None, every key of ``bt.tl.list_point_features()`` is used.
    """
    shapes = ["cell_shape", "nucleus_shape"] if shapes_names is None else shapes_names
    features = feature_names
    if features is None:
        features = list(bt.tl.list_point_features().keys())
    bt.tl.analyze_points(adata, shape_names=shapes, feature_names=features, groupby='gene')
|
||
|
||
def plot_points_features_analysis_results(adata: AnnData, fname: str) -> None:
    """Call ``bt.pl.points`` on *adata*, forwarding *fname* as the ``fname`` argument."""
    plot_kwargs = {'fname': fname}
    bt.pl.points(adata, **plot_kwargs)
|
||
|
||
def analysis_rna_forest(adata: AnnData) -> None:
    """Run ``bt.tl.lp`` and then ``bt.tl.lp_stats`` on *adata*."""
    # lp must run first; lp_stats is called on the same object afterwards
    bt.tl.lp(adata)
    bt.tl.lp_stats(adata)
|
||
|
||
def plot_rna_forest_analysis_results(adata: AnnData, fname1: str, fname2: str) -> None:
    """Call ``bt.pl.lp_genes`` with *fname1* and then ``bt.pl.lp_dist`` with *fname2*."""
    genes_plot_path, dist_plot_path = fname1, fname2
    bt.pl.lp_genes(adata, fname=genes_plot_path)
    bt.pl.lp_dist(adata, fname=dist_plot_path)
|
||
|
||
def analysis_colocalization(adata: AnnData, fname: str, ranks: Optional[List[int]] = None) -> None:
    """Run the colocalization analysis for the cytoplasm and nucleus compartments.

    Adds ``obs['cytoplasm_shape']`` and ``uns['points']['cytoplasm']`` to *adata*, calls
    ``bt.tl.coloc_quotient`` and then ``colocation`` with *ranks* and *fname*.
    When *ranks* is None, the ranks 1 to 5 are used.
    """
    if ranks is None:
        ranks = list(range(1, 6))

    # Cytoplasm = cell - nucleus
    cell_shapes = bt.geo.get_shape(adata, "cell_shape")
    nucleus_shapes = bt.geo.get_shape(adata, "nucleus_shape")
    adata.obs["cytoplasm_shape"] = cell_shapes - nucleus_shapes

    # Create point index: 1 where the nucleus value cast to int is negative, else 0
    nucleus_index = adata.uns["points"]["nucleus"].astype(int)
    adata.uns["points"]["cytoplasm"] = (nucleus_index < 0).astype(int)

    compartments = ["cytoplasm_shape", "nucleus_shape"]
    bt.tl.coloc_quotient(adata, shapes=compartments)

    colocation(adata, ranks=ranks, fname=fname)
|
||
|
||
def plot_colocalization_analysis_results(adata: AnnData, fname: str, rank: int) -> None:
    """Call ``bt.pl.colocation`` on *adata*, forwarding *rank* and *fname*."""
    plot_kwargs = {'rank': rank, 'fname': fname}
    bt.pl.colocation(adata, **plot_kwargs)
|
||
|
||
def chekc_genes_number(adata: AnnData) -> None:
    """Print diagnostics comparing the genes of uns['points'] and uns['cell_gene_features'].

    Prints the shape of *adata*, the number of unique genes in both tables, the set of
    genes found in the points table only and, for each of them, its rows of the points table.
    """
    print(f'adata shape: {adata.shape}')
    n_point_genes = len(adata.uns["points"]["gene"].unique())
    print(f'adata points genes: {n_point_genes}')
    n_feature_genes = len(adata.uns["cell_gene_features"]["gene"].unique())
    print(f'adata cell_gene_features genes: {n_feature_genes}')
    point_genes = set(adata.uns["points"]["gene"])
    feature_genes = set(adata.uns["cell_gene_features"]["gene"])
    diff_set = point_genes - feature_genes
    print(diff_set)
    for gene in diff_set:
        print(f'{gene}')
        is_gene = adata.uns['points']['gene'] == gene
        print(adata.uns['points'][is_gene])