Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2714 required fields #2776

Merged
merged 4 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import dataclasses

from data_service.model.case import Case, make_custom_case_class
from datetime import date
from typing import Optional, Union

from data_service.model.case import add_field_to_case_class, observe_case_class
from data_service.model.field import Field
from data_service.util.errors import ConflictError, PreconditionUnsatisfiedError
from data_service.util.errors import PreconditionUnsatisfiedError

Case = None


def case_class_observer(cls: type):
global Case
Case = cls


class SchemaController:
Expand All @@ -11,14 +21,29 @@ class SchemaController:
def __init__(self, store):
self.store = store
self.restore_saved_fields()
observe_case_class(case_class_observer)

def restore_saved_fields(self) -> None:
"""Find previously-created fields in the store and apply them."""
for field in self.store.get_case_fields():
self.add_field(field.key, field.type, field.data_dictionary_text, False)
fields = self.store.get_case_fields()
for field in fields:
self.add_field(
field.key,
field.type,
field.data_dictionary_text,
field.required,
field.default,
False,
)

def add_field(
self, name: str, type_name: str, description: str, store_field: bool = True
self,
name: str,
type_name: str,
description: str,
required: bool = False,
default: Optional[Union[bool, str, int, date]] = None,
store_field: bool = True,
):
global Case
"""Add a field of the specified type to the Case class. There cannot
Expand All @@ -31,21 +56,12 @@ def add_field(
The description will be used in the data dictionary.

Fields will by default be added to the store. Set store_field to False if this is not
necessary, for example if the field to be added is coming from the store."""
existing_fields = dataclasses.fields(Case)
if name in [f.name for f in existing_fields]:
raise ConflictError(f"field {name} already exists")
if type_name not in Field.acceptable_types:
raise PreconditionUnsatisfiedError(
f"cannot use {type_name} as the type of a field"
)
type = Field.model_type(type_name)
fields_list = [(f.name, f.type, f) for f in existing_fields]
fields_list.append((name, type, dataclasses.field(init=False, default=None)))
# re-invent the Case class
Case = make_custom_case_class("Case", fields_list)
necessary, for example if the field to be added is coming from the store.

If a field is required, set required = True. You must also set a default value so that
existing cases have an initial setting for the field."""
required = required if required is not None else False
field_model = Field(name, type_name, description, required, default)
add_field_to_case_class(field_model)
if store_field:
# create a storable model of the field and store it
# FIXME rewrite the validation logic above to use the data model
field_model = Field(name, type_name, description)
self.store.add_field(field_model)
8 changes: 7 additions & 1 deletion data-serving/reusable-data-service/data_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,13 @@ def excluded_case_ids():
def add_field_to_case_schema():
try:
req = request.get_json()
schema_controller.add_field(req["name"], req["type"], req["description"])
schema_controller.add_field(
req["name"],
req["type"],
req["description"],
req.get("required"),
req.get("default"),
)
return "", 201
except WebApplicationError as e:
return jsonify({"message": e.args[0]}), e.http_code
Expand Down
40 changes: 37 additions & 3 deletions data-serving/reusable-data-service/data_service/model/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import flask.json

from collections.abc import Callable
from operator import attrgetter
from typing import Any, List

from data_service.model.case_exclusion_metadata import CaseExclusionMetadata
from data_service.model.case_reference import CaseReference
from data_service.model.document import Document
from data_service.model.field import Field
from data_service.util.errors import (
ConflictError,
DependencyFailedError,
PreconditionUnsatisfiedError,
ValidationError,
Expand Down Expand Up @@ -38,6 +41,8 @@ class DayZeroCase(Document):
caseReference: CaseReference = dataclasses.field(init=False, default=None)
caseExclusion: CaseExclusionMetadata = dataclasses.field(init=False, default=None)

custom_fields = []

@classmethod
def from_json(cls, obj: str) -> type:
"""Create an instance of this class from a JSON representation."""
Expand Down Expand Up @@ -86,6 +91,10 @@ def validate(self):
elif self.caseReference is None:
raise ValidationError("Case Reference must have a value")
self.caseReference.validate()
print(f"validating custom fields {self.custom_fields}")
for field in self.custom_fields:
if field.required is True and attrgetter(field.key)(self) is None:
raise ValidationError(f"{field.key} must have a value")


observers = []
Expand All @@ -94,13 +103,18 @@ def validate(self):
# so Case is the class that you should use.


def make_custom_case_class(name: str, fields=[]) -> type:
"""Generate a class extending the DayZeroCase class with additional fields."""
def make_custom_case_class(name: str, fields=[], field_models=[]) -> type:
"""Generate a class extending the DayZeroCase class with additional fields.
fields is a list of dataclass fields that should be added to the generated class.
field_models is a list of model objects describing the fields for the data dictionary
and for validation."""
# FIXME generate the fields list from the field_models
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my plan for this is when the day zero schema is committed, to create a list of Field objects that contain the day zero fields and get played into the Case class; then the logic for make_custom_case_class is simply to take a list of all known Fields and make data class fields for them. At that point I think the DayZeroCase intermediate class goes away.

global Case
try:
new_case_class = dataclasses.make_dataclass(name, fields, bases=(DayZeroCase,))
except TypeError as e:
raise DependencyFailedError(*(e.args))
new_case_class.custom_fields = field_models
for observer in observers:
observer(new_case_class)
# also store it locally so anyone who does import Case from here gets the new one from now on
Expand Down Expand Up @@ -137,7 +151,27 @@ def remove_case_class_observer(observer: Callable[[type], None]) -> None:
def reset_custom_case_fields() -> None:
"""When you want to get back to where you started, for example to load the field definitions from
storage or if you're writing tests that modify the Case class."""
make_custom_case_class("Case")
make_custom_case_class("Case", [], [])


def add_field_to_case_class(field_model: Field) -> None:
existing_fields = dataclasses.fields(Case)
field_models = Case.custom_fields
if field_model.key in [f.name for f in existing_fields]:
raise ConflictError(f"field {field_model.key} already exists")
if field_model.type not in Field.acceptable_types:
raise PreconditionUnsatisfiedError(
f"cannot use {field_model.type} as the type of a field"
)
if field_model.required is True and field_model.default is None:
raise PreconditionUnsatisfiedError(
f"field {field_model.key} is required so it must have a default value"
)
fields_list = [(f.name, f.type, f) for f in existing_fields]
fields_list.append(field_model.dataclasses_tuple())
field_models.append(field_model)
# re-invent the Case class
make_custom_case_class("Case", fields_list, field_models)


# let's start with a clean slate on first load
Expand Down
23 changes: 22 additions & 1 deletion data-serving/reusable-data-service/data_service/model/field.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
from datetime import date
from typing import Optional, Union

from data_service.model.document import Document
from data_service.util.errors import PreconditionUnsatisfiedError
Expand All @@ -12,9 +13,14 @@ class Field(Document):
key: str = dataclasses.field(init=True, default=None)
type: str = dataclasses.field(init=True, default=None)
data_dictionary_text: str = dataclasses.field(init=True, default=None)
required: bool = dataclasses.field(init=True, default=False)
default: Optional[Union[bool, str, int, date]] = dataclasses.field(
init=True, default=None
)
STRING = "string"
DATE = "date"
type_map = {STRING: str, DATE: date}
INTEGER = "integer"
type_map = {STRING: str, DATE: date, INTEGER: int}
acceptable_types = type_map.keys()

@classmethod
Expand All @@ -23,3 +29,18 @@ def model_type(cls, name: str) -> type:
return cls.type_map[name]
except KeyError:
raise PreconditionUnsatisfiedError(f"cannot use type {name} in a Field")

def python_type(self) -> type:
return self.model_type(self.type)

def dataclasses_tuple(self) -> (str, type, dataclasses.Field):
# Note that the default value here is always None, even if I have a default value!
# That's because the meaning of "required" in a field model is "a user _is required_ to
# supply a value" and the meaning of "default" is "for cases that don't already have this
# key, use the default value"; if I give every Case the default value then there's no sense
# in which a user is required to define it themselves.
return (
self.key,
self.python_type(),
dataclasses.field(init=False, default=None),
)
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ def update_case(self, id: str, update: DocumentUpdate):
command = self.mongodb_update_command(update)
self.get_case_collection().update_one({"_id": ObjectId(id)}, command)

def update_cases(self, filter: Filter, update: DocumentUpdate):
if len(update) == 0:
return
command = self.mongodb_update_command(update)
self.get_case_collection().update_many(filter.to_mongo_query(), command)

def batch_update(self, updates: dict[str, DocumentUpdate]):
mongo_commands = {
ObjectId(k): self.mongodb_update_command(v)
Expand Down Expand Up @@ -236,10 +242,19 @@ def case_exclusion_to_bson_compatible_dict(exclusion: CaseExclusionMetadata):

def add_field(self, field: Field):
self.get_schema_collection().insert_one(field.to_dict())
if field.required is True:
update = DocumentUpdate.from_dict({field.key: field.default})
self.update_cases(Anything(), update)

def get_case_fields(self):
return [
Field(doc["key"], doc["type"], doc["data_dictionary_text"])
Field(
doc["key"],
doc["type"],
doc["data_dictionary_text"],
doc["required"],
doc["default"],
)
for doc in self.get_schema_collection().find({})
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

from data_service import app
from data_service.main import set_up_controllers
from data_service.model.case import reset_custom_case_fields


@pytest.fixture
def client_with_patched_mongo(monkeypatch):
reset_custom_case_fields()
# configure controllers
monkeypatch.setenv("DATA_STORAGE_BACKEND", "mongodb")
monkeypatch.setenv(
Expand All @@ -26,3 +28,4 @@ def fake_mongo(connection_string):
app.config["TESTING"] = True
client = app.test_client()
yield client
reset_custom_case_fields()
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,61 @@ def test_adding_field_then_downloading_csv(client_with_patched_mongo):
assert len(cases) == 1
case = cases[0]
assert case["someField"] == "well, what have we here"


def test_required_field_default_value_spread_to_existing_cases(
client_with_patched_mongo,
):
response = client_with_patched_mongo.post(
"/api/cases",
json={
"confirmationDate": "2022-06-01T00:00:00.000Z",
"caseReference": {
"status": "UNVERIFIED",
"sourceId": "24680135792468013579fedc",
},
},
)
assert response.status_code == 201
response = client_with_patched_mongo.post(
"/api/schema",
json={
"name": "requiredField",
"type": "string",
"description": "You must supply a value for this",
"default": "PENDING",
"required": True,
},
)
assert response.status_code == 201
response = client_with_patched_mongo.get("/api/cases")
assert response.status_code == 200
case_list = response.get_json()
assert case_list["total"] == 1
assert len(case_list["cases"]) == 1
assert case_list["cases"][0]["requiredField"] == "PENDING"


def test_required_field_becomes_required_in_validation(client_with_patched_mongo):
response = client_with_patched_mongo.post(
"/api/schema",
json={
"name": "importantInformation",
"type": "string",
"description": "You must supply a value for this",
"default": "PENDING",
"required": True,
},
)
assert response.status_code == 201
response = client_with_patched_mongo.post(
"/api/cases",
json={
"confirmationDate": "2022-06-01T00:00:00.000Z",
"caseReference": {
"status": "UNVERIFIED",
"sourceId": "24680135792468013579fedc",
},
},
)
assert response.status_code == 422
17 changes: 16 additions & 1 deletion data-serving/reusable-data-service/tests/test_mongo_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,27 @@
from bson import ObjectId
from datetime import date

from data_service.model.case import Case
from data_service.model.case import (
observe_case_class,
remove_case_class_observer,
reset_custom_case_fields,
)
from data_service.model.case_reference import CaseReference
from data_service.model.filter import Anything
from data_service.stores.mongo_store import MongoStore

Case = None


def case_observer(cls):
global Case
Case = cls


@pytest.fixture
def mongo_store(monkeypatch):
reset_custom_case_fields()
observe_case_class(case_observer)
db = mongomock.MongoClient()

def fake_mongo(connection_string):
Expand All @@ -22,6 +35,8 @@ def fake_mongo(connection_string):
"mongodb://localhost:27017/outbreak", "outbreak", "cases", "schema"
)
yield store
remove_case_class_observer(case_observer)
reset_custom_case_fields()


"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,10 @@ def test_stored_field_gets_added(observing_case_changes):
controller = SchemaController(store_with_preexisting_field)
assert "outcome_date" in Case.field_names()
assert Case.field_type("outcome_date") is datetime.date


def test_required_field_must_have_default_value(schema_controller):
with pytest.raises(PreconditionUnsatisfiedError):
schema_controller.add_field(
"countSomething", Field.INTEGER, "Some Number", required=True
)