diff --git a/uvdat/core/rest/dataset.py b/uvdat/core/rest/dataset.py index 1ef7eaf9..85f6f1be 100644 --- a/uvdat/core/rest/dataset.py +++ b/uvdat/core/rest/dataset.py @@ -1,11 +1,12 @@ import json +from django.db import connection from django.http import HttpResponse from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet -from uvdat.core.models import Dataset, NetworkEdge, NetworkNode +from uvdat.core.models import Dataset, Network, NetworkEdge, NetworkNode from uvdat.core.rest.access_control import GuardianFilter, GuardianPermission from uvdat.core.rest.serializers import ( DatasetSerializer, @@ -16,6 +17,76 @@ ) from uvdat.core.tasks.chart import add_gcc_chart_datum +GCC_QUERY = """ +WITH RECURSIVE n as ( + -- starting node + SELECT id FROM ( + SELECT cnn.id + FROM core_networknode cnn + WHERE + cnn.network_id = %(network_id)s AND + NOT (cnn.id = ANY(%(excluded_nodes)s)) + ORDER BY random() + LIMIT 1 + ) nn + UNION + -- Select the *other* node in the edge + SELECT CASE + WHEN e.to_node_id = n.id + THEN e.from_node_id + ELSE e.to_node_id + END + FROM n + JOIN ( + SELECT * + FROM core_networkedge ne + WHERE + ne.network_id = %(network_id)s AND + NOT ( + ne.from_node_id = ANY(%(excluded_nodes)s) OR + ne.to_node_id = ANY(%(excluded_nodes)s) + ) + ) e + ON + e.from_node_id = n.id OR + e.to_node_id = n.id +) +SELECT id FROM n ORDER BY id +; +""" + + +def find_network_gcc(network: Network, excluded_nodes: list[int]) -> list[int]: + total_nodes = NetworkNode.objects.filter(network=network).count() + + # This is used to store all the nodes we've already visited, + # starting with the explicitly excluded nodes + cur_excluded_nodes = excluded_nodes.copy() + + # Store largest network found so far + gcc = [] + + with connection.cursor() as cursor: + # If the GCC size is greater than half the network, we know that there's no way to find a + # larger one. If we've exhausted all nodes, also stop searching + while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes): + cursor.execute( + GCC_QUERY, + { + 'excluded_nodes': cur_excluded_nodes, + 'network_id': network.pk, + }, + ) + nodes = [x[0] for x in cursor.fetchall()] + if not nodes: + raise Exception('Expected to find nodes but found none') + + cur_excluded_nodes.extend(nodes) + if len(nodes) > len(gcc): + gcc = nodes + + return gcc + class DatasetViewSet(ModelViewSet): queryset = Dataset.objects.all() @@ -82,16 +153,14 @@ def gcc(self, request, **kwargs): exclude_nodes = exclude_nodes.split(',') exclude_nodes = [int(n) for n in exclude_nodes if len(n)] - # TODO: improve this for datasets with multiple networks; - # this currently returns the gcc for the network with the most excluded nodes - results = [] + # Find the GCC for each network in the dataset + network_gccs: list[list[int]] = [] for network in dataset.networks.all(): - excluded_node_names = [n.name for n in network.nodes.all() if n.id in exclude_nodes] - gcc = network.get_gcc(exclude_nodes) - results.append(dict(excluded=excluded_node_names, gcc=gcc)) - if len(results): - results.sort(key=lambda r: len(r.get('excluded')), reverse=True) - gcc = results[0].get('gcc') - excluded = results[0].get('excluded') - add_gcc_chart_datum(dataset, project_id, excluded, len(gcc)) - return HttpResponse(json.dumps(gcc), status=200) + network_gccs.append(find_network_gcc(network=network, excluded_nodes=exclude_nodes)) + + # TODO: improve this for datasets with multiple networks. + # This currently returns the gcc for the network with the most excluded nodes + gcc = max(network_gccs, key=len) + + add_gcc_chart_datum(dataset, project_id, exclude_nodes, len(gcc)) + return Response(gcc, status=200)