Skip to content

Commit

Permalink
Configure factories to automatically commit when called (#951)
Browse files Browse the repository at this point in the history
* Use traits + relatedfactories for fxa, mofo, and amo

* set sqlalchemy_session_persistence = "commit" in base factory

* Autouse dbsession in tests, with the option to disable

* Build newsletters in schema tests, instead of create

* Modify tests to remove dbsessions that no longer need them
  • Loading branch information
grahamalama authored Aug 1, 2024
1 parent 17ea7d6 commit e2e4380
Show file tree
Hide file tree
Showing 16 changed files with 121 additions and 185 deletions.
40 changes: 20 additions & 20 deletions tests/factories/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class BaseSQLAlchemyModelFactory(SQLAlchemyModelFactory):
class Meta:
abstract = True
sqlalchemy_session = ScopedSessionLocal
sqlalchemy_session_persistence = "commit"


class NewsletterFactory(BaseSQLAlchemyModelFactory):
Expand Down Expand Up @@ -136,26 +137,25 @@ def waitlists(self, create, extracted, **kwargs):
for _ in range(extracted):
WaitlistFactory(email=self, **kwargs)

@factory.post_generation
def fxa(self, create, extracted, **kwargs):
if not create:
return
if extracted:
FirefoxAccountFactory(email=self, **kwargs)

@factory.post_generation
def mofo(self, create, extracted, **kwargs):
if not create:
return
if extracted:
MozillaFoundationContactFactory(email=self, **kwargs)

@factory.post_generation
def amo(self, create, extracted, **kwargs):
if not create:
return
if extracted:
AmoAccountFactory(email=self, **kwargs)
class Params:
with_fxa = factory.Trait(
fxa=factory.RelatedFactory(
FirefoxAccountFactory,
factory_related_name="email",
)
)
with_amo = factory.Trait(
amo=factory.RelatedFactory(
AmoAccountFactory,
factory_related_name="email",
)
)
with_mofo = factory.Trait(
mofo=factory.RelatedFactory(
MozillaFoundationContactFactory,
factory_related_name="email",
)
)


__all__ = (
Expand Down
17 changes: 10 additions & 7 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,20 @@ def connection(engine):
conn.close()


@pytest.fixture
def dbsession(connection):
@pytest.fixture(autouse=True)
def dbsession(request, connection):
"""Return a database session that rolls back.
Adapted from https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
"""
transaction = connection.begin()
session = ScopedSessionLocal()
yield session
session.close()
transaction.rollback()
if "disable_autouse" in request.keywords:
yield
else:
transaction = connection.begin()
session = ScopedSessionLocal()
yield session
session.close()
transaction.rollback()


# Database models
Expand Down
5 changes: 1 addition & 4 deletions tests/unit/routers/contacts/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,13 @@ def test_unauthorized_api_call_fails(anon_client, method, path, params):


@pytest.mark.parametrize("method,path,params", API_TEST_CASES)
def test_authorized_api_call_succeeds(
client, dbsession, email_factory, method, path, params
):
def test_authorized_api_call_succeeds(client, email_factory, method, path, params):
"""Calling the API with credentials succeeds."""

email_factory(
email_id="332de237-cab7-4461-bcc3-48e68f42bd5c",
primary_email="contact@example.com",
)
dbsession.commit()

if method == "GET":
resp = client.request(method, path, params=params)
Expand Down
8 changes: 2 additions & 6 deletions tests/unit/routers/contacts/test_api_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@ def test_delete_contact_by_primary_email_not_found(client):
assert resp.status_code == 404


def test_delete_contact_by_primary_email(client, dbsession, email_factory):
def test_delete_contact_by_primary_email(client, email_factory):
primary_email = email_factory().primary_email
dbsession.commit()

resp = client.delete(f"/ctms/{primary_email}")
assert resp.status_code == 200
resp = client.delete(f"/ctms/{primary_email}")
assert resp.status_code == 404


def test_delete_contact_by_primary_email_with_basket_token_unset(
client, dbsession, email_factory
):
def test_delete_contact_by_primary_email_with_basket_token_unset(client, email_factory):
email = email_factory(basket_token=None)
dbsession.commit()

resp = client.delete(f"/ctms/{email.primary_email}")
assert resp.status_code == 200
4 changes: 2 additions & 2 deletions tests/unit/routers/contacts/test_api_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import pytest


def test_get_ctms_for_minimal_contact(client, dbsession, email_factory):
def test_get_ctms_for_minimal_contact(client, email_factory):
"""GET /ctms/{email_id} returns a contact with most fields unset."""
contact = email_factory(newsletters=1)
newsletter = contact.newsletters[0]
dbsession.commit()

email_id = str(contact.email_id)
resp = client.get(f"/ctms/{email_id}")
assert resp.status_code == 200
Expand Down
11 changes: 4 additions & 7 deletions tests/unit/routers/contacts/test_api_get_by_alt_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,22 @@
("mofo_email_id", "195207d2-63f2-4c9f-b149-80e9c408477a"),
],
)
def test_get_ctms_by_alt_id(
dbsession, email_factory, client, alt_id_name, alt_id_value
):
def test_get_ctms_by_alt_id(email_factory, client, alt_id_name, alt_id_value):
"""The desired contact can be fetched by alternate ID."""
email_factory(
email_id="67e52c77-950f-4f28-accb-bb3ea1a2c51a",
primary_email="mozilla-fan@example.com",
basket_token="d9ba6182-f5dd-4728-a477-2cc11bf62b69",
sfdc_id="001A000001aMozFan",
amo=True,
with_amo=True,
amo__user_id="123",
fxa=True,
with_fxa=True,
fxa__fxa_id="611b6788-2bba-42a6-98c9-9ce6eb9cbd34",
fxa__primary_email="fxa-firefox-fan@example.com",
mofo=True,
with_mofo=True,
mofo__mofo_contact_id="5e499cc0-eeb5-4f0e-aae6-a101721874b8",
mofo__mofo_email_id="195207d2-63f2-4c9f-b149-80e9c408477a",
)
dbsession.commit()

resp = client.get("/ctms", params={alt_id_name: alt_id_value})
assert resp.status_code == 200
Expand Down
Loading

0 comments on commit e2e4380

Please sign in to comment.