Skip to content

Commit

Permalink
Merge pull request #2256 from opensafely-core/evansd/revert-in-memory
Browse files Browse the repository at this point in the history
Revert buggy in-memory engine change
  • Loading branch information
evansd authored Nov 28, 2024
2 parents e45c221 + b748589 commit 95197de
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 13 deletions.
26 changes: 14 additions & 12 deletions ehrql/query_engines/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def sort(self, sort_index):
def pick_at_index(self, ix):
return PatientTable(
{
name: col.pick_at_index(ix, name == "patient_id")
name: col.pick_at_index(ix)
for name, col in self.name_to_col.items()
if name != "row_id"
}
Expand Down Expand Up @@ -351,14 +351,19 @@ def sort(self, sort_index):
{p: rows.sort(sort_index[p]) for p, rows in self.patient_to_rows.items()}
)

def pick_at_index(self, ix, is_patient_id=False):
if is_patient_id:
# The patient_id column is special, and should always be a mapping
# from an id to itself. Rows.pick_at_index will return None if a
# patient has no rows in the column.
return PatientColumn({p: p for p in self.patient_to_rows})
def pick_at_index(self, ix):
# It is arguable that for a patient with no rows (which would occur if
# this EventColumn was derived by filtering another EventColumn), the
# patient should be present in the new PatientColumn, with value None.
#
# However, we have decided to instead omit the patient from the new
# PatientColumn.
return PatientColumn(
{p: rows.pick_at_index(ix) for p, rows in self.patient_to_rows.items()}
{
p: rows.pick_at_index(ix)
for p, rows in self.patient_to_rows.items()
if rows
}
)


Expand Down Expand Up @@ -430,10 +435,7 @@ def sort(self, sort_index):
def pick_at_index(self, ix):
"""Return element at given position."""

try:
k = list(self)[ix]
except IndexError:
return None
k = list(self)[ix]
return self[k]


Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_query_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,23 @@ def test_population_which_uses_combine_as_set_and_no_patient_frame(engine):
assert engine.extract_qm(variables) == [
{"patient_id": 1, "v": True},
]


def test_picking_row_doesnt_cause_filtered_rows_to_reappear(engine):
# Regression test for a bug we introduced in the in-memory engine
dataset = create_dataset()
dataset.define_population(events.exists_for_patient())

rows = events.where(events.i < 0).sort_by(events.i).first_for_patient()
dataset.has_row = rows.exists_for_patient()
dataset.row_count = rows.count_for_patient()

engine.populate(
{
events: [{"patient_id": 1, "i": 2}],
}
)

assert engine.extract(dataset) == [
{"patient_id": 1, "has_row": False, "row_count": 0},
]
1 change: 0 additions & 1 deletion tests/unit/query_engines/test_in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,6 @@ def test_event_table_filter_then_pick_at_index():
--+-----+-----
1 | 102 | 112
2 | 203 | 211
3 | |
""",
)

Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_quiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
clinical_events,
medications,
patients,
practice_registrations,
)


Expand Down Expand Up @@ -159,6 +160,25 @@ def test_check_answer_dataset_column_has_missing_patients(engine):
assert msg == "Incorrect `age` value for patient 1: expected 49, got 50 instead."


@pytest.mark.parametrize(
"order, message",
[
([0, 1], "Missing patient(s): 7."),
([1, 0], "Found extra patient(s): 7."),
],
)
def test_check_answer_patient_series_has_missing_or_extra_patients(
engine, order, message
):
series = [
practice_registrations.for_patient_on("2013-12-01").practice_pseudo_id,
practice_registrations.for_patient_on("2014-01-01").practice_pseudo_id,
]
answer, expected = (series[i] for i in order)
msg = quiz.check_answer(engine=engine, answer=answer, expected=expected)
assert msg == message


def test_check_answer_patient_series_has_incorrect_value(engine):
msg = quiz.check_answer(
engine=engine,
Expand Down

0 comments on commit 95197de

Please sign in to comment.