From ba33eb08fb44b39d3d3159539538d42db4b9bceb Mon Sep 17 00:00:00 2001 From: Jacob Nesbitt Date: Tue, 24 Sep 2024 18:34:06 -0400 Subject: [PATCH 1/2] perf: Implement GCC algorithm in native postgres --- uvdat/core/rest/dataset.py | 95 ++++++++++++++++++++++++++++++++------ 1 file changed, 82 insertions(+), 13 deletions(-) diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py index 1ef7eaf9..85f6f1be 100644 --- a/uvdat/core/rest/dataset.py +++ b/uvdat/core/rest/dataset.py @@ -1,11 +1,12 @@ import json +from django.db import connection from django.http import HttpResponse from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet -from uvdat.core.models import Dataset, NetworkEdge, NetworkNode +from uvdat.core.models import Dataset, Network, NetworkEdge, NetworkNode from uvdat.core.rest.access_control import GuardianFilter, GuardianPermission from uvdat.core.rest.serializers import ( DatasetSerializer, @@ -16,6 +17,76 @@ ) from uvdat.core.tasks.chart import add_gcc_chart_datum +GCC_QUERY = """ +WITH RECURSIVE n as ( + -- starting node + SELECT id FROM ( + SELECT cnn.id + FROM core_networknode cnn + WHERE + cnn.network_id = %(network_id)s AND + NOT (cnn.id = ANY(%(excluded_nodes)s)) + ORDER BY random() + LIMIT 1 + ) nn + UNION + -- Select the *other* node in the edge + SELECT CASE + WHEN e.to_node_id = n.id + THEN e.from_node_id + ELSE e.to_node_id + END + FROM n + JOIN ( + SELECT * + FROM core_networkedge ne + WHERE + ne.network_id = %(network_id)s AND + NOT ( + ne.from_node_id = ANY(%(excluded_nodes)s) OR + ne.to_node_id = ANY(%(excluded_nodes)s) + ) + ) e + ON + e.from_node_id = n.id OR + e.to_node_id = n.id +) +SELECT id FROM n ORDER BY id +; +""" + + +def find_network_gcc(network: Network, excluded_nodes: list[int]) -> list[int]: + total_nodes = NetworkNode.objects.filter(network=network).count() + + # This is used to store all the nodes we've already visited, + # starting with the explicitly excluded nodes + cur_excluded_nodes = excluded_nodes.copy() + + # Store largest network found so far + gcc = [] + + with connection.cursor() as cursor: + # If the GCC size is greater than half the network, we know that there's no way to find a + # larger one. If we've exhausted all nodes, also stop searching + while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes): + cursor.execute( + GCC_QUERY, + { + 'excluded_nodes': cur_excluded_nodes, + 'network_id': network.pk, + }, + ) + nodes = [x[0] for x in cursor.fetchall()] + if not nodes: + raise Exception('Expected to find nodes but found none') + + cur_excluded_nodes.extend(nodes) + if len(nodes) > len(gcc): + gcc = nodes + + return gcc + class DatasetViewSet(ModelViewSet): queryset = Dataset.objects.all() @@ -82,16 +153,14 @@ def gcc(self, request, **kwargs): exclude_nodes = exclude_nodes.split(',') exclude_nodes = [int(n) for n in exclude_nodes if len(n)] - # TODO: improve this for datasets with multiple networks; - # this currently returns the gcc for the network with the most excluded nodes - results = [] + # Find the GCC for each network in the dataset + network_gccs: list[list[int]] = [] for network in dataset.networks.all(): - excluded_node_names = [n.name for n in network.nodes.all() if n.id in exclude_nodes] - gcc = network.get_gcc(exclude_nodes) - results.append(dict(excluded=excluded_node_names, gcc=gcc)) - if len(results): - results.sort(key=lambda r: len(r.get('excluded')), reverse=True) - gcc = results[0].get('gcc') - excluded = results[0].get('excluded') - add_gcc_chart_datum(dataset, project_id, excluded, len(gcc)) - return HttpResponse(json.dumps(gcc), status=200) + network_gccs.append(find_network_gcc(network=network, excluded_nodes=exclude_nodes)) + + # TODO: improve this for datasets with multiple networks. + # This currently returns the gcc for the network with the most excluded nodes + gcc = max(network_gccs, key=len) + + add_gcc_chart_datum(dataset, project_id, exclude_nodes, len(gcc)) + return Response(gcc, status=200) From 1eadf03c9949d6d85dce9f2f94cb6e0d32a24925 Mon Sep 17 00:00:00 2001 From: Jacob Nesbitt Date: Wed, 25 Sep 2024 17:08:18 -0400 Subject: [PATCH 2/2] refactor: Move find_network_gcc into Network.get_gcc --- uvdat/core/models/networks.py | 79 +++++++++++++++++++++++++++++------ uvdat/core/rest/dataset.py | 74 +------------------------------- 2 files changed, 69 insertions(+), 84 deletions(-) diff --git a/uvdat/core/models/networks.py b/uvdat/core/models/networks.py index 235156d2..5bafd566 100644 --- a/uvdat/core/models/networks.py +++ b/uvdat/core/models/networks.py @@ -1,27 +1,82 @@ from django.contrib.gis.db import models as geo_models -from django.db import models -import networkx as nx +from django.db import connection, models from .dataset import Dataset +GCC_QUERY = """ +WITH RECURSIVE n as ( + -- starting node + SELECT id FROM ( + SELECT cnn.id + FROM core_networknode cnn + WHERE + cnn.network_id = %(network_id)s AND + NOT (cnn.id = ANY(%(excluded_nodes)s)) + ORDER BY random() + LIMIT 1 + ) nn + UNION + -- Select the *other* node in the edge + SELECT CASE + WHEN e.to_node_id = n.id + THEN e.from_node_id + ELSE e.to_node_id + END + FROM n + JOIN ( + SELECT * + FROM core_networkedge ne + WHERE + ne.network_id = %(network_id)s AND + NOT ( + ne.from_node_id = ANY(%(excluded_nodes)s) OR + ne.to_node_id = ANY(%(excluded_nodes)s) + ) + ) e + ON + e.from_node_id = n.id OR + e.to_node_id = n.id +) +SELECT id FROM n ORDER BY id +; +""" + class Network(models.Model): dataset = models.ForeignKey(Dataset, on_delete=models.CASCADE, related_name='networks') category = models.CharField(max_length=25) metadata = models.JSONField(blank=True, null=True) - def get_graph(self): - from uvdat.core.tasks.networks import get_network_graph + def get_gcc(self, excluded_nodes: list[int]): + total_nodes = NetworkNode.objects.filter(network=self).count() + + # This is used to store all the nodes we've already visited, + # starting with the explicitly excluded nodes + cur_excluded_nodes = excluded_nodes.copy() + + # Store largest network found so far + gcc = [] + + with connection.cursor() as cursor: + # If the GCC size is greater than half the network, we know that there's no way to + # find a larger one. If we've exhausted all nodes, also stop searching. + while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes): + cursor.execute( + GCC_QUERY, + { + 'excluded_nodes': cur_excluded_nodes, + 'network_id': self.pk, + }, + ) + nodes = [x[0] for x in cursor.fetchall()] + if not nodes: + raise Exception('Expected to find nodes but found none') - return get_network_graph(self) + cur_excluded_nodes.extend(nodes) + if len(nodes) > len(gcc): + gcc = nodes - def get_gcc(self, exclude_nodes): - graph = self.get_graph() - graph.remove_nodes_from(exclude_nodes) - if graph.number_of_nodes == 0 or nx.number_connected_components(graph) == 0: - return [] - gcc = max(nx.connected_components(graph), key=len) - return list(gcc) + return gcc class NetworkNode(models.Model): diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py index 85f6f1be..dbfed9f7 100644 --- a/uvdat/core/rest/dataset.py +++ b/uvdat/core/rest/dataset.py @@ -1,6 +1,5 @@ import json -from django.db import connection from django.http import HttpResponse from rest_framework.decorators import action from rest_framework.response import Response @@ -17,76 +16,6 @@ ) from uvdat.core.tasks.chart import add_gcc_chart_datum -GCC_QUERY = """ -WITH RECURSIVE n as ( - -- starting node - SELECT id FROM ( - SELECT cnn.id - FROM core_networknode cnn - WHERE - cnn.network_id = %(network_id)s AND - NOT (cnn.id = ANY(%(excluded_nodes)s)) - ORDER BY random() - LIMIT 1 - ) nn - UNION - -- Select the *other* node in the edge - SELECT CASE - WHEN e.to_node_id = n.id - THEN e.from_node_id - ELSE e.to_node_id - END - FROM n - JOIN ( - SELECT * - FROM core_networkedge ne - WHERE - ne.network_id = %(network_id)s AND - NOT ( - ne.from_node_id = ANY(%(excluded_nodes)s) OR - ne.to_node_id = ANY(%(excluded_nodes)s) - ) - ) e - ON - e.from_node_id = n.id OR - e.to_node_id = n.id -) -SELECT id FROM n ORDER BY id -; -""" - - -def find_network_gcc(network: Network, excluded_nodes: list[int]) -> list[int]: - total_nodes = NetworkNode.objects.filter(network=network).count() - - # This is used to store all the nodes we've already visited, - # starting with the explicitly excluded nodes - cur_excluded_nodes = excluded_nodes.copy() - - # Store largest network found so far - gcc = [] - - with connection.cursor() as cursor: - # If the GCC size is greater than half the network, we know that there's no way to find a - # larger one. If we've exhausted all nodes, also stop searching - while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes): - cursor.execute( - GCC_QUERY, - { - 'excluded_nodes': cur_excluded_nodes, - 'network_id': network.pk, - }, - ) - nodes = [x[0] for x in cursor.fetchall()] - if not nodes: - raise Exception('Expected to find nodes but found none') - - cur_excluded_nodes.extend(nodes) - if len(nodes) > len(gcc): - gcc = nodes - - return gcc - class DatasetViewSet(ModelViewSet): queryset = Dataset.objects.all() @@ -156,7 +85,8 @@ def gcc(self, request, **kwargs): # Find the GCC for each network in the dataset network_gccs: list[list[int]] = [] for network in dataset.networks.all(): - network_gccs.append(find_network_gcc(network=network, excluded_nodes=exclude_nodes)) + network: Network + network_gccs.append(network.get_gcc(excluded_nodes=exclude_nodes)) # TODO: improve this for datasets with multiple networks. # This currently returns the gcc for the network with the most excluded nodes