diff --git a/ctms/crud.py b/ctms/crud.py index 8d19ba55..abe2d74e 100644 --- a/ctms/crud.py +++ b/ctms/crud.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session, joinedload, load_only, selectinload from ctms.schemas.newsletter import NewsletterTableSchema +from ctms.schemas.waitlist import WaitlistTableSchema from .auth import hash_password from .backport_legacy_waitlists import format_legacy_vpn_relay_waitlist_input @@ -478,15 +479,19 @@ def create_or_update_newsletters( def create_waitlist( - db: Session, email_id: UUID4, waitlist: WaitlistInSchema + db: Session, email_id: UUID4, waitlist: WaitlistInSchema | WaitlistTableSchema ) -> Optional[Waitlist]: if waitlist.is_default(): return None - if not isinstance(waitlist, WaitlistInSchema): - # Sample data are used as both input (`WaitlistInSchema`) and internal (`WaitlistSchema`) - # representations. - waitlist = WaitlistInSchema(**waitlist.dict()) - db_waitlist = Waitlist(email_id=email_id, **waitlist.orm_dict()) + + # This is called from API input data with `WaitlistInSchema`, and from tests fixtures + # with `WaitlistTableSchema` + # Unlike `WaitlistTableSchema`, `WaitlistInSchema` has no `email_id`, `subscribed` fields. + attrs = {"email_id": email_id, **waitlist.dict()} + if "subscribed" in attrs: + del attrs["subscribed"] + + db_waitlist = Waitlist(**attrs) db.add(db_waitlist) return db_waitlist @@ -522,7 +527,10 @@ def create_or_update_waitlists( def create_contact( - db: Session, email_id: UUID4, contact: ContactInSchema, metrics: Optional[Dict] + db: Session, + email_id: UUID4, + contact: ContactInSchema | ContactSchema, + metrics: Optional[Dict], ): create_email(db, contact.email) if contact.amo: diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 600694a0..a63d55d4 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -180,7 +180,16 @@ def test_post_get_put(client, post_contact, put_contact, update_fetched): resp = client.get(f"/ctms/{email_id}") assert resp.status_code == 200 - fetched = ContactInSchema(**resp.json()) + # TODO: remove this once we remove support of `vpn_waitlist` and + # `relay_waitlist` as input. + # If we don't strip these two fields before turning the data into + # a `ContactInSchema`, they will create waitlist objects. + without_alias_fields = { + k: v + for k, v in resp.json().items() + if k not in ("vpn_waitlist", "relay_waitlist") + } + fetched = ContactInSchema(**without_alias_fields) update_fetched(fetched) new_default_fields = find_default_fields(fetched) # We set new_default_fields here because the returned response above diff --git a/tests/unit/test_api_patch.py b/tests/unit/test_api_patch.py index 0af4681a..734979b3 100644 --- a/tests/unit/test_api_patch.py +++ b/tests/unit/test_api_patch.py @@ -17,6 +17,7 @@ MozillaFoundationInSchema, MozillaFoundationSchema, ) +from ctms.schemas.waitlist import WaitlistInSchema def swap_bool(existing): @@ -187,6 +188,19 @@ def test_patch_cannot_set_timestamps(client, maximal_contact): assert actual["amo"]["update_timestamp"] != new_ts expected["amo"]["update_timestamp"] = actual["amo"]["update_timestamp"] expected["email"]["update_timestamp"] = actual["email"]["update_timestamp"] + # `actual` comes a `CTMSResponse`, and `expected` is a `ContactTableSchema` + # that has timestamps. + # Since this test compares the two instances directly, we strip the timestamps from + # `expected`. + for newsletter in expected["newsletters"]: + del newsletter["email_id"] + del newsletter["create_timestamp"] + del newsletter["update_timestamp"] + for waitlist in expected["waitlists"]: + del waitlist["email_id"] + del waitlist["create_timestamp"] + del waitlist["update_timestamp"] + # products list is not (yet) in output schema assert expected["products"] == [] assert "products" not in actual @@ -323,14 +337,9 @@ def test_patch_to_unsubscribe(client, maximal_contact): """PATCH can unsubscribe by setting a newsletter field.""" email_id = maximal_contact.email.email_id existing_news_data = maximal_contact.newsletters[1].dict() - assert existing_news_data == { - "format": "T", - "lang": "fr", - "name": "common-voice", - "source": "https://commonvoice.mozilla.org/fr", - "subscribed": True, - "unsub_reason": None, - } + assert existing_news_data["subscribed"] + assert existing_news_data["name"] == "common-voice" + assert existing_news_data["unsub_reason"] is None patch_data = { "newsletters": [ { @@ -476,7 +485,9 @@ def test_patch_does_not_add_an_unsubscribed_waitlist(client, maximal_contact): def test_patch_to_update_a_waitlist(client, maximal_contact): """PATCH can update a waitlist.""" email_id = maximal_contact.email.email_id - existing = [wl.dict() for wl in maximal_contact.waitlists] + existing = [ + WaitlistInSchema(**wl.dict()).dict() for wl in maximal_contact.waitlists + ] existing[0]["fields"]["geo"] = "ca" patch_data = {"waitlists": existing} resp = client.patch(f"/ctms/{email_id}", json=patch_data, allow_redirects=True) @@ -492,7 +503,9 @@ def test_patch_to_remove_a_waitlist(client, maximal_contact): """PATCH can remove a single waitlist.""" email_id = maximal_contact.email.email_id existing = [wl.dict() for wl in maximal_contact.waitlists] - patch_data = {"waitlists": [{**existing[-1], "subscribed": False}]} + patch_data = { + "waitlists": [WaitlistInSchema(**existing[-1], subscribed=False).dict()] + } resp = client.patch(f"/ctms/{email_id}", json=patch_data, allow_redirects=True) assert resp.status_code == 200 actual = resp.json() diff --git a/tests/unit/test_bulk.py b/tests/unit/test_bulk.py index 8460c164..784776c9 100644 --- a/tests/unit/test_bulk.py +++ b/tests/unit/test_bulk.py @@ -118,6 +118,15 @@ def test_get_ctms_bulk_by_timerange( # does not have them. del dict_contact_actual["vpn_waitlist"] del dict_contact_actual["relay_waitlist"] + # The reponse does not show `email_id` and timestamp fields. + for newsletter in dict_contact_expected["newsletters"]: + del newsletter["email_id"] + del newsletter["create_timestamp"] + del newsletter["update_timestamp"] + for waitlist in dict_contact_expected["waitlists"]: + del waitlist["email_id"] + del waitlist["create_timestamp"] + del waitlist["update_timestamp"] assert dict_contact_expected == dict_contact_actual assert results["next"] is not None diff --git a/tests/unit/test_crud.py b/tests/unit/test_crud.py index 5e1ec8c1..07bff35a 100644 --- a/tests/unit/test_crud.py +++ b/tests/unit/test_crud.py @@ -454,8 +454,9 @@ def test_get_bulk_contacts_some_after_higher_limit( after_email_id=after_id, ) assert len(bulk_contact_list) == 2 - assert last_contact in bulk_contact_list - assert sorted_list[-2] in bulk_contact_list + bulk_contact_list_ids = [c.email.email_id for c in bulk_contact_list] + assert last_contact.email.email_id in bulk_contact_list_ids + assert sorted_list[-2].email.email_id in bulk_contact_list_ids def test_get_bulk_contacts_some_after( @@ -482,7 +483,7 @@ def test_get_bulk_contacts_some_after( after_email_id=after_id, ) assert len(bulk_contact_list) == 1 - assert last_contact in bulk_contact_list + assert last_contact.email.email_id == bulk_contact_list[0].email.email_id def test_get_bulk_contacts_some( @@ -502,9 +503,10 @@ def test_get_bulk_contacts_some( limit=10, ) assert len(bulk_contact_list) >= 3 - assert example_contact in bulk_contact_list - assert maximal_contact in bulk_contact_list - assert minimal_contact in bulk_contact_list + bulk_contact_list_ids = [c.email.email_id for c in bulk_contact_list] + assert example_contact.email.email_id in bulk_contact_list_ids + assert maximal_contact.email.email_id in bulk_contact_list_ids + assert minimal_contact.email.email_id in bulk_contact_list_ids def test_get_bulk_contacts_one(dbsession, example_contact):