Skip to content

Commit

Permalink
Finish up PUT endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
imbstack committed Mar 22, 2021
1 parent a11f635 commit c6486a6
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 19 deletions.
6 changes: 2 additions & 4 deletions ctms/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,8 @@ def create_or_update_newsletters(
db.query(Newsletter).filter(
Newsletter.email_id == email_id, Newsletter.name.notin_(names)
).delete(
synchronize_session="fetch"
) # TODO: investigate if this is the right sync_session

# TODO: figure out on_conflict here
synchronize_session=False
) # This doesn't need to be synchronized because the next query only alters the other remaining rows. They can happen in whatever order. If you plan to change what the rest of this function does, consider changing this as well!

if newsletters:
stmt = insert(Newsletter).values(
Expand Down
16 changes: 9 additions & 7 deletions ctms/schemas/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ def as_identity_response(self) -> "IdentityResponse":
def find_default_fields(self) -> Set[str]:
"""Return names of fields that contain default values only"""
default_fields = set()
if self.amo and self.amo.is_default():
if hasattr(self, "amo") and self.amo and self.amo.is_default():
default_fields.add("amo")
if self.fxa and self.fxa.is_default():
if hasattr(self, "fxa") and self.fxa and self.fxa.is_default():
default_fields.add("fxa")
if self.vpn_waitlist and self.vpn_waitlist.is_default():
if (
hasattr(self, "vpn_waitlist")
and self.vpn_waitlist
and self.vpn_waitlist.is_default()
):
default_fields.add("vpn_waitlist")
if all(n.is_default() for n in self.newsletters):
default_fields.add("newsletters")
Expand Down Expand Up @@ -84,16 +88,14 @@ def _noneify(field):


class ContactInSchema(ContactInBase):
"""A contact as provided by callers."""
"""A contact as provided by callers when using POST. This is nearly identical to the ContactPutSchema but doesn't require an email_id."""

# TODO: Docuement these better
email: EmailInSchema


class ContactPutSchema(ContactInBase):
"""A contact as provided by callers."""
"""A contact as provided by callers when using POST. This is nearly identical to the ContactInSchema but does require an email_id."""

# TODO: Docuement these better
email: EmailPutSchema


Expand Down
4 changes: 4 additions & 0 deletions ctms/schemas/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class EmailSchema(EmailBase):


class EmailInSchema(EmailBase):
"""Nearly identical to EmailPutSchema but the email_id is not required."""

email_id: Optional[UUID4] = Field(
default=None,
description=EMAIL_ID_DESCRIPTION,
Expand All @@ -106,6 +108,8 @@ class EmailInSchema(EmailBase):


class EmailPutSchema(EmailBase):
"""Nearly identical to EmailInSchema but the email_id is required."""

email_id: UUID4 = Field(
description=EMAIL_ID_DESCRIPTION,
example=EMAIL_ID_EXAMPLE,
Expand Down
56 changes: 48 additions & 8 deletions tests/unit/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_vpn_by_email_id,
)
from ctms.sample_data import SAMPLE_CONTACTS
from ctms.schemas import ContactInSchema, ContactSchema
from ctms.schemas import ContactInSchema, ContactSchema, NewsletterInSchema


def test_get_ctms_for_minimal_contact(client, minimal_contact):
Expand Down Expand Up @@ -424,7 +424,7 @@ def _check_written(field, getter, result_list=False):
else:
written_id = resp.headers["location"].split("/")[-1]
results = getter(dbsession, written_id)
if sample.dict()[field] and code == 303:
if sample.dict().get(field) and code == 303:
if field in fields_not_written:
if result_list:
assert (
Expand Down Expand Up @@ -479,7 +479,7 @@ def _compare_written_contacts(
for f in fields_not_written:
setattr(sample, f, [] if f == "newsletters" else None)

assert saved_contact.dict() == sample.dict()
assert saved_contact.idempotent_equal(sample)


@pytest.mark.parametrize("post_contact", SAMPLE_CONTACTS.keys(), indirect=True)
Expand Down Expand Up @@ -636,7 +636,7 @@ def _check_written(field, getter):
else:
written_id = resp.headers["location"].split("/")[-1]
results = getter(dbsession, written_id)
if sample.dict()[field] and code == 303:
if sample.dict().get(field) and code == 303:
if field in fields_not_written or field in new_default_fields:
assert (
results is None
Expand Down Expand Up @@ -711,9 +711,6 @@ def test_create_or_update_identical(put_contact):
_compare_written_contacts(saved_contacts[0], sample, email_id)


# TODO: make sure to try out the newsletter logic a lot


@pytest.mark.parametrize("put_contact", SAMPLE_CONTACTS.keys(), indirect=True)
def test_create_or_update_change_primary_email(put_contact):
"""We can update a primary_email given a ctms ID"""
Expand Down Expand Up @@ -744,8 +741,50 @@ def _change_basket(contact):
_compare_written_contacts(saved_contacts[0], sample, email_id)


def _subscribe(contact):
contact.newsletters.append(NewsletterInSchema(name="new-newsletter"))


def _unsubscribe(contact):
contact.newsletters = contact.newsletters[0:-1]


def _subscribe_and_change(contact):
if contact.newsletters:
contact.newsletters[-1].subscribed = not contact.newsletters[-1].subscribed
contact.newsletters.append(
NewsletterInSchema(name="a-newsletter", subscribed=False)
)
contact.newsletters.append(
NewsletterInSchema(name="another-newsletter", subscribed=True)
)


def _un_amo(contact):
if contact.amo:
del contact.amo


def _change_email(contact):
contact.email.primary_email = "something-new@some-website.com"


_test_get_put_modifiers = [
_subscribe,
_unsubscribe,
_un_amo,
_change_email,
_subscribe_and_change,
]


@pytest.fixture(params=_test_get_put_modifiers)
def update_fetched(request):
return request.param


@pytest.mark.parametrize("post_contact", SAMPLE_CONTACTS.keys(), indirect=True)
def test_post_get_put(client, post_contact, put_contact):
def test_post_get_put(client, post_contact, put_contact, update_fetched):
"""This encompasses the entire expected flow for basket"""
saved_contacts, sample, email_id = post_contact()
_compare_written_contacts(saved_contacts[0], sample, email_id)
Expand All @@ -754,6 +793,7 @@ def test_post_get_put(client, post_contact, put_contact):
assert resp.status_code == 200

fetched = ContactSchema(**resp.json())
update_fetched(fetched)
new_default_fields = fetched.find_default_fields()
# We set new_default_fields here because the returned response above
# _includes_ defaults for many fields and we want to not write
Expand Down

0 comments on commit c6486a6

Please sign in to comment.