Skip to content

Commit

Permalink
Remove indirection in getting contacts (#661)
Browse files Browse the repository at this point in the history
* Return Contact object directly from get_contact_by_email_id

Also remove `get_acoustic_record_as_contact`, because now that is the
same as the new behavior of `get_contact_by_email_id`

* Use get_contacts_by_any_id for all queries with multiple IDs
  • Loading branch information
grahamalama authored May 4, 2023
1 parent 8f6de8e commit 7323dab
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 223 deletions.
54 changes: 6 additions & 48 deletions ctms/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
get_api_client_by_id,
get_bulk_contacts,
get_contact_by_email_id,
get_contacts_by_any_id,
get_email,
get_emails_by_any_id,
schedule_acoustic_record,
update_contact,
)
Expand Down Expand Up @@ -190,48 +190,6 @@ def all_ids(
}


def get_contacts_by_ids(
db: Session,
email_id: Optional[UUID] = None,
primary_email: Optional[str] = None,
basket_token: Optional[UUID] = None,
sfdc_id: Optional[str] = None,
mofo_contact_id: Optional[str] = None,
mofo_email_id: Optional[str] = None,
amo_user_id: Optional[str] = None,
fxa_id: Optional[str] = None,
fxa_primary_email: Optional[str] = None,
) -> List[ContactSchema]:
"""Get contacts by any ID.
Callers are expected to set just one ID, but if multiple are set, a contact
must match all IDs.
"""
rows = get_emails_by_any_id(
db,
email_id,
primary_email,
basket_token,
sfdc_id,
mofo_contact_id,
mofo_email_id,
amo_user_id,
fxa_id,
fxa_primary_email,
)
return [
ContactSchema(
amo=email.amo,
email=email,
fxa=email.fxa,
mofo=email.mofo,
newsletters=email.newsletters,
waitlists=email.waitlists,
)
for email in rows
]


def get_bulk_contacts_by_timestamp_or_4xx(
db: Session,
start_time: datetime,
Expand Down Expand Up @@ -481,7 +439,7 @@ def read_ctms_by_any_id(
f"No identifiers provided, at least one is needed: {', '.join(ids.keys())}"
)
raise HTTPException(status_code=400, detail=detail)
contacts = get_contacts_by_ids(db, **ids)
contacts = get_contacts_by_any_id(db, **ids)
traced = set()
for contact in contacts:
email = contact.email.primary_email
Expand Down Expand Up @@ -539,11 +497,11 @@ def create_ctms_contact(
email_id = contact.email.email_id
existing = get_contact_by_email_id(db, email_id)
if existing:
email = existing["email"].primary_email
email = existing.email.primary_email
if re_trace_email.match(email):
request.state.log_context["trace"] = email
request.state.log_context["trace_json"] = content_json
if ContactInSchema(**existing).idempotent_equal(contact):
if ContactInSchema(**existing.dict()).idempotent_equal(contact):
response.headers["Location"] = f"/ctms/{email_id}"
response.status_code = 200
return get_ctms_response_or_404(db=db, email_id=email_id)
Expand Down Expand Up @@ -683,7 +641,7 @@ def delete_contact_by_primary_email(
api_client: ApiClientSchema = Depends(get_enabled_api_client),
):
ids = all_ids(primary_email=primary_email.lower())
contacts = get_contacts_by_ids(db, **ids)
contacts = get_contacts_by_any_id(db, **ids)

if not contacts:
raise HTTPException(status_code=404, detail=f"email {primary_email} not found!")
Expand Down Expand Up @@ -749,7 +707,7 @@ def read_identities(
f"No identifiers provided, at least one is needed: {', '.join(ids.keys())}"
)
raise HTTPException(status_code=400, detail=detail)
contacts = get_contacts_by_ids(db, **ids)
contacts = get_contacts_by_any_id(db, **ids)
return [contact.as_identity_response() for contact in contacts]


Expand Down
96 changes: 26 additions & 70 deletions ctms/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,25 +176,24 @@ def get_email(db: Session, email_id: UUID4) -> Optional[Email]:
)


def get_contact_by_email_id(db: Session, email_id: UUID4) -> Optional[Dict]:
"""Get all the data for a contact, as a dict."""
def get_contact_by_email_id(db: Session, email_id: UUID4) -> Optional[ContactSchema]:
"""Return a Contact object for a given email id"""
email = get_email(db, email_id)
if email is None:
return None
products = []
products.extend(get_stripe_products(email))
return {
"amo": email.amo,
"email": email,
"fxa": email.fxa,
"mofo": email.mofo,
"newsletters": email.newsletters,
"products": products,
"waitlists": email.waitlists,
}
products = get_stripe_products(email)
return ContactSchema(
amo=email.amo,
email=email,
fxa=email.fxa,
mofo=email.mofo,
newsletters=email.newsletters,
products=products,
waitlists=email.waitlists,
)


def get_emails_by_any_id(
def get_contacts_by_any_id(
db: Session,
email_id: Optional[UUID4] = None,
primary_email: Optional[str] = None,
Expand All @@ -205,9 +204,9 @@ def get_emails_by_any_id(
amo_user_id: Optional[str] = None,
fxa_id: Optional[str] = None,
fxa_primary_email: Optional[str] = None,
) -> List[Email]:
) -> List[ContactSchema]:
"""
Get all the data for multiple contacts by IDs as a list of Email instances.
Get all the data for multiple contacts by ID as a list of Contacts.
Newsletters are retrieved in batches of 500 email_ids, so it will be two
queries for most calls.
Expand Down Expand Up @@ -252,52 +251,18 @@ def get_emails_by_any_id(
statement = statement.join(Email.fxa).filter_by(
fxa_primary_email_insensitive_comparator=fxa_primary_email
)
return cast(List[Email], statement.all())


def get_contacts_by_any_id(
db: Session,
email_id: Optional[UUID4] = None,
primary_email: Optional[str] = None,
basket_token: Optional[UUID4] = None,
sfdc_id: Optional[str] = None,
mofo_contact_id: Optional[str] = None,
mofo_email_id: Optional[str] = None,
amo_user_id: Optional[str] = None,
fxa_id: Optional[str] = None,
fxa_primary_email: Optional[str] = None,
) -> List[Dict]:
"""
Get all the data for multiple contacts by ID as a list of dicts.
Newsletters are retrieved in batches of 500 email_ids, so it will be two
queries for most calls.
"""
emails = get_emails_by_any_id(
db,
email_id,
primary_email,
basket_token,
sfdc_id,
mofo_contact_id,
mofo_email_id,
amo_user_id,
fxa_id,
fxa_primary_email,
)
data = []
for email in emails:
data.append(
{
"amo": email.amo,
"email": email,
"fxa": email.fxa,
"mofo": email.mofo,
"newsletters": email.newsletters,
"waitlists": email.waitlists,
}
emails = cast(List[Email], statement.all())
return [
ContactSchema(
amo=email.amo,
email=email,
fxa=email.fxa,
mofo=email.mofo,
newsletters=email.newsletters,
waitlists=email.waitlists,
)
return data
for email in emails
]


def _acoustic_sync_retry_query(db: Session):
Expand Down Expand Up @@ -359,15 +324,6 @@ def get_all_acoustic_records_count(
return pending_records_count


def get_acoustic_record_as_contact(
db: Session,
record: PendingAcousticRecord,
) -> ContactSchema:
contact = get_contact_by_email_id(db, record.email_id)
contact_schema: ContactSchema = ContactSchema.parse_obj(contact)
return contact_schema


def bulk_schedule_acoustic_records(db: Session, primary_emails: list[str]):
"""Mark a list of primary email as pending synchronization."""
statement = _contact_base_query(db).filter(Email.primary_email.in_(primary_emails))
Expand Down
7 changes: 2 additions & 5 deletions ctms/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from ctms.background_metrics import BackgroundMetricService
from ctms.crud import (
delete_acoustic_record,
get_acoustic_record_as_contact,
get_all_acoustic_records_before,
get_all_acoustic_records_count,
get_all_acoustic_retries_count,
get_contact_by_email_id,
retry_acoustic_record,
)
from ctms.models import AcousticField, AcousticNewsletterMapping, PendingAcousticRecord
from ctms.schemas import ContactSchema


class CTMSToAcousticSync:
Expand Down Expand Up @@ -65,9 +64,7 @@ def _sync_pending_record(
try:
sync_error = None
if self.is_acoustic_enabled:
contact: ContactSchema = get_acoustic_record_as_contact(
db, pending_record
)
contact = get_contact_by_email_id(db, pending_record.email_id)
try:
self.ctms_to_acoustic.attempt_to_upload_ctms_contact(
contact, main_fields, newsletters_mapping
Expand Down
12 changes: 3 additions & 9 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,7 @@ def _add(
assert resp.status_code == code, resp.text
if check_redirect:
assert resp.headers["location"] == f"/ctms/{sample.email.email_id}"
saved = [
ContactSchema(**c)
for c in get_contacts_by_any_id(dbsession, **query_fields)
]
saved = get_contacts_by_any_id(dbsession, **query_fields)
assert len(saved) == stored_contacts

# Now make sure that we skip writing default models
Expand Down Expand Up @@ -612,10 +609,7 @@ def _add(
sample = modifier(sample)
resp = client.put(f"/ctms/{sample.email.email_id}", sample.json())
assert resp.status_code == code, resp.text
saved = [
ContactSchema(**c)
for c in get_contacts_by_any_id(dbsession, **query_fields)
]
saved = get_contacts_by_any_id(dbsession, **query_fields)
assert len(saved) == stored_contacts

# Now make sure that we skip writing default models
Expand Down Expand Up @@ -883,7 +877,7 @@ def contact_with_stripe_subscription(
dbsession.commit()

contact = get_contact_by_email_id(dbsession, stripe_customer.email.email_id)
return ContactSchema(**contact)
return contact


@pytest.fixture
Expand Down
Loading

0 comments on commit 7323dab

Please sign in to comment.