Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2714 batch upsert #2739

Merged
merged 13 commits into from
Jun 29, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,33 @@ def validate_case_dictionary(self, maybe_case: dict):
# PreconditionError means it's a case, but not one we can use
return pe.args[0], 422

def batch_upsert(self, body: dict):
"""Upsert a collection of cases (updating ones that already exist, inserting
new cases). This method can potentially return a 207 mixed status as each case is
handled separately. The response will report the number of cases inserted, the
number updated, and any validation errors encountered."""
if body is None:
return "", 415
cases = body.get("cases")
if cases is None:
return "", 400
if len(cases) == 0:
return "", 400
errors = {}
usable_cases = []
for i, maybe_case in enumerate(cases):
try:
case = self.create_case_if_valid(maybe_case)
usable_cases.append(case)
except Exception as e:
errors[str(i)] = e.args[0]
(created, updated) = (
self.store.batch_upsert(usable_cases) if len(usable_cases) > 0 else (0, 0)
)
status = 200 if len(errors) == 0 else 207
response = {"numCreated": created, "numUpdated": updated, "errors": errors}
return jsonify(response), status

def create_case_if_valid(self, maybe_case: dict):
"""Attempts to create a case from an input dictionary and validate it against
the application rules. Raises ValueError or PreconditionError on invalid input."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from datetime import date
from flask import Flask, request
from . import CaseController, MongoStore
from reusable_data_service.util.iso_json_encoder import ISOJSONEncoder
from reusable_data_service.util.iso_json_encoder import DataServiceJSONEncoder

import os
import logging

app = Flask(__name__)
app.json_encoder = ISOJSONEncoder
app.json_encoder = DataServiceJSONEncoder

case_controller = None # Will be set up in main()

Expand Down Expand Up @@ -35,6 +35,11 @@ def list_cases():
return case_controller.create_case(potential_case, num_cases=count)


@app.route("/api/cases/batchUpsert", methods=["POST"])
def batch_upsert_cases():
return case_controller.batch_upsert(request.get_json())


def set_up_controllers():
global case_controller
store_options = {"mongodb": MongoStore.setup}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import Any

from reusable_data_service.model.case_reference import CaseReference


@dataclasses.dataclass()
class DayZeroCase:
Expand All @@ -18,7 +20,13 @@ class DayZeroCase:
to that function)."""

_: dataclasses.KW_ONLY
"""_id is treated as an opaque identifier by the model, allowing the
store to use whatever format it needs to uniquely identify a stored case.
The _id is allowed to be None, for cases that have been created but not
yet saved into a store."""
_id: str = dataclasses.field(init=False, default=None)
confirmationDate: datetime.date = dataclasses.field(init=False)
caseReference: CaseReference = dataclasses.field(init=False, default=None)

@classmethod
def from_json(cls, obj: str) -> type:
Expand Down Expand Up @@ -47,6 +55,18 @@ def from_dict(cls, dictionary: dict[str, Any]) -> type:
).date()
else:
raise ValueError(f"Cannot interpret date {maybe_date}")
elif key == "caseReference":
caseRef = dictionary[key]
value = (
CaseReference.from_dict(caseRef) if caseRef is not None else None
)
elif key == "_id":
the_id = dictionary[key]
if isinstance(the_id, dict):
# this came from a BSON objectID representation
value = the_id["$oid"]
else:
value = the_id
else:
value = dictionary[key]
setattr(case, key, value)
Expand All @@ -59,16 +79,20 @@ def validate(self):
raise ValueError("Confirmation Date is mandatory")
elif self.confirmationDate is None:
raise ValueError("Confirmation Date must have a value")
if not hasattr(self, "caseReference"):
raise ValueError("Case Reference is mandatory")
elif self.caseReference is None:
raise ValueError("Case Reference must have a value")
self.caseReference.validate()

def to_dict(self):
"""Return myself as a dictionary."""
return dataclasses.asdict(self)

@classmethod
def date_fields(cls) -> list[str]:
"""Record where dates are kept because they sometimes need special treatment.
A subclass could override this method to indicate it stores additional date fields."""
return ["confirmationDate"]
"""Record where dates are kept because they sometimes need special treatment."""
return [f.name for f in dataclasses.fields(cls) if f.type == datetime.date]


# Actually we want to capture extra fields which can be specified dynamically:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import bson
import dataclasses


@dataclasses.dataclass
class CaseReference:
"""Represents information about the source of a given case."""

_: dataclasses.KW_ONLY
sourceId: bson.ObjectId = dataclasses.field(init=False, default=None)

def validate(self):
"""Check whether I am consistent. Raise ValueError if not."""
if not hasattr(self, "sourceId"):
raise ValueError("Source ID is mandatory")
elif self.sourceId is None:
raise ValueError("Source ID must have a value")

@staticmethod
def from_dict(d: dict[str, str]):
"""Create a CaseReference from a dictionary representation."""
ref = CaseReference()
if "sourceId" in d:
theId = d["sourceId"]
if isinstance(theId, str):
ref.sourceId = bson.ObjectId(theId)
elif "$oid" in theId:
ref.sourceId = bson.ObjectId(theId["$oid"])
else:
raise ValueError(f"Cannot interpret {theId} as an ObjectId")
return ref
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bson.errors import InvalidId
from bson.json_util import dumps
from bson.objectid import ObjectId

from typing import List, Tuple

class MongoStore:
"""A line list store backed by mongodb."""
Expand Down Expand Up @@ -51,12 +51,17 @@ def count_cases(self, filter: Filter) -> int:
return self.get_case_collection().count_documents(filter.to_mongo_query())

def insert_case(self, case: Case):
to_insert = case.to_dict()
for field in Case.date_fields():
# BSON works with datetimes, not dates
to_insert[field] = date_to_datetime(to_insert[field])
to_insert = MongoStore.case_to_bson_compatible_dict(case)
self.get_case_collection().insert_one(to_insert)

def batch_upsert(self, cases: List[Case]) -> Tuple[int, int]:
to_insert = [MongoStore.case_to_bson_compatible_dict(c) for c in cases if c._id is None]
to_replace = {c._id: MongoStore.case_to_bson_compatible_dict(c) for c in cases if c._id is not None}
inserts = [pymongo.InsertOne(d) for d in to_insert]
replacements = [pymongo.ReplaceOne({ "_id": k}, v) for (k,v) in to_replace.items()]
results = self.get_case_collection().bulk_write(inserts + replacements)
return results.inserted_count, results.modified_count

@staticmethod
def setup():
"""Configure a store instance from the environment."""
Expand All @@ -68,6 +73,21 @@ def setup():
)
return mongo_store

@staticmethod
def case_to_bson_compatible_dict(case: Case):
"""Turn a case into a representation that mongo will accept."""
bson_case = case.to_dict()
# Mongo mostly won't like having the _id left around: for inserts
# it will try to use the (None) _id and fail, and for updates it
# will complain that you're trying to rewrite the _id (to the same)
# value it already had! Therefore remove it always here. If you find
# a case where mongo wants the _id in a document, add it back for that
# operation.
del bson_case['_id']
for field in Case.date_fields():
# BSON works with datetimes, not dates
bson_case[field] = date_to_datetime(bson_case[field])
return bson_case

def date_to_datetime(dt: datetime.date) -> datetime.datetime:
"""Convert datetime.date to datetime.datetime for encoding as BSON"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import flask.json
import bson
import datetime


class ISOJSONEncoder(flask.json.JSONEncoder):
class DataServiceJSONEncoder(flask.json.JSONEncoder):
def default(self, obj):
try:
if isinstance(obj, datetime.date):
return obj.isoformat()
elif isinstance(obj, bson.ObjectId):
return str(obj)
iterable = iter(obj)
except TypeError:
pass
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
{
"confirmationDate": "2021-12-31T01:23:45.678Z"
"confirmationDate": "2021-12-31T01:23:45.678Z",
"caseReference": {
"sourceId": "fedc09876543210987654321"
}
}
74 changes: 67 additions & 7 deletions data-serving/reusable-data-service/tests/test_case_controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
import json
from datetime import date
from typing import List

from reusable_data_service import Case, CaseController, app

Expand All @@ -15,11 +17,11 @@ def case_by_id(self, id: str):
return self.cases.get(id)

def put_case(self, id: str, case: Case):
"""Used in the tests to populate the store."""
"""This is test-only interface for populating the store."""
self.cases[id] = case

def insert_case(self, case: Case):
"""Used by the controller to insert a new case."""
"""This is the external case insertion API that the case controller uses."""
self.next_id += 1
self.put_case(str(self.next_id), case)

Expand All @@ -29,6 +31,11 @@ def fetch_cases(self, page: int, limit: int, *args):
def count_cases(self, *args):
return len(self.cases)

def batch_upsert(self, cases: List[Case]):
for case in cases:
self.insert_case(case)
return len(cases), 0


@pytest.fixture
def case_controller():
Expand Down Expand Up @@ -112,29 +119,43 @@ def test_create_case_with_missing_properties_400_error(case_controller):

def test_create_case_with_invalid_data_422_error(case_controller):
(response, status) = case_controller.create_case(
{"confirmationDate": date(2001, 3, 17)}
{
"confirmationDate": date(2001, 3, 17),
"caseReference": {"sourceId": "123ab4567890123ef4567890"},
}
)
assert status == 422


def test_create_valid_case_adds_to_collection(case_controller):
(response, status) = case_controller.create_case(
{"confirmationDate": date(2021, 6, 3)}
{
"confirmationDate": date(2021, 6, 3),
"caseReference": {"sourceId": "123ab4567890123ef4567890"},
}
)
assert status == 201
assert case_controller.store.count_cases() == 1


def test_create_valid_case_with_negative_count_400_error(case_controller):
(response, status) = case_controller.create_case(
{"confirmationDate": date(2021, 6, 3)}, num_cases=-7
{
"confirmationDate": date(2021, 6, 3),
"caseReference": {"sourceId": "123ab4567890123ef4567890"},
},
num_cases=-7,
)
assert status == 400


def test_create_valid_case_with_positive_count_adds_to_collection(case_controller):
(response, status) = case_controller.create_case(
{"confirmationDate": date(2021, 6, 3)}, num_cases=7
{
"confirmationDate": date(2021, 6, 3),
"caseReference": {"sourceId": "123ab4567890123ef4567890"},
},
num_cases=7,
)
assert status == 201
assert case_controller.store.count_cases() == 7
Expand All @@ -149,7 +170,46 @@ def test_validate_case_with_valid_case_returns_204_and_does_not_add_case(
case_controller,
):
(response, status) = case_controller.validate_case_dictionary(
{"confirmationDate": date(2021, 6, 3)}
{
"confirmationDate": date(2021, 6, 3),
"caseReference": {"sourceId": "123ab4567890123ef4567890"},
}
)
assert status == 204
assert case_controller.store.count_cases() == 0


def test_batch_upsert_with_no_body_returns_415(case_controller):
(response, status) = case_controller.batch_upsert(None)
assert status == 415


def test_batch_upsert_with_no_case_list_returns_400(case_controller):
(response, status) = case_controller.batch_upsert({})
assert status == 400


def test_batch_upsert_with_empty_case_list_returns_400(case_controller):
(response, status) = case_controller.batch_upsert({"cases": []})
assert status == 400


def test_batch_upsert_creates_valid_case(case_controller):
with open("./tests/data/case.minimal.json", "r") as minimal_file:
minimal_case_description = json.loads(minimal_file.read())
(response, status) = case_controller.batch_upsert(
{"cases": [minimal_case_description]}
)
assert status == 200
assert case_controller.store.count_cases() == 1
assert response.json["numCreated"] == 1
assert response.json["numUpdated"] == 0
assert response.json["errors"] == {}


def test_batch_upsert_reports_errors(case_controller):
(response, status) = case_controller.batch_upsert({"cases": [{}]})
assert status == 207
assert response.json["numCreated"] == 0
assert response.json["numUpdated"] == 0
assert response.json["errors"] == {"0": "Confirmation Date is mandatory"}
Loading