Skip to content

Commit

Permalink
refactor: Move find_network_gcc into Network.get_gcc
Browse files Browse the repository at this point in the history
  • Loading branch information
jjnesbitt committed Oct 1, 2024
1 parent ba33eb0 commit 1eadf03
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 84 deletions.
79 changes: 67 additions & 12 deletions uvdat/core/models/networks.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,82 @@
from django.contrib.gis.db import models as geo_models
from django.db import models
import networkx as nx
from django.db import connection, models

from .dataset import Dataset

GCC_QUERY = """
WITH RECURSIVE n as (
-- starting node
SELECT id FROM (
SELECT cnn.id
FROM core_networknode cnn
WHERE
cnn.network_id = %(network_id)s AND
NOT (cnn.id = ANY(%(excluded_nodes)s))
ORDER BY random()
LIMIT 1
) nn
UNION
-- Select the *other* node in the edge
SELECT CASE
WHEN e.to_node_id = n.id
THEN e.from_node_id
ELSE e.to_node_id
END
FROM n
JOIN (
SELECT *
FROM core_networkedge ne
WHERE
ne.network_id = %(network_id)s AND
NOT (
ne.from_node_id = ANY(%(excluded_nodes)s) OR
ne.to_node_id = ANY(%(excluded_nodes)s)
)
) e
ON
e.from_node_id = n.id OR
e.to_node_id = n.id
)
SELECT id FROM n ORDER BY id
;
"""


class Network(models.Model):
dataset = models.ForeignKey(Dataset, on_delete=models.CASCADE, related_name='networks')
category = models.CharField(max_length=25)
metadata = models.JSONField(blank=True, null=True)

def get_graph(self):
from uvdat.core.tasks.networks import get_network_graph
def get_gcc(self, excluded_nodes: list[int]):
total_nodes = NetworkNode.objects.filter(network=self).count()

# This is used to store all the nodes we've already visited,
# starting with the explicitly excluded nodes
cur_excluded_nodes = excluded_nodes.copy()

# Store largest network found so far
gcc = []

with connection.cursor() as cursor:
# If the GCC size is greater than half the network, we know that there's no way to
# find a larger one. If we've exhausted all nodes, also stop searching.
while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes):
cursor.execute(
GCC_QUERY,
{
'excluded_nodes': cur_excluded_nodes,
'network_id': self.pk,
},
)
nodes = [x[0] for x in cursor.fetchall()]
if not nodes:
raise Exception('Expected to find nodes but found none')

return get_network_graph(self)
cur_excluded_nodes.extend(nodes)
if len(nodes) > len(gcc):
gcc = nodes

def get_gcc(self, exclude_nodes):
graph = self.get_graph()
graph.remove_nodes_from(exclude_nodes)
if graph.number_of_nodes == 0 or nx.number_connected_components(graph) == 0:
return []
gcc = max(nx.connected_components(graph), key=len)
return list(gcc)
return gcc


class NetworkNode(models.Model):
Expand Down
74 changes: 2 additions & 72 deletions uvdat/core/rest/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json

from django.db import connection
from django.http import HttpResponse
from rest_framework.decorators import action
from rest_framework.response import Response
Expand All @@ -17,76 +16,6 @@
)
from uvdat.core.tasks.chart import add_gcc_chart_datum

GCC_QUERY = """
WITH RECURSIVE n as (
-- starting node
SELECT id FROM (
SELECT cnn.id
FROM core_networknode cnn
WHERE
cnn.network_id = %(network_id)s AND
NOT (cnn.id = ANY(%(excluded_nodes)s))
ORDER BY random()
LIMIT 1
) nn
UNION
-- Select the *other* node in the edge
SELECT CASE
WHEN e.to_node_id = n.id
THEN e.from_node_id
ELSE e.to_node_id
END
FROM n
JOIN (
SELECT *
FROM core_networkedge ne
WHERE
ne.network_id = %(network_id)s AND
NOT (
ne.from_node_id = ANY(%(excluded_nodes)s) OR
ne.to_node_id = ANY(%(excluded_nodes)s)
)
) e
ON
e.from_node_id = n.id OR
e.to_node_id = n.id
)
SELECT id FROM n ORDER BY id
;
"""


def find_network_gcc(network: Network, excluded_nodes: list[int]) -> list[int]:
total_nodes = NetworkNode.objects.filter(network=network).count()

# This is used to store all the nodes we've already visited,
# starting with the explicitly excluded nodes
cur_excluded_nodes = excluded_nodes.copy()

# Store largest network found so far
gcc = []

with connection.cursor() as cursor:
# If the GCC size is greater than half the network, we know that there's no way to find a
# larger one. If we've exhausted all nodes, also stop searching
while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes):
cursor.execute(
GCC_QUERY,
{
'excluded_nodes': cur_excluded_nodes,
'network_id': network.pk,
},
)
nodes = [x[0] for x in cursor.fetchall()]
if not nodes:
raise Exception('Expected to find nodes but found none')

cur_excluded_nodes.extend(nodes)
if len(nodes) > len(gcc):
gcc = nodes

return gcc


class DatasetViewSet(ModelViewSet):
queryset = Dataset.objects.all()
Expand Down Expand Up @@ -156,7 +85,8 @@ def gcc(self, request, **kwargs):
# Find the GCC for each network in the dataset
network_gccs: list[list[int]] = []
for network in dataset.networks.all():
network_gccs.append(find_network_gcc(network=network, excluded_nodes=exclude_nodes))
network: Network
network_gccs.append(network.get_gcc(excluded_nodes=exclude_nodes))

# TODO: improve this for datasets with multiple networks.
# This currently returns the gcc for the network with the most excluded nodes
Expand Down

0 comments on commit 1eadf03

Please sign in to comment.