diff --git a/tests/test_schema.py b/tests/test_schema.py index ccdf31b..27fc192 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -4,30 +4,62 @@ import pyarrow as pa import pytest -from meds import patient_schema, label, dataset_metadata +from meds import ( + data, label, dataset_metadata, patient_split, code_metadata, train_split, tuning_split, held_out_split +) - -def test_patient_schema(): +def test_data_schema(): """ - Test that mock patient data follows the patient_schema schema. + Test that mock data follows the data schema. """ # Each element in the list is a row in the table - patient_data = [ + data = [ { "patient_id": 123, - "events": [{ # Nested list for events - "time": datetime.datetime(2020, 1, 1, 12, 0, 0), - "code": "some_code", - "text_value": "Example", - "numeric_value": 10.0, - "datetime_value": datetime.datetime(2020, 1, 1, 12, 0, 0), - "properties": None - }] + "time": datetime.datetime(2020, 1, 1, 12, 0, 0), + "code": "some_code", + "text_value": "Example", + "numeric_value": 10.0, } ] - patient_table = pa.Table.from_pylist(patient_data, schema=patient_schema()) - assert patient_table.schema.equals(patient_schema()), "Patient schema does not match" + schema = data([("text_value", pa.string())]) + + table = pa.Table.from_pylist(data, schema=schema) + assert table.schema.equals(schema), "Patient schema does not match" + +def test_code_metadata_schema(): + """ + Test that mock code metadata follows the schema. + """ + # Each element in the list is a row in the table + data = [ + { + "code": "some_code", + "description": "foo", + "parent_code": ["parent_code"], + } + ] + + schema = code_metadata() + + table = pa.Table.from_pylist(data, schema=schema) + assert table.schema.equals(schema), "Code metadata schema does not match" + +def test_patient_split_schema(): + """ + Test that mock data follows the data schema. + """ + # Each element in the list is a row in the table + data = [ + {"patient_id": 123, "split": train_split}, + {"patient_id": 123, "split": tuning_split}, + {"patient_id": 123, "split": held_out_split}, + {"patient_id": 123, "split": "special"}, + ] + + table = pa.Table.from_pylist(data, schema=patient_split) + assert table.schema.equals(patient_split), "Patient split schema does not match" def test_label_schema(): """ @@ -83,12 +115,6 @@ def test_dataset_metadata_schema(): "dataset_version": "1.0", "etl_name": "Test ETL", "etl_version": "1.0", - "code_metadata": { - "test_code": { - "description": "A test code", - "standard_ontology_codes": ["12345"], - } - }, } jsonschema.validate(instance=metadata, schema=dataset_metadata)