Skip to content

Commit

Permalink
Improve distribution of how we generate dummy event tables
Browse files Browse the repository at this point in the history
  • Loading branch information
DRMacIver committed Nov 28, 2024
1 parent 808f2d3 commit 715cb41
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 3 deletions.
37 changes: 35 additions & 2 deletions ehrql/dummy_data_nextgen/generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import functools
import itertools
import logging
import math
import random
import string
import time
from bisect import bisect_left
from collections import defaultdict
from datetime import date, timedelta

from ehrql.dummy_data_nextgen.query_info import QueryInfo, filter_values
Expand Down Expand Up @@ -86,6 +88,7 @@ def get_data(self):
generated += len(patient_batch)
database.populate(merge_table_data(*patient_batch.values()))
results = engine.get_results(population_query)
valid_patient_ids = set()
# Accumulate all data from matching patients, returning once we have enough
for row in results:
# Because of the existence of InlinePatientTables it's possible to get
Expand All @@ -94,6 +97,8 @@ def get_data(self):
if row.patient_id not in patient_batch:
continue

valid_patient_ids.add(row.patient_id)

extend_table_data(
data,
patient_batch[row.patient_id],
Expand All @@ -105,6 +110,25 @@ def get_data(self):
if found >= self.population_size:
break

if generator.required_tables is None and valid_patient_ids:
forbidden_tables = set(database.tables)
assert generator.forbidden_tables is None
tables_by_id = defaultdict(set)
for table, rows in database.tables.items():
for row in rows.to_records():
tables_by_id[row["patient_id"]].add(table)
required_tables = None
for patient_id in valid_patient_ids:
tables = tables_by_id[patient_id]
forbidden_tables -= tables
if required_tables is None:
required_tables = set(tables)
else:
required_tables &= tables
assert required_tables is not None
generator.required_tables = frozenset(required_tables)
generator.forbidden_tables = frozenset(forbidden_tables)

if found >= self.population_size:
return data

Expand Down Expand Up @@ -147,6 +171,8 @@ def __init__(self, variable_definitions, random_seed, today, population_size):

self.__column_values = {}
self.__reset_event_range()
self.required_tables = None
self.forbidden_tables = None

def get_patient_data_for_population_condition(self, patient_id):
# Generate data for just those tables needed for determining whether the patient
Expand Down Expand Up @@ -265,8 +291,15 @@ def rows_for_practice_registrations(self, table_info):

def empty_rows(self, table_info):
    """Decide how many rows a patient gets in `table_info` and return
    that many empty row stubs (dicts) for the caller to populate.

    One-row-per-patient tables get 0 or 1 rows; event-level tables get
    a small geometrically-distributed number of rows. Tables learnt to
    be forbidden/required for matching patients get 0 / at least 1 rows
    respectively.
    """
    # Tables which, as far as we've learnt, no matching patient ever
    # has rows in: generate nothing for them.
    if self.forbidden_tables and table_info.name in self.forbidden_tables:
        return []
    if table_info.has_one_row_per_patient:
        row_count = self.rnd.randint(0, 1)
    else:
        # Sample a geometric distribution with parameter 0.25 by
        # inverse transform sampling, giving a mean of 3 events per
        # patient. Use `1 - random()` so the argument to `log` lies in
        # (0.0, 1.0]: `random()` itself can return exactly 0.0, which
        # would make `math.log` raise ValueError.
        row_count = math.floor(math.log(1 - self.rnd.random()) / math.log(1 - 0.25))
    if self.required_tables and table_info.name in self.required_tables:
        # Tables every matching patient must have rows in: guarantee
        # at least one.
        row_count += 1
    return [{} for _ in range(row_count)]

def populate_row(self, table_info, row):
Expand Down
63 changes: 62 additions & 1 deletion tests/unit/dummy_data_nextgen/test_specific_datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from collections import Counter
from datetime import date
from unittest import mock

Expand All @@ -18,7 +19,11 @@
table,
table_from_rows,
)
from ehrql.tables.core import clinical_events, medications, patients
from ehrql.tables.core import (
clinical_events,
medications,
patients,
)


index_date = date(2022, 3, 1)
Expand Down Expand Up @@ -81,6 +86,38 @@ def test_queries_with_exact_one_shot_generation(patched_time, query):
assert len(patient_ids) == target_size


@mock.patch("ehrql.dummy_data_nextgen.generator.time")
@pytest.mark.parametrize(
    "query",
    [
        clinical_events.exists_for_patient(),
        ~clinical_events.exists_for_patient(),
    ],
    ids=pretty,
)
def test_queries_with_exact_two_shot_generation(patched_time, query):
    """For queries which we can't guarantee correct from the start
    but we can reliably figure out enough in the first batch of results
    that the second one is complete."""
    dataset = create_dataset()

    dataset.define_population(patients.exists_for_patient() & query)

    target_size = 1000

    variable_definitions = compile(dataset)
    generator = DummyDataGenerator(variable_definitions, population_size=target_size)
    generator.batch_size = target_size
    generator.timeout = 10

    # Configure `time.time()` so we allow exactly two loop passes before
    # timing out (0.0 and 1.0 are within the 10s timeout, 20.0 is not):
    # the first pass learns which tables matching patients need, the
    # second should then generate a complete population.
    patched_time.time.side_effect = [0.0, 1.0, 20.0]
    patient_ids = {row.patient_id for row in generator.get_results()}

    assert len(patient_ids) == target_size


@st.composite
def birthday_range_query(draw):
# We generate a single date that we require to be valid for
Expand Down Expand Up @@ -315,3 +352,27 @@ def test_generates_events_starting_from_birthdate():

for row in generator.get_results():
assert row.after_dob


def test_distribution_of_booleans():
    """A boolean dataset variable that is not constrained by the
    population definition should come out with a reasonable mix of True
    and False values, rather than collapsing to a near-constant."""
    dataset = create_dataset()

    # A property that holds only for patients with at least one clinical
    # event carrying this specific (arbitrary) code.
    dataset.has_the_thing = clinical_events.where(
        clinical_events.snomedct_code == "123456789"
    ).exists_for_patient()

    dataset.define_population(patients.exists_for_patient())

    target_size = 1000

    variable_definitions = compile(dataset)
    generator = DummyDataGenerator(variable_definitions, population_size=target_size)
    generator.batch_size = target_size

    property_counts = Counter(row.has_the_thing for row in generator.get_results())

    # Every generated patient has a boolean value for the property...
    assert property_counts[False] + property_counts[True] == target_size
    # ...and neither value dominates excessively.
    assert 0.2 < property_counts[True] / target_size < 0.8

0 comments on commit 715cb41

Please sign in to comment.