Skip to content

Commit

Permalink
Implement GCC algorithm in native postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
jjnesbitt committed Sep 25, 2024
1 parent ad426e9 commit bcf5cc9
Showing 1 changed file with 82 additions and 12 deletions.
94 changes: 82 additions & 12 deletions uvdat/core/rest/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,86 @@
import json

from django.db import connection
from django.http import HttpResponse
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from uvdat.core.models import Dataset, NetworkEdge, NetworkNode
from uvdat.core.models.networks import Network
from uvdat.core.rest import serializers as uvdat_serializers
from uvdat.core.tasks.chart import add_gcc_chart_datum

GCC_QUERY = """
WITH RECURSIVE n as (
-- starting node
SELECT id FROM (
SELECT cnn.id
FROM core_networknode cnn
WHERE
cnn.network_id = %(network_id)s AND
NOT (cnn.id = ANY(%(excluded_nodes)s))
ORDER BY random()
LIMIT 1
) nn
UNION
-- Select the *other* node in the edge
SELECT CASE
WHEN e.to_node_id = n.id
THEN e.from_node_id
ELSE e.to_node_id
END
FROM n
JOIN (
SELECT *
FROM core_networkedge ne
WHERE
ne.network_id = %(network_id)s AND
NOT (
ne.from_node_id = ANY(%(excluded_nodes)s) OR
ne.to_node_id = ANY(%(excluded_nodes)s)
)
) e
ON
e.from_node_id = n.id OR
e.to_node_id = n.id
)
SELECT id FROM n ORDER BY id
;
"""


def find_network_gcc(network: Network, excluded_nodes: list[int]) -> list[int]:
total_nodes = NetworkNode.objects.filter(network=network).count()

# This is used to store all the nodes we've already visited,
# starting with the explicitly excluded nodes
cur_excluded_nodes = excluded_nodes.copy()

# Store largest network found so far
gcc = []

with connection.cursor() as cursor:
# If the GCC size is greater than half the network, we know that there's no way to find a larger one
# If we've exhausted all nodes, also stop searching
while not (len(gcc) > (total_nodes // 2) or len(cur_excluded_nodes) >= total_nodes):
cursor.execute(
GCC_QUERY,
{
'excluded_nodes': cur_excluded_nodes,
'network_id': network.pk,
},
)
nodes = [x[0] for x in cursor.fetchall()]
if not nodes:
raise Exception('Expected to find nodes but found none')

cur_excluded_nodes.extend(nodes)
if len(nodes) > len(gcc):
gcc = nodes

return gcc


class DatasetViewSet(ModelViewSet):
serializer_class = uvdat_serializers.DatasetSerializer
Expand Down Expand Up @@ -70,16 +142,14 @@ def gcc(self, request, **kwargs):
exclude_nodes = exclude_nodes.split(',')
exclude_nodes = [int(n) for n in exclude_nodes if len(n)]

# TODO: improve this for datasets with multiple networks;
# this currently returns the gcc for the network with the most excluded nodes
results = []
# Find the GCC for each network in the dataset
network_gccs: list[list[int]] = []
for network in dataset.networks.all():
excluded_node_names = [n.name for n in network.nodes.all() if n.id in exclude_nodes]
gcc = network.get_gcc(exclude_nodes)
results.append(dict(excluded=excluded_node_names, gcc=gcc))
if len(results):
results.sort(key=lambda r: len(r.get('excluded')), reverse=True)
gcc = results[0].get('gcc')
excluded = results[0].get('excluded')
add_gcc_chart_datum(dataset, context_id, excluded, len(gcc))
return HttpResponse(json.dumps(gcc), status=200)
network_gccs.append(find_network_gcc(network=network, excluded_nodes=exclude_nodes))

# TODO: improve this for datasets with multiple networks.
# This currently returns the gcc for the network with the most excluded nodes
gcc = max(network_gccs, key=len)

add_gcc_chart_datum(dataset, context_id, exclude_nodes, len(gcc))
return Response(gcc, status=200)

0 comments on commit bcf5cc9

Please sign in to comment.