Skip to content

Commit

Permalink
Handle invalid entries in custom plots view.
Browse files Browse the repository at this point in the history
  • Loading branch information
njohner committed Apr 26, 2024
1 parent 2b7ff15 commit 47c3e27
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 27 deletions.
23 changes: 17 additions & 6 deletions webapp/chlamdb/forms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from collections import namedtuple
from io import StringIO

from Bio import SeqIO
Expand All @@ -9,6 +10,7 @@
from django import forms
from django.core.exceptions import ValidationError
from django.core.validators import MaxLengthValidator, MinLengthValidator
from views.utils import EntryIdParser


def get_accessions(db, all=False, plasmid=False):
Expand Down Expand Up @@ -693,8 +695,11 @@ class CustomPlotsForm(forms.Form):
widget=forms.Textarea(attrs={'cols': 50, 'rows': 5}),
required=True, label="Entry IDs", help_text=help_text)

def __init__(self, *args, **kwargs):
Entry = namedtuple("Entry", "id label type")

def __init__(self, db, *args, **kwargs):
super().__init__(*args, **kwargs)
self.db = db
self.helper = FormHelper()

self.helper.form_method = 'post'
Expand All @@ -711,14 +716,20 @@ def __init__(self, *args, **kwargs):
)
)

def clean_entries(self):
    """Validate and parse the comma-separated ``entries`` field.

    Each item is either ``<identifier>`` or ``<identifier>:<label>``.
    Every identifier is resolved with ``EntryIdParser`` against ``self.db``.

    Returns:
        A list of ``self.Entry`` namedtuples ``(id, label, type)``, in the
        order given by the user.

    Raises:
        ValidationError: if an identifier cannot be resolved to a known
            object type.
    """
    parser = EntryIdParser(self.db)
    entries = []
    for raw_entry in self.cleaned_data["entries"].split(","):
        # Split off the optional label; strip both sides so that
        # "K00001 : my label" is handled like "K00001:my label".
        entry, _, label = raw_entry.partition(":")
        entry = entry.strip()
        # Fall back to the identifier when no (or an empty) label is given.
        label = label.strip() or entry

        try:
            parsed = parser.id_to_object_type(entry)
        except Exception as exc:
            # Lookup failures are reported to the user as an invalid
            # identifier; the cause is chained so it is not lost.
            raise ValidationError(f'Invalid identifier "{entry}".',
                                  code="invalid") from exc
        # The parser signals "unknown identifier" by returning None; check
        # it explicitly instead of relying on a tuple-unpacking TypeError.
        if parsed is None:
            raise ValidationError(f'Invalid identifier "{entry}".',
                                  code="invalid")

        object_type, entry_id = parsed
        entries.append(self.Entry(entry_id, label, object_type))
    return entries
22 changes: 22 additions & 0 deletions webapp/lib/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2631,6 +2631,28 @@ def get_number_of_ko_entries(self):
query = "SELECT COUNT(*) FROM (SELECT DISTINCT ko_id FROM ko_hits)"
return self.server.adaptor.execute_and_fetchall(query)[0][0]

def check_entry_existence(self, entry_id, entry_col, table):
    """Return True if ``table`` has at least one row with ``entry_col == entry_id``.

    Args:
        entry_id: value to look for; bound as a query parameter because it
            originates from user input (custom plots form).
        entry_col: column name. Interpolated into the SQL, so it must only
            ever come from trusted, hard-coded callers.
        table: table name. Same restriction as ``entry_col``.

    Returns:
        bool: whether a matching row exists.
    """
    query = (
        f"SELECT 1 FROM {table} "
        f"WHERE {entry_col}={self.placeholder} LIMIT 1"
    )
    # NOTE(review): execute_and_fetchall returns a (possibly empty) list, so
    # a missing entry maps to False. BioSQL's execute_one is expected to
    # raise when the row count is not exactly one - confirm adaptor type.
    rows = self.server.adaptor.execute_and_fetchall(query, [entry_id])
    return bool(rows)

def check_og_entry_id(self, entry_id):
    """Return whether ``entry_id`` occurs in ``og_hits.orthogroup``."""
    return self.check_entry_existence(entry_id, "orthogroup", "og_hits")

def check_orthogroup_entry_id(self, entry_id):
    """Alias of ``check_og_entry_id`` named after the ``orthogroup`` type.

    ``EntryIdParser.id_to_object_type`` looks the check up under this name;
    without it the call fails with AttributeError.
    """
    return self.check_og_entry_id(entry_id)

def check_ko_entry_id(self, entry_id):
    """Return whether ``entry_id`` occurs in ``ko_def.ko_id``."""
    return self.check_entry_existence(entry_id, "ko_id", "ko_def")

def check_cog_entry_id(self, entry_id):
    """Return whether ``entry_id`` occurs in ``cog_names.cog_id``."""
    return self.check_entry_existence(entry_id, "cog_id", "cog_names")

def check_pfam_entry_id(self, entry_id):
    """Return whether ``entry_id`` occurs in ``pfam_table.pfam_id``."""
    return self.check_entry_existence(entry_id, "pfam_id", "pfam_table")

def check_vf_entry_id(self, entry_id):
    """Return whether ``entry_id`` occurs in ``vf_defs.vf_gene_id``."""
    return self.check_entry_existence(entry_id, "vf_gene_id", "vf_defs")

def check_amr_entry_id(self, entry_id):
    """Return whether ``entry_id`` occurs in ``amr_hits.gene``."""
    return self.check_entry_existence(entry_id, "gene", "amr_hits")

# Return a comma-separated string containing one copy of self.placeholder
# per element of args; the values in args themselves are not used.
# Presumably meant for building the "IN (...)" part of a parameterised
# query - verify against callers.
def gen_placeholder_string(self, args):
return ",".join(self.placeholder for _ in args)

Expand Down
17 changes: 7 additions & 10 deletions webapp/views/custom_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from lib.db_utils import DB
from lib.ete_phylo import EteTree, SimpleColorColumn

from views.errors import errors
from views.mixins import ComparisonViewMixin
from views.object_type_metadata import my_locals
from views.utils import EntryIdIdentifier, ResultTab, TabularResultTab
from views.utils import ResultTab, TabularResultTab


class CusomPlotsView(View):
Expand Down Expand Up @@ -52,27 +51,25 @@ def get_context(self, **kwargs):
return my_locals(context)

def get(self, request, *args, **kwargs):
    """Render the page with an empty (unbound) form.

    The form is given ``self.db`` because its constructor takes the
    database handle as first argument (it is used to validate entry IDs).
    """
    self.form = self.form_class(self.db)
    return render(request, self.template, self.get_context())

def post(self, request, *args, **kwargs):
self.form = self.form_class(request.POST)
self.form = self.form_class(self.db, request.POST)
if not self.form.is_valid():
return render(request, self.template, self.get_context())

entries, entry2label = self.form.get_entries()
entries = self.form.cleaned_data["entries"]

# We make 1 query for each entry, although we could of course make
# a single query for each object type, but I don't expect any
# performance issues here, so I'd rather keep it simple (and maintain
# the order of the entries as defined by the user).
entry_id_identifier = EntryIdIdentifier()
counts = []
for entry in entries:
object_type, entry_id = entry_id_identifier.id_to_object_type(entry)
mixin = ComparisonViewMixin.type2mixin[object_type]
hits = mixin().get_hit_counts([entry_id], search_on=object_type)
hits = hits.rename({entry_id: entry2label.get(entry, entry)})
mixin = ComparisonViewMixin.type2mixin[entry.type]
hits = mixin().get_hit_counts([entry.id], search_on=entry.type)
hits = hits.rename({entry.id: entry.label})
counts.append(hits)

genome_descriptions = self.db.get_genomes_description()
Expand Down
33 changes: 22 additions & 11 deletions webapp/views/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,35 +283,46 @@ def __init__(self, tabid, title, template="chlamdb/result_table.html",
tabid, title, template, show_badge=show_badge, badge=badge, **kwargs)


class EntryIdParser:
    """Resolve a user-supplied entry identifier to ``(object_type, id)``.

    An identifier is only accepted for a given object type if it both
    matches that type's pattern and exists in the database (checked through
    the ``check_*_entry_id`` methods of ``db``).
    """

    # "+" rather than "*": a bare "group_" must not match, otherwise
    # int("") raises ValueError below.
    og_re = re.compile(r"group_([0-9]+)")
    cog_re = re.compile("COG([0-9]{4})")
    pfam_re = re.compile("PF([0-9]{4,5})")
    ko_re = re.compile("K([0-9]{5})")
    vf_re = re.compile("VFG[0-9]{6}")

    def __init__(self, db):
        """Store the database handle used for the existence checks."""
        self.db = db

    def id_to_object_type(self, identifier):
        """Return ``(object_type, parsed_id)`` for ``identifier``.

        For orthogroup, cog, pfam and ko the parsed id is the integer part
        of the identifier; for vf and amr it is the identifier itself.

        Returns:
            A 2-tuple, or None if the identifier matches no known entry.
        """
        # Types whose identifier is a prefix followed by a numeric id.
        numeric_types = (
            ("orthogroup", self.og_re, self.db.check_og_entry_id),
            ("cog", self.cog_re, self.db.check_cog_entry_id),
            ("pfam", self.pfam_re, self.db.check_pfam_entry_id),
            ("ko", self.ko_re, self.db.check_ko_entry_id),
        )
        for object_type, pattern, entry_exists in numeric_types:
            match = pattern.match(identifier)
            if match is None:
                continue
            parsed_id = int(match.group(1))
            # Test the match, not the truthiness of the id: 0 is a
            # legitimate numeric id (e.g. "group_0").
            if entry_exists(parsed_id):
                return object_type, parsed_id

        if self.vf_re.match(identifier) \
                and self.db.check_vf_entry_id(identifier):
            return "vf", identifier

        # AMR gene names follow no fixed pattern; rely on the lookup alone.
        if self.db.check_amr_entry_id(identifier):
            return "amr", identifier

        return None


def locusx_genomic_region(db, seqid, window):
Expand Down

0 comments on commit 47c3e27

Please sign in to comment.