Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to generate a Contact from an Email #663

Merged
merged 4 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions ctms/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,7 @@ def get_email_or_404(db: Session, email_id) -> Email:
def get_contact_or_404(db: Session, email_id) -> ContactSchema:
"""Get a contact by email_ID, or raise a 404 exception."""
email = get_email_or_404(db, email_id)
return ContactSchema(
amo=email.amo,
email=email,
fxa=email.fxa,
mofo=email.mofo,
newsletters=email.newsletters,
waitlists=email.waitlists,
)
return ContactSchema.from_email(email)


def all_ids(
Expand Down
129 changes: 3 additions & 126 deletions ctms/crud.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import uuid
from collections import defaultdict
from datetime import datetime, timezone
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, cast
Expand Down Expand Up @@ -44,7 +43,6 @@
FirefoxAccountsInSchema,
MozillaFoundationInSchema,
NewsletterInSchema,
ProductBaseSchema,
StripeCustomerCreateSchema,
StripeInvoiceCreateSchema,
StripeInvoiceLineItemCreateSchema,
Expand Down Expand Up @@ -153,19 +151,7 @@ def get_bulk_contacts(
.all()
)

return [
ContactSchema.parse_obj(
{
"amo": email.amo,
"email": email,
"fxa": email.fxa,
"mofo": email.mofo,
"newsletters": email.newsletters,
"waitlists": email.waitlists,
}
)
for email in bulk_contacts
]
return [ContactSchema.from_email(email) for email in bulk_contacts]


def get_email(db: Session, email_id: UUID4) -> Optional[Email]:
Expand All @@ -181,16 +167,7 @@ def get_contact_by_email_id(db: Session, email_id: UUID4) -> Optional[ContactSch
email = get_email(db, email_id)
if email is None:
return None
products = get_stripe_products(email)
return ContactSchema(
amo=email.amo,
email=email,
fxa=email.fxa,
mofo=email.mofo,
newsletters=email.newsletters,
products=products,
waitlists=email.waitlists,
)
return ContactSchema.from_email(email)


def get_contacts_by_any_id(
Expand Down Expand Up @@ -252,17 +229,7 @@ def get_contacts_by_any_id(
fxa_primary_email_insensitive_comparator=fxa_primary_email
)
emails = cast(List[Email], statement.all())
return [
ContactSchema(
amo=email.amo,
email=email,
fxa=email.fxa,
mofo=email.mofo,
newsletters=email.newsletters,
waitlists=email.waitlists,
)
for email in emails
]
return [ContactSchema.from_email(email) for email in emails]


def _acoustic_sync_retry_query(db: Session):
Expand Down Expand Up @@ -784,96 +751,6 @@ def get_stripe_customer_by_fxa_id(
return cast(Optional[StripeCustomer], obj)


def get_stripe_products(email: Email) -> List[ProductBaseSchema]:
"""Return a list of Stripe products for the contact, if any."""
if not email.stripe_customer:
return []

base_data: Dict[str, Any] = {
"payment_service": "stripe",
# These come from the Payment Method, not imported from Stripe.
"payment_type": None,
"card_brand": None,
"card_last4": None,
"billing_country": None,
}
by_product = defaultdict(list)

for subscription in email.stripe_customer.subscriptions:
subscription_data = base_data.copy()
subscription_data.update(
{
"status": subscription.status,
"created": subscription.stripe_created,
"start": subscription.start_date,
"current_period_start": subscription.current_period_start,
"current_period_end": subscription.current_period_end,
"canceled_at": subscription.canceled_at,
"cancel_at_period_end": subscription.cancel_at_period_end,
"ended_at": subscription.ended_at,
}
)
for item in subscription.subscription_items:
product_data = subscription_data.copy()
price = item.price
product_data.update(
{
"product_id": price.stripe_product_id,
"product_name": None, # Products are not imported
"price_id": price.stripe_id,
"currency": price.currency,
"amount": price.unit_amount,
"interval_count": price.recurring_interval_count,
"interval": price.recurring_interval,
}
)
by_product[price.stripe_product_id].append(product_data)

products = []
for subscriptions in by_product.values():
# Sort to find the latest subscription
def get_current_period(sub: Dict) -> datetime:
return cast(datetime, sub["current_period_end"])

subscriptions.sort(key=get_current_period, reverse=True)
latest = subscriptions[0]
data = latest.copy()
if len(subscriptions) == 1:
segment_prefix = ""
else:
segment_prefix = "re-"
if latest["status"] == "active":
if latest["canceled_at"]:
segment = "cancelling"
changed = latest["canceled_at"]
else:
segment = "active"
changed = latest["start"]
elif latest["status"] == "canceled":
segment = "canceled"
changed = latest["ended_at"]
else:
segment_prefix = ""
segment = "other"
changed = latest["created"]

assert changed
data.update(
{
"sub_count": len(subscriptions),
"segment": f"{segment_prefix}{segment}",
"changed": changed,
}
)
products.append(ProductBaseSchema(**data))

def get_product_id(prod: ProductBaseSchema) -> str:
return prod.product_id or ""

products.sort(key=get_product_id)
return products


def get_all_acoustic_fields(dbsession: Session, tablename: Optional[str] = None):
query = dbsession.query(AcousticField).order_by(
asc(AcousticField.tablename), asc(AcousticField.field)
Expand Down
103 changes: 102 additions & 1 deletion ctms/schemas/contact.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from datetime import datetime
from typing import List, Literal, Optional, Set, Union
from typing import TYPE_CHECKING, List, Literal, Optional, Set, Union, cast
from uuid import UUID

from pydantic import AnyUrl, BaseModel, Field, root_validator, validator
Expand Down Expand Up @@ -27,6 +28,94 @@
validate_waitlist_newsletters,
)

if TYPE_CHECKING:
from models import Email


def _subscriptions_by_product(subscriptions):
by_product = defaultdict(list)

for subscription in subscriptions:
for item in subscription.subscription_items:
price = item.price
product_data = {
"payment_service": "stripe",
###
# These come from the Payment Method, not imported from Stripe.
"payment_type": None,
"card_brand": None,
"card_last4": None,
"billing_country": None,
###
"status": subscription.status,
"created": subscription.stripe_created,
"start": subscription.start_date,
"current_period_start": subscription.current_period_start,
"current_period_end": subscription.current_period_end,
"canceled_at": subscription.canceled_at,
"cancel_at_period_end": subscription.cancel_at_period_end,
"ended_at": subscription.ended_at,
"product_id": price.stripe_product_id,
"product_name": None, # Products are not imported
"price_id": price.stripe_id,
"currency": price.currency,
"amount": price.unit_amount,
"interval_count": price.recurring_interval_count,
"interval": price.recurring_interval,
}
by_product[price.stripe_product_id].append(product_data)
return by_product


def _product_metadata(subscriptions_by_product):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a comment may help here, at least explaining that we pick the latest product from the subscription dict values

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leplatrem ended up going in a pretty different direction for these functions, so retagging you for review.

latest = max(
subscriptions_by_product,
key=lambda sub: cast(datetime, sub["current_period_end"]),
)
if len(subscriptions_by_product) == 1:
segment_prefix = ""
else:
segment_prefix = "re-"
if latest["status"] == "active":
if latest["canceled_at"]:
segment = "cancelling"
changed = latest["canceled_at"]
else:
segment = "active"
changed = latest["start"]
elif latest["status"] == "canceled":
segment = "canceled"
changed = latest["ended_at"]
else:
segment_prefix = ""
segment = "other"
changed = latest["created"]

assert changed
latest.update(
{
"sub_count": len(subscriptions_by_product),
"segment": f"{segment_prefix}{segment}",
"changed": changed,
}
)
return ProductBaseSchema(**latest)


def get_stripe_products(email: "Email") -> List[ProductBaseSchema]:
"""Return a list of Stripe products for the contact, if any."""
if not email.stripe_customer:
return []
subscription_metadata_by_product = _subscriptions_by_product(
email.stripe_customer.subscriptions
)
products = [
_product_metadata(subscriptions)
for subscriptions in subscription_metadata_by_product.values()
]
products.sort(key=lambda prod: prod.product_id or "")
return products


class ContactSchema(ComparableBase):
"""A complete contact."""
Expand All @@ -39,6 +128,18 @@ class ContactSchema(ComparableBase):
waitlists: List[WaitlistSchema] = []
products: List[ProductBaseSchema] = []

@classmethod
def from_email(cls, email: "Email") -> "ContactSchema":
return cls(
amo=email.amo,
email=email,
fxa=email.fxa,
mofo=email.mofo,
newsletters=email.newsletters,
waitlists=email.waitlists,
products=get_stripe_products(email),
)

class Config:
fields = {
"newsletters": {
Expand Down