Skip to content

Commit

Permalink
add async node creation
Browse files Browse the repository at this point in the history
  • Loading branch information
JR-1991 committed Mar 15, 2024
1 parent 4b085bd commit 3607c4e
Showing 1 changed file with 60 additions and 35 deletions.
95 changes: 60 additions & 35 deletions pyeed/network/network.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import List, Optional
import networkx as nx
import plotly.graph_objects as go
Expand Down Expand Up @@ -74,10 +75,13 @@ def add_target(self, target: AbstractSequence):
self.targets.append(target.source_id)

@property
def graph(self) -> nx.Graph:
def graph(self, n_parallel: int = 200) -> nx.Graph:
"""
Maps properties of alignments to a network graph.
Args:
n_parallel (int): The number of parallel tasks to run. Default is 200.
Returns:
nx.Graph: The network graph representing the sequence network.
Expand All @@ -94,40 +98,7 @@ def graph(self) -> nx.Graph:
graph = nx.Graph()

# Add nodes and assign node attributes
if all([isinstance(sequence, ProteinInfo) for sequence in self.sequences]):
for sequence in self.sequences:
graph.add_node(
sequence.source_id,
name=sequence.name,
familiy_name=sequence.family_name,
domain=sequence.organism.domain,
kingdome=sequence.organism.kingdom,
phylum=sequence.organism.phylum,
tax_class=sequence.organism.tax_class,
order=sequence.organism.order,
family=sequence.organism.family,
genus=sequence.organism.genus,
species=sequence.organism.species,
ec_number=sequence.ec_number,
mol_weight=sequence.mol_weight,
taxonomy_id=sequence.organism.taxonomy_id,
)

else:
for sequence in self.sequences:
graph.add_node(
sequence.source_id,
name=sequence.name,
domain=sequence.organism.domain,
kingdome=sequence.organism.kingdom,
phylum=sequence.organism.phylum,
tax_class=sequence.organism.tax_class,
order=sequence.organism.order,
family=sequence.organism.family,
genus=sequence.organism.genus,
species=sequence.organism.species,
taxonomy_id=sequence.organism.taxonomy_id,
)
asyncio.run(self._add_nodes(graph, n_parallel))

# Add edges and assign edge attributes
if self.threshold != None:
Expand Down Expand Up @@ -170,6 +141,60 @@ def graph(self) -> nx.Graph:

return graph

async def _add_nodes(
    self,
    graph: nx.Graph,
    n_parallel: int = 200,
):
    """
    Adds one node per entry of ``self.sequences`` to the graph.

    Parameters:
    - graph: The graph to add the nodes to.
    - n_parallel: The maximum number of node tasks allowed to hold the
      semaphore at the same time. Must be at least 1.

    Returns:
    None

    Raises:
    ValueError: If ``n_parallel`` is smaller than 1.
    """

    # A semaphore initialised with 0 never admits a task, so the gather
    # below would wait forever; reject such values up front.
    if n_parallel < 1:
        raise ValueError(f"n_parallel must be at least 1, got {n_parallel}")

    # Create semaphore to control the number of parallel tasks
    sem = asyncio.Semaphore(n_parallel)

    # Collect one coroutine per sequence; nothing runs until gathered
    tasks = [self._add_node(graph, sequence, sem) for sequence in self.sequences]

    # Run tasks and wait until every node has been added
    await asyncio.gather(*tasks)

@staticmethod
async def _add_node(
    graph: nx.Graph,
    sequence: ProteinInfo,
    sem: asyncio.Semaphore,
):
    """
    Adds a single node to the graph.

    The node is keyed by ``sequence.source_id``; all other fields returned
    by ``sequence.dict()`` are stored as node attributes.

    Parameters:
    - graph: The graph to add the node to.
    - sequence: The sequence object representing the node.
    - sem: The semaphore to control the number of parallel tasks.

    Returns:
    None

    Raises:
    TypeError: If ``sequence`` is not a ProteinInfo object.
    """

    async with sem:
        # Explicit check instead of ``assert`` so the validation is still
        # performed when Python runs with optimisations (-O) enabled.
        if not isinstance(sequence, ProteinInfo):
            raise TypeError("Sequence must be a ProteinInfo object")

        # source_id is the node key, so it is excluded from the attributes
        graph.add_node(
            sequence.source_id,
            **sequence.dict(exclude={"source_id"}),
        )

def visualize(self):
"""
Visualizes the network graph.
Expand Down

0 comments on commit 3607c4e

Please sign in to comment.