Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2714 custom fields #2757

Merged
merged 9 commits into from
Jul 18, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ class to actually work with the cases collection, so that different
storage technology can be chosen.
All methods return a tuple of (response, HTTP status code)"""

def __init__(self, app, store, outbreak_date: date):
"""store is the flask app
store is an adapter to the external storage technology.
def __init__(self, store, outbreak_date: date):
"""store is an adapter to the external storage technology.
outbreak_date is the earliest date on which this instance should accept cases."""
self.app = app
self.store = store
self.outbreak_date = outbreak_date

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import dataclasses

from data_service.model.case import Case, make_custom_case_class
from data_service.model.field import Field
from data_service.util.errors import ConflictError, PreconditionUnsatisfiedError


class SchemaController:
"""Manipulate the fields on the Case class."""

def __init__(self, store):
self.store = store

def add_field(self, name: str, type_name: str, description: str):
global Case
"""Add a field of the specified type to the Case class. There cannot
already be a field of that name, either built in, as part of the
DayZeroCase schema, or added through this method previously.

Additionally dataclasses imposes other conditions (for example names
cannot be Python keywords).

The description will be used in the data dictionary."""
existing_fields = dataclasses.fields(Case)
if name in [f.name for f in existing_fields]:
raise ConflictError(f"field {name} already exists")
if type_name not in Field.acceptable_types:
raise PreconditionUnsatisfiedError(
f"cannot use {type_name} as the type of a field"
)
type = Field.model_type(type_name)
fields_list = [(f.name, f.type, f) for f in existing_fields]
fields_list.append((name, type, dataclasses.field(init=False, default=None)))
# re-invent the Case class
Case = make_custom_case_class("Case", fields_list)
# create a storable model of the field and store it
# FIXME rewrite the validation logic above to use the data model
field_model = Field(name, type_name, description)
self.store.add_field(field_model)
17 changes: 15 additions & 2 deletions data-serving/reusable-data-service/data_service/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import date
from flask import Flask, jsonify, request
from data_service.controller.case_controller import CaseController
from data_service.controller.schema_controller import SchemaController
from data_service.stores.mongo_store import MongoStore
from data_service.util.errors import (
PreconditionUnsatisfiedError,
Expand All @@ -17,6 +18,7 @@
app.json_encoder = JSONEncoder

case_controller = None # Will be set up in main()
schema_controller = None


@app.route("/api/cases/<id>", methods=["GET", "PUT", "DELETE"])
Expand Down Expand Up @@ -150,8 +152,18 @@ def excluded_case_ids():
return jsonify({"message": e.args[0]}), e.http_code


@app.route("/api/schema", methods=["POST"])
def add_field_to_case_schema():
try:
req = request.get_json()
schema_controller.add_field(req["name"], req["type"], req["description"])
return "", 201
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except ConflictError returns 409, except PreconditionUnsatisfiedError returns 400,
and except KeyError returns 400?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so KeyError shouldn't happen: that will come from Field.model_type if you try to use a field type that isn't in the allowed list. So we turn it into PreconditionUnsatisfiedError which is the same as when we explicitly check that the type isn't found in the allowed list: i.e. both exception paths mean the same thing so they result in the same error (400 bad request: you've asked for something we aren't going to let you do).

ConflictError does indeed result in a 409 Conflict status.

except WebApplicationError as e:
return jsonify({"message": e.args[0]}), e.http_code


def set_up_controllers():
global case_controller
global case_controller, schema_controller
store_options = {"mongodb": MongoStore.setup}
if store_choice := os.environ.get("DATA_STORAGE_BACKEND"):
try:
Expand All @@ -162,7 +174,8 @@ def set_up_controllers():
outbreak_date = os.environ.get("OUTBREAK_DATE")
if outbreak_date is None:
raise ValueError("Define $OUTBREAK_DATE in the environment")
case_controller = CaseController(app, store, date.fromisoformat(outbreak_date))
case_controller = CaseController(store, date.fromisoformat(outbreak_date))
schema_controller = SchemaController(store)


def main():
Expand Down
55 changes: 54 additions & 1 deletion data-serving/reusable-data-service/data_service/model/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import json
import flask.json

from collections.abc import Callable
from typing import Any, List

from data_service.model.case_exclusion_metadata import CaseExclusionMetadata
from data_service.model.case_reference import CaseReference
from data_service.model.document import Document
from data_service.util.errors import (
DependencyFailedError,
PreconditionUnsatisfiedError,
ValidationError,
)
Expand Down Expand Up @@ -86,6 +88,57 @@ def validate(self):
self.caseReference.validate()


observers = []

# Actually we want to capture extra fields which can be specified dynamically:
# so Case is the class that you should use.
Case = dataclasses.make_dataclass("Case", fields=[], bases=(DayZeroCase,))


def make_custom_case_class(name: str, fields=[]) -> type:
"""Generate a class extending the DayZeroCase class with additional fields."""
global Case
try:
new_case_class = dataclasses.make_dataclass(name, fields, bases=(DayZeroCase,))
except TypeError as e:
raise DependencyFailedError(*(e.args))
for observer in observers:
observer(new_case_class)
# also store it locally so anyone who does import Case from here gets the new one from now on
Case = new_case_class
return new_case_class


def observe_case_class(observer: Callable[[type], None]) -> None:
"""When someone imports a class by name, they get a reference to that class object.
Unfortunately that means that when we recreate the Case class (e.g. because someone
calls make_custom_case_class) nobody finds out about that. They would if we modified
the existing Case class, but dataclasses doesn't provide for that. So provide a
mechanism for importers to discover that the class has been recreated. An implementation of
observer will probably look something like this:

def observer(new_case_class: type) -> None:
global Case
Case = new_case_class

But you could also do something more subtle (like rewrite the __class__ on instances of Case
you already have, or recreate a working set of Cases).

This function calls the observer so that clients can get the initial definition of Case without
also having to import that."""
observers.append(observer)
observer(Case)


def remove_case_class_observer(observer: Callable[[type], None]) -> None:
"""When you're done watching for changes to Case, call this."""
observers.remove(observer)


def reset_custom_case_fields() -> None:
"""When you want to get back to where you started, for example to load the field definitions from
storage or if you're writing tests that modify the Case class."""
make_custom_case_class("Case")


# let's start with a clean slate on first load
reset_custom_case_fields()
25 changes: 25 additions & 0 deletions data-serving/reusable-data-service/data_service/model/field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import dataclasses
from datetime import date

from data_service.model.document import Document
from data_service.util.errors import PreconditionUnsatisfiedError


@dataclasses.dataclass
class Field(Document):
"""Represents a custom field in a Document object."""

key: str = dataclasses.field(init=True, default=None)
type: str = dataclasses.field(init=True, default=None)
data_dictionary_text: str = dataclasses.field(init=True, default=None)
STRING = "string"
DATE = "date"
type_map = {STRING: str, DATE: date}
acceptable_types = type_map.keys()

@classmethod
def model_type(cls, name: str) -> type:
try:
return cls.type_map[name]
except KeyError:
raise PreconditionUnsatisfiedError(f"cannot use type {name} in a Field")
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ class PropertyFilter(Filter):
"""Represents a test that an object's property has a value that satisfies some constraint."""

def __init__(self, property_name: str, operation: str, value: Any):
valid_ops = [FilterOperator.LESS_THAN, FilterOperator.GREATER_THAN]
valid_ops = [
FilterOperator.LESS_THAN,
FilterOperator.GREATER_THAN,
FilterOperator.EQUAL,
]
if operation not in valid_ops:
raise ValueError(f"Unknown operation {operation}")
self.property_name = property_name
Expand All @@ -37,6 +41,7 @@ def __str__(self) -> str:
class FilterOperator:
LESS_THAN = "<"
GREATER_THAN = ">"
EQUAL = "="


class AndFilter(Filter):
Expand Down
128 changes: 128 additions & 0 deletions data-serving/reusable-data-service/data_service/stores/memory_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from functools import reduce
from operator import attrgetter, and_
from typing import List, Optional

from data_service.model.case import Case
from data_service.model.case_exclusion_metadata import CaseExclusionMetadata
from data_service.model.document_update import DocumentUpdate
from data_service.model.field import Field
from data_service.model.filter import (
Filter,
Anything,
PropertyFilter,
AndFilter,
FilterOperator,
)


class MemoryStore:
"""Simple dictionary-based store for cases."""

def __init__(self):
self.cases = dict()
self.fields = []
self.next_id = 0

def case_by_id(self, id: str):
return self.cases.get(id)

def put_case(self, id: str, case: Case):
"""This is test-only interface for populating the store."""
self.cases[id] = case

def insert_case(self, case: Case):
"""This is the external case insertion API that the case controller uses."""
self.next_id += 1
id = str(self.next_id)
case._id = id
self.put_case(id, case)

def replace_case(self, id: str, case: Case):
self.put_case(id, case)

def update_case(self, id: str, update: DocumentUpdate):
case = self.case_by_id(id)
case.apply_update(update)

def batch_update(self, updates: dict[str, DocumentUpdate]):
for id, update in iter(updates.items()):
self.update_case(id, update)
return len(updates)

def update_case_status(
self, id: str, status: str, exclusion: CaseExclusionMetadata
):
case = self.case_by_id(id)
case.caseReference.status = status
case.caseExclusion = exclusion

def fetch_cases(self, page: int, limit: int, predicate: Filter):
return list(self.cases.values())[(page - 1) * limit : page * limit]

def count_cases(self, predicate: Filter = Anything()):
return len([True for c in self.cases.values() if predicate(c)])

def batch_upsert(self, cases: List[Case]):
for case in cases:
self.insert_case(case)
return len(cases), 0

def excluded_cases(self, source_id: str, filter: Filter):
return [
c
for c in self.cases.values()
if c.caseReference.sourceId == source_id
and c.caseReference.status == "EXCLUDED"
]

def delete_case(self, case_id: str):
del self.cases[case_id]

def delete_cases(self, query: Filter):
self.cases = dict()

def matching_case_iterator(self, query: Filter):
return iter(self.cases.values())

def identified_case_iterator(self, case_ids):
ids_as_ints = [int(x) for x in case_ids]
all_cases = list(self.cases.values())
matching_cases = [all_cases[i] for i in ids_as_ints]
return iter(matching_cases)

def add_field(self, field: Field):
self.fields.append(field)

def get_case_fields(self) -> List[Field]:
return self.fields


def anything_call(self, case: Case):
return True


Anything.__call__ = anything_call


def property_call(self, case: Case):
my_value = self.value
its_value = attrgetter(self.property_name)(case)
match self.operation:
case FilterOperator.LESS_THAN:
return its_value < my_value
case FilterOperator.GREATER_THAN:
return its_value > my_value
case FilterOperator.EQUAL:
return its_value == my_value
case _:
raise ValueError(f"Unhandled operation {self.operation}")


PropertyFilter.__call__ = property_call


def and_call(self, case: Case):
return reduce(and_, [f(case) for f in self.filters])


AndFilter.__call__ = and_call
Loading