Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#2714 update case #2748

Merged
merged 6 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from data_service.model.case_page import CasePage
from data_service.model.case_reference import CaseReference
from data_service.model.case_upsert_outcome import CaseUpsertOutcome
from data_service.model.document_update import DocumentUpdate
from data_service.model.filter import (
Anything,
Filter,
Expand Down Expand Up @@ -188,7 +189,9 @@ def update_status(id: str, status: str, note: str):
for case in case_iterator:
update_status(case._id, status, note)

def excluded_case_ids(self, source_id: str, query: Optional[str] = None) -> List[str]:
def excluded_case_ids(
self, source_id: str, query: Optional[str] = None
) -> List[str]:
"""Return the identifiers of all excluded cases for a given source."""
if source_id is None:
raise PreconditionUnsatisfiedError("No sourceId provided")
Expand All @@ -197,6 +200,24 @@ def excluded_case_ids(self, source_id: str, query: Optional[str] = None) -> List
raise ValidationError(f"cannot understand query {predicate}")
return [c._id for c in self.store.excluded_cases(source_id, predicate)]

def update_case(self, source_id: str, update: dict) -> Case:
"""Update the case document with the provided ID. Raises NotFoundError if
there is no case with that ID, or ValidationError if the case would not be
left in a valid state. If the update is successfully applied, returns the updated
form of the case."""
case = self.store.case_by_id(source_id)
if case is None:
raise NotFoundError(f"No case with ID {source_id}")
# build the updated version of the case to validate
diff = DocumentUpdate.from_dict(update)
updated_case = case.updated_document(diff)
updated_case.validate()
self.check_case_preconditions(updated_case)
# tell the store to apply the update rather than replacing the whole document:
# should be more efficient given a competent DB
self.store.update_case(source_id, diff)
return updated_case

def create_case_if_valid(self, maybe_case: dict):
"""Attempts to create a case from an input dictionary and validate it against
the application rules. Raises ValidationError or PreconditionUnsatisfiedError on invalid input."""
Expand Down
7 changes: 5 additions & 2 deletions data-serving/reusable-data-service/data_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
case_controller = None # Will be set up in main()


@app.route("/api/cases/<id>")
@app.route("/api/cases/<id>", methods=["GET", "PUT"])
def get_case(id):
try:
return jsonify(case_controller.get_case(id)), 200
if request.method == "GET":
return jsonify(case_controller.get_case(id)), 200
else:
return jsonify(case_controller.update_case(id, request.get_json())), 200
except WebApplicationError as e:
return jsonify({"message": e.args[0]}), e.http_code

Expand Down
47 changes: 47 additions & 0 deletions data-serving/reusable-data-service/data_service/model/document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import copy
import dataclasses
import datetime
import operator

from data_service.model.document_update import DocumentUpdate
from data_service.util.json_encoder import JSONEncoder

from typing import List
Expand Down Expand Up @@ -115,3 +117,48 @@ def to_tsv(self) -> str:
def to_csv(self) -> str:
"""Generate a row in a CSV file representing myself."""
return self.delimiter_separated_values(",")

def updated_document(self, update: DocumentUpdate):
"""A copy of myself with the updates applied."""
other = copy.deepcopy(self)
other.apply_update(update)
return other

def apply_update(self, update: DocumentUpdate):
"""Apply a document update to myself."""
for key, value in update.updates_iter():
self._internal_set_value(key, value)
for key in update.unsets_iter():
self._internal_set_value(key, None)

def _internal_set_value(self, key, value):
container, prop = self._internal_object_and_property_for_key_path(key)
# patch up the type for updates created from a JSON API
if container.field_type(prop) == datetime.date and type(value) == str:
value = datetime.date.fromisoformat(value)
setattr(container, prop, value)

def _internal_object_and_property_for_key_path(self, key):
if (dot_index := key.rfind(".")) == -1:
container = self
prop = key
else:
container_key = key[:dot_index]
prop = key[dot_index + 1 :]
container = operator.attrgetter(container_key)(self)
return container, prop

@classmethod
def field_type(cls, prop: str) -> type:
fields = dataclasses.fields(cls)
the_field = [f for f in fields if f.name == prop][0]
return the_field.type

@classmethod
def field_type_for_key_path(cls, key_path: str):
props = key_path.split(".")
a_type = cls
while props != []:
name = props.pop(0)
a_type = [f.type for f in dataclasses.fields(a_type) if f.name == name][0]
return a_type
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import collections


class DocumentUpdate:
"""Represents a collection of changes to a document."""

def __init__(self):
self.updates = dict()
self.unsets = set()

@staticmethod
def from_dict(dict):
"""A dictionary representation of an update looks like this:
{
"foo": 123, # update foo to 123
"bar": None, # unset bar
"sub_document": {
"baz": False, # set sub_document.baz to False
}
}
"""
update = DocumentUpdate()
DocumentUpdate._internal_from_dict(update, dict, "")
return update

@staticmethod
def _internal_from_dict(update, dict, prefix):
for k, v in iter(dict.items()):
if isinstance(v, collections.abc.Mapping):
DocumentUpdate._internal_from_dict(update, v, prefix + k + ".")
else:
update.update(prefix + k, v)

def update(self, key, value):
"""Record that the value at key should be changed to the supplied value."""
if value is None:
self.unsets.add(key)
else:
self.updates[key] = value

def updates_iter(self):
return iter(self.updates.items())

def unsets_iter(self):
return iter(self.unsets)

def __len__(self):
return len(self.updates) + len(self.unsets)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Filter:

class Anything(Filter):
"""Represents a lack of constraints."""

def __str__(self) -> str:
return "Anything()"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pymongo
from data_service.model.case import Case
from data_service.model.case_exclusion_metadata import CaseExclusionMetadata
from data_service.model.document_update import DocumentUpdate
from data_service.model.filter import (
Filter,
Anything,
Expand Down Expand Up @@ -107,12 +108,31 @@ def excluded_cases(self, source_id: str, filter: Filter) -> List[Case]:
"caseReference.sourceId": ObjectId(source_id),
"caseReference.status": "EXCLUDED",
},
query
query,
]
}
)
return [Case.from_json(dumps(c)) for c in cases]

def update_case(self, id: str, update: DocumentUpdate):
if len(update) == 0:
return # nothing to do
# TODO convert str to ObjectId
objectify_id = (
lambda k, v: ObjectId(v)
if Case.field_type_for_key_path(k) == ObjectId
else v
)
sets = {key: objectify_id(key, value) for key, value in update.updates_iter()}
unsets = {key: True for key in update.unsets_iter()}
command = dict()
if len(sets) > 0:
command["$set"] = sets
if len(unsets) > 0:
command["$unset"] = unsets

self.get_case_collection().update_one({"_id": ObjectId(id)}, command)

def matching_case_iterator(self, predicate: Filter):
"""Return an object that iterates over cases matching the predicate."""
cases = self.get_case_collection().find(predicate.to_mongo_query())
Expand Down
61 changes: 60 additions & 1 deletion data-serving/reusable-data-service/tests/test_case_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from data_service.controller.case_controller import CaseController
from data_service.model.case import Case
from data_service.model.case_exclusion_metadata import CaseExclusionMetadata
from data_service.model.document_update import DocumentUpdate
from data_service.util.errors import (
NotFoundError,
PreconditionUnsatisfiedError,
Expand Down Expand Up @@ -41,10 +42,13 @@ def insert_case(self, case: Case):
def replace_case(self, id: str, case: Case):
self.put_case(id, case)

def update_case(self, id: str, update: DocumentUpdate):
case = self.case_by_id(id)
case.apply_update(update)

def update_case_status(
self, id: str, status: str, exclusion: CaseExclusionMetadata
):
print(f"updating {id} to {status} with {exclusion}")
case = self.case_by_id(id)
case.caseReference.status = status
case.caseExclusion = exclusion
Expand Down Expand Up @@ -470,3 +474,58 @@ def test_excluded_case_ids_returns_ids_of_matching_cases(case_controller):
ids = case_controller.excluded_case_ids("123ab4567890123ef4567890")
assert len(ids) == 1
assert ids[0] == "1"


def test_updating_missing_case_should_throw_NotFoundError(case_controller):
case_controller.create_case(
{
"confirmationDate": date(2021, 6, 23),
"caseReference": {
"sourceId": "123ab4567890123ef4567890",
"status": "EXCLUDED",
},
"caseExclusion": {
"date": date(2022, 5, 17),
"note": "I told him we already have one",
},
}
)
with pytest.raises(NotFoundError):
case_controller.update_case("2", {"caseExclusion": {"note": "Duplicate"}})


def test_updating_case_to_invalid_state_should_throw_ValidationError(case_controller):
case_controller.create_case(
{
"confirmationDate": date(2021, 6, 23),
"caseReference": {
"sourceId": "123ab4567890123ef4567890",
"status": "EXCLUDED",
},
"caseExclusion": {
"date": date(2022, 5, 17),
"note": "I told him we already have one",
},
}
)
with pytest.raises(ValidationError):
case_controller.update_case("1", {"confirmationDate": None})


def test_updating_case_to_valid_state_returns_updated_case(case_controller):
case_controller.create_case(
{
"confirmationDate": date(2021, 6, 23),
"caseReference": {
"sourceId": "123ab4567890123ef4567890",
"status": "EXCLUDED",
},
"caseExclusion": {
"date": date(2022, 5, 17),
"note": "I told him we already have one",
},
}
)

new_case = case_controller.update_case("1", {"confirmationDate": date(2021, 6, 24)})
assert new_case.confirmationDate == date(2021, 6, 24)
52 changes: 49 additions & 3 deletions data-serving/reusable-data-service/tests/test_case_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,12 @@ def test_filter_excluded_case_ids(client_with_patched_mongo):
"confirmationDate": datetime(2022, 5, i),
"caseReference": {
"sourceId": bson.ObjectId("fedc12345678901234567890"),
"status": "EXCLUDED"
"status": "EXCLUDED",
},
"caseExclusion": {
"date": datetime(2022, 6, i),
"note": f"Excluded upon this day, the {i}th of June"
}
"note": f"Excluded upon this day, the {i}th of June",
},
}
for i in range(1, 4)
]
Expand All @@ -422,3 +422,49 @@ def test_filter_excluded_case_ids(client_with_patched_mongo):
assert str(inserted_ids[0]) in ids
assert str(inserted_ids[1]) in ids
assert str(inserted_ids[2]) not in ids


def test_update_case(client_with_patched_mongo):
db = pymongo.MongoClient("mongodb://localhost:27017/outbreak")
inserted = (
db["outbreak"]["cases"]
.insert_one(
{
"confirmationDate": datetime(2022, 5, 10),
"caseReference": {
"sourceId": bson.ObjectId("fedc12345678901234567890")
},
}
)
.inserted_id
)
put_response = client_with_patched_mongo.put(
f"/api/cases/{str(inserted)}", json={"confirmationDate": "2022-05-11"}
)
assert put_response.status_code == 200
assert put_response.get_json()["confirmationDate"] == "2022-05-11"


def test_update_object_id_on_case(client_with_patched_mongo):
db = pymongo.MongoClient("mongodb://localhost:27017/outbreak")
inserted = (
db["outbreak"]["cases"]
.insert_one(
{
"confirmationDate": datetime(2022, 5, 10),
"caseReference": {
"sourceId": bson.ObjectId("fedc12345678901234567890")
},
}
)
.inserted_id
)
put_response = client_with_patched_mongo.put(
f"/api/cases/{str(inserted)}",
json={"caseReference": {"sourceId": "fedc1234567890123456789a"}},
)
assert put_response.status_code == 200
assert (
put_response.get_json()["caseReference"]["sourceId"]
== "fedc1234567890123456789a"
)
Loading