Skip to content

Commit

Permalink
Merge pull request #756 from wwang-chcn/suite_dev
Browse files Browse the repository at this point in the history
Integrate Bento Analysis
  • Loading branch information
jiajic authored Sep 15, 2023
2 parents 3c5a72d + d4b660e commit ed93119
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 0 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ Collate:
'interactivity.R'
'interoperability.R'
'poly_influence.R'
'python_bento.R'
'python_environment.R'
'python_hmrf.R'
'python_scrublet.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ export(comparePolygonExpression)
export(convertEnsemblToGeneSymbol)
export(convertGiottoLargeImageToMG)
export(createArchRProj)
export(createBentoAdata)
export(createCellMetaObj)
export(createCrossSection)
export(createDimObj)
Expand Down
30 changes: 30 additions & 0 deletions R/python_bento.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#' @title Create bento adata object from gobject
#' @name createBentoAdata
#' @description Create bento adata object from gobject. Transcript coordinates
#' are read from the \code{rna} feature information and polygon vertices from
#' the \code{cell} and \code{nucleus} spatial information of the Giotto object.
#' The three tables are passed to the python function \code{create_AnnData()}
#' defined in \code{python/g2bento.py} of the installed Giotto package.
#' @param gobject Giotto object
#' @return bento_adata bento adata object
#' @export
createBentoAdata <- function(gobject) {
  # Fail early with a readable message: system.file() returns "" when the
  # script cannot be found, which would otherwise surface as an obscure
  # error from reticulate.
  g2bento_path <- system.file("python", "g2bento.py", package = "Giotto")
  if (!nzchar(g2bento_path)) {
    stop("createBentoAdata: 'python/g2bento.py' was not found in the ",
         "installed Giotto package", call. = FALSE)
  }

  feat_rna <- gobject@feat_info[["rna"]]
  if (is.null(feat_rna)) {
    stop("createBentoAdata: gobject has no 'rna' feature information",
         call. = FALSE)
  }

  # Vertex table (cell_id, x, y, batch) for one set of polygons.
  # TODO: Add batch information based on?
  poly_vertices <- function(poly_name) {
    poly_info <- gobject@spatial_info[[poly_name]]
    if (is.null(poly_info)) {
      stop("createBentoAdata: gobject has no '", poly_name,
           "' spatial information", call. = FALSE)
    }
    vertices <- spatVector_to_dt(poly_info@spatVector)
    data.frame(cell_id = vertices$poly_ID,
               x = vertices$x,
               y = vertices$y,
               batch = 0L)
  }

  # Transcripts
  transcripts_df <- as.data.frame(sf::st_as_sf(feat_rna@spatVector))
  coordinates_df <- sf::st_coordinates(transcripts_df[["geometry"]])
  t_df <- as.data.frame(cbind(coordinates_df, transcripts_df["feat_ID"]))
  colnames(t_df) <- c("x", "y", "gene")

  # Cell shapes
  cell_poly <- poly_vertices("cell")

  # Nuclei shapes
  nucleus_poly <- poly_vertices("nucleus")

  # Create AnnData object
  reticulate::source_python(g2bento_path)
  # NOTE: 'trainscripts' (sic) is the keyword name used by create_AnnData()
  # in g2bento.py; it must match the python signature.
  bento_adata <- create_AnnData(trainscripts = t_df,
                                cell_shape = cell_poly,
                                nucleus_shape = nucleus_poly)

  return(bento_adata)
}
4 changes: 4 additions & 0 deletions R/python_environment.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ install_giotto_environment_specific = function(packages_to_install = c('pandas',
python_version = python_version)
}

# reticulate doesn't support installation from GitHub yet,
# so a system call to pip is used instead
config <- reticulate::py_discover_config(use_environment='giotto_env')
system2(config$python, c("-m", "pip", "install", "git+https://github.com/wwang-chcn/bento-tools.git@giotto_install"))

}

Expand Down
92 changes: 92 additions & 0 deletions inst/python/g2bento.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import bento as bt
import geopandas as gpd
import pandas as pd
from anndata import AnnData
from shapely.geometry import MultiPolygon, Polygon

from log import debug, warning, info

def create_seg_df(vertices_df: pd.DataFrame, x: str = 'x', y: str = 'y', cell_id: str = 'cell_id') -> pd.DataFrame:
    """
    Create a dataframe with cell_id and geometry columns from a dataframe with vertices.

    One polygon is built per ``cell_id`` group from that group's (x, y) rows, in row order.
    Polygons that are not valid are repaired with ``buffer(0)``; if a geometry is a
    MultiPolygon, only its largest part (by area) is kept.

    :param vertices_df: a dataframe with columns: cell_id, x, y
    :param x: the column name of x coordinates
    :param y: the column name of y coordinates
    :param cell_id: the column name of cell id
    :return: a GeoDataFrame indexed by cell id with a single 'geometry' column
    """
    # --- create polygons ---
    polygons = vertices_df.groupby(cell_id).apply(lambda group: Polygon(zip(group[x], group[y])))  # type: ignore
    seg_df = gpd.GeoDataFrame(polygons, columns=['geometry'])
    # --- correct invalid polygons ---
    corrected_seg_df = seg_df.copy(deep=True)
    for i in range(seg_df.shape[0]):
        shape = seg_df.iloc[i, 0]
        changed = False
        if not shape.is_valid:  # type: ignore
            shape = shape.buffer(0)  # type: ignore
            changed = True
        if isinstance(shape, MultiPolygon):
            # Iterate over `.geoms`: multi-part geometries are not iterable
            # themselves in Shapely 2.0, while `.geoms` works in 1.8 and 2.x.
            shape = max(shape.geoms, key=lambda part: part.area)
            changed = True
        if changed:
            corrected_seg_df.iloc[i, 0] = shape  # type: ignore
    return corrected_seg_df


def add_batch(adata: AnnData, cell_shape: pd.DataFrame):
    """
    Add batch information to an AnnData object (modified in place).

    If cell_shape has a 'batch' column, every cell in ``adata.obs_names`` and every
    point in ``adata.uns['points']`` (via its 'cell' column) receives the batch of the
    first cell_shape row with the matching 'cell_id'. Otherwise all batches are set to 0.
    In both cases the resulting 'batch' columns are converted to categorical dtype.

    :param adata: an AnnData object; ``adata.obs`` and ``adata.uns['points']`` are updated
    :param cell_shape: vertex table with a 'cell_id' column and, optionally, a 'batch' column
    :return: None
    :raises IndexError: if a cell referenced by adata is absent from cell_shape
    """
    if 'batch' in cell_shape.columns:
        info('Batch information found in cell_shape, adding batch information to adata')
        # cell_shape may hold many rows per cell; build the cell -> batch lookup
        # once instead of scanning the whole table for every cell and every point.
        first_rows = cell_shape.drop_duplicates(subset='cell_id', keep='first')
        batch_by_cell = dict(zip(first_rows['cell_id'], first_rows['batch']))

        def lookup(cells):
            try:
                return [batch_by_cell[cell] for cell in cells]
            except KeyError as exc:
                # IndexError is kept as the exception type of a failed lookup.
                raise IndexError(f'cell {exc.args[0]!r} not found in cell_shape') from exc

        adata.obs['batch'] = lookup(adata.obs_names)  # type: ignore
        adata.uns['points']['batch'] = lookup(adata.uns['points']['cell'])  # type: ignore
    else:  # Interim measure: batch information may not have been transferred to cell_shape
        warning('Batch information not found in cell_shape, all batch will be set to 0')
        adata.obs['batch'] = 0
        adata.uns['points']['batch'] = 0
    adata.obs['batch'] = adata.obs['batch'].astype('category')
    adata.uns['points']['batch'] = adata.uns['points']['batch'].astype('category')


def create_AnnData(trainscripts, cell_shape, nucleus_shape) -> AnnData:
    """
    Build a bento AnnData object from transcript and polygon-vertex tables.

    :param trainscripts: transcript table passed to ``bt.io.prepare`` as ``molecules``
        (the parameter name is kept as is because callers pass it by keyword)
    :param cell_shape: cell vertex table with columns cell_id, x, y and optionally batch
    :param nucleus_shape: nucleus vertex table with columns cell_id, x, y
    :return: the AnnData restricted to genes that occur in ``uns['points']``
    """
    # --- processing input ---
    trainscripts = pd.DataFrame(trainscripts)
    cell_shape = pd.DataFrame(cell_shape)
    cell_shape['cell_id'] = cell_shape['cell_id'].astype('category')
    if 'batch' in cell_shape.columns:
        cell_shape['batch'] = cell_shape['batch'].astype('category')
    nucleus_shape = pd.DataFrame(nucleus_shape)
    nucleus_shape['cell_id'] = nucleus_shape['cell_id'].astype('category')

    # --- create shape ---
    cell_seg = create_seg_df(cell_shape, x='x', y='y', cell_id='cell_id')
    nucleus_seg = create_seg_df(nucleus_shape, x='x', y='y', cell_id='cell_id')
    if cell_seg.shape[0] > 500:
        warning('cell_seg has more than 500 cells, processing may take a long time.')

    # --- create AnnData ---
    adata: AnnData = bt.io.prepare(molecules=trainscripts, cell_seg=cell_seg, other_seg={'nucleus': nucleus_seg})  # type: ignore
    add_batch(adata, cell_shape)

    # --- filter cells ---
    # Let Giotto perform the filtering: keep every cell.
    # The mask is sized from adata itself, not from cell_seg, so it still fits
    # if bt.io.prepare returned a different number of cells than were passed in.
    legal_cells = pd.Series([True] * adata.n_obs)

    # --- filter genes ---
    # Interim measure:
    # subsetGiottoLocs doesn't give the wanted result, so a workaround is used here
    legal_genes = adata.var_names.isin(set(adata.uns['points']['gene'].values))
    filtered_adata = adata[legal_cells, legal_genes]  # type: ignore
    filtered_adata.uns['points']['gene'] = filtered_adata.uns['points']['gene'].cat.remove_unused_categories()

    # Interim measure:
    # a subsetted adata (_is_view == True) doesn't have the _X property, which causes an unexpected error in adata.__sizeof__.
    # When the object is created in R, reticulate calls sys.getsizeof to get its size, which calls adata.__sizeof__.
    # https://github.com/rstudio/reticulate/issues/1332
    # https://github.com/rstudio/rstudio/issues/13491
    # Waiting for AnnData to fix this issue.
    filtered_adata._X = filtered_adata.X
    return filtered_adata

# return adata
41 changes: 41 additions & 0 deletions inst/python/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import sys
import time


def get_current_time() -> str:
    """Return the local wall-clock time formatted as HH:MM:SS."""
    now = time.localtime()
    return time.strftime('%H:%M:%S', now)


def write_direct_message(message: str):
    """Write a timestamped line containing *message* to standard output and flush it."""
    curr_time_str = get_current_time()
    line = f'{curr_time_str} --- {message}\n'
    out = sys.stdout
    out.write(line)
    out.flush()


def debug(message: str):
    """Write *message* to standard output, prefixed with the DEBUG tag."""
    tagged = f'DEBUG: {message}'
    write_direct_message(tagged)


def info(message: str):
    """Write *message* to standard output, prefixed with the INFO tag."""
    tagged = f'INFO: {message}'
    write_direct_message(tagged)


def write_direct_message_err(message: str):
    """Write a timestamped line containing *message* to standard error and flush it."""
    curr_time_str = get_current_time()
    line = f'{curr_time_str} --- {message}\n'
    err = sys.stderr
    err.write(line)
    err.flush()


def warning(message: str):
    """Write *message* to standard error, prefixed with the WARNING tag."""
    tagged = f'WARNING: {message}'
    write_direct_message_err(tagged)


def error(message: str):
    """Write *message* to standard error, prefixed with the ERROR tag."""
    tagged = f'ERROR: {message}'
    write_direct_message_err(tagged)


def critical(message: str):
    """Write *message* to standard error, prefixed with the CRITICAL tag."""
    tagged = f'CRITICAL: {message}'
    write_direct_message_err(tagged)


# Names exported by ``from log import *``: only the level-tagged helpers.
# get_current_time and the two write_direct_message* functions are not listed.
__all__ = ['debug', 'info', 'warning', 'error', 'critical']
143 changes: 143 additions & 0 deletions inst/python/python_bento_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import List, Optional
import bento as bt
from anndata import AnnData
import emoji
from bento._utils import track
from bento.tools._colocation import _colocation_tensor
from bento.tools import decompose
from kneed import KneeLocator

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from log import warning, info


# ---------------------------------
# modified bento and dependencies functions/classes
# ---------------------------------
@track
def colocation(
    data: AnnData,
    ranks: List[int],
    fname: str,
    iterations: int = 3,
    plot_error: bool = True,
    copy: bool = False,
):
    """Decompose a tensor of pairwise colocalization quotients into signatures.
    Parameters
    ----------
    data : AnnData
        Spatial formatted AnnData object.
    ranks : list
        List of ranks to decompose the tensor.
    fname : str
        Path the error plot is saved to (only used when the plot is drawn).
    iterations : int
        Number of iterations to run the decomposition.
    plot_error : bool
        Whether to plot the error of the decomposition. The plot is only drawn
        when the error table has more than one row.
    copy : bool
        Whether to return a copy of the AnnData object. Default False.
    Returns
    -------
    adata : AnnData
        .uns['factors']: Decomposed tensor factors.
        .uns['factors_error']: Decomposition error.
        Returned only when ``copy`` is True; otherwise ``data`` is modified in
        place and None is returned.
    """
    adata = data.copy() if copy else data

    print("Preparing tensor...")
    # `adata` is already the object to modify (our own copy when copy=True), so
    # the tensor must be written to it in place; it is read from it right below.
    # NOTE(review): presumably copy=True would make _colocation_tensor write to
    # a second, discarded copy - confirm against bento.
    _colocation_tensor(adata, copy=False)

    tensor = adata.uns["tensor"]

    print(emoji.emojize(":running: Decomposing tensor..."))
    factors, errors = decompose(tensor, ranks, iterations=iterations)

    if plot_error and errors.shape[0] > 1:
        kl = KneeLocator(errors["rank"], errors["rmse"], direction="decreasing", curve="convex")
        if kl.knee is None:
            # f-string so the rank bounds are interpolated instead of printed literally
            warning(f'No knee found, please extend the ranks range.\nCurrent ranks range: [{ranks[0]},{ranks[-1]}]')
        else:
            info(f'Knee found at rank {kl.knee}')
        sns.lineplot(data=errors, x="rank", y="rmse", ci=95, marker="o")  # type: ignore
        if kl.knee is not None:
            # Only mark the knee when one exists; there is no x position to draw otherwise.
            plt.axvline(kl.knee, linestyle="--")
        plt.savefig(fname)
        info(f"Saved to {fname}")

    adata.uns["factors"] = factors
    adata.uns["factors_error"] = errors

    print(emoji.emojize(":heavy_check_mark: Done."))
    return adata if copy else None


# ---------------------------------
# bento wrapper functions
# ---------------------------------


def analysis_shape_features(adata: AnnData, feature_names: Optional[List[str]] = None) -> None:
    """Run ``bt.tl.obs_stats`` on *adata* for the requested shape features.

    When *feature_names* is None, every key of ``bt.tl.list_shape_features()`` is used.
    """
    selected = feature_names
    if selected is None:
        available = bt.tl.list_shape_features()
        selected = list(available.keys())
    bt.tl.obs_stats(adata, feature_names=selected)


def plot_shape_features_analysis_results(adata: AnnData, fname: str):
    """Call ``bt.pl.shapes`` on *adata*, forwarding *fname* as its ``fname`` argument."""
    plot_shapes = bt.pl.shapes
    plot_shapes(adata, fname=fname)


def analysis_points_features(adata: AnnData,
                             shapes_names: Optional[List[str]] = None,
                             feature_names: Optional[List[str]] = None) -> None:
    """Run ``bt.tl.analyze_points`` on *adata*, grouped by 'gene'.

    *shapes_names* defaults to the cell and nucleus shapes; *feature_names* defaults to
    every key of ``bt.tl.list_point_features()``.
    """
    shapes = ["cell_shape", "nucleus_shape"] if shapes_names is None else shapes_names
    if feature_names is None:
        features = list(bt.tl.list_point_features().keys())
    else:
        features = feature_names
    bt.tl.analyze_points(adata, shape_names=shapes, feature_names=features, groupby='gene')


def plot_points_features_analysis_results(adata: AnnData, fname: str) -> None:
    """Call ``bt.pl.points`` on *adata*, forwarding *fname* as its ``fname`` argument."""
    plot_points = bt.pl.points
    plot_points(adata, fname=fname)


def analysis_rna_forest(adata: AnnData) -> None:
    """Run ``bt.tl.lp`` followed by ``bt.tl.lp_stats`` on *adata*."""
    for step in (bt.tl.lp, bt.tl.lp_stats):
        step(adata)


def plot_rna_forest_analysis_results(adata: AnnData, fname1: str, fname2: str) -> None:
    """Call ``bt.pl.lp_genes`` with *fname1*, then ``bt.pl.lp_dist`` with *fname2*."""
    genes_path, dist_path = fname1, fname2
    bt.pl.lp_genes(adata, fname=genes_path)
    bt.pl.lp_dist(adata, fname=dist_path)


def analysis_colocalization(adata: AnnData, fname: str, ranks: Optional[List[int]] = None) -> None:
    """Compute colocation quotients for cytoplasm and nucleus, then decompose them.

    Adds a 'cytoplasm_shape' column to ``adata.obs`` and a 'cytoplasm' column to
    ``adata.uns['points']`` before calling ``bt.tl.coloc_quotient`` and ``colocation``.
    *ranks* defaults to 1..5; *fname* is forwarded to ``colocation``.
    """
    rank_list = list(range(1, 6)) if ranks is None else ranks

    # Cytoplasm = cell - nucleus
    cell_geom = bt.geo.get_shape(adata, "cell_shape")
    nucleus_geom = bt.geo.get_shape(adata, "nucleus_shape")
    adata.obs["cytoplasm_shape"] = cell_geom - nucleus_geom

    # Create point index: 1 where the point's nucleus value is negative, else 0
    nucleus_index = adata.uns["points"]["nucleus"].astype(int)
    outside_nucleus = nucleus_index < 0
    adata.uns["points"]["cytoplasm"] = outside_nucleus.astype(int)

    compartments = ["cytoplasm_shape", "nucleus_shape"]
    bt.tl.coloc_quotient(adata, shapes=compartments)

    colocation(adata, ranks=rank_list, fname=fname)


def plot_colocalization_analysis_results(adata: AnnData, fname: str, rank: int) -> None:
    """Call ``bt.pl.colocation`` on *adata* for the given *rank*, forwarding *fname*."""
    plot_colocation = bt.pl.colocation
    plot_colocation(adata, rank=rank, fname=fname)


def chekc_genes_number(adata: AnnData) -> None:
    """Print a comparison of the genes in ``uns['points']`` and ``uns['cell_gene_features']``.

    Prints the shape of *adata*, the number of unique genes in each table, the set of
    genes present in the points table but not in the cell-gene feature table, and, for
    each such gene, the matching rows of the points table.

    The misspelled name is kept for existing callers; ``check_genes_number`` is an alias.
    """
    points = adata.uns['points']
    print(f'adata shape: {adata.shape}')
    print(f'adata points genes: {len(adata.uns["points"]["gene"].unique())}')
    print(f'adata cell_gene_features genes: {len(adata.uns["cell_gene_features"]["gene"].unique())}')
    diff_set = set(adata.uns["points"]["gene"]) - set(adata.uns["cell_gene_features"]["gene"])
    print(diff_set)
    for g in diff_set:
        print(f'{g}')
        print(points[points['gene'] == g])


# Correctly spelled alias for the function above.
check_genes_number = chekc_genes_number

0 comments on commit ed93119

Please sign in to comment.