Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type hinting and checking, run by black #102

Open
wants to merge 23 commits into
base: Code_Class_movement
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions openmmdl/openmmdl_analysis/barcode_generation.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Union


class BarcodeGenerator:
def __init__(self, df):
def __init__(self, df: pd.DataFrame):
"""
Initializes the BarcodeGenerator with a dataframe.

Args:
df (pandas dataframe): Dataframe containing all interactions from plip analysis (typically df_all)
df (pd.DataFrame): Dataframe containing all interactions from plip analysis (typically df_all).
"""
self.df = df
self.interactions = self.gather_interactions()

def gather_interactions(self):
def gather_interactions(self) -> Dict[str, pd.Index]:
"""
Gathers columns related to different types of interactions.

Returns:
Dict[str, pd.Index]: Dictionary where keys are interaction types and values are columns corresponding to those interactions.
"""
hydrophobic_interactions = self.df.filter(regex="hydrophobic").columns
acceptor_interactions = self.df.filter(regex="Acceptor_hbond").columns
donor_interactions = self.df.filter(regex="Donor_hbond").columns
Expand All @@ -39,15 +47,15 @@ def gather_interactions(self):
"metal": metal_interactions,
}

def generate_barcode(self, interaction):
def generate_barcode(self, interaction: str) -> np.ndarray:
"""
Generates barcodes for a given interaction.

Args:
interaction (str): Name of the interaction to generate a barcode for
interaction (str): Name of the interaction to generate a barcode for.

Returns:
np.array: Binary array with 1 representing the interaction is present in the corresponding frame
np.ndarray: Binary array with 1 representing the interaction is present in the corresponding frame.
"""
barcode = []
unique_frames = self.df["FRAME"].unique()
Expand All @@ -61,15 +69,15 @@ def generate_barcode(self, interaction):

return np.array(barcode)

def generate_waterids_barcode(self, interaction):
def generate_waterids_barcode(self, interaction: str) -> List[Union[int, int]]:
"""
Generates a barcode containing corresponding water ids for a given interaction.

Args:
interaction (str): Name of the interaction to generate a barcode for
interaction (str): Name of the interaction to generate a barcode for.

Returns:
list: List of water ids for the frames where the interaction is present, 0 if no interaction present
List[Union[int, int]]: List of water ids for the frames where the interaction is present, 0 if no interaction present.
"""
water_id_list = []
waterid_barcode = []
Expand All @@ -88,15 +96,15 @@ def generate_waterids_barcode(self, interaction):

return waterid_barcode

def interacting_water_ids(self, waterbridge_interactions):
"""Generates a list of all water ids that form water bridge interactions.
def interacting_water_ids(self, waterbridge_interactions: List[str]) -> List[int]:
"""
Generates a list of all water ids that form water bridge interactions.

Args:
df_all (pandas dataframe): dataframe containing all interactions from plip analysis (typicaly df_all)
waterbridge_interactions (list): list of strings containing the names of all water bridge interactions
waterbridge_interactions (List[str]): List of strings containing the names of all water bridge interactions.

Returns:
list: list of all unique water ids that form water bridge interactions
List[int]: List of all unique water ids that form water bridge interactions.
"""
interacting_waters = []
for waterbridge_interaction in waterbridge_interactions:
Expand All @@ -108,11 +116,24 @@ def interacting_water_ids(self, waterbridge_interactions):


class BarcodePlotter:
def __init__(self, df_all):
def __init__(self, df_all: pd.DataFrame):
"""
Initializes the BarcodePlotter with a dataframe and BarcodeGenerator instance.

Args:
df_all (pd.DataFrame): Dataframe containing all interactions from PLIP analysis.
"""
self.df_all = df_all
self.barcode_gen = BarcodeGenerator(df_all)

def plot_barcodes(self, barcodes, save_path):
def plot_barcodes(self, barcodes: Dict[str, np.ndarray], save_path: str) -> None:
"""
Plots barcodes and saves the figure to the specified path.

Args:
barcodes (Dict[str, np.ndarray]): Dictionary where keys are interaction names and values are binary barcode arrays.
save_path (str): Path to save the plotted figure.
"""
if not barcodes:
print("No barcodes to plot.")
return
Expand Down Expand Up @@ -154,8 +175,19 @@ def plot_barcodes(self, barcodes, save_path):
plt.savefig(save_path, dpi=300, bbox_inches="tight")

def plot_waterbridge_piechart(
self, waterbridge_barcodes, waterbridge_interactions, fig_type
):
self,
waterbridge_barcodes: Dict[str, np.ndarray],
waterbridge_interactions: List[str],
fig_type: str,
) -> None:
"""
Plots pie charts for waterbridge interactions and saves them to files.

Args:
waterbridge_barcodes (Dict[str, np.ndarray]): Dictionary where keys are interaction names and values are binary barcode arrays.
waterbridge_interactions (List[str]): List of water bridge interaction names.
fig_type (str): File extension for the saved figures (e.g., 'png', 'svg').
"""
if not waterbridge_barcodes:
print("No Piecharts to plot.")
return
Expand Down Expand Up @@ -201,7 +233,7 @@ def plot_waterbridge_piechart(
plt.pie(
values,
labels=labels,
autopct=lambda pct: f"{pct:.1f}%\n({int(round(pct/100.0 * sum(values)))})",
autopct=lambda pct: f"{pct:.1f}%\n({int(round(pct / 100.0 * sum(values)))})",
shadow=False,
startangle=140,
)
Expand All @@ -227,8 +259,18 @@ def plot_waterbridge_piechart(
dpi=300,
)

def plot_barcodes_grouped(self, interactions, interaction_type, fig_type):
ligatoms_dict = {}
def plot_barcodes_grouped(
self, interactions: List[str], interaction_type: str, fig_type: str
) -> None:
"""
Plots grouped barcodes for interactions and saves the figure to a file.

Args:
interactions (List[str]): List of interaction names.
interaction_type (str): Type of interaction for grouping.
fig_type (str): File extension for the saved figure (e.g., 'png', 'pdf').
"""
ligatoms_dict: Dict[str, List[str]] = {}
for interaction in interactions:
ligatom = interaction.split("_")
ligatom.pop(0)
Expand All @@ -245,7 +287,7 @@ def plot_barcodes_grouped(self, interactions, interaction_type, fig_type):
ligatom.pop(-1)
ligatom = "_".join(ligatom)
ligatoms_dict.setdefault(ligatom, []).append(interaction)

total_interactions = {}
for ligatom in ligatoms_dict:
ligatom_interaction_barcodes = {}
Expand Down
Loading