diff --git a/uvdat/core/models/networks.py b/uvdat/core/models/networks.py index 235156d2..5bafd566 100644 --- a/uvdat/core/models/networks.py +++ b/uvdat/core/models/networks.py @@ -1,27 +1,82 @@ from django.contrib.gis.db import models as geo_models -from django.db import models -import networkx as nx +from django.db import connection, models from .dataset import Dataset +GCC_QUERY = """ +WITH RECURSIVE n as ( + -- starting node + SELECT id FROM ( + SELECT cnn.id + FROM core_networknode cnn + WHERE + cnn.network_id = %(network_id)s AND + NOT (cnn.id = ANY(%(excluded_nodes)s)) + ORDER BY random() + LIMIT 1 + ) nn + UNION + -- Select the *other* node in the edge + SELECT CASE + WHEN e.to_node_id = n.id + THEN e.from_node_id + ELSE e.to_node_id + END + FROM n + JOIN ( + SELECT * + FROM core_networkedge ne + WHERE + ne.network_id = %(network_id)s AND + NOT ( + ne.from_node_id = ANY(%(excluded_nodes)s) OR + ne.to_node_id = ANY(%(excluded_nodes)s) + ) + ) e + ON + e.from_node_id = n.id OR + e.to_node_id = n.id +) +SELECT id FROM n ORDER BY id +; +""" + class Network(models.Model): dataset = models.ForeignKey(Dataset, on_delete=models.CASCADE, related_name='networks') category = models.CharField(max_length=25) metadata = models.JSONField(blank=True, null=True) - def get_graph(self): - from uvdat.core.tasks.networks import get_network_graph + def get_gcc(self, excluded_nodes: list[int]): + total_nodes = NetworkNode.objects.filter(network=self).count() + + # This is used to store all the nodes we've already visited, + # starting with the explicitly excluded nodes + cur_excluded_nodes = excluded_nodes.copy() + + # Store largest network found so far + gcc = [] + + with connection.cursor() as cursor: + # If the GCC size is greater than half the network, we know that there's no way to + # find a larger one. If we've exhausted all nodes, also stop searching. + while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes): + cursor.execute( + GCC_QUERY, + { + 'excluded_nodes': cur_excluded_nodes, + 'network_id': self.pk, + }, + ) + nodes = [x[0] for x in cursor.fetchall()] + if not nodes: + raise Exception('Expected to find nodes but found none') - return get_network_graph(self) + cur_excluded_nodes.extend(nodes) + if len(nodes) > len(gcc): + gcc = nodes - def get_gcc(self, exclude_nodes): - graph = self.get_graph() - graph.remove_nodes_from(exclude_nodes) - if graph.number_of_nodes == 0 or nx.number_connected_components(graph) == 0: - return [] - gcc = max(nx.connected_components(graph), key=len) - return list(gcc) + return gcc class NetworkNode(models.Model): diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py index 85f6f1be..dbfed9f7 100644 --- a/uvdat/core/rest/dataset.py +++ b/uvdat/core/rest/dataset.py @@ -1,6 +1,5 @@ import json -from django.db import connection from django.http import HttpResponse from rest_framework.decorators import action from rest_framework.response import Response @@ -17,76 +16,6 @@ ) from uvdat.core.tasks.chart import add_gcc_chart_datum -GCC_QUERY = """ -WITH RECURSIVE n as ( - -- starting node - SELECT id FROM ( - SELECT cnn.id - FROM core_networknode cnn - WHERE - cnn.network_id = %(network_id)s AND - NOT (cnn.id = ANY(%(excluded_nodes)s)) - ORDER BY random() - LIMIT 1 - ) nn - UNION - -- Select the *other* node in the edge - SELECT CASE - WHEN e.to_node_id = n.id - THEN e.from_node_id - ELSE e.to_node_id - END - FROM n - JOIN ( - SELECT * - FROM core_networkedge ne - WHERE - ne.network_id = %(network_id)s AND - NOT ( - ne.from_node_id = ANY(%(excluded_nodes)s) OR - ne.to_node_id = ANY(%(excluded_nodes)s) - ) - ) e - ON - e.from_node_id = n.id OR - e.to_node_id = n.id -) -SELECT id FROM n ORDER BY id -; -""" - - -def find_network_gcc(network: Network, excluded_nodes: list[int]) -> list[int]: - total_nodes = NetworkNode.objects.filter(network=network).count() - - # This is used to store all the nodes we've already visited, - # starting with the explicitly excluded nodes - cur_excluded_nodes = excluded_nodes.copy() - - # Store largest network found so far - gcc = [] - - with connection.cursor() as cursor: - # If the GCC size is greater than half the network, we know that there's no way to find a - # larger one. If we've exhausted all nodes, also stop searching - while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes): - cursor.execute( - GCC_QUERY, - { - 'excluded_nodes': cur_excluded_nodes, - 'network_id': network.pk, - }, - ) - nodes = [x[0] for x in cursor.fetchall()] - if not nodes: - raise Exception('Expected to find nodes but found none') - - cur_excluded_nodes.extend(nodes) - if len(nodes) > len(gcc): - gcc = nodes - - return gcc - class DatasetViewSet(ModelViewSet): queryset = Dataset.objects.all() @@ -156,7 +85,8 @@ def gcc(self, request, **kwargs): # Find the GCC for each network in the dataset network_gccs: list[list[int]] = [] for network in dataset.networks.all(): - network_gccs.append(find_network_gcc(network=network, excluded_nodes=exclude_nodes)) + network: Network + network_gccs.append(network.get_gcc(excluded_nodes=exclude_nodes)) # TODO: improve this for datasets with multiple networks. # This currently returns the gcc for the network with the most excluded nodes