diff --git a/openmmdl/openmmdl_analysis/barcode_generation.py b/openmmdl/openmmdl_analysis/barcode_generation.py index e362bb83..931b3155 100644 --- a/openmmdl/openmmdl_analysis/barcode_generation.py +++ b/openmmdl/openmmdl_analysis/barcode_generation.py @@ -1,20 +1,28 @@ import os import numpy as np +import pandas as pd import matplotlib.pyplot as plt +from typing import List, Dict, Union class BarcodeGenerator: - def __init__(self, df): + def __init__(self, df: pd.DataFrame): """ Initializes the BarcodeGenerator with a dataframe. Args: - df (pandas dataframe): Dataframe containing all interactions from plip analysis (typically df_all) + df (pd.DataFrame): Dataframe containing all interactions from plip analysis (typically df_all). """ self.df = df self.interactions = self.gather_interactions() - def gather_interactions(self): + def gather_interactions(self) -> Dict[str, pd.Index]: + """ + Gathers columns related to different types of interactions. + + Returns: + Dict[str, pd.Index]: Dictionary where keys are interaction types and values are columns corresponding to those interactions. + """ hydrophobic_interactions = self.df.filter(regex="hydrophobic").columns acceptor_interactions = self.df.filter(regex="Acceptor_hbond").columns donor_interactions = self.df.filter(regex="Donor_hbond").columns @@ -39,15 +47,15 @@ def gather_interactions(self): "metal": metal_interactions, } - def generate_barcode(self, interaction): + def generate_barcode(self, interaction: str) -> np.ndarray: """ Generates barcodes for a given interaction. Args: - interaction (str): Name of the interaction to generate a barcode for + interaction (str): Name of the interaction to generate a barcode for. Returns: - np.array: Binary array with 1 representing the interaction is present in the corresponding frame + np.ndarray: Binary array with 1 representing the interaction is present in the corresponding frame. """ barcode = [] unique_frames = self.df["FRAME"].unique() @@ -61,15 +69,15 @@ def generate_barcode(self, interaction): return np.array(barcode) - def generate_waterids_barcode(self, interaction): + def generate_waterids_barcode(self, interaction: str) -> List[Union[int, int]]: """ Generates a barcode containing corresponding water ids for a given interaction. Args: - interaction (str): Name of the interaction to generate a barcode for + interaction (str): Name of the interaction to generate a barcode for. Returns: - list: List of water ids for the frames where the interaction is present, 0 if no interaction present + List[Union[int, int]]: List of water ids for the frames where the interaction is present, 0 if no interaction present. """ water_id_list = [] waterid_barcode = [] @@ -88,15 +96,15 @@ def generate_waterids_barcode(self, interaction): return waterid_barcode - def interacting_water_ids(self, waterbridge_interactions): - """Generates a list of all water ids that form water bridge interactions. + def interacting_water_ids(self, waterbridge_interactions: List[str]) -> List[int]: + """ + Generates a list of all water ids that form water bridge interactions. Args: - df_all (pandas dataframe): dataframe containing all interactions from plip analysis (typicaly df_all) - waterbridge_interactions (list): list of strings containing the names of all water bridge interactions + waterbridge_interactions (List[str]): List of strings containing the names of all water bridge interactions. Returns: - list: list of all unique water ids that form water bridge interactions + List[int]: List of all unique water ids that form water bridge interactions. """ interacting_waters = [] for waterbridge_interaction in waterbridge_interactions: @@ -108,11 +116,24 @@ def interacting_water_ids(self, waterbridge_interactions): class BarcodePlotter: - def __init__(self, df_all): + def __init__(self, df_all: pd.DataFrame): + """ + Initializes the BarcodePlotter with a dataframe and BarcodeGenerator instance. + + Args: + df_all (pd.DataFrame): Dataframe containing all interactions from PLIP analysis. + """ self.df_all = df_all self.barcode_gen = BarcodeGenerator(df_all) - def plot_barcodes(self, barcodes, save_path): + def plot_barcodes(self, barcodes: Dict[str, np.ndarray], save_path: str) -> None: + """ + Plots barcodes and saves the figure to the specified path. + + Args: + barcodes (Dict[str, np.ndarray]): Dictionary where keys are interaction names and values are binary barcode arrays. + save_path (str): Path to save the plotted figure. + """ if not barcodes: print("No barcodes to plot.") return @@ -154,8 +175,19 @@ def plot_barcodes(self, barcodes, save_path): plt.savefig(save_path, dpi=300, bbox_inches="tight") def plot_waterbridge_piechart( - self, waterbridge_barcodes, waterbridge_interactions, fig_type - ): + self, + waterbridge_barcodes: Dict[str, np.ndarray], + waterbridge_interactions: List[str], + fig_type: str, + ) -> None: + """ + Plots pie charts for waterbridge interactions and saves them to files. + + Args: + waterbridge_barcodes (Dict[str, np.ndarray]): Dictionary where keys are interaction names and values are binary barcode arrays. + waterbridge_interactions (List[str]): List of water bridge interaction names. + fig_type (str): File extension for the saved figures (e.g., 'png', 'svg'). + """ if not waterbridge_barcodes: print("No Piecharts to plot.") return @@ -201,7 +233,7 @@ def plot_waterbridge_piechart( plt.pie( values, labels=labels, - autopct=lambda pct: f"{pct:.1f}%\n({int(round(pct/100.0 * sum(values)))})", + autopct=lambda pct: f"{pct:.1f}%\n({int(round(pct / 100.0 * sum(values)))})", shadow=False, startangle=140, ) @@ -227,8 +259,18 @@ def plot_waterbridge_piechart( dpi=300, ) - def plot_barcodes_grouped(self, interactions, interaction_type, fig_type): - ligatoms_dict = {} + def plot_barcodes_grouped( + self, interactions: List[str], interaction_type: str, fig_type: str + ) -> None: + """ + Plots grouped barcodes for interactions and saves the figure to a file. + + Args: + interactions (List[str]): List of interaction names. + interaction_type (str): Type of interaction for grouping. + fig_type (str): File extension for the saved figure (e.g., 'png', 'pdf'). + """ + ligatoms_dict: Dict[str, List[str]] = {} for interaction in interactions: ligatom = interaction.split("_") ligatom.pop(0) @@ -245,7 +287,7 @@ def plot_barcodes_grouped(self, interactions, interaction_type, fig_type): ligatom.pop(-1) ligatom = "_".join(ligatom) ligatoms_dict.setdefault(ligatom, []).append(interaction) - + total_interactions = {} for ligatom in ligatoms_dict: ligatom_interaction_barcodes = {} diff --git a/openmmdl/openmmdl_analysis/find_stable_waters.py b/openmmdl/openmmdl_analysis/find_stable_waters.py index a67659a3..fb75208b 100644 --- a/openmmdl/openmmdl_analysis/find_stable_waters.py +++ b/openmmdl/openmmdl_analysis/find_stable_waters.py @@ -5,29 +5,28 @@ from sklearn.cluster import DBSCAN from tqdm import tqdm from io import StringIO -from Bio.PDB import PDBParser +from Bio.PDB import PDBParser, Structure +from typing import Tuple, Dict, List, Optional class StableWaters: - def __init__(self, trajectory, topology, water_eps): + def __init__(self, trajectory: str, topology: str, water_eps: float) -> None: self.trajectory = trajectory self.topology = topology self.u = mda.Universe(self.topology, self.trajectory) self.water_eps = water_eps - def trace_waters(self, output_directory): - """trace the water molecules in a trajectory and write all which move below one Angstrom distance. To adjust the distance alter the integer + def trace_waters(self, output_directory: str) -> Tuple[pd.DataFrame, int]: + """Trace the water molecules in a trajectory and write all which move below one Angstrom distance. To adjust the distance alter the integer Args: - topology (str): Path to the topology file. - trajectory (str): Path to the trajectory file. output_directory (str): Directory where output files will be saved. Returns: - pd.DataFrame: DataFrame containing stable water coordinates. - int: Total number of frames. + Tuple[pd.DataFrame, int]: DataFrame containing stable water coordinates and total number of frames. """ # Get the total number of frames for the progress bar total_frames = len(self.u.trajectory) + # Create an empty DataFrame to store stable water coordinates stable_waters = pd.DataFrame( columns=["Frame", "Residue", "Oxygen_X", "Oxygen_Y", "Oxygen_Z"] @@ -40,14 +39,13 @@ def trace_waters(self, output_directory): for ts in tqdm( self.u.trajectory, total=total_frames, - desc="Processing frames for the wateranalysis", + desc="Processing frames for the water analysis", ): frame_num = ts.frame frame_coords = {} # Iterate through oxygen atoms of the specified water type - # for atom in u.select_atoms(f"resname {water_type} and name O"): - for atom in self.u.select_atoms(f"resname HOH and name O"): + for atom in self.u.select_atoms("resname HOH and name O"): frame_coords[atom.index] = ( atom.position[0], atom.position[1], @@ -74,7 +72,7 @@ def trace_waters(self, output_directory): ) # Append stable water coordinates to the stable_waters DataFrame - if stable_coords: # Check if stable_coords is not empty + if stable_coords: stable_waters = pd.concat( [ stable_waters, @@ -100,17 +98,23 @@ def trace_waters(self, output_directory): return stable_waters, total_frames def perform_clustering_and_writing( - self, stable_waters, cluster_eps, total_frames, output_directory - ): + self, + stable_waters: pd.DataFrame, + cluster_eps: float, + total_frames: int, + output_directory: str, + ) -> None: """ Perform DBSCAN clustering on the stable water coordinates, and write the clusters and their representatives to PDB files. Args: stable_waters (pd.DataFrame): DataFrame containing stable water coordinates. - cluster_eps (float): DBSCAN clustering epsilon parameter. This is in Angstrom in this case, and defines which Water distances should be within one cluster + cluster_eps (float): DBSCAN clustering epsilon parameter. This is in Angstrom in this case + and defines which Water distances should be within one cluster. total_frames (int): Total number of frames. output_directory (str): Directory where output files will be saved. """ + # Feature extraction: XYZ coordinates X = stable_waters[["Oxygen_X", "Oxygen_Y", "Oxygen_Z"]] @@ -138,8 +142,11 @@ def perform_clustering_and_writing( ) def write_pdb_clusters_and_representatives( - self, clustered_waters, min_samples, output_sub_directory - ): + self, + clustered_waters: pd.DataFrame, + min_samples: int, + output_sub_directory: str, + ) -> None: """ Writes the clusters and their representatives to PDB files. @@ -147,16 +154,13 @@ def write_pdb_clusters_and_representatives( clustered_waters (pd.DataFrame): DataFrame containing clustered water coordinates. min_samples (int): Minimum number of samples for DBSCAN clustering. output_sub_directory (str): Subdirectory where output PDB files will be saved. - """ atom_counter = 1 pdb_file_counter = 1 print("minsamples:") print(min_samples) os.makedirs(output_sub_directory, exist_ok=True) - with pd.option_context( - "display.max_rows", None - ): # Temporarily set display options + with pd.option_context("display.max_rows", None): for label, cluster in clustered_waters.groupby("Cluster_Label"): pdb_lines = [] for _, row in cluster.iterrows(): @@ -188,22 +192,17 @@ def write_pdb_clusters_and_representatives( pdb_line = f"ATOM{index + 1:6} O WAT A{index + 1:4} {x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 O\n" pdb_file.write(pdb_line) - # Example usage - # stable_waters_pipeline("topology_file", "trajectory_file", 0.5) - def stable_waters_pipeline(self, output_directory="./stableWaters"): + def stable_waters_pipeline(self, output_directory: str = "./stableWaters") -> None: """Function to run the pipeline to extract stable water clusters, and their representatives from a PDB & DCD file Args: - topology (str): Path to the topology file. - trajectory (str): Path to the trajectory file. - water_eps (float): DBSCAN clustering epsilon parameter. output_directory (str, optional): Directory where output files will be saved. Default is "./stableWaters". - """ # Load the PDB and DCD files output_directory += "_clusterEps_" strEps = str(self.water_eps).replace(".", "") output_directory += strEps os.makedirs(output_directory, exist_ok=True) + # Create a stable waters list by calling the process_trajectory_and_cluster function stable_waters, total_frames = self.trace_waters(output_directory) # Now call perform_clustering_and_writing with the returned values @@ -211,13 +210,13 @@ def stable_waters_pipeline(self, output_directory="./stableWaters"): stable_waters, self.water_eps, total_frames, output_directory ) - def filter_and_parse_pdb(protein_pdb): - """This function reads in a PDB and returns the structure with bioparser. + def filter_and_parse_pdb(self, protein_pdb: str) -> Structure: + """Reads in a PDB and returns the structure with bioparser. Args: protein_pdb (str): Path to a protein PDB file. Returns: - biopython.structure: PDB structure object. + Structure: Biopython PDB structure object. """ with open(protein_pdb, "r") as pdb_file: lines = [ @@ -226,9 +225,7 @@ def filter_and_parse_pdb(protein_pdb): if ( line.startswith("ATOM") and line[17:20].strip() not in ["HOH", "WAT", "T4P", "T3P"] - and line[22:26] - .strip() - .isdigit() # Exclude lines with non-numeric sequence identifiers + and line[22:26].strip().isdigit() ) ] @@ -242,22 +239,28 @@ def filter_and_parse_pdb(protein_pdb): return structure - def find_interacting_residues(structure, representative_waters, distance_threshold): - """This function maps waters (e.g. the representative waters) to interacting residues of a different PDB structure input. Use "filter_and_parse_pdb" to get the input for this function + def find_interacting_residues( + self, + structure: Structure, + representative_waters: pd.DataFrame, + distance_threshold: float, + ) -> Dict[int, List[Tuple[str, int]]]: + """Maps waters (e.g. the representative waters) to interacting residues of a different PDB structure input. + Use "filter_and_parse_pdb" to get the input for this function. Args: - structure (biopython.structure): Biopython PDB structure object. - representative_waters (pandasd.DataFrame): DataFrame containing representative water coordinates. + structure (Structure): Biopython PDB structure object. + representative_waters (pd.DataFrame): DataFrame containing representative water coordinates. distance_threshold (float): Threshold distance for identifying interacting residues. Returns: - dict: Dictionary mapping cluster numbers to interacting residues. + Dict[int, List[Tuple[str, int]]]: Dictionary mapping cluster numbers to interacting residues. """ interacting_residues = {} for model in structure: for chain in model: + # Check if the residue is a protein residue (not a heteroatom or water molecule) for residue in chain: - # Check if the residue is a protein residue (not a heteroatom or water molecule) if ( residue.id[0] == " " and residue.id[2] == " " @@ -277,7 +280,7 @@ def find_interacting_residues(structure, representative_waters, distance_thresho distance = np.linalg.norm(wat_coords - residue_coords) if distance < distance_threshold: - key = wat_index # Assuming wat_index is the number of the water cluster + key = wat_index # Assuming wat_index is the number of the water cluster if key not in interacting_residues: interacting_residues[key] = [] interacting_residues[key].append( @@ -286,19 +289,18 @@ def find_interacting_residues(structure, representative_waters, distance_thresho return interacting_residues - def read_pdb_as_dataframe(pdb_file): + def read_pdb_as_dataframe(self, pdb_file: str) -> pd.DataFrame: """Helper function reading a PDB Args: pdb_file (str): Path to the PDB file. Returns: - pandas.DataFrame: DataFrame containing PDB data. + pd.DataFrame: DataFrame containing PDB data. """ lines = [] with open(pdb_file, "r") as f: lines = f.readlines() - # Extract relevant information from PDB file lines data = [] for line in lines: if line.startswith("ATOM"): @@ -307,25 +309,23 @@ def read_pdb_as_dataframe(pdb_file): z = float(line[46:54].strip()) data.append([x, y, z]) - # Create a DataFrame columns = ["Oxygen_X", "Oxygen_Y", "Oxygen_Z"] representative_waters = pd.DataFrame(data, columns=columns) return representative_waters - # Analyse protein and water interaction, get the residues and the corresponding weater molecules that interact. def analyze_protein_and_water_interaction( self, - protein_pdb_file, - representative_waters_file, - cluster_eps, - output_directory="./stableWaters", - distance_threshold=5.0, - ): - """Analyse the interaction of residues to water molecules using a threshold that can be specified when calling the function + protein_pdb_file: str, + representative_waters_file: str, + cluster_eps: float, + output_directory: str = "./stableWaters", + distance_threshold: float = 5.0, + ) -> None: + """Analyze the interaction of residues to water molecules using a threshold that can be specified when calling the function Args: protein_pdb_file (str): Path to the protein PDB file without waters. - representative_waters_file (str): Path to the representative waters PDB file, or any PDB file containing only waters + representative_waters_file (str): Path to the representative waters PDB file, or any PDB file containing only waters. cluster_eps (float): DBSCAN clustering epsilon parameter. output_directory (str, optional): Directory where output files will be saved. Default is "./stableWaters". distance_threshold (float, optional): Threshold distance for identifying interacting residues. Default is 5.0 (Angstrom). diff --git a/openmmdl/openmmdl_analysis/markov_state_figure_generation.py b/openmmdl/openmmdl_analysis/markov_state_figure_generation.py index fd9d398d..220acff8 100644 --- a/openmmdl/openmmdl_analysis/markov_state_figure_generation.py +++ b/openmmdl/openmmdl_analysis/markov_state_figure_generation.py @@ -2,18 +2,19 @@ import matplotlib.pyplot as plt from matplotlib.patches import Patch import os +from typing import Dict, List, Tuple, Union class MarkovChainAnalysis: - def __init__(self, min_transition): + def __init__(self, min_transition: float): self.min_transition = min_transition - self.min_transitions = self.calculate_min_transitions() + self.min_transitions: List[float] = self.calculate_min_transitions() - def calculate_min_transitions(self): + def calculate_min_transitions(self) -> List[float]: """Calculates a list based on the minimum transition time provided values and returns it in factors 1, 2, 5, 10. Returns: - list: List with the minimum transition time with factors 1, 2, 5, 10. + List[float]: List with the minimum transition time with factors 1, 2, 5, 10. """ min_transitions = [ self.min_transition, @@ -24,42 +25,49 @@ def calculate_min_transitions(self): return min_transitions def generate_transition_graph( - self, total_frames, combined_dict, fig_type="png", font_size=36, size_node=200 - ): + self, + total_frames: int, + combined_dict: Dict[str, List[str]], + fig_type: str = "png", + font_size: int = 36, + size_node: int = 200, + ) -> None: """Generate Markov Chain plots based on transition probabilities. Args: total_frames (int): The number of frames in the protein-ligand MD simulation. - combined_dict (dict): A dictionary with the information of the Binding Modes and their order of appearance during the simulation for all frames. + combined_dict (Dict[str, List[str]]): A dictionary with the information of the Binding Modes and their order of appearance during the simulation for all frames. fig_type (str, optional): File type for the output figures. Default is 'png'. font_size (int, optional): The font size for the node labels. The default value is set to 36. size_node (int, optional): The size of the nodes in the Markov Chain plot. The default value is set to 200. """ # Calculate the number of elements in each part - total_length = len(combined_dict["all"]) - part_length = total_length // 3 - remaining_length = total_length % 3 + total_length: int = len(combined_dict["all"]) + part_length: int = total_length // 3 + remaining_length: int = total_length % 3 # Divide the 'all_data' into three parts - part1_length = part_length + remaining_length - part2_length = part_length - part1_data = combined_dict["all"][:part1_length] - part2_data = combined_dict["all"][part1_length : part1_length + part2_length] - part3_data = combined_dict["all"][part1_length + part2_length :] + part1_length: int = part_length + remaining_length + part2_length: int = part_length + part1_data: List[str] = combined_dict["all"][:part1_length] + part2_data: List[str] = combined_dict["all"][ + part1_length : part1_length + part2_length + ] + part3_data: List[str] = combined_dict["all"][part1_length + part2_length :] # Count the occurrences of each node in each part - part1_node_occurrences = { + part1_node_occurrences: Dict[str, int] = { node: part1_data.count(node) for node in set(part1_data) } - part2_node_occurrences = { + part2_node_occurrences: Dict[str, int] = { node: part2_data.count(node) for node in set(part2_data) } - part3_node_occurrences = { + part3_node_occurrences: Dict[str, int] = { node: part3_data.count(node) for node in set(part3_data) } # Create the legend - legend_labels = { + legend_labels: Dict[str, str] = { "Blue": "Binding mode not in top 10 occurrence", "Green": "Binding Mode occurrence mostly in first third of frames", "Orange": "Binding Mode occurrence mostly in second third of frames", @@ -67,57 +75,59 @@ def generate_transition_graph( "Yellow": "Binding Mode occurs throughout all trajectory equally", } - legend_colors = ["skyblue", "green", "orange", "red", "yellow"] + legend_colors: List[str] = ["skyblue", "green", "orange", "red", "yellow"] - legend_handles = [ + legend_handles: List[Patch] = [ Patch(color=color, label=label) for color, label in zip(legend_colors, legend_labels.values()) ] # Get the top 10 nodes with the most occurrences - node_occurrences = { + node_occurrences: Dict[str, int] = { node: combined_dict["all"].count(node) for node in set(combined_dict["all"]) } - top_10_nodes = sorted(node_occurrences, key=node_occurrences.get, reverse=True)[ - :10 - ] + top_10_nodes: List[str] = sorted( + node_occurrences, key=node_occurrences.get, reverse=True + )[:10] for min_transition_percent in self.min_transitions: - min_prob = min_transition_percent / 100 # Convert percentage to probability + min_prob: float = ( + min_transition_percent / 100 + ) # Convert percentage to probability # Create a directed graph - G = nx.DiGraph() + G: nx.DiGraph = nx.DiGraph() # Count the occurrences of each transition and self-loop - transitions = {} - self_loops = {} + transitions: Dict[Tuple[str, str], int] = {} + self_loops: Dict[Tuple[str, str], int] = {} for i in range(len(combined_dict["all"]) - 1): - current_state = combined_dict["all"][i] - next_state = combined_dict["all"][i + 1] + current_state: str = combined_dict["all"][i] + next_state: str = combined_dict["all"][i + 1] if current_state == next_state: # Check for self-loop - self_loop = (current_state, next_state) + self_loop: Tuple[str, str] = (current_state, next_state) self_loops[self_loop] = self_loops.get(self_loop, 0) + 1 else: - transition = (current_state, next_state) + transition: Tuple[str, str] = (current_state, next_state) transitions[transition] = transitions.get(transition, 0) + 1 # Add edges to the graph with their probabilities for transition, count in transitions.items(): current_state, next_state = transition - probability = ( - count / len(combined_dict["all"]) * 100 - ) # Convert probability to percentage + probability: float = ( + count / len(combined_dict["all"]) + ) * 100 # Convert probability to percentage if probability >= min_transition_percent: G.add_edge(current_state, next_state, weight=probability) # Include the reverse transition with a different color - reverse_transition = (next_state, current_state) - reverse_count = transitions.get( + reverse_transition: Tuple[str, str] = (next_state, current_state) + reverse_count: int = transitions.get( reverse_transition, 0 ) # Use the correct count for the reverse transition - reverse_probability = ( - reverse_count / len(combined_dict["all"]) * 100 - ) + reverse_probability: float = ( + reverse_count / len(combined_dict["all"]) + ) * 100 G.add_edge( next_state, current_state, @@ -127,49 +137,47 @@ def generate_transition_graph( # Add self-loops to the graph with their probabilities for self_loop, count in self_loops.items(): - state = self_loop[0] - probability = ( - count / len(combined_dict["all"]) * 100 - ) # Convert probability to percentage + state: str = self_loop[0] + probability: float = ( + count / len(combined_dict["all"]) + ) * 100 # Convert probability to percentage if probability >= min_transition_percent: G.add_edge(state, state, weight=probability) # Calculate transition probabilities for each direction (excluding self-loops) - transition_probabilities_forward = {} - transition_probabilities_backward = {} - transition_occurrences_forward = {} - transition_occurrences_backward = {} + transition_probabilities_forward: Dict[Tuple[str, str], float] = {} + transition_probabilities_backward: Dict[Tuple[str, str], float] = {} + transition_occurrences_forward: Dict[Tuple[str, str], float] = {} + transition_occurrences_backward: Dict[Tuple[str, str], float] = {} for transition, count in transitions.items(): start_state, end_state = transition - forward_transition = (start_state, end_state) - backward_transition = (end_state, start_state) + forward_transition: Tuple[str, str] = (start_state, end_state) + backward_transition: Tuple[str, str] = (end_state, start_state) # Separate counts for forward and backward transitions - forward_count = transitions.get(forward_transition, 0) - backward_count = transitions.get(backward_transition, 0) + forward_count: int = transitions.get(forward_transition, 0) + backward_count: int = transitions.get(backward_transition, 0) transition_probabilities_forward[forward_transition] = ( - forward_count / node_occurrences[start_state] * 100 - ) - + forward_count / node_occurrences[start_state] + ) * 100 transition_occurrences_forward[forward_transition] = ( - forward_count / len(combined_dict["all"]) * 100 - ) + forward_count / len(combined_dict["all"]) + ) * 100 transition_probabilities_backward[backward_transition] = ( - backward_count / node_occurrences[end_state] * 100 - ) - + backward_count / node_occurrences[end_state] + ) * 100 transition_occurrences_backward[backward_transition] = ( - backward_count / len(combined_dict["all"]) * 100 - ) + backward_count / len(combined_dict["all"]) + ) * 100 # Calculate self-loop probabilities - self_loop_probabilities = {} - self_loop_occurences = {} + self_loop_probabilities: Dict[str, float] = {} + self_loop_occurences: Dict[str, float] = {} for self_loop, count in self_loops.items(): - state = self_loop[0] + state: str = self_loop[0] self_loop_probabilities[state] = count / node_occurrences[state] self_loop_occurences[state] = count / len(combined_dict["all"]) * 100 @@ -181,18 +189,18 @@ def generate_transition_graph( ) # Draw nodes and edges - pos = nx.spring_layout( + pos: Dict[str, Tuple[float, float]] = nx.spring_layout( G, k=2, seed=42 ) # Increased distance between nodes (k=2) - edge_colors = [] + edge_colors: List[str] = [] for u, v, data in G.edges(data=True): - weight = data["weight"] + weight: float = data["weight"] if u == v: # Check if it is a self-loop edge_colors.append("green") # Set green color for self-loop arrows - width = 0.1 # Make self-loop arrows smaller - connection_style = ( + width: float = 0.1 # Make self-loop arrows smaller + connection_style: str = ( "arc3,rad=-0.1" # Make the self-loops more curved and compact ) nx.draw_networkx_edges( @@ -207,16 +215,20 @@ def generate_transition_graph( elif weight >= min_transition_percent: edge_colors.append( "black" - ) # Highlight significant transitions in red + ) # Highlight significant transitions in black # Check if both nodes are present before adding labels if G.has_node(u) and G.has_node(v): - width = 4.0 - forward_label = f"{transition_occurrences_forward.get((v, u), 0):.2f}% of Frames →, {transition_probabilities_forward.get((v, u), 0):.2f}% probability" - backward_label = f"{transition_occurrences_backward.get((u, v), 0):.2f}% of Frames ←, {transition_probabilities_backward.get((u, v), 0):.2f}% probability" - edge_label = f"{forward_label}\n{backward_label}" + width: float = 4.0 + forward_label: str = ( + f"{transition_occurrences_forward.get((v, u), 0):.2f}% of Frames →, {transition_probabilities_forward.get((v, u), 0):.2f}% probability" + ) + backward_label: str = ( + f"{transition_occurrences_backward.get((u, v), 0):.2f}% of Frames ←, {transition_probabilities_backward.get((u, v), 0):.2f}% probability" + ) + edge_label: str = f"{forward_label}\n{backward_label}" - connection_style = "arc3,rad=-0.1" + connection_style: str = "arc3,rad=-0.1" nx.draw_networkx_edges( G, pos, @@ -232,16 +244,20 @@ def generate_transition_graph( else: edge_colors.append( "grey" - ) # Use black for non-significant transitions - width = 0.5 + ) # Use grey for non-significant transitions + width: float = 0.5 # Check if both nodes are present before adding labels if G.has_node(u) and G.has_node(v): - forward_label = f"{transition_occurrences_forward.get((v, u), 0):.2f}% of Frames →, {transition_probabilities_forward.get((v, u), 0):.2f}% probability" - backward_label = f"{transition_occurrences_backward.get((u, v), 0):.2f}% of Frames ←, {transition_probabilities_backward.get((u, v), 0):.2f}% probability" - edge_label = f"{forward_label}\n{backward_label}" + forward_label: str = ( + f"{transition_occurrences_forward.get((v, u), 0):.2f}% of Frames →, {transition_probabilities_forward.get((v, u), 0):.2f}% probability" + ) + backward_label: str = ( + f"{transition_occurrences_backward.get((u, v), 0):.2f}% of Frames ←, {transition_probabilities_backward.get((u, v), 0):.2f}% probability" + ) + edge_label: str = f"{forward_label}\n{backward_label}" - connection_style = "arc3,rad=-0.1" + connection_style: str = "arc3,rad=-0.1" nx.draw_networkx_edges( G, pos, @@ -256,16 +272,16 @@ def generate_transition_graph( ) # Update the node colors based on their appearance percentages in each part - node_colors = [] + node_colors: List[str] = [] for node in G.nodes(): if node in top_10_nodes: - part1_percentage = ( + part1_percentage: float = ( part1_node_occurrences.get(node, 0) / node_occurrences[node] ) - part2_percentage = ( + part2_percentage: float = ( part2_node_occurrences.get(node, 0) / node_occurrences[node] ) - part3_percentage = ( + part3_percentage: float = ( part3_node_occurrences.get(node, 0) / node_occurrences[node] ) @@ -281,22 +297,24 @@ def generate_transition_graph( node_colors.append("skyblue") # Draw nodes with sizes correlated to occurrences and color top 10 nodes - node_size = [size_node * node_occurrences[node] for node in G.nodes()] + node_size: List[int] = [ + size_node * node_occurrences[node] for node in G.nodes() + ] nx.draw_networkx_nodes( G, pos, node_size=node_size, node_color=node_colors, alpha=0.8 ) # Draw node labels with occurrence percentage and self-loop probability for nodes with edges - node_labels = {} + node_labels: Dict[str, str] = {} for node in G.nodes(): if G.degree(node) > 0: # Check if the node has at least one edge - edges_with_node = [ + edges_with_node: List[Tuple[str, str, float]] = [ (u, v, data["weight"]) for u, v, data in G.edges(data=True) if u == node or v == node ] - relevant_edges = [ + relevant_edges: List[Tuple[str, str, float]] = [ edge for edge in edges_with_node if edge[2] >= min_transition_percent @@ -304,14 +322,18 @@ def generate_transition_graph( if relevant_edges: if node in top_10_nodes: - node_occurrence_percentage = ( - node_occurrences[node] / len(combined_dict["all"]) * 100 - ) - self_loop_probability = ( + node_occurrence_percentage: float = ( + node_occurrences[node] / len(combined_dict["all"]) + ) * 100 + self_loop_probability: float = ( self_loop_probabilities.get(node, 0) * 100 ) - self_loop_occurence = self_loop_occurences.get(node, 0) - node_label = f"{node}\nOccurrences: {node_occurrence_percentage:.2f}%\nSelf-Loop Probability: {self_loop_probability:.2f}% \nSelf-Loop Occurrence: {self_loop_occurence:.2f}%" + self_loop_occurence: float = self_loop_occurences.get( + node, 0 + ) + node_label: str = ( + f"{node}\nOccurrences: {node_occurrence_percentage:.2f}%\nSelf-Loop Probability: {self_loop_probability:.2f}% \nSelf-Loop Occurrence: {self_loop_occurence:.2f}%" + ) node_labels[node] = node_label else: node_labels[node] = node @@ -332,8 +354,10 @@ def generate_transition_graph( plt.tight_layout() # Save the plot - plot_filename = f"markov_chain_plot_{min_transition_percent}.{fig_type}" - plot_path = os.path.join("Binding_Modes_Markov_States", plot_filename) + plot_filename: str = ( + f"markov_chain_plot_{min_transition_percent}.{fig_type}" + ) + plot_path: str = os.path.join("Binding_Modes_Markov_States", plot_filename) os.makedirs( "Binding_Modes_Markov_States", exist_ok=True ) # Create the folder if it doesn't exist diff --git a/openmmdl/openmmdl_analysis/openmmdlanalysis.py b/openmmdl/openmmdl_analysis/openmmdlanalysis.py index 0b5c3940..70286a4b 100644 --- a/openmmdl/openmmdl_analysis/openmmdlanalysis.py +++ b/openmmdl/openmmdl_analysis/openmmdlanalysis.py @@ -373,7 +373,7 @@ def main(): interaction_analysis = InteractionAnalyzer( pdb_md, dataframe, cpu_count, ligand, special_ligand, peptide ) - interaction_list = interaction_analysis.ineraction_list + interaction_list = interaction_analysis.interaction_list interaction_list.to_csv("missing_frames_filled.csv") interaction_list = interaction_list.reset_index(drop=True) diff --git a/openmmdl/openmmdl_analysis/pharmacophore.py b/openmmdl/openmmdl_analysis/pharmacophore.py index f06a950c..d0f684ef 100644 --- a/openmmdl/openmmdl_analysis/pharmacophore.py +++ b/openmmdl/openmmdl_analysis/pharmacophore.py @@ -2,17 +2,20 @@ import pandas as pd import xml.etree.ElementTree as ET import numpy as np +from typing import Dict, List, Optional, Union class PharmacophoreGenerator: - def __init__(self, df_all, ligand_name): + def __init__(self, df_all: pd.DataFrame, ligand_name: str): self.df_all = df_all self.ligand_name = ligand_name - self.comlex_name = f"{ligand_name}_complex" + self.complex_name = f"{ligand_name}_complex" self.coord_pattern = re.compile(r"\(([\d.-]+), ([\d.-]+), ([\d.-]+)\)") self.clouds = self._generate_clouds() - def _generate_clouds(self): + def _generate_clouds( + self, + ) -> Dict[str, Dict[str, Union[List[List[float]], List[float], float]]]: interaction_coords = { "hydrophobic": [], "acceptor": [], @@ -57,7 +60,9 @@ def _generate_clouds(self): return self._format_clouds(interaction_coords) - def _format_clouds(self, interaction_coords): + def _format_clouds( + self, interaction_coords: Dict[str, List[List[float]]] + ) -> Dict[str, Dict[str, Union[List[List[float]], List[float], float]]]: color_mapping = { "hydrophobic": [1.0, 1.0, 0.0], "acceptor": [1.0, 0.0, 0.0], @@ -79,18 +84,21 @@ def _format_clouds(self, interaction_coords): for interaction, coords in interaction_coords.items() } - def to_dict(self): + def to_dict( + self, + ) -> Dict[str, Dict[str, Union[List[List[float]], List[float], float]]]: return self.clouds - def generate_pharmacophore_centers(self, interactions): - """Generates pharmacophore points for interactions that are points such as hydrophobic and ionic interactions + def generate_pharmacophore_centers( + self, interactions: List[str] + ) -> Dict[str, List[float]]: + """Generates pharmacophore points for interactions that are points such as hydrophobic and ionic interactions. Args: - df (pandas dataframe): dataframe generated by analisis using plip - interactions (list): list of interactions to generate pharmacophore from + interactions (List[str]): List of interactions to generate pharmacophore from. Returns: - dict: Dict of interactions from which pharmacophore is generated as key and list of coordinates as value + Dict[str, List[float]]: Dict of interactions from which pharmacophore is generated as key and list of coordinates as value. """ coord_pattern = re.compile(r"\(([\d.-]+), ([\d.-]+), ([\d.-]+)\)") pharmacophore = {} @@ -107,21 +115,23 @@ def generate_pharmacophore_centers(self, interactions): sum_z += z counter += 1 - center_x = round((sum_x / counter), 3) - center_y = round((sum_y / counter), 3) - center_z = round((sum_z / counter), 3) - pharmacophore[interaction] = [center_x, center_y, center_z] + if counter > 0: + center_x = round((sum_x / counter), 3) + center_y = round((sum_y / counter), 3) + center_z = round((sum_z / counter), 3) + pharmacophore[interaction] = [center_x, center_y, center_z] return pharmacophore - def generate_pharmacophore_vectors(self, interactions): - """Generates pharmacophore points for interactions that are vectors such as hydrogen bond donors or acceptors + def generate_pharmacophore_vectors( + self, interactions: List[str] + ) -> Dict[str, List[List[float]]]: + """Generates pharmacophore points for interactions that are vectors such as hydrogen bond donors or acceptors. Args: - df (pandas dataframe): dataframe generated by analisis using plip - interactions (list): list of interactions to generate pharmacophore from + interactions (List[str]): List of interactions to generate pharmacophore from. Returns: - dict: Dict of interactions from which pharmacophore is generated as key and list of coordinates as value (first coords are ligand side, second are protein side) + Dict[str, List[List[float]]]: Dict of interactions from which pharmacophore is generated as key and list of coordinates as value (first coords are ligand side, second are protein side). """ coord_pattern = re.compile(r"\(([\d.-]+), ([\d.-]+), ([\d.-]+)\)") pharmacophore = {} @@ -145,27 +155,28 @@ def generate_pharmacophore_vectors(self, interactions): sum_c += c counter += 1 - center_x = round((sum_x / counter), 3) - center_y = round((sum_y / counter), 3) - center_z = round((sum_z / counter), 3) - center_a = round((sum_a / counter), 3) - center_b = round((sum_b / counter), 3) - center_c = round((sum_c / counter), 3) - pharmacophore[interaction] = [ - [center_x, center_y, center_z], - [center_a, center_b, center_c], - ] + if counter > 0: + center_x = round((sum_x / counter), 3) + center_y = round((sum_y / counter), 3) + center_z = round((sum_z / counter), 3) + center_a = round((sum_a / counter), 3) + center_b = round((sum_b / counter), 3) + center_c = round((sum_c / counter), 3) + pharmacophore[interaction] = [ + [center_x, center_y, center_z], + [center_a, center_b, center_c], + ] return pharmacophore - def generate_md_pharmacophore_cloudcenters(self, output_filename, id_num=0): + def generate_md_pharmacophore_cloudcenters( + self, output_filename: str, id_num: int = 0 + ) -> None: """Generates pharmacophore from all interactions formed in the MD simulation. - A feature is generated for each interaction at the center of all its ocurrences. + A feature is generated for each interaction at the center of all its occurrences. Args: - df (pandas dataframe): dataframe generated by analysis using plip (generally df_all) - output_filename (str): name the of the output .pml file - sysname (str): name of thesystem simulated - id_num (int, optional): id number as an identifier in the PML file. Defaults to 0. + output_filename (str): Name of the output .pml file. + id_num (int, optional): ID number as an identifier in the PML file. Defaults to 0. """ feature_id_counter = 0 @@ -173,12 +184,12 @@ def generate_md_pharmacophore_cloudcenters(self, output_filename, id_num=0): "MolecularEnvironment", version="0.0", id=f"OpennMMDL_Analysis{id_num}", - name=self.comlex_name, + name=self.complex_name, ) pharmacophore = ET.SubElement( root, "pharmacophore", - name=self.comlex_name, + name=self.complex_name, id=f"pharmacophore{id_num}", pharmacophoreType="LIGAND_SCOUT", ) @@ -330,18 +341,16 @@ def generate_md_pharmacophore_cloudcenters(self, output_filename, id_num=0): def generate_bindingmode_pharmacophore( self, - dict_bindingmode, - outname, - id_num=0, - ): - """Generates pharmacophore from a binding mode and writes it to a .pml file + dict_bindingmode: Dict[str, Dict[str, List[List[float]]]], + outname: str, + id_num: int = 0, + ) -> None: + """Generates pharmacophore from a binding mode and writes it to a .pml file. Args: - dict_bindingmode (dict): dictionary containing all interactions of the bindingmode and thei coresponding ligand and protein coordinates - core_compound (str): name of the ligand - sysname (str): name of the analysed system - outname (str): name of the output .pml file - id_num (int, optional): if multiple id number can enumerate the diferent bindingmodes. Defaults to 0. + dict_bindingmode (Dict[str, Dict[str, List[List[float]]]]): Dictionary containing all interactions of the binding mode and their corresponding ligand and protein coordinates. + outname (str): Name of the output .pml file. + id_num (int, optional): ID number for enumerating different binding modes. Defaults to 0. """ feature_types = { "Acceptor_hbond": "HBA", @@ -356,12 +365,12 @@ def generate_bindingmode_pharmacophore( "MolecularEnvironment", version="0.0", id=f"OpennMMDL_Analysis{id_num}", - name=self.comlex_name, + name=self.complex_name, ) pharmacophore = ET.SubElement( root, "pharmacophore", - name=self.comlex_name, + name=self.complex_name, id=f"pharmacophore{id_num}", pharmacophoreType="LIGAND_SCOUT", ) @@ -372,7 +381,8 @@ def generate_bindingmode_pharmacophore( if interactiontype in interaction: feature_type = feature_types[interactiontype] break - # generate vector features + + # Generate vector features if feature_type in ["HBA", "HBD"]: if feature_type == "HBA": orig_loc = dict_bindingmode[interaction]["PROTCOO"][0] @@ -393,7 +403,7 @@ def generate_bindingmode_pharmacophore( optional="false", disabled="false", weight="1.0", - coreCompound=self.comlex_name, + coreCompound=self.complex_name, id=f"feature{str(feature_id_counter)}", ) origin = ET.SubElement( @@ -412,7 +422,7 @@ def generate_bindingmode_pharmacophore( z3=str(targ_loc[2]), tolerance="1.5", ) - # generate point features + # Generate point features elif feature_type in ["H", "PI", "NI"]: position = dict_bindingmode[interaction]["LIGCOO"][0] feature_id_counter += 1 @@ -424,7 +434,7 @@ def generate_bindingmode_pharmacophore( optional="false", disabled="false", weight="1.0", - coreCompound=self.comlex_name, + coreCompound=self.complex_name, id=f"feature{str(feature_id_counter)}", ) position = ET.SubElement( @@ -435,13 +445,13 @@ def generate_bindingmode_pharmacophore( z3=str(position[2]), tolerance="1.5", ) - # generate plane features + # Generate plane features elif feature_type == "AR": feature_id_counter += 1 lig_loc = dict_bindingmode[interaction]["LIGCOO"][0] prot_loc = dict_bindingmode[interaction]["PROTCOO"][0] - # normalize vector of plane + # Normalize vector of plane vector = np.array(lig_loc) - np.array(prot_loc) normal_vector = vector / np.linalg.norm(vector) x, y, z = normal_vector @@ -454,7 +464,7 @@ def generate_bindingmode_pharmacophore( optional="false", disabled="false", weight="1.0", - coreCompound=self.comlex_name, + coreCompound=self.complex_name, id=f"feature{str(feature_id_counter)}", ) position = ET.SubElement( @@ -481,15 +491,16 @@ def generate_bindingmode_pharmacophore( xml_declaration=True, ) - def generate_pharmacophore_centers_all_points(self, interactions): - """Generates pharmacophore points for all interactions to generate point cloud + def generate_pharmacophore_centers_all_points( + self, interactions: List[str] + ) -> Dict[str, List[List[float]]]: + """Generates pharmacophore points for all interactions to generate point cloud. Args: - df (pandas dataframe): dataframe generated by analysis using plip - interactions (list): list of interactions to generate pharmacophore from + interactions (List[str]): List of interactions to generate pharmacophore from. Returns: - dict: Dict of interactions from which pharmacophore is generated as key and list of coordinates as value + Dict[str, List[List[float]]]: Dict of interactions with pharmacophore points as values. """ coord_pattern = re.compile(r"\(([\d.-]+), ([\d.-]+), ([\d.-]+)\)") pharmacophore = {} @@ -506,13 +517,11 @@ def generate_pharmacophore_centers_all_points(self, interactions): pharmacophore[interaction] = pharmacophore_points return pharmacophore - def generate_point_cloud_pml(self, outname): - """Generates pharmacophore point cloud and writes it to a .pml file + def generate_point_cloud_pml(self, outname: str) -> None: + """Generates pharmacophore point cloud and writes it to a .pml file. Args: - cloud_dict (dict): dictionary containing all interactions of the trajectory and their corresponding ligand coordinates - sysname (str): name of the simulated system - outname (str): name of the output .pml file + outname (str): Name of the output .pml file. """ cloud_dict = {} cloud_dict["H"] = self.generate_pharmacophore_centers_all_points( @@ -539,7 +548,7 @@ def generate_point_cloud_pml(self, outname): pharmacophore = ET.Element( "pharmacophore", - name=f"{self.comlex_name}_pointcloud", + name=f"{self.complex_name}_pointcloud", id=f"pharmacophore0", pharmacophoreType="LIGAND_SCOUT", ) @@ -573,6 +582,7 @@ def generate_point_cloud_pml(self, outname): z3=str(round(additional_point[2], 2)), weight="1.0", ) + feature_id_counter += 1 tree = ET.ElementTree(pharmacophore) tree.write(f"{outname}.pml", encoding="UTF-8", xml_declaration=True) diff --git a/openmmdl/openmmdl_analysis/preprocessing.py b/openmmdl/openmmdl_analysis/preprocessing.py index 15928c92..db79d43d 100644 --- a/openmmdl/openmmdl_analysis/preprocessing.py +++ b/openmmdl/openmmdl_analysis/preprocessing.py @@ -1,28 +1,31 @@ import MDAnalysis as mda -import subprocess +import mdtraj as md import os import re +from typing import List, Optional import rdkit -import mdtraj as md from rdkit import Chem -from rdkit.Chem import Draw -from rdkit.Chem import AllChem -from rdkit.Chem.Draw import rdMolDraw2D +from rdkit.Chem import AllChem, Mol from openbabel import pybel + class Preprocessing: def __init__(self): pass - def renumber_protein_residues(self, input_pdb, reference_pdb, output_pdb): - """Renumber protein residues in a molecular dynamics trajectory based on a reference structure. + def renumber_protein_residues( + self, input_pdb: str, reference_pdb: str, output_pdb: str + ) -> None: + """ + Renumber protein residues in a molecular dynamics trajectory based on a reference structure. Args: input_pdb (str): Path to the input PDB file representing the molecular dynamics trajectory to be renumbered. reference_pdb (str): Path to the reference PDB file representing the molecular dynamics trajectory used as a reference. - output_pdb (str): Path to the output PDB file where the renumbered trajectory will be saved.. + output_pdb (str): Path to the output PDB file where the renumbered trajectory will be saved. """ + # Load trajectories traj_input = md.load(input_pdb) traj_reference = md.load(reference_pdb) @@ -87,24 +90,26 @@ def renumber_protein_residues(self, input_pdb, reference_pdb, output_pdb): # Save the renumbered trajectory to a new PDB file new_traj.save(output_pdb) - def increase_ring_indices(self, ring, lig_index): - """Increases the atom indices in a ring of the ligand obtained from the ligand to fit the atom indices present in the protein-ligand complex. + def increase_ring_indices(self, ring: List[int], lig_index: int) -> List[int]: + """ + Increases the atom indices in a ring of the ligand obtained from the ligand to fit the atom indices present in the protein-ligand complex. Args: - ring (str): A list of atom indices belonging to a ring that need to be modified. - lig_index (int): An integer that is the first number of the ligand atom indices obtained from the protein-ligand, which is used to modify the ring indices + ring (List[int]): A list of atom indices belonging to a ring that need to be modified. + lig_index (int): An integer that is the first number of the ligand atom indices obtained from the protein-ligand, which is used to modify the ring indices. Returns: - list: A new list with modified atom indicies. + List[int]: A new list with modified atom indices. """ return [atom_idx + lig_index for atom_idx in ring] - def convert_ligand_to_smiles(self, input_sdf, output_smi): - """Converts ligand structures from an SDF file to SMILES :) format + def convert_ligand_to_smiles(self, input_sdf: str, output_smi: str) -> None: + """ + Converts ligand structures from an SDF file to SMILES :) format. Args: - input_sdf (str): Path to the SDF file with the ligand that wll be converted. - output_smi (str): Path to the output SMILES files. + input_sdf (str): Path to the SDF file with the ligand that will be converted. + output_smi (str): Path to the output SMILES file. """ # Create a molecule supplier from an SDF file mol_supplier = Chem.SDMolSupplier(input_sdf) @@ -113,43 +118,42 @@ def convert_ligand_to_smiles(self, input_sdf, output_smi): with open(output_smi, "w") as output_file: # Iterate through molecules and convert each to SMILES for mol in mol_supplier: - if mol is not None: # Make sure the molecule was successfully read + if mol is not None: smiles = Chem.MolToSmiles(mol) output_file.write(smiles + "\n") else: - print("nono") + print("Molecule could not be read.") - def process_pdb_file(self, input_pdb_filename): - """Process a PDB file to make it compatible with the openmmdl_analysis package. + def process_pdb_file(self, input_pdb_filename: str) -> None: + """ + Process a PDB file to make it compatible with OpenMMDL Analysis. Args: - input_pdb_filename (str): path to the input PDB file + input_pdb_filename (str): Path to the input PDB file. """ # Load the PDB file u = mda.Universe(input_pdb_filename) - # Iterate through the topology to modify residue names for atom in u.atoms: resname = atom.resname - # Check if the residue name is one of the specified water names if resname in ["SPC", "TIP3", "TIP4", "WAT", "T3P", "T4P", "T5P"]: atom.residue.resname = "HOH" elif resname == "*": atom.residue.resname = "UNK" - # Save the modified topology to a new PDB file u.atoms.write(input_pdb_filename) def extract_and_save_ligand_as_sdf( - self, input_pdb_filename, output_filename, target_resname - ): - """Extract and save the ligand from the receptor ligand complex PDB file into a new PDB file by itself. + self, input_pdb_filename: str, output_filename: str, target_resname: str + ) -> None: + """ + Extract and save the ligand from the receptor-ligand complex PDB file into a new SDF file. Args: - input_pdb_filename (str): name of the input PDB file - output_pdb_filename (str): name of the output SDF file - target_resname (str): resname of the ligand in the original PDB file + input_pdb_filename (str): Name of the input PDB file. + output_filename (str): Name of the output SDF file. + target_resname (str): Residue name of the ligand in the original PDB file. """ # Load the PDB file using MDAnalysis u = mda.Universe(input_pdb_filename) @@ -165,8 +169,6 @@ def extract_and_save_ligand_as_sdf( # Create a new Universe with only the ligand ligand_universe = mda.Merge(ligand_atoms) - - # Save the ligand Universe as a PDB file ligand_universe.atoms.write("lig.pdb") # use openbabel to convert pdb to sdf @@ -175,8 +177,11 @@ def extract_and_save_ligand_as_sdf( # remove the temporary pdb file os.remove("lig.pdb") - def renumber_atoms_in_residues(self, input_pdb_file, output_pdb_file, lig_name): - """Renumer the atoms of the ligand in the topology PDB file. + def renumber_atoms_in_residues( + self, input_pdb_file: str, output_pdb_file: str, lig_name: str + ) -> None: + """ + Renumber the atoms of the ligand in the topology PDB file. Args: input_pdb_file (str): Path to the initial PDB file. @@ -199,13 +204,11 @@ def renumber_atoms_in_residues(self, input_pdb_file, output_pdb_file, lig_name): # Check if the residue is residue_name if residue_name == lig_name: + # Extract the element from the atom name (assuming it starts with a capital letter) element_match = re.match(r"([A-Z]+)", atom_name) - if element_match: - element = element_match.group(1) - else: - element = atom_name - + element = element_match.group(1) if element_match else atom_name + # Increment the count for the current element in the current residue_name residue lig_residue_elements[element] = ( lig_residue_elements.get(element, 0) + 1 @@ -224,8 +227,9 @@ def renumber_atoms_in_residues(self, input_pdb_file, output_pdb_file, lig_name): with open(output_pdb_file, "w") as f: f.writelines(new_pdb_lines) - def replace_atom_type(self, data): - """Replace wrong ligand atom types in the topology PDB file. + def replace_atom_type(self, data: str) -> str: + """ + Replace incorrect ligand atom types in the topology PDB file. Args: data (str): Text of the initial PDB file. @@ -238,16 +242,18 @@ def replace_atom_type(self, data): if " LIG X" in line: # Extract the last column which contains the atom type (O/N/H) atom_type = line[12:13].strip() + # Replace 'X' with the correct atom type lines[i] = line.replace(" LIG X", f" LIG {atom_type}") return "\n".join(lines) - def process_pdb(self, input_file, output_file): - """Wrapper function to process a PDB file. + def process_pdb(self, input_file: str, output_file: str) -> None: + """ + Wrapper function to process a PDB file. Args: input_file (str): Path to the input PDB file. - output_file (str): Path of the output PDB file. + output_file (str): Path to the output PDB file. """ with open(input_file, "r") as f: pdb_data = f.read() diff --git a/openmmdl/openmmdl_analysis/rdkit_figure_generation.py b/openmmdl/openmmdl_analysis/rdkit_figure_generation.py index e435f9a2..94860543 100644 --- a/openmmdl/openmmdl_analysis/rdkit_figure_generation.py +++ b/openmmdl/openmmdl_analysis/rdkit_figure_generation.py @@ -7,18 +7,19 @@ import pylab import os import MDAnalysis as mda - +from typing import List, Dict, Tuple +from typing import List, Optional class LigandImageGenerator: def __init__( self, - ligand_name, - complex_pdb_file, - ligand_no_h_pdb_file, - smiles_file, - output_svg_filename, - fig_type="svg", - ): + ligand_name: str, + complex_pdb_file: str, + ligand_no_h_pdb_file: str, + smiles_file: str, + output_svg_filename: str, + fig_type: str = "svg", + ) -> None: """ Initialize the LigandImageGenerator class. @@ -30,14 +31,14 @@ def __init__( output_svg_filename (str): Name of the output SVG file. fig_type (str): Type of the output figure. Can be "svg" or "png". """ - self.ligand_name = ligand_name - self.complex_pdb_file = complex_pdb_file - self.ligand_no_h_pdb_file = ligand_no_h_pdb_file - self.smiles_file = smiles_file - self.output_svg_filename = output_svg_filename - self.fig_type = fig_type - - def generate_image(self): + self.ligand_name: str = ligand_name + self.complex_pdb_file: str = complex_pdb_file + self.ligand_no_h_pdb_file: str = ligand_no_h_pdb_file + self.smiles_file: str = smiles_file + self.output_svg_filename: str = output_svg_filename + self.fig_type: str = fig_type + + def generate_image(self) -> None: """Generates an SVG image (or PNG) of the ligand.""" try: # Load complex and ligand structures @@ -47,31 +48,31 @@ def generate_image(self): complex_lig = complex.select_atoms(f"resname {self.ligand_name}") # Load ligand from PDB file - mol = Chem.MolFromPDBFile(self.ligand_no_h_pdb_file) - lig_rd = mol + mol: Chem.Mol = Chem.MolFromPDBFile(self.ligand_no_h_pdb_file) + lig_rd: Chem.Mol = mol # Load reference SMILES with open(self.smiles_file, "r") as file: - reference_smiles = file.read().strip() - reference_mol = Chem.MolFromSmiles(reference_smiles) + reference_smiles: str = file.read().strip() + reference_mol: Chem.Mol = Chem.MolFromSmiles(reference_smiles) # Prepare ligand - prepared_ligand = AllChem.AssignBondOrdersFromTemplate( + prepared_ligand: Chem.Mol = AllChem.AssignBondOrdersFromTemplate( reference_mol, lig_rd ) AllChem.Compute2DCoords(prepared_ligand) # Map atom indices between ligand_no_h and complex for atom in prepared_ligand.GetAtoms(): - atom_index = atom.GetIdx() + atom_index: int = atom.GetIdx() for lig_atom in lig_noh: - lig_index = lig_atom.index + lig_index: int = lig_atom.index if atom_index == lig_index: - lig_atom_name = lig_atom.name + lig_atom_name: str = lig_atom.name for comp_lig in complex_lig: - comp_lig_name = comp_lig.name + comp_lig_name: str = comp_lig.name if lig_atom_name == comp_lig_name: - num = int(comp_lig.id) + num: int = int(comp_lig.id) atom.SetAtomMapNum(num) # Generate an SVG image of the ligand @@ -82,13 +83,13 @@ def generate_image(self): drawer.DrawMolecule(prepared_ligand) # Adjust font size in the SVG output using the FontSize method - font_size = drawer.FontSize() + font_size: float = drawer.FontSize() drawer.SetFontSize( font_size * 0.5 ) # You can adjust the multiplier as needed drawer.FinishDrawing() - svg = drawer.GetDrawingText().replace("svg:", "") + svg: str = drawer.GetDrawingText().replace("svg:", "") # Save the SVG image to the specified output file with open(self.output_svg_filename, "w") as f: @@ -96,7 +97,7 @@ def generate_image(self): # Convert to PNG if requested if self.fig_type == "png": - png_filename = self.output_svg_filename.replace(".svg", ".png") + png_filename: str = self.output_svg_filename.replace(".svg", ".png") cairosvg.svg2png(url=self.output_svg_filename, write_to=png_filename) print(f"PNG image saved as: {png_filename}") @@ -104,8 +105,11 @@ def generate_image(self): print(f"Error: {e}") + + + class InteractionProcessor: - def __init__(self, complex_pdb_file, ligand_no_h_pdb_file): + def __init__(self, complex_pdb_file: str, ligand_no_h_pdb_file: str): """ Initialize the InteractionProcessor class. @@ -119,14 +123,14 @@ def __init__(self, complex_pdb_file, ligand_no_h_pdb_file): self.ligand_no_h = mda.Universe(ligand_no_h_pdb_file) self.lig_noh = self.ligand_no_h.select_atoms("all") - def split_interaction_data(self, data): + def split_interaction_data(self, data: List[str]) -> List[str]: """Splits the input data into multiple parts. Args: - data (list): A list of ResNr and ResType, Atom indices, interaction type that needs to be split. + data (List[str]): A list of ResNr and ResType, Atom indices, interaction type that needs to be split. Returns: - list: A new list of the interaction data that consists of three parts. + List[str]: A new list of the interaction data that consists of three parts. """ split_data = [] for item in data: @@ -138,28 +142,43 @@ def split_interaction_data(self, data): split_data.append(split_value) return split_data - def highlight_numbers(self, split_data, starting_idx): + def highlight_numbers( + self, split_data: List[str], starting_idx: List[int] + ) -> Tuple[ + List[int], + List[int], + List[int], + List[int], + List[int], + List[int], + List[int], + List[int], + List[int], + List[int], + List[int], + ]: """Extracts the data from the split_data output of the interactions and categorizes it to its respective list. Args: - split_data (list): A list of interaction data items, where each item contains information about protein partner name, + split_data (List[str]): A list of interaction data items, where each item contains information about protein partner name, numeric codes and interaction type. - starting_idx (list): Starting index of the ligand atom indices used for identifying the correct atom to highlight. + starting_idx (List[int]): Starting index of the ligand atom indices used for identifying the correct atom to highlight. Returns: - tuple: A tuple that contains list of all of the highlighted atoms of all of the interactions. + Tuple[List[int], List[int], List[int], List[int], List[int], List[int], List[int], List[int], List[int], List[int], List[int]]: + A tuple that contains list of all of the highlighted atoms of all of the interactions. """ - highlighted_hbond_acceptor = [] - highlighted_hbond_donor = [] - highlighted_hydrophobic = [] - highlighted_hbond_both = [] - highlighted_waterbridge = [] - highlighted_pistacking = [] - highlighted_halogen = [] - highlighted_ni = [] - highlighted_pi = [] - highlighted_pication = [] - highlighted_metal = [] + highlighted_hbond_acceptor: List[int] = [] + highlighted_hbond_donor: List[int] = [] + highlighted_hydrophobic: List[int] = [] + highlighted_hbond_both: List[int] = [] + highlighted_waterbridge: List[int] = [] + highlighted_pistacking: List[int] = [] + highlighted_halogen: List[int] = [] + highlighted_ni: List[int] = [] + highlighted_pi: List[int] = [] + highlighted_pication: List[int] = [] + highlighted_metal: List[int] = [] for item in split_data: parts = item.split() @@ -300,15 +319,17 @@ def highlight_numbers(self, split_data, starting_idx): highlighted_metal, ) - def generate_interaction_dict(self, interaction_type, keys): + def generate_interaction_dict( + self, interaction_type: str, keys: List[int] + ) -> Dict[int, Tuple[float, float, float]]: """Generates a dictionary of interaction RGB color model based on the provided interaction type. Args: interaction_type (str): The type of the interaction, for example 'hydrophobic'. - keys (list): List of the highlighted atoms that display an interaction. + keys (List[int]): List of the highlighted atoms that display an interaction. Returns: - dict: A dictionary with the interaction types are associated with their respective RGB color codes. + Dict[int, Tuple[float, float, float]]: A dictionary with the interaction types associated with their respective RGB color codes. """ interaction_dict = { "hbond_acceptor": (1.0, 0.6, 0.6), @@ -324,17 +345,21 @@ def generate_interaction_dict(self, interaction_type, keys): "metal": (1.0, 0.6, 0.0), } - interaction_dict = { - int(key): interaction_dict[interaction_type] for key in keys - } - return interaction_dict + if interaction_type not in interaction_dict: + raise ValueError(f"Unknown interaction type: {interaction_type}") - def update_dict(self, target_dict, *source_dicts): + return {int(key): interaction_dict[interaction_type] for key in keys} + + def update_dict( + self, + target_dict: Dict[int, Tuple[float, float, float]], + *source_dicts: Dict[int, Tuple[float, float, float]], + ): """Updates the dictionary with the keys and values from other dictionaries. Args: - target_dict (dict): The dictionary that needs to be updated with new keys and values. - source_dicts (dict): One or multiple dictionaries that are used to update the target dictionary with new keys and values. + target_dict (Dict[int, Tuple[float, float, float]]): The dictionary that needs to be updated with new keys and values. + source_dicts (Dict[int, Tuple[float, float, float]]): One or more dictionaries with keys and values to be merged into the target dictionary. """ for source_dict in source_dicts: for key, value in source_dict.items(): @@ -345,14 +370,18 @@ def update_dict(self, target_dict, *source_dicts): class ImageMerger: def __init__( - self, binding_mode, occurrence_percent, split_data, merged_image_paths - ): + self, + binding_mode: str, + occurrence_percent: float, + split_data: List[str], + merged_image_paths: List[str], + ) -> None: self.binding_mode = binding_mode self.occurrence_percent = occurrence_percent self.split_data = split_data self.merged_image_paths = merged_image_paths - def create_and_merge_images(self): + def create_and_merge_images(self) -> List[str]: """Create and merge images to generate a legend for binding modes. Returns: @@ -374,17 +403,19 @@ def create_and_merge_images(self): for i, data in enumerate(filtered_split_data): y = data_points[i] label = data.split()[-1] - type = data.split()[-2] + type_ = data.split()[ + -2 + ] if label == "hydrophobic": (line,) = ax.plot( x, y, label=data, color=(1.0, 1.0, 0.0), linewidth=5.0 ) elif label == "hbond": - if type == "Acceptor": + if type_ == "Acceptor": (line,) = ax.plot( x, y, label=data, color=(1.0, 0.6, 0.6), linewidth=5.0 ) - elif type == "Donor": + elif type_ == "Donor": (line,) = ax.plot( x, y, label=data, color=(0.3, 0.5, 1.0), linewidth=5.0 ) @@ -409,11 +440,11 @@ def create_and_merge_images(self): x, y, label=data, color=(1.0, 0.6, 0.0), linewidth=5.0 ) elif label == "saltbridge": - if type == "NI": + if type_ == "NI": (line,) = ax.plot( x, y, label=data, color=(0.0, 0.0, 1.0), linewidth=5.0 ) - elif type == "PI": + elif type_ == "PI": (line,) = ax.plot( x, y, label=data, color=(1.0, 0.0, 0.0), linewidth=5.0 ) @@ -476,11 +507,11 @@ def create_and_merge_images(self): class FigureArranger: - def __init__(self, merged_image_paths, output_path): + def __init__(self, merged_image_paths: List[str], output_path: str) -> None: self.merged_image_paths = merged_image_paths self.output_path = output_path - def arranged_figure_generation(self): + def arranged_figure_generation(self) -> None: """Generate an arranged figure by arranging merged images in rows and columns.""" # Open the list of images merged_images = [Image.open(path) for path in self.merged_image_paths] diff --git a/openmmdl/openmmdl_analysis/rmsd_calculation.py b/openmmdl/openmmdl_analysis/rmsd_calculation.py index 60a650da..65d64f76 100644 --- a/openmmdl/openmmdl_analysis/rmsd_calculation.py +++ b/openmmdl/openmmdl_analysis/rmsd_calculation.py @@ -7,10 +7,11 @@ import MDAnalysis as mda from MDAnalysis.analysis import rms, diffusionmap from MDAnalysis.analysis.distances import dist +from typing import Optional, List, Tuple, Union @jit(nopython=True, parallel=True, nogil=True) -def calc_rmsd_2frames_jit(ref, frame): +def calc_rmsd_2frames_jit(ref: np.ndarray, frame: np.ndarray) -> float: dist = np.zeros(len(frame)) for atom in range(len(frame)): dist[atom] = ( @@ -23,12 +24,16 @@ def calc_rmsd_2frames_jit(ref, frame): class RMSDAnalyzer: - def __init__(self, prot_lig_top_file, prot_lig_traj_file): - self.prot_lig_top_file = prot_lig_top_file - self.prot_lig_traj_file = prot_lig_traj_file - self.universe = mda.Universe(prot_lig_top_file, prot_lig_traj_file) + def __init__(self, prot_lig_top_file: str, prot_lig_traj_file: str) -> None: + self.prot_lig_top_file: str = prot_lig_top_file + self.prot_lig_traj_file: str = prot_lig_traj_file + self.universe: mda.Universe = mda.Universe( + prot_lig_top_file, prot_lig_traj_file + ) - def rmsd_for_atomgroups(self, fig_type, selection1, selection2=None): + def rmsd_for_atomgroups( + self, fig_type: str, selection1: str, selection2: Optional[List[str]] = None + ) -> pd.DataFrame: """Calculate the RMSD for selected atom groups, and save the csv file and plot. Args: @@ -54,16 +59,18 @@ def rmsd_for_atomgroups(self, fig_type, selection1, selection2=None): os.makedirs(output_directory, exist_ok=True) # Save the RMSD values to a CSV file in the created directory - rmsd_df.to_csv("./RMSD/RMSD_over_time.csv", sep=" ") + rmsd_df.to_csv(f"{output_directory}/RMSD_over_time.csv", sep=" ") # Plot and save the RMSD over time as a PNG file rmsd_df.plot(title="RMSD of protein and ligand") plt.ylabel("RMSD (Å)") - plt.savefig(f"./RMSD/RMSD_over_time.{fig_type}") + plt.savefig(f"{output_directory}/RMSD_over_time.{fig_type}") return rmsd_df - def rmsd_dist_frames(self, fig_type, lig, nucleic=False): + def rmsd_dist_frames( + self, fig_type: str, lig: str, nucleic: bool = False + ) -> Tuple[np.ndarray, np.ndarray]: """Calculate the RMSD between all frames in a matrix. Args: @@ -76,24 +83,24 @@ def rmsd_dist_frames(self, fig_type, lig, nucleic=False): np.array: pairwise_rmsd_lig. Numpy array of RMSD values for ligand structures. """ if nucleic: - pairwise_rmsd_prot = ( + pairwise_rmsd_prot: np.ndarray = ( diffusionmap.DistanceMatrix(self.universe, select="nucleic") .run() .dist_matrix ) else: - pairwise_rmsd_prot = ( + pairwise_rmsd_prot: np.ndarray = ( diffusionmap.DistanceMatrix(self.universe, select="protein") .run() .dist_matrix ) - pairwise_rmsd_lig = ( + pairwise_rmsd_lig: np.ndarray = ( diffusionmap.DistanceMatrix(self.universe, f"resname {lig}") .run() .dist_matrix ) - max_dist = max(np.amax(pairwise_rmsd_lig), np.amax(pairwise_rmsd_prot)) + max_dist: float = max(np.amax(pairwise_rmsd_lig), np.amax(pairwise_rmsd_prot)) fig, ax = plt.subplots(1, 2) fig.suptitle("RMSD between the frames") @@ -116,17 +123,17 @@ def rmsd_dist_frames(self, fig_type, lig, nucleic=False): img1, ax=ax, orientation="horizontal", fraction=0.1, label="RMSD (Å)" ) - plt.savefig(f"./RMSD/RMSD_between_the_frames.{fig_type}") + plt.savefig(f"{output_directory}/RMSD_between_the_frames.{fig_type}") return pairwise_rmsd_prot, pairwise_rmsd_lig - def calc_rmsd_2frames(self, ref, frame): + def calc_rmsd_2frames(self, ref: np.ndarray, frame: np.ndarray) -> float: """ RMSD calculation between a reference and a frame. """ return calc_rmsd_2frames_jit(ref, frame) - def calculate_distance_matrix(self, selection): - distances = np.zeros( + def calculate_distance_matrix(self, selection: str) -> np.ndarray: + distances: np.ndarray = np.zeros( (len(self.universe.trajectory), len(self.universe.trajectory)) ) # calculate distance matrix @@ -135,44 +142,35 @@ def calculate_distance_matrix(self, selection): desc="\033[1mCalculating distance matrix:\033[0m", ): self.universe.trajectory[i] - frame_i = self.universe.select_atoms(selection).positions - # distances[i] = md.rmsd(traj_aligned, traj_aligned, frame=i) + frame_i: np.ndarray = self.universe.select_atoms(selection).positions for j in range(i + 1, len(self.universe.trajectory)): self.universe.trajectory[j] - frame_j = self.universe.select_atoms(selection).positions - rmsd = self.calc_rmsd_2frames(frame_i, frame_j) + frame_j: np.ndarray = self.universe.select_atoms(selection).positions + rmsd: float = self.calc_rmsd_2frames(frame_i, frame_j) distances[i][j] = rmsd distances[j][i] = rmsd return distances - def calculate_representative_frame(self, bmode_frames, DM): - """Calculates the most representative frame for a bindingmode. This is based uppon the averagwe RMSD of a frame to all other frames in the binding mode. + def calculate_representative_frame( + self, binding_mode_frames: List[int], distance_matrix: np.ndarray + ) -> int: + """Calculates the most representative frame for a binding mode. This is based upon the average RMSD of a frame to all other frames in the binding mode. Args: - bmode_frame_list (list): List of frames belonging to a binding mode. - DM (np.array): Distance matrix of trajectory. + binding_mode_frames (list): List of frames belonging to a binding mode. + distance_matrix (np.array): Distance matrix of trajectory. Returns: int: Number of the most representative frame. """ - frames = bmode_frames - mean_rmsd_per_frame = {} - # first loop : first frame - for frame_i in frames: + mean_rmsd_per_frame: dict = {} + for frame_i in binding_mode_frames: mean_rmsd_per_frame[frame_i] = 0 - # we will add the rmsd between theses 2 frames and then calcul the - # mean - for frame_j in frames: - # We don't want to calcul the same frame. - if not frame_j == frame_i: - # we add to the corresponding value in the list of all rmsd - # the RMSD betwween frame_i and frame_j - mean_rmsd_per_frame[frame_i] += DM[frame_i - 1, frame_j - 1] - # mean calculation - mean_rmsd_per_frame[frame_i] /= len(frames) - - # Representative frame = frame with lower RMSD between all other - # frame of the cluster - repre = min(mean_rmsd_per_frame, key=mean_rmsd_per_frame.get) + for frame_j in binding_mode_frames: + if frame_j != frame_i: + mean_rmsd_per_frame[frame_i] += distance_matrix[frame_i - 1, frame_j - 1] + mean_rmsd_per_frame[frame_i] /= len(binding_mode_frames) + + repre: int = min(mean_rmsd_per_frame, key=mean_rmsd_per_frame.get) return repre diff --git a/openmmdl/openmmdl_analysis/visualization.ipynb b/openmmdl/openmmdl_analysis/visualization.ipynb deleted file mode 100644 index 8040c014..00000000 --- a/openmmdl/openmmdl_analysis/visualization.ipynb +++ /dev/null @@ -1,94 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from openmmdl.openmmdl_analysis.visualization_functions import Visualizer\n", - "import MDAnalysis as mda" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "ligname = 'UNK' # add the name of the ligand from the original PDB File\n", - "special = None # if special ligand like HEM is present add the name of the special ligand from the original PDB File\n", - "md = mda.Universe('PATH_TO_interacting_waters.pdb', 'PATH_TO_interacting_waters.dcd')\n", - "cloud = 'PATH_TO_clouds.json'\n", - "\n", - "visualizer = Visualizer(md, cloud, ligname, special)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "09acfdb0e59f46d481905cef43c231c0", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "ThemeManager()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4b33c5fd6d6541ed8b9766bdbfe27e42", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "NGLWidget(gui_style='ngl', layout=Layout(height='1000px', width='1000px'), max_frame=9999)" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Predefined variables: receptor_type=\"protein or nucleic\", height=\"1000px\", width=\"1000px\"\n", - "view = visualizer.visualize()\n", - "view.display(gui=True, style='ngl')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.10" - }, - "vscode": { - "interpreter": { - "hash": "e021570d17280ebc37ed8e9fc470ea25281789feeed8bd826cec564b6552ec2d" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/openmmdl/openmmdl_analysis/visualization_functions.py b/openmmdl/openmmdl_analysis/visualization_functions.py index f7981118..8f935f7b 100644 --- a/openmmdl/openmmdl_analysis/visualization_functions.py +++ b/openmmdl/openmmdl_analysis/visualization_functions.py @@ -6,30 +6,35 @@ import subprocess import os import shutil +from typing import List, Dict, Any, Optional, Union from openmmdl.openmmdl_analysis.barcode_generation import BarcodeGenerator class TrajectorySaver: - def __init__(self, pdb_md, ligname, special, nucleic): - """Initializes the TrajectorySaver with a mda.Universe object, ligand name, special residue name and if receptor is nucleic. + def __init__( + self, pdb_md: mda.Universe, ligname: str, special: str, nucleic: bool + ) -> None: + """Initializes the TrajectorySaver with an mda.Universe object, ligand name, special residue name and receptor type. Args: - pdb_md (mda.Universe): MDAnalysis Universe object containing the trajectory - ligname (str): name of the ligand in the pdb file - special (str): name of the special residue/ligand in the pdb file (e.g. HEM) - nucleic (bool): True if receptor is nucleic, False otherwise + pdb_md (mda.Universe): MDAnalysis Universe object containing the trajectory. + ligname (str): Name of the ligand in the pdb file. + special (str): Name of the special residue/ligand in the pdb file (e.g., HEM). + nucleic (bool): True if the receptor is nucleic, False otherwise. """ self.pdb_md = pdb_md self.ligname = ligname self.special = special self.nucleic = nucleic - def save_interacting_waters_trajectory(self, interacting_waters, outputpath): + def save_interacting_waters_trajectory( + self, interacting_waters: List[int], outputpath: str = "./Visualization/" + ) -> None: """Saves .pdb and .dcd files of the trajectory containing ligand, receptor and all interacting waters. Args: - interacting_waters (list): list of all interacting water ids - outputpath (str, optional): filepath to output new pdb and dcd files. Defaults to './Visualization/'. + interacting_waters (List[int]): List of all interacting water IDs. + outputpath (str, optional): Filepath to output new pdb and dcd files. Defaults to './Visualization/'. """ water_atoms = self.pdb_md.select_atoms( f"protein or nucleic or resname {self.ligname} or resname {self.special}" @@ -47,13 +52,15 @@ def save_interacting_waters_trajectory(self, interacting_waters, outputpath): for ts in self.pdb_md.trajectory: W.write(water_atoms) - def save_frame(self, frame, outpath, selection=False): + def save_frame( + self, frame: int, outpath: str, selection: Optional[str] = None + ) -> None: """Saves a single frame of the trajectory. Args: - frame (int): Number of the frame to save - outpath (str): Path to save the frame to - selection (str, optional): MDAnalysis selection string. Defaults to False. + frame (int): Number of the frame to save. + outpath (str): Path to save the frame to. + selection (Optional[str], optional): MDAnalysis selection string. Defaults to None. """ self.pdb_md.trajectory[frame] if selection: @@ -64,30 +71,48 @@ def save_frame(self, frame, outpath, selection=False): class Visualizer: - def __init__(self, md, cloud_path, ligname, special): + def __init__( + self, + md: mda.Universe, + cloud_path: str, + ligname: str, + special: Optional[str] = None, + ) -> None: self.md = md self.cloud = self.load_cloud(cloud_path) self.ligname = ligname self.special = special - def load_cloud(self, cloud_path): + def load_cloud( + self, cloud_path: str + ) -> Dict[str, Dict[str, Union[List[float], List[int]]]]: + """Loads interaction cloud data from a JSON file. + + Args: + cloud_path (str): Path to the cloud data JSON file. + + Returns: + Dict[str, Dict[str, Union[List[float], List[int]]]]: The loaded cloud data. + """ with open(cloud_path, "r") as f: data = json.load(f) return data def visualize( - self, receptor_type="protein or nucleic", height="1000px", width="1000px" + self, + receptor_type: str = "protein or nucleic", + height: str = "1000px", + width: str = "1000px", ): """Generates visualization of the trajectory with the interacting waters and interaction clouds. Args: - ligname (str): name of the ligand in the pdb file - receptor_type (str, optional): type of receptor. Defaults to 'protein or nucleic'. - height (str, optional): height of the visualization. Defaults to '1000px'. - width (str, optional): width of the visualization. Defaults to '1000px'. + receptor_type (str, optional): Type of receptor. Defaults to 'protein or nucleic'. + height (str, optional): Height of the visualization. Defaults to '1000px'. + width (str, optional): Width of the visualization. Defaults to '1000px'. Returns: - nglview widget: returns an nglview.widget object containing the visualization + nglview widget: Returns an nglview.widget object containing the visualization. """ sphere_buffers = [] @@ -141,8 +166,8 @@ def visualize( return view -def run_visualization(): - """Runs the visualization notebook in the current directory. The visualization notebook is copied from the package directory to the current directory and automaticaly started.""" +def run_visualization() -> None: + """Runs the visualization notebook in the current directory. The visualization notebook is copied from the package directory to the current directory and automatically started.""" package_dir = os.path.dirname(__file__) notebook_path = os.path.join(package_dir, "visualization.ipynb") current_dir = os.getcwd()