Skip to content

Commit

Permalink
Fix #642: count total number of contacts (#759)
Browse files Browse the repository at this point in the history
* Fix #642: count total number of contacts

* Rewrite using a count estimate

* Lint

* Add comment about count

* Do not rely on vacuum/anaylyze in metrics tests
  • Loading branch information
leplatrem authored Jul 18, 2023
1 parent c092dd3 commit bd8b23d
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 4 deletions.
15 changes: 15 additions & 0 deletions ctms/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ def ping(db: Session):
return False


def count_total_contacts(db: Session):
"""Return the total number of email records.
Since the table is huge, we rely on the PostgreSQL internal
catalog to retrieve an approximate size efficiently.
This metadata is refreshed on `VACUUM` or `ANALYSIS` which
is run regularly by default on our database instances.
"""
query = text(
"SELECT reltuples AS estimate "
"FROM pg_class "
f"where relname = '{Email.__tablename__}'"
)
return int(db.execute(query).first()["estimate"])


def get_amo_by_email_id(db: Session, email_id: UUID4):
return db.query(AmoAccount).filter(AmoAccount.email_id == email_id).one_or_none()

Expand Down
15 changes: 12 additions & 3 deletions ctms/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

from fastapi import FastAPI
from fastapi.security import HTTPBasic
from prometheus_client import CollectorRegistry, Counter, Histogram
from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram
from prometheus_client.utils import INF
from sqlalchemy.orm import Session
from starlette.routing import Route

from ctms.auth import OAuth2ClientCredentials
from ctms.crud import get_active_api_client_ids

METRICS_PARAMS: dict[str, tuple[Type[Counter] | Type[Histogram], dict]] = {
METRICS_PARAMS: dict[
str, tuple[Type[Counter] | Type[Histogram] | type[Gauge], dict]
] = {
"requests": (
Counter,
{
Expand Down Expand Up @@ -63,6 +65,13 @@
"documentation": "Total count of API calls that use the legacy waitlists format",
},
),
"contacts": (
Gauge,
{
"name": "ctms_contacts_total",
"documentation": "Total count of contacts in the database",
},
),
}

# We could use the default prometheus_client.REGISTRY, but it makes tests
Expand All @@ -84,7 +93,7 @@ def set_metrics(metrics: Any) -> None:
token_scheme = HTTPBasic(auto_error=False)


def init_metrics(registry: CollectorRegistry) -> dict[str, Counter | Histogram]:
def init_metrics(registry: CollectorRegistry) -> dict[str, Counter | Histogram | Gauge]:
"""Initialize the metrics with the registry."""
metrics = {}
for name, init_bits in METRICS_PARAMS.items():
Expand Down
10 changes: 9 additions & 1 deletion ctms/routers/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from ctms.config import Settings, get_version
from ctms.crud import (
count_total_contacts,
get_all_acoustic_fields,
get_all_acoustic_newsletters_mapping,
get_api_client_by_id,
Expand All @@ -29,7 +30,7 @@
get_settings,
get_token_settings,
)
from ctms.metrics import get_metrics_registry, token_scheme
from ctms.metrics import get_metrics, get_metrics_registry, token_scheme
from ctms.schemas.api_client import ApiClientSchema
from ctms.schemas.web import BadRequestResponse, TokenResponse

Expand Down Expand Up @@ -150,6 +151,13 @@ def heartbeat(
**details,
}

if alive and (appmetrics := get_metrics()):
# Report number of contacts in the database.
# Sending the metric in this heartbeat endpoint is simpler than reporting
# it in every write endpoint. Plus, performance does not matter much here.
total_contacts = count_total_contacts(db)
appmetrics["contacts"].set(total_contacts)

status_code = 200 if (alive and acoustic_success) else 503
return JSONResponse(content=result, status_code=status_code)

Expand Down
34 changes: 34 additions & 0 deletions tests/unit/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlalchemy.orm import Session

from ctms.crud import (
count_total_contacts,
create_acoustic_field,
create_acoustic_newsletters_mapping,
create_amo,
Expand All @@ -31,6 +32,7 @@
retry_acoustic_record,
schedule_acoustic_record,
)
from ctms.database import ScopedSessionLocal
from ctms.models import (
AcousticField,
AcousticNewsletterMapping,
Expand All @@ -51,6 +53,38 @@
pytestmark = pytest.mark.filterwarnings("error::sqlalchemy.exc.SAWarning")


def test_email_count(connection, email_factory):
# The default `dbsession` fixture will run in a nested transaction
# that is rollback.
# In this test, we manipulate raw connections and transactions because
# we need to force a VACUUM operation outside a running transaction.

# Insert contacts in the table.
transaction = connection.begin()
session = ScopedSessionLocal()
email_factory.create_batch(3)
session.commit()
session.close()
transaction.commit()

# Force an analysis of the table.
old_isolation_level = connection.connection.isolation_level
connection.connection.set_isolation_level(0)
session.execute(sqlalchemy.text(f"VACUUM ANALYZE {Email.__tablename__}"))
session.close()
connection.connection.set_isolation_level(old_isolation_level)

# Query the count result (since last analyze)
session = ScopedSessionLocal()
count = count_total_contacts(session)
assert count == 3

# Delete created objects (since our transaction was not rollback automatically)
session.query(Email).delete()
session.commit()
session.close()


def test_get_email(dbsession, email_factory):
email = email_factory()
dbsession.commit()
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Test for metrics
from unittest import mock

import pytest
from prometheus_client import CollectorRegistry, generate_latest
from prometheus_client.parser import text_string_to_metric_families
Expand Down Expand Up @@ -165,6 +167,14 @@ def test_homepage_request(anon_client, registry):
assert_duration_metric_obs(registry, "GET", "/docs", "2xx")


def test_contacts_total(anon_client, dbsession, registry):
"""Total number of contacts is reported in heartbeat."""
with mock.patch("ctms.routers.platform.count_total_contacts", return_value=3):
anon_client.get("/__heartbeat__")

assert registry.get_sample_value("ctms_contacts_total") == 3


def test_api_request(client, minimal_contact, registry):
"""An API request emits API metrics as well."""
email_id = minimal_contact.email.email_id
Expand Down

0 comments on commit bd8b23d

Please sign in to comment.