Skip to content

Commit

Permalink
errors_correction6
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiomarco25 committed Dec 19, 2024
1 parent fa43a29 commit 5386490
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 119 deletions.
133 changes: 82 additions & 51 deletions src/troutpy/pl/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
from matplotlib.colors import Colormap, Normalize
from pathlib import Path




def sorted_heatmap(celltype_by_feature, output_path:str='',filename:str="Heatmap_target_cells_by_gene",format='pdf',cmap='viridis',vmax=None,save=False,figsize=(10, 10)):
"""
Plots the heatmap of target cells by gene.
Expand Down Expand Up @@ -195,47 +192,47 @@ def plot_crosstab(data, xvar: str = '', yvar: str = '', normalize=True, axis=1,
-----------
data : pd.DataFrame
Input dataset containing the variables for the cross-tabulation.
xvar : str, optional (default: '')
The variable to use on the x-axis for the cross-tabulation.
yvar : str, optional (default: '')
The variable to use on the y-axis for the cross-tabulation.
normalize : bool, optional (default: True)
Whether to normalize the cross-tabulated data (percentages). If True, the data will be normalized.
axis : int, optional (default: 1)
The axis to normalize across. Use `1` for row normalization and `0` for column normalization.
kind : str, optional (default: 'barh')
The kind of plot to generate. Options include:
- 'barh': Horizontal bar plot
- 'bar': Vertical bar plot
- 'heatmap': Heatmap visualization
- 'clustermap': Clustermap visualization
save : bool, optional (default: True)
If True, the plot will be saved to a file.
figures_path : str, optional (default: '')
The directory path where the figure should be saved. If not specified, the plot will be saved in the current directory.
stacked : bool, optional (default: True)
If True, the bar plots will be stacked. Only applicable for 'barh' and 'bar' plot kinds.
figsize : tuple, optional (default: (6, 10))
The size of the figure for the plot (width, height).
cmap : str, optional (default: 'viridis')
The colormap to use for the plot, especially for heatmap and clustermap visualizations.
saving_format : str, optional (default: 'pdf')
The format to save the plot in. Options include 'png', 'pdf', etc.
sortby : str, optional (default: None)
The column or row to sort the cross-tabulated data by before plotting.
Returns:
--------
None
Expand Down Expand Up @@ -314,7 +311,7 @@ def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool
"""

plt.figure()
y = np.array([np.sum(data[groupby] == False), np.sum(data[groupby] )])
y = np.array([np.sum(~data[groupby]), np.sum(data[groupby] )])
mylabels = [f"{groupby}=False", f"{groupby}=True"]

plt.pie(y, labels=mylabels, colors=['#a0b7e0', '#c5e493'])
Expand All @@ -326,7 +323,7 @@ def pie_of_positive(data, groupby: str = '', figures_path: str = '', save: bool

def genes_over_noise(sdata, scores_by_genes,layer='extracellular_transcripts', output_path:str='',save=True,format:str='pdf'):
"""Function that plots log fold change per gene over noise using a boxplot.
Parameters:
- data_quantified: DataFrame containing the extracellular transcript data, including feature names and codeword categories.
- scores_by_genes: DataFrame containing gene scores with feature names and log fold ratios.
Expand Down Expand Up @@ -650,36 +647,36 @@ def paired_nmf_factors(
):
"""
Plots the spatial distribution of NMF factors for extracellular transcripts and cells.
Parameters:
----------
sdata : spatial data object
The spatial data object containing both extracellular and cell data.
layer : str, optional
Layer in sdata to extract the NMF data from (default: 'nmf_data').
n_factors : int, optional
Number of NMF factors to plot (default: 5).
figsize : tuple, optional
Size of the figure for each subplot (default: (12, 6)).
spot_size_exrna : float, optional
Size of the spots for extracellular transcript scatter plot (default: 5).
spot_size_cells : float, optional
Size of the spots for cell scatter plot (default: 10).
cmap_exrna : str, optional
Colormap for the extracellular transcript NMF factors (default: 'YlGnBu').
cmap_cells : str, optional
Colormap for the cell NMF factors (default: 'Reds').
vmax_exrna : str or float, optional
Maximum value for extracellular transcript color scale (default: 'p99').
vmax_cells : str or float, optional
Maximum value for cell color scale (default: None).
"""
Expand Down Expand Up @@ -797,61 +794,61 @@ def spatial_interactions(
----------
sdata : AnnData
An AnnData object containing the spatial omics data, including transcript expression and cell positions.
layer : str, optional, default: 'extracellular_transcripts_enriched'
The layer in the AnnData object that contains the extracellular RNA transcript data.
gene : str, optional, default: 'Arc'
The gene of interest to be visualized in terms of its spatial interaction with source and target cells.
gene_key : str, optional, default: 'feature_name'
The column name in the AnnData object used to identify the gene.
cell_id_key : str, optional, default: 'cell_id'
The column name in the AnnData object used to identify individual cells.
color_target : str, optional, default: 'blue'
The color to be used for target cells in the plot.
color_source : str, optional, default: 'red'
The color to be used for source cells in the plot.
color_transcript : str, optional, default: 'green'
The color to be used for the RNA transcripts in the plot.
spatial_key : str, optional, default: 'spatial'
The key in the AnnData object that stores the spatial coordinates of the cells.
img : Optional[Union[bool, Sequence]], optional, default: None
A background image to overlay on the plot, such as a tissue section. Can be set to `None` to omit.
img_alpha : Optional[float], optional, default: None
The transparency level of the background image. Ignored if `img` is `None`.
image_cmap : Optional[Colormap], optional, default: None
The colormap to be used for the background image, if applicable.
size : Optional[Union[float, Sequence[float]]], optional, default: 8
The size of the scatter plot points for the cells and transcripts.
alpha : float, optional, default: 0.6
The transparency level for the scatter plot points.
title : Optional[Union[str, Sequence[str]]], optional, default: None
The title of the plot. If `None`, the gene name is used.
legend_loc : Optional[str], optional, default: 'best'
The location of the legend in the plot.
figsize : Tuple[float, float], optional, default: (10, 10)
The dimensions of the plot in inches.
dpi : Optional[int], optional, default: 100
The resolution (dots per inch) for the plot.
save : Optional[Union[str, Path]], optional, default: None
The path to save the plot image. If `None`, the plot is displayed but not saved.
**kwargs : Additional keyword arguments
Any additional arguments passed to the `scatter` or `imshow` functions for customizing plot appearance.
Expand Down Expand Up @@ -895,7 +892,7 @@ def interactions_with_arrows(
cell_id_key: str = 'cell_id',
color_target: str = 'blue',
color_source: str = 'red',
color_transcript:str='green',
color_transcript: str = 'green',
spatial_key: str = 'spatial',
img: Optional[Union[bool, Sequence]] = None,
img_alpha: Optional[float] = None,
Expand All @@ -909,6 +906,40 @@ def interactions_with_arrows(
save: Optional[Union[str, Path]] = None,
**kwargs
):
"""
Visualizes interactions between source and target cells using arrows, along with transcript locations.
The function plots arrows from source to target cells based on transcript proximity, color-coding source and target cells, and transcript locations. An optional image layer can be overlaid behind the plot.
Parameters:
sdata (AnnData): The AnnData object containing the spatial omics data.
layer (str, optional): The key in `sdata` for the extracellular transcript layer to analyze. Default is 'extracellular_transcripts_enriched'.
gene (str, optional): The gene of interest. Default is 'Arc'.
gene_key (str, optional): The key for gene names in the data. Default is 'feature_name'.
cell_id_key (str, optional): The key for cell IDs. Default is 'cell_id'.
color_target (str, optional): Color for the target cells. Default is 'blue'.
color_source (str, optional): Color for the source cells. Default is 'red'.
color_transcript (str, optional): Color for the transcript locations. Default is 'green'.
spatial_key (str, optional): The key for spatial coordinates in `sdata`. Default is 'spatial'.
img (Optional[Union[bool, Sequence]], optional): Optional background image (e.g., tissue section) to display behind the plot.
img_alpha (Optional[float], optional): Transparency level for the background image. Default is None (no image).
image_cmap (Optional[Colormap], optional): Colormap for the image. Default is None.
size (Optional[Union[float, Sequence[float]]], optional): Size of the plotted points (cells and transcripts). Default is 8.
alpha (float, optional): Transparency level for plotted points. Default is 0.6.
title (Optional[Union[str, Sequence[str]]], optional): Title of the plot. Default is the gene name.
legend_loc (Optional[str], optional): Location of the legend on the plot. Default is 'best'.
figsize (Tuple[float, float], optional): Size of the plot. Default is (10, 10).
dpi (Optional[int], optional): Resolution of the plot. Default is 100.
save (Optional[Union[str, Path]], optional): If provided, the path where the plot will be saved.
**kwargs: Additional arguments passed to the `scatter` and `imshow` functions for customization.
Returns:
None: The function displays or saves a plot of interactions between cells and transcripts.
Notes:
The plot will show arrows from source to target cells, with different colors for source, target, and transcript points.
"""

# Extract relevant data
transcripts = sdata.points[layer]
trans_filt = transcripts[transcripts[gene_key] == gene]
Expand All @@ -924,7 +955,7 @@ def interactions_with_arrows(
# Plot arrows between each paired source and target cell
for source, target in zip(source_cells, target_cells):
if source in cell_positions.index and target in cell_positions.index:
if source!=target:
if source != target:
x_start, y_start = cell_positions.loc[source, 'x'], cell_positions.loc[source, 'y']
x_end, y_end = cell_positions.loc[target, 'x'], cell_positions.loc[target, 'y']
plt.arrow(x_start, y_start, x_end - x_start, y_end - y_start, color='black', alpha=0.8, head_width=8, head_length=8)
Expand All @@ -933,7 +964,7 @@ def interactions_with_arrows(
plt.scatter(cell_positions['x'], cell_positions['y'], c='grey', s=0.6, alpha=alpha, **kwargs)
plt.scatter(cell_positions.loc[target_cells, 'x'], cell_positions.loc[target_cells, 'y'], c=color_target, s=size, label='Target Cells', **kwargs)
plt.scatter(cell_positions.loc[source_cells, 'x'], cell_positions.loc[source_cells, 'y'], c=color_source, s=size, label='Source Cells', **kwargs)
plt.scatter(trans_filt['x'], trans_filt['y'], c=color_transcript, s=size*0.4, label='Transcripts', **kwargs)
plt.scatter(trans_filt['x'], trans_filt['y'], c=color_transcript, s=size * 0.4, label='Transcripts', **kwargs)

# Titles and Legends
plt.title(title or gene)
Expand All @@ -944,4 +975,4 @@ def interactions_with_arrows(
# Save the plot if path provided
if save:
plt.savefig(save)
plt.show()
plt.show()
Loading

0 comments on commit 5386490

Please sign in to comment.