diff --git a/data-serving/reusable-data-service/data_service/controller/schema_controller.py b/data-serving/reusable-data-service/data_service/controller/schema_controller.py index d46e3a40c..77e308865 100644 --- a/data-serving/reusable-data-service/data_service/controller/schema_controller.py +++ b/data-serving/reusable-data-service/data_service/controller/schema_controller.py @@ -1,8 +1,18 @@ import dataclasses -from data_service.model.case import Case, make_custom_case_class +from datetime import date +from typing import Optional, Union + +from data_service.model.case import add_field_to_case_class, observe_case_class from data_service.model.field import Field -from data_service.util.errors import ConflictError, PreconditionUnsatisfiedError +from data_service.util.errors import PreconditionUnsatisfiedError + +Case = None + + +def case_class_observer(cls: type): + global Case + Case = cls class SchemaController: @@ -11,14 +21,29 @@ class SchemaController: def __init__(self, store): self.store = store self.restore_saved_fields() + observe_case_class(case_class_observer) def restore_saved_fields(self) -> None: """Find previously-created fields in the store and apply them.""" - for field in self.store.get_case_fields(): - self.add_field(field.key, field.type, field.data_dictionary_text, False) + fields = self.store.get_case_fields() + for field in fields: + self.add_field( + field.key, + field.type, + field.data_dictionary_text, + field.required, + field.default, + False, + ) def add_field( - self, name: str, type_name: str, description: str, store_field: bool = True + self, + name: str, + type_name: str, + description: str, + required: bool = False, + default: Optional[Union[bool, str, int, date]] = None, + store_field: bool = True, ): global Case """Add a field of the specified type to the Case class. There cannot @@ -31,21 +56,12 @@ def add_field( The description will be used in the data dictionary. Fields will by default be added to the store. Set store_field to False if this is not - necessary, for example if the field to be added is coming from the store.""" - existing_fields = dataclasses.fields(Case) - if name in [f.name for f in existing_fields]: - raise ConflictError(f"field {name} already exists") - if type_name not in Field.acceptable_types: - raise PreconditionUnsatisfiedError( - f"cannot use {type_name} as the type of a field" - ) - type = Field.model_type(type_name) - fields_list = [(f.name, f.type, f) for f in existing_fields] - fields_list.append((name, type, dataclasses.field(init=False, default=None))) - # re-invent the Case class - Case = make_custom_case_class("Case", fields_list) + necessary, for example if the field to be added is coming from the store. + + If a field is required, set required = True. You must also set a default value so that + existing cases have an initial setting for the field.""" + required = required if required is not None else False + field_model = Field(name, type_name, description, required, default) + add_field_to_case_class(field_model) if store_field: - # create a storable model of the field and store it - # FIXME rewrite the validation logic above to use the data model - field_model = Field(name, type_name, description) self.store.add_field(field_model) diff --git a/data-serving/reusable-data-service/data_service/main.py b/data-serving/reusable-data-service/data_service/main.py index 3b1443285..d94070581 100644 --- a/data-serving/reusable-data-service/data_service/main.py +++ b/data-serving/reusable-data-service/data_service/main.py @@ -156,7 +156,13 @@ def excluded_case_ids(): def add_field_to_case_schema(): try: req = request.get_json() - schema_controller.add_field(req["name"], req["type"], req["description"]) + schema_controller.add_field( + req["name"], + req["type"], + req["description"], + req.get("required"), + req.get("default"), + ) return "", 201 except WebApplicationError as e: return jsonify({"message": e.args[0]}), e.http_code diff --git a/data-serving/reusable-data-service/data_service/model/case.py b/data-serving/reusable-data-service/data_service/model/case.py index 74f9d984e..4a9952cd0 100644 --- a/data-serving/reusable-data-service/data_service/model/case.py +++ b/data-serving/reusable-data-service/data_service/model/case.py @@ -4,12 +4,15 @@ import flask.json from collections.abc import Callable +from operator import attrgetter from typing import Any, List from data_service.model.case_exclusion_metadata import CaseExclusionMetadata from data_service.model.case_reference import CaseReference from data_service.model.document import Document +from data_service.model.field import Field from data_service.util.errors import ( + ConflictError, DependencyFailedError, PreconditionUnsatisfiedError, ValidationError, @@ -38,6 +41,8 @@ class DayZeroCase(Document): caseReference: CaseReference = dataclasses.field(init=False, default=None) caseExclusion: CaseExclusionMetadata = dataclasses.field(init=False, default=None) + custom_fields = [] + @classmethod def from_json(cls, obj: str) -> type: """Create an instance of this class from a JSON representation.""" @@ -86,6 +91,10 @@ def validate(self): elif self.caseReference is None: raise ValidationError("Case Reference must have a value") self.caseReference.validate() + print(f"validating custom fields {self.custom_fields}") + for field in self.custom_fields: + if field.required is True and attrgetter(field.key)(self) is None: + raise ValidationError(f"{field.key} must have a value") observers = [] @@ -94,13 +103,18 @@ def validate(self): # so Case is the class that you should use. -def make_custom_case_class(name: str, fields=[]) -> type: - """Generate a class extending the DayZeroCase class with additional fields.""" +def make_custom_case_class(name: str, fields=[], field_models=[]) -> type: + """Generate a class extending the DayZeroCase class with additional fields. + fields is a list of dataclass fields that should be added to the generated class. + field_models is a list of model objects describing the fields for the data dictionary + and for validation.""" + # FIXME generate the fields list from the field_models global Case try: new_case_class = dataclasses.make_dataclass(name, fields, bases=(DayZeroCase,)) except TypeError as e: raise DependencyFailedError(*(e.args)) + new_case_class.custom_fields = field_models for observer in observers: observer(new_case_class) # also store it locally so anyone who does import Case from here gets the new one from now on @@ -137,7 +151,27 @@ def remove_case_class_observer(observer: Callable[[type], None]) -> None: def reset_custom_case_fields() -> None: """When you want to get back to where you started, for example to load the field definitions from storage or if you're writing tests that modify the Case class.""" - make_custom_case_class("Case") + make_custom_case_class("Case", [], []) + + +def add_field_to_case_class(field_model: Field) -> None: + existing_fields = dataclasses.fields(Case) + field_models = Case.custom_fields + if field_model.key in [f.name for f in existing_fields]: + raise ConflictError(f"field {field_model.key} already exists") + if field_model.type not in Field.acceptable_types: + raise PreconditionUnsatisfiedError( + f"cannot use {field_model.type} as the type of a field" + ) + if field_model.required is True and field_model.default is None: + raise PreconditionUnsatisfiedError( + f"field {field_model.key} is required so it must have a default value" + ) + fields_list = [(f.name, f.type, f) for f in existing_fields] + fields_list.append(field_model.dataclasses_tuple()) + field_models.append(field_model) + # re-invent the Case class + make_custom_case_class("Case", fields_list, field_models) # let's start with a clean slate on first load diff --git a/data-serving/reusable-data-service/data_service/model/field.py b/data-serving/reusable-data-service/data_service/model/field.py index 86a457584..d5599cea8 100644 --- a/data-serving/reusable-data-service/data_service/model/field.py +++ b/data-serving/reusable-data-service/data_service/model/field.py @@ -1,5 +1,6 @@ import dataclasses from datetime import date +from typing import Optional, Union from data_service.model.document import Document from data_service.util.errors import PreconditionUnsatisfiedError @@ -12,9 +13,14 @@ class Field(Document): key: str = dataclasses.field(init=True, default=None) type: str = dataclasses.field(init=True, default=None) data_dictionary_text: str = dataclasses.field(init=True, default=None) + required: bool = dataclasses.field(init=True, default=False) + default: Optional[Union[bool, str, int, date]] = dataclasses.field( + init=True, default=None + ) STRING = "string" DATE = "date" - type_map = {STRING: str, DATE: date} + INTEGER = "integer" + type_map = {STRING: str, DATE: date, INTEGER: int} acceptable_types = type_map.keys() @classmethod @@ -23,3 +29,18 @@ def model_type(cls, name: str) -> type: return cls.type_map[name] except KeyError: raise PreconditionUnsatisfiedError(f"cannot use type {name} in a Field") + + def python_type(self) -> type: + return self.model_type(self.type) + + def dataclasses_tuple(self) -> (str, type, dataclasses.Field): + # Note that the default value here is always None, even if I have a default value! + # That's because the meaning of "required" in a field model is "a user _is required_ to + # supply a value" and the meaning of "default" is "for cases that don't already have this + # key, use the default value"; if I give every Case the default value then there's no sense + # in which a user is required to define it themselves. + return ( + self.key, + self.python_type(), + dataclasses.field(init=False, default=None), + ) diff --git a/data-serving/reusable-data-service/data_service/stores/mongo_store.py b/data-serving/reusable-data-service/data_service/stores/mongo_store.py index 7370021f3..70c65a50d 100644 --- a/data-serving/reusable-data-service/data_service/stores/mongo_store.py +++ b/data-serving/reusable-data-service/data_service/stores/mongo_store.py @@ -138,6 +138,12 @@ def update_case(self, id: str, update: DocumentUpdate): command = self.mongodb_update_command(update) self.get_case_collection().update_one({"_id": ObjectId(id)}, command) + def update_cases(self, filter: Filter, update: DocumentUpdate): + if len(update) == 0: + return + command = self.mongodb_update_command(update) + self.get_case_collection().update_many(filter.to_mongo_query(), command) + def batch_update(self, updates: dict[str, DocumentUpdate]): mongo_commands = { ObjectId(k): self.mongodb_update_command(v) @@ -236,10 +242,19 @@ def case_exclusion_to_bson_compatible_dict(exclusion: CaseExclusionMetadata): def add_field(self, field: Field): self.get_schema_collection().insert_one(field.to_dict()) + if field.required is True: + update = DocumentUpdate.from_dict({field.key: field.default}) + self.update_cases(Anything(), update) def get_case_fields(self): return [ - Field(doc["key"], doc["type"], doc["data_dictionary_text"]) + Field( + doc["key"], + doc["type"], + doc["data_dictionary_text"], + doc["required"], + doc["default"], + ) for doc in self.get_schema_collection().find({}) ] diff --git a/data-serving/reusable-data-service/tests/end_to_end_fixture.py b/data-serving/reusable-data-service/tests/end_to_end_fixture.py index 1d5a1d2ea..86c474f85 100644 --- a/data-serving/reusable-data-service/tests/end_to_end_fixture.py +++ b/data-serving/reusable-data-service/tests/end_to_end_fixture.py @@ -3,10 +3,12 @@ from data_service import app from data_service.main import set_up_controllers +from data_service.model.case import reset_custom_case_fields @pytest.fixture def client_with_patched_mongo(monkeypatch): + reset_custom_case_fields() # configure controllers monkeypatch.setenv("DATA_STORAGE_BACKEND", "mongodb") monkeypatch.setenv( @@ -26,3 +28,4 @@ def fake_mongo(connection_string): app.config["TESTING"] = True client = app.test_client() yield client + reset_custom_case_fields() diff --git a/data-serving/reusable-data-service/tests/test_case_schema_integration.py b/data-serving/reusable-data-service/tests/test_case_schema_integration.py index 6c60e525a..c8e25d89f 100644 --- a/data-serving/reusable-data-service/tests/test_case_schema_integration.py +++ b/data-serving/reusable-data-service/tests/test_case_schema_integration.py @@ -66,3 +66,61 @@ def test_adding_field_then_downloading_csv(client_with_patched_mongo): assert len(cases) == 1 case = cases[0] assert case["someField"] == "well, what have we here" + + +def test_required_field_default_value_spread_to_existing_cases( + client_with_patched_mongo, +): + response = client_with_patched_mongo.post( + "/api/cases", + json={ + "confirmationDate": "2022-06-01T00:00:00.000Z", + "caseReference": { + "status": "UNVERIFIED", + "sourceId": "24680135792468013579fedc", + }, + }, + ) + assert response.status_code == 201 + response = client_with_patched_mongo.post( + "/api/schema", + json={ + "name": "requiredField", + "type": "string", + "description": "You must supply a value for this", + "default": "PENDING", + "required": True, + }, + ) + assert response.status_code == 201 + response = client_with_patched_mongo.get("/api/cases") + assert response.status_code == 200 + case_list = response.get_json() + assert case_list["total"] == 1 + assert len(case_list["cases"]) == 1 + assert case_list["cases"][0]["requiredField"] == "PENDING" + + +def test_required_field_becomes_required_in_validation(client_with_patched_mongo): + response = client_with_patched_mongo.post( + "/api/schema", + json={ + "name": "importantInformation", + "type": "string", + "description": "You must supply a value for this", + "default": "PENDING", + "required": True, + }, + ) + assert response.status_code == 201 + response = client_with_patched_mongo.post( + "/api/cases", + json={ + "confirmationDate": "2022-06-01T00:00:00.000Z", + "caseReference": { + "status": "UNVERIFIED", + "sourceId": "24680135792468013579fedc", + }, + }, + ) + assert response.status_code == 422 diff --git a/data-serving/reusable-data-service/tests/test_mongo_store.py b/data-serving/reusable-data-service/tests/test_mongo_store.py index 45dad6ecf..b8e7ac5d3 100644 --- a/data-serving/reusable-data-service/tests/test_mongo_store.py +++ b/data-serving/reusable-data-service/tests/test_mongo_store.py @@ -4,14 +4,27 @@ from bson import ObjectId from datetime import date -from data_service.model.case import Case +from data_service.model.case import ( + observe_case_class, + remove_case_class_observer, + reset_custom_case_fields, +) from data_service.model.case_reference import CaseReference from data_service.model.filter import Anything from data_service.stores.mongo_store import MongoStore +Case = None + + +def case_observer(cls): + global Case + Case = cls + @pytest.fixture def mongo_store(monkeypatch): + reset_custom_case_fields() + observe_case_class(case_observer) db = mongomock.MongoClient() def fake_mongo(connection_string): @@ -22,6 +35,8 @@ def fake_mongo(connection_string): "mongodb://localhost:27017/outbreak", "outbreak", "cases", "schema" ) yield store + remove_case_class_observer(case_observer) + reset_custom_case_fields() """ diff --git a/data-serving/reusable-data-service/tests/test_schema_controller_field_changes.py b/data-serving/reusable-data-service/tests/test_schema_controller_field_changes.py index 9b8f24177..d17089c3d 100644 --- a/data-serving/reusable-data-service/tests/test_schema_controller_field_changes.py +++ b/data-serving/reusable-data-service/tests/test_schema_controller_field_changes.py @@ -87,3 +87,10 @@ def test_stored_field_gets_added(observing_case_changes): controller = SchemaController(store_with_preexisting_field) assert "outcome_date" in Case.field_names() assert Case.field_type("outcome_date") is datetime.date + + +def test_required_field_must_have_default_value(schema_controller): + with pytest.raises(PreconditionUnsatisfiedError): + schema_controller.add_field( + "countSomething", Field.INTEGER, "Some Number", required=True + )