diff --git a/apollo/submissions/filters.py b/apollo/submissions/filters.py index ebccea04a..0b5be0ba6 100644 --- a/apollo/submissions/filters.py +++ b/apollo/submissions/filters.py @@ -58,43 +58,40 @@ def __init__(self, *args, **kwargs): self.participant_set_id = participant_set_id self.filter_on_locations = filter_on_locations - kwargs['choices'] = _make_choices(sample_choices, _('Sample')) + kwargs["choices"] = _make_choices(sample_choices, _("Sample")) super().__init__(*args, **kwargs) def filter_by_locations(self, query, value): + # Join with Location model if not already present if utils.has_model(query, models.Location): query1 = query else: query1 = query.join(models.Submission.location) - sample_locations = models.Participant.query.filter_by( - participant_set_id=participant_set_id - ).join( - models.Participant.samples - ).filter( - models.Sample.participant_set_id == participant_set_id, - models.Sample.id == value - ).with_entities( - models.Participant.location_id + locations_subquery = ( + models.Sample.query.filter( + models.Sample.id == value, models.Sample.participant_set_id == participant_set_id + ) + .join(models.Participant.samples) + .with_entities(models.Participant.location_id.label("loc_id")) + .subquery() ) + query2 = query1.join(locations_subquery, models.Submission.location_id == locations_subquery.c.loc_id) - query2 = query1.filter( - models.Submission.location_id.in_(sample_locations) - ) return query2 def filter_by_participants(self, query, value): - participants_in_sample = models.Participant.query.join( - models.Participant.samples - ).filter( - models.Participant.participant_set_id == participant_set_id, - models.Sample.id == value + participants_subquery = ( + models.Sample.query.filter( + models.Sample.id == value, models.Sample.participant_set_id == participant_set_id + ) + .join(models.Participant.samples) + .with_entities(models.Participant.id.label("part_id")) + .subquery() + ) + query2 = query.join( + participants_subquery, models.Submission.participant_id == participants_subquery.c.part_id ) - participant_ids = list( - chain(*participants_in_sample.with_entities( - models.Participant.id).all())) - query2 = query.filter( - models.Submission.participant_id.in_(participant_ids)) return query2