diff --git a/data-serving/reusable-data-service/data_service/__init__.py b/data-serving/reusable-data-service/data_service/__init__.py index 96cd370ac..53afb94fb 100644 --- a/data-serving/reusable-data-service/data_service/__init__.py +++ b/data-serving/reusable-data-service/data_service/__init__.py @@ -1,7 +1,3 @@ __version__ = "0.1.0" -from .model.case import Case -from .model.case_reference import CaseReference -from .controller.case_controller import CaseController -from .stores.mongo_store import MongoStore -from .main import app, main, set_up_controllers +from .main import app, main diff --git a/data-serving/reusable-data-service/data_service/controller/case_controller.py b/data-serving/reusable-data-service/data_service/controller/case_controller.py index 692b6cbc2..494119282 100644 --- a/data-serving/reusable-data-service/data_service/controller/case_controller.py +++ b/data-serving/reusable-data-service/data_service/controller/case_controller.py @@ -1,9 +1,11 @@ from flask import jsonify from datetime import date -from typing import List +from typing import List, Optional from data_service.model.case import Case +from data_service.model.case_exclusion_metadata import CaseExclusionMetadata from data_service.model.case_page import CasePage +from data_service.model.case_reference import CaseReference from data_service.model.case_upsert_outcome import CaseUpsertOutcome from data_service.model.filter import ( Anything, @@ -148,6 +150,44 @@ def generate_output(): return generate_output + def batch_status_change( + self, + status: str, + note: Optional[str] = None, + case_ids: Optional[List[str]] = None, + filter: Optional[str] = None, + ): + """Update all of the cases identified by case_ids or filter to have the supplied curation status. + Raises PreconditionUnsatisfiedError or ValidationError on invalid input.""" + statuses = CaseReference.valid_statuses() + if status not in statuses: + raise PreconditionUnsatisfiedError(f"status {status} not one of {statuses}") + if filter is not None and case_ids is not None: + raise PreconditionUnsatisfiedError( + "Do not supply both a filter and a list of IDs" + ) + if status == "EXCLUDED" and note is None: + raise ValidationError("Excluding cases must be documented in a note") + + def update_status(id: str, status: str, note: str): + if status == "EXCLUDED": + caseExclusion = CaseExclusionMetadata() + caseExclusion.note = note + else: + caseExclusion = None + self.store.update_case_status(id, status, caseExclusion) + + if case_ids is not None: + for anId in case_ids: + update_status(anId, status, note) + else: + predicate = CaseController.parse_filter(filter) + if predicate is None: + raise ValidationError(f"cannot understand query {filter}") + case_iterator = self.store.matching_case_iterator(predicate) + for case in case_iterator: + update_status(case._id, status, note) + def create_case_if_valid(self, maybe_case: dict): """Attempts to create a case from an input dictionary and validate it against the application rules. Raises ValidationError or PreconditionUnsatisfiedError on invalid input.""" diff --git a/data-serving/reusable-data-service/data_service/main.py b/data-serving/reusable-data-service/data_service/main.py index 4ed44e610..5d0307b04 100644 --- a/data-serving/reusable-data-service/data_service/main.py +++ b/data-serving/reusable-data-service/data_service/main.py @@ -1,6 +1,7 @@ from datetime import date from flask import Flask, jsonify, request -from . 
import CaseController, MongoStore +from data_service.controller.case_controller import CaseController +from data_service.stores.mongo_store import MongoStore from data_service.util.errors import ( PreconditionUnsatisfiedError, UnsupportedTypeError, @@ -88,6 +89,21 @@ def download_cases(): return jsonify({"message": e.args[0]}), e.http_code +@app.route("/api/cases/batchStatusChange", methods=["POST"]) +def batch_status_change(): + try: + req = request.get_json() + case_controller.batch_status_change( + status=req.get("status"), + note=req.get("note"), + case_ids=req.get("caseIds"), + filter=req.get("query"), + ) + return "", 204 + except WebApplicationError as e: + return jsonify({"message": e.args[0]}), e.http_code + + def set_up_controllers(): global case_controller store_options = {"mongodb": MongoStore.setup} diff --git a/data-serving/reusable-data-service/data_service/model/case.py b/data-serving/reusable-data-service/data_service/model/case.py index ccd6e3801..dbf658669 100644 --- a/data-serving/reusable-data-service/data_service/model/case.py +++ b/data-serving/reusable-data-service/data_service/model/case.py @@ -5,16 +5,17 @@ from typing import Any, List +from data_service.model.case_exclusion_metadata import CaseExclusionMetadata from data_service.model.case_reference import CaseReference +from data_service.model.document import Document from data_service.util.errors import ( PreconditionUnsatisfiedError, ValidationError, ) -from data_service.util.json_encoder import JSONEncoder @dataclasses.dataclass() -class DayZeroCase: +class DayZeroCase(Document): """This class implements the "day-zero" data schema for Global.health. At the beginning of an outbreak, we want to collect at least this much information about an individual case for the line list. 
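For reference, here is a minimal client-side sketch of how the new /api/cases/batchStatusChange route above might be exercised. The base URL and the use of the requests package are assumptions for illustration, not part of this change; the case IDs are placeholders.

# Hypothetical client sketch: exclude two cases by ID, then change status by query.
# "EXCLUDED" requires a "note"; "caseIds" and "query" must not be supplied together.
import requests

BASE_URL = "http://localhost:8080"  # assumed deployment address

resp = requests.post(
    f"{BASE_URL}/api/cases/batchStatusChange",
    json={
        "status": "EXCLUDED",
        "note": "Duplicate of an existing case",
        "caseIds": ["62bc12345678901234567890", "62bc12345678901234567891"],
    },
)
assert resp.status_code == 204  # success carries no response body

resp = requests.post(
    f"{BASE_URL}/api/cases/batchStatusChange",
    json={"status": "UNVERIFIED", "query": "dateconfirmedafter:2021-06-01"},
)
assert resp.status_code == 204

Invalid input surfaces as a JSON error message with the matching HTTP code: an unknown status or supplying both caseIds and query yields 400, and excluding without a note yields 422.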
@@ -33,6 +34,7 @@ class DayZeroCase: _id: str = dataclasses.field(init=False, default=None) confirmationDate: datetime.date = dataclasses.field(init=False) caseReference: CaseReference = dataclasses.field(init=False, default=None) + caseExclusion: CaseExclusionMetadata = dataclasses.field(init=False, default=None) @classmethod def from_json(cls, obj: str) -> type: @@ -45,27 +47,19 @@ def from_dict(cls, dictionary: dict[str, Any]) -> type: case = cls() for key in dictionary: if key in cls.date_fields(): - # handle a few different ways dates get represented in dictionaries - maybe_date = dictionary[key] - if isinstance(maybe_date, datetime.datetime): - value = maybe_date.date() - elif isinstance(maybe_date, datetime.date): - value = maybe_date - elif isinstance(maybe_date, str): - value = datetime.datetime.strptime( - maybe_date, "%Y-%m-%dT%H:%M:%S.%fZ" - ).date() - elif isinstance(maybe_date, dict) and "$date" in maybe_date: - value = datetime.datetime.strptime( - maybe_date["$date"], "%Y-%m-%dT%H:%M:%SZ" - ).date() - else: - raise ValueError(f"Cannot interpret date {maybe_date}") + value = cls.interpret_date(dictionary[key]) elif key == "caseReference": caseRef = dictionary[key] value = ( CaseReference.from_dict(caseRef) if caseRef is not None else None ) + elif key == "caseExclusion": + exclusion = dictionary[key] + value = ( + CaseExclusionMetadata.from_dict(exclusion) + if exclusion is not None + else None + ) elif key == "_id": the_id = dictionary[key] if isinstance(the_id, dict): @@ -91,83 +85,6 @@ def validate(self): raise ValidationError("Case Reference must have a value") self.caseReference.validate() - def to_dict(self): - """Return myself as a dictionary.""" - return dataclasses.asdict(self) - - def to_json(self): - """Return myself as JSON""" - return JSONEncoder().encode(self.to_dict()) - - @classmethod - def date_fields(cls) -> list[str]: - """Record where dates are kept because they sometimes need special treatment.""" - return [f.name for f in dataclasses.fields(cls) if f.type == datetime.date] - - @classmethod - def field_names(cls) -> List[str]: - """The list of names of fields in this class and member dataclasses.""" - fields = [] - for f in dataclasses.fields(cls): - if dataclasses.is_dataclass(f.type): - fields += [f"{f.name}.{g.name}" for g in dataclasses.fields(f.type)] - else: - fields.append(f.name) - return fields - - @classmethod - def delimiter_separated_header(cls, sep: str) -> str: - """Create a line naming all of the fields in this class and member dataclasses.""" - return sep.join(cls.field_names()) + "\n" - - @classmethod - def tsv_header(cls) -> str: - """Generate the header row for a TSV file containing members of this class.""" - return cls.delimiter_separated_header("\t") - - @classmethod - def csv_header(cls) -> str: - """Generate the header row for a CSV file containing members of this class.""" - return cls.delimiter_separated_header(",") - - @classmethod - def json_header(cls) -> str: - """The start of a JSON array.""" - return "[" - - @classmethod - def json_footer(cls) -> str: - """The end of a JSON array.""" - return "]" - - @classmethod - def json_separator(cls) -> str: - """The string between values in a JSON array.""" - return "," - - def field_values(self) -> List[str]: - """The list of values of fields on this object and member dataclasses.""" - fields = [] - for f in dataclasses.fields(self): - value = getattr(self, f.name) - if dataclasses.is_dataclass(f.type): - fields.append(value.to_csv()) - else: - fields.append(str(value) if value is 
not None else "") - return fields - - def delimiter_separated_values(self, sep: str) -> str: - """Create a line listing all of the fields in me and my member dataclasses.""" - return sep.join(self.field_values()) + "\n" - - def to_tsv(self) -> str: - """Generate a row in a CSV file representing myself.""" - return self.delimiter_separated_values("\t") - - def to_csv(self) -> str: - """Generate a row in a CSV file representing myself.""" - return self.delimiter_separated_values(",") - # Actually we want to capture extra fields which can be specified dynamically: # so Case is the class that you should use. diff --git a/data-serving/reusable-data-service/data_service/model/case_exclusion_metadata.py b/data-serving/reusable-data-service/data_service/model/case_exclusion_metadata.py new file mode 100644 index 000000000..97c7d8a5a --- /dev/null +++ b/data-serving/reusable-data-service/data_service/model/case_exclusion_metadata.py @@ -0,0 +1,34 @@ +import dataclasses +import datetime + +from typing import Any + +from data_service.model.document import Document + + +@dataclasses.dataclass +class CaseExclusionMetadata(Document): + """If a case is excluded, record when and why.""" + + _: dataclasses.KW_ONLY + note: str = dataclasses.field(init=False, default=None) + date: datetime.date = dataclasses.field( + init=False, default=None + ) # Populate at initialisation time, not class load time + + def __post_init__(self): + self.date = datetime.datetime.now().date() + + @classmethod + def exclude_from_download(cls): + return True + + @classmethod + def from_dict(cls, dictionary: dict[str, Any]) -> type: + """Create a CaseExclusionMetadata from a dictionary representation.""" + exclusion = CaseExclusionMetadata() + exclusion.note = dictionary.get("note") + exclusion.date = cls.interpret_date(dictionary.get("date")) + if exclusion.date is None: + raise ValueError(f"date missing in CaseExclusion document {dictionary}") + return exclusion diff --git a/data-serving/reusable-data-service/data_service/model/case_reference.py b/data-serving/reusable-data-service/data_service/model/case_reference.py index edbc3274e..5786c43f4 100644 --- a/data-serving/reusable-data-service/data_service/model/case_reference.py +++ b/data-serving/reusable-data-service/data_service/model/case_reference.py @@ -1,13 +1,16 @@ import bson import dataclasses +from data_service.model.document import Document + @dataclasses.dataclass -class CaseReference: +class CaseReference(Document): """Represents information about the source of a given case.""" _: dataclasses.KW_ONLY sourceId: bson.ObjectId = dataclasses.field(init=False, default=None) + status: str = dataclasses.field(init=False, default="UNVERIFIED") def validate(self): """Check whether I am consistent. 
Raise ValueError if not.""" @@ -15,6 +18,13 @@ def validate(self): raise ValueError("Source ID is mandatory") elif self.sourceId is None: raise ValueError("Source ID must have a value") + if self.status not in self.valid_statuses(): + raise ValueError(f"Status {self.status} is not acceptable") + + @staticmethod + def valid_statuses(): + """A case reference must have one of these statuses.""" + return ["EXCLUDED", "UNVERIFIED", "VERIFIED"] @staticmethod def from_dict(d: dict[str, str]): @@ -28,14 +38,5 @@ def from_dict(d: dict[str, str]): ref.sourceId = bson.ObjectId(theId["$oid"]) else: raise ValueError(f"Cannot interpret {theId} as an ObjectId") + ref.status = d["status"] if "status" in d else "UNVERIFIED" return ref - - def to_csv(self) -> str: - """Generate a row in a CSV file representing myself.""" - fields = [] - for f in dataclasses.fields(self): - if dataclasses.is_dataclass(f.type): - fields.append(getattr(self, f.name).to_csv()) - else: - fields.append(str(getattr(self, f.name))) - return ",".join(fields) diff --git a/data-serving/reusable-data-service/data_service/model/document.py b/data-serving/reusable-data-service/data_service/model/document.py new file mode 100644 index 000000000..fb756d9fb --- /dev/null +++ b/data-serving/reusable-data-service/data_service/model/document.py @@ -0,0 +1,116 @@ +import dataclasses +import datetime + +from data_service.util.json_encoder import JSONEncoder + +from typing import List + + +@dataclasses.dataclass +class Document: + """The base class for anything that's going into the database.""" + + def to_dict(self): + """Me, as a dictionary.""" + return dataclasses.asdict(self) + + def to_json(self): + """Return myself as JSON""" + return JSONEncoder().encode(self.to_dict()) + + @classmethod + def date_fields(cls) -> list[str]: + """Record where dates are kept because they sometimes need special treatment.""" + return [f.name for f in dataclasses.fields(cls) if f.type == datetime.date] + + @staticmethod + def interpret_date(maybe_date) -> datetime.date: + value = None + if maybe_date is None: + value = None + elif isinstance(maybe_date, datetime.datetime): + value = maybe_date.date() + elif isinstance(maybe_date, datetime.date): + value = maybe_date + elif isinstance(maybe_date, str): + value = datetime.datetime.strptime( + maybe_date, "%Y-%m-%dT%H:%M:%S.%fZ" + ).date() + elif isinstance(maybe_date, dict) and "$date" in maybe_date: + value = datetime.datetime.strptime( + maybe_date["$date"], "%Y-%m-%dT%H:%M:%SZ" + ).date() + else: + raise ValueError(f"Cannot interpret date {maybe_date}") + return value + + @classmethod + def field_names(cls) -> List[str]: + """The list of names of fields in this class and member dataclasses.""" + fields = [] + for f in dataclasses.fields(cls): + if dataclasses.is_dataclass(f.type): + if cls.include_dataclass_fields(f.type): + fields += [f"{f.name}.{g.name}" for g in dataclasses.fields(f.type)] + else: + fields.append(f.name) + return fields + + @classmethod + def delimiter_separated_header(cls, sep: str) -> str: + """Create a line naming all of the fields in this class and member dataclasses.""" + return sep.join(cls.field_names()) + "\n" + + @classmethod + def tsv_header(cls) -> str: + """Generate the header row for a TSV file containing members of this class.""" + return cls.delimiter_separated_header("\t") + + @classmethod + def csv_header(cls) -> str: + """Generate the header row for a CSV file containing members of this class.""" + return cls.delimiter_separated_header(",") + + @classmethod + def 
json_header(cls) -> str: + """The start of a JSON array.""" + return "[" + + @classmethod + def json_footer(cls) -> str: + """The end of a JSON array.""" + return "]" + + @classmethod + def json_separator(cls) -> str: + """The string between values in a JSON array.""" + return "," + + def field_values(self) -> List[str]: + """The list of values of fields on this object and member dataclasses.""" + fields = [] + for f in dataclasses.fields(self): + value = getattr(self, f.name) + if issubclass(f.type, Document): + if self.include_dataclass_fields(f.type): + fields += value.field_values() + else: + fields.append(str(value) if value is not None else "") + return fields + + @staticmethod + def include_dataclass_fields(aType: type): + test_exclusion = getattr(aType, "exclude_from_download", None) + return test_exclusion is None or test_exclusion() is False + + def delimiter_separated_values(self, sep: str) -> str: + """Create a line listing all of the fields in me and my member dataclasses.""" + return sep.join(self.field_values()) + "\n" + + def to_tsv(self) -> str: + """Generate a row in a TSV file representing myself.""" + return self.delimiter_separated_values("\t") + + def to_csv(self) -> str: + """Generate a row in a CSV file representing myself.""" + return self.delimiter_separated_values(",") diff --git a/data-serving/reusable-data-service/data_service/stores/mongo_store.py b/data-serving/reusable-data-service/data_service/stores/mongo_store.py index 820964c14..556dd8c7b 100644 --- a/data-serving/reusable-data-service/data_service/stores/mongo_store.py +++ b/data-serving/reusable-data-service/data_service/stores/mongo_store.py @@ -2,6 +2,7 @@ import os import pymongo from data_service.model.case import Case +from data_service.model.case_exclusion_metadata import CaseExclusionMetadata from data_service.model.filter import ( Filter, Anything, @@ -61,6 +62,25 @@ def insert_case(self, case: Case): to_insert = MongoStore.case_to_bson_compatible_dict(case) self.get_case_collection().insert_one(to_insert) + def replace_case(self, id: str, case: Case): + to_replace = MongoStore.case_to_bson_compatible_dict(case) + oid = ObjectId(id) + result = self.get_case_collection().replace_one({"_id": oid}, to_replace) + if result.modified_count != 1: + raise ValueError("Did not update any documents!") + + def update_case_status( + self, id: str, status: str, exclusion: CaseExclusionMetadata + ): + update = {"$set": {"caseReference.status": status}} + if exclusion: + update["$set"][ + "caseExclusion" + ] = self.case_exclusion_to_bson_compatible_dict(exclusion) + else: + update["$unset"] = {"caseExclusion": True} + self.get_case_collection().update_one({"_id": ObjectId(id)}, update) + def batch_upsert(self, cases: List[Case]) -> Tuple[int, int]: + to_insert = [ MongoStore.case_to_bson_compatible_dict(c) for c in cases if c._id is None @@ -110,16 +130,29 @@ def case_to_bson_compatible_dict(case: Case): bson_case = case.to_dict() # Mongo mostly won't like having the _id left around: for inserts # it will try to use the (None) _id and fail, and for updates it - # will complain that you're trying to rewrite the _id (to the same) - # value it already had! Therefore remove it always here. If you find + # will complain that you're trying to rewrite the _id, even though it is + # the same value it already had, because it treats the string value as + # different from the ObjectId value! Therefore remove it always here. If you find # a case where mongo wants the _id in a document, add it back for that # operation. 
del bson_case["_id"] + # BSON works with datetimes, not dates for field in Case.date_fields(): - # BSON works with datetimes, not dates bson_case[field] = date_to_datetime(bson_case[field]) + if case.caseExclusion is not None: + bson_case[ + "caseExclusion" + ] = MongoStore.case_exclusion_to_bson_compatible_dict(case.caseExclusion) return bson_case + @staticmethod + def case_exclusion_to_bson_compatible_dict(exclusion: CaseExclusionMetadata): + """Turn a case exclusion document into a representation that mongo will accept.""" + bson_exclusion = exclusion.to_dict() + for field in CaseExclusionMetadata.date_fields(): + bson_exclusion[field] = date_to_datetime(bson_exclusion[field]) + return bson_exclusion + def date_to_datetime(dt: datetime.date) -> datetime.datetime: """Convert datetime.date to datetime.datetime for encoding as BSON""" diff --git a/data-serving/reusable-data-service/data_service/util/errors.py b/data-serving/reusable-data-service/data_service/util/errors.py index 7316b744a..74ba6ff31 100644 --- a/data-serving/reusable-data-service/data_service/util/errors.py +++ b/data-serving/reusable-data-service/data_service/util/errors.py @@ -1,23 +1,28 @@ class WebApplicationError(Exception): """Represents something going wrong on a web service.""" + http_code = 500 class PreconditionUnsatisfiedError(WebApplicationError): """Represents a bad request.""" + http_code = 400 class NotFoundError(WebApplicationError): """Represents a missing resource.""" + http_code = 404 class UnsupportedTypeError(WebApplicationError): """Something received a type it couldn't work with.""" + http_code = 415 class ValidationError(WebApplicationError): """Represents invalid data""" + http_code = 422 diff --git a/data-serving/reusable-data-service/poetry.lock b/data-serving/reusable-data-service/poetry.lock index 7beeb9cc3..ae507032f 100644 --- a/data-serving/reusable-data-service/poetry.lock +++ b/data-serving/reusable-data-service/poetry.lock @@ -94,6 +94,17 @@ Werkzeug = ">=2.0" async = ["asgiref (>=3.2)"] dotenv = ["python-dotenv"] +[[package]] +name = "freezegun" +version = "1.2.1" +description = "Let your Python tests travel through time" +category = "dev" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +python-dateutil = ">=2.7" + [[package]] name = "iniconfig" version = "1.1.1" @@ -256,6 +267,17 @@ tomli = ">=1.0.0" [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" + +[package.dependencies] +six = ">=1.5" + [[package]] name = "sentinels" version = "1.0.0" @@ -294,7 +316,7 @@ watchdog = ["watchdog"] [metadata] lock-version = "1.1" python-versions = "^3.10" -content-hash = "833bf0d9ec25246698e3559bd27c2d277ce097e1af9bb49db50af6c0106fe0f7" +content-hash = "95f4e76e65ed782dae0f6f9c95060577d666e2dbe38f009a3976063409e629b8" [metadata.files] atomicwrites = [ @@ -346,6 +368,10 @@ flask = [ {file = "Flask-2.1.2-py3-none-any.whl", hash = "sha256:fad5b446feb0d6db6aec0c3184d16a8c1f6c3e464b511649c8918a9be100b4fe"}, {file = "Flask-2.1.2.tar.gz", hash = "sha256:315ded2ddf8a6281567edb27393010fe3406188bafbfe65a3339d5787d89e477"}, ] +freezegun = [ + {file = "freezegun-1.2.1-py3-none-any.whl", hash = "sha256:15103a67dfa868ad809a8f508146e396be2995172d25f927e48ce51c0bf5cb09"}, + {file = "freezegun-1.2.1.tar.gz", 
hash = "sha256:b4c64efb275e6bc68dc6e771b17ffe0ff0f90b81a2a5189043550b6519926ba4"}, +] iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, @@ -521,6 +547,10 @@ pytest = [ {file = "pytest-7.1.2-py3-none-any.whl", hash = "sha256:13d0e3ccfc2b6e26be000cb6568c832ba67ba32e719443bfe725814d3c42433c"}, {file = "pytest-7.1.2.tar.gz", hash = "sha256:a06a0425453864a270bc45e71f783330a7428defb4230fb5e6a731fde06ecd45"}, ] +python-dateutil = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] sentinels = [ {file = "sentinels-1.0.0.tar.gz", hash = "sha256:7be0704d7fe1925e397e92d18669ace2f619c92b5d4eb21a89f31e026f9ff4b1"}, ] diff --git a/data-serving/reusable-data-service/pyproject.toml b/data-serving/reusable-data-service/pyproject.toml index 183ca9216..b71719d5c 100644 --- a/data-serving/reusable-data-service/pyproject.toml +++ b/data-serving/reusable-data-service/pyproject.toml @@ -13,6 +13,7 @@ pymongo = {extras = ["srv"], version = "^4.1.1"} pytest = "^7.1.2" mongomock = "^4.0.0" black = "^22.6.0" +freezegun = "^1.2.1" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/data-serving/reusable-data-service/tests/test_case_controller.py b/data-serving/reusable-data-service/tests/test_case_controller.py index ddcc7b861..968d471b5 100644 --- a/data-serving/reusable-data-service/tests/test_case_controller.py +++ b/data-serving/reusable-data-service/tests/test_case_controller.py @@ -1,9 +1,14 @@ +import freezegun import pytest import json + from datetime import date from typing import List -from data_service import Case, CaseController, app +from data_service import app +from data_service.controller.case_controller import CaseController +from data_service.model.case import Case +from data_service.model.case_exclusion_metadata import CaseExclusionMetadata from data_service.util.errors import ( NotFoundError, PreconditionUnsatisfiedError, @@ -33,6 +38,17 @@ def insert_case(self, case: Case): case._id = id self.put_case(id, case) + def replace_case(self, id: str, case: Case): + self.put_case(id, case) + + def update_case_status( + self, id: str, status: str, exclusion: CaseExclusionMetadata + ): + print(f"updating {id} to {status} with {exclusion}") + case = self.case_by_id(id) + case.caseReference.status = status + case.caseExclusion = exclusion + def fetch_cases(self, page: int, limit: int, *args): return list(self.cases.values())[(page - 1) * limit : page * limit] @@ -317,3 +333,94 @@ def test_download_supports_json(case_controller): assert len(result) == 2 assert result[0]["confirmationDate"] == "2021-06-03" assert result[1]["caseReference"]["sourceId"] == "123ab4567890123ef4567890" + + +def test_batch_status_change_rejects_invalid_status(case_controller): + with pytest.raises(PreconditionUnsatisfiedError): + case_controller.batch_status_change("xxx", case_ids=[]) + + +def test_batch_status_change_rejects_exclusion_with_no_note(case_controller): + with pytest.raises(ValidationError): + case_controller.batch_status_change("EXCLUDED", case_ids=[]) + + +def test_batch_status_change_excludes_cases_with_note(case_controller): + for i in range(4): + _ = case_controller.create_case( + { 
+ "confirmationDate": date(2021, 6, i + 1), + "caseReference": {"sourceId": "123ab4567890123ef4567890"}, + }, + ) + case_controller.batch_status_change( + "EXCLUDED", "I dislike this case", case_ids=["1", "2"] + ) + an_excluded_case = case_controller.store.case_by_id("1") + assert an_excluded_case.caseReference.status == "EXCLUDED" + assert an_excluded_case.caseExclusion.note == "I dislike this case" + another_case = case_controller.store.case_by_id("3") + assert another_case.caseReference.status == "UNVERIFIED" + assert another_case.caseExclusion is None + + +@freezegun.freeze_time("Aug 13th, 2021") +def test_batch_status_change_records_date_of_exclusion(case_controller): + case_controller.create_case( + { + "confirmationDate": date(2021, 6, 23), + "caseReference": { + "sourceId": "123ab4567890123ef4567890", + }, + } + ) + + case_controller.batch_status_change( + "EXCLUDED", "Mistakes have been made", case_ids=["1"] + ) + + case = case_controller.store.case_by_id("1") + assert case.caseReference.status == "EXCLUDED" + assert case.caseExclusion.note == "Mistakes have been made" + assert case.caseExclusion.date == date(2021, 8, 13) + + +def test_batch_status_change_removes_exclusion_data_on_unexcluding_case( + case_controller, +): + case_controller.create_case( + { + "confirmationDate": date(2021, 6, 23), + "caseReference": { + "sourceId": "123ab4567890123ef4567890", + }, + } + ) + + case_controller.batch_status_change( + "EXCLUDED", "Mistakes have been made", case_ids=["1"] + ) + case_controller.batch_status_change("UNVERIFIED", case_ids=["1"]) + + case = case_controller.store.case_by_id("1") + assert case.caseReference.status == "UNVERIFIED" + assert case.caseExclusion is None + + +def test_batch_status_change_by_query(case_controller): + case_controller.create_case( + { + "confirmationDate": date(2021, 6, 23), + "caseReference": { + "sourceId": "123ab4567890123ef4567890", + }, + } + ) + + case_controller.batch_status_change( + "EXCLUDED", "Mistakes have been made", filter="dateconfirmedafter:2021-06-01" + ) + + case = case_controller.store.case_by_id("1") + assert case.caseReference.status == "EXCLUDED" + assert case.caseExclusion is not None diff --git a/data-serving/reusable-data-service/tests/test_case_end_to_end.py b/data-serving/reusable-data-service/tests/test_case_end_to_end.py index 02575a207..7960eca18 100644 --- a/data-serving/reusable-data-service/tests/test_case_end_to_end.py +++ b/data-serving/reusable-data-service/tests/test_case_end_to_end.py @@ -1,11 +1,13 @@ import pytest import bson +import freezegun import mongomock import pymongo from datetime import datetime -from data_service import app, set_up_controllers +from data_service import app +from data_service.main import set_up_controllers @pytest.fixture @@ -336,3 +338,29 @@ def test_download_selected_cases_tsv(client_with_patched_mongo): assert len(cases) == 2 assert cases[0]["confirmationDate"] == "2022-05-01" assert cases[1]["confirmationDate"] == "2022-05-03" + + +def test_exclude_selected_cases(client_with_patched_mongo): + db = pymongo.MongoClient("mongodb://localhost:27017/outbreak") + inserted = ( + db["outbreak"]["cases"] + .insert_one( + { + "confirmationDate": datetime(2022, 5, 10), + "caseReference": { + "sourceId": bson.ObjectId("fedc12345678901234567890") + }, + } + ) + .inserted_id + ) + post_response = client_with_patched_mongo.post( + "/api/cases/batchStatusChange", + json={"status": "EXCLUDED", "caseIds": [str(inserted)], "note": "Duplicate"}, + ) + assert post_response.status_code == 204 + 
get_response = client_with_patched_mongo.get(f"/api/cases/{str(inserted)}") + assert get_response.status_code == 200 + document = get_response.get_json() + assert document["caseReference"]["status"] == "EXCLUDED" + assert document["caseExclusion"]["note"] == "Duplicate" diff --git a/data-serving/reusable-data-service/tests/test_case_model.py b/data-serving/reusable-data-service/tests/test_case_model.py index 9753f983f..94c19cf06 100644 --- a/data-serving/reusable-data-service/tests/test_case_model.py +++ b/data-serving/reusable-data-service/tests/test_case_model.py @@ -1,7 +1,8 @@ import pytest import bson from datetime import date -from data_service import Case, CaseReference +from data_service.model.case import Case +from data_service.model.case_reference import CaseReference from data_service.util.errors import ValidationError @@ -18,7 +19,10 @@ def test_case_from_minimal_json_is_valid(): def test_csv_header(): header_line = Case.csv_header() - assert header_line == "_id,confirmationDate,caseReference.sourceId\n" + assert ( + header_line + == "_id,confirmationDate,caseReference.sourceId,caseReference.status\n" + ) def test_csv_row_with_no_id(): @@ -30,7 +34,7 @@ def test_csv_row_with_no_id(): case.confirmationDate = date(2022, 6, 13) case.caseReference = ref csv = case.to_csv() - assert csv == ",2022-06-13,abcd12903478565647382910\n" + assert csv == ",2022-06-13,abcd12903478565647382910,UNVERIFIED\n" def test_csv_row_with_id(): @@ -44,4 +48,4 @@ def test_csv_row_with_id(): case.confirmationDate = date(2022, 6, 13) case.caseReference = ref csv = case.to_csv() - assert csv == f"{id1},2022-06-13,{id2}\n" + assert csv == f"{id1},2022-06-13,{id2},UNVERIFIED\n" diff --git a/data-serving/reusable-data-service/tests/test_case_reference.py b/data-serving/reusable-data-service/tests/test_case_reference.py index 70e417d5a..ee6377756 100644 --- a/data-serving/reusable-data-service/tests/test_case_reference.py +++ b/data-serving/reusable-data-service/tests/test_case_reference.py @@ -4,10 +4,30 @@ from data_service.model.case_reference import CaseReference -def test_csv_row(): +def test_csv_row_unexcluded(): identifier = "abcd12903478565647382910" oid = bson.ObjectId(identifier) ref = CaseReference() ref.sourceId = oid csv = ref.to_csv() - assert csv == identifier + assert csv == "abcd12903478565647382910,UNVERIFIED\n" + + +def test_csv_row_excluded(): + identifier = "abcd12903478565647382910" + oid = bson.ObjectId(identifier) + ref = CaseReference() + ref.sourceId = oid + ref.status = "EXCLUDED" + csv = ref.to_csv() + assert csv == "abcd12903478565647382910,EXCLUDED\n" + + +def test_reference_must_have_valid_status(): + identifier = "abcd12903478565647382910" + oid = bson.ObjectId(identifier) + ref = CaseReference() + ref.sourceId = oid + ref.status = "BANANA" + with pytest.raises(ValueError): + ref.validate()
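As a closing illustration of the model change these tests exercise, here is a small sketch of how the new status field on CaseReference behaves; it uses only classes and values that appear in this diff.

# Sketch: CaseReference now carries a curation status that defaults to
# UNVERIFIED, is restricted to valid_statuses(), and appears in CSV output.
import bson
from data_service.model.case_reference import CaseReference

ref = CaseReference()
ref.sourceId = bson.ObjectId("abcd12903478565647382910")

CaseReference.valid_statuses()  # ['EXCLUDED', 'UNVERIFIED', 'VERIFIED']
ref.status                      # 'UNVERIFIED', the default
ref.to_csv()                    # 'abcd12903478565647382910,UNVERIFIED\n'

ref.status = "EXCLUDED"
ref.validate()                  # passes: EXCLUDED is an accepted status
ref.to_csv()                    # 'abcd12903478565647382910,EXCLUDED\n'

ref.status = "BANANA"
# ref.validate() would now raise ValueError("Status BANANA is not acceptable"),
# as exercised by test_reference_must_have_valid_status above.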