From bd8b23d99efb6e4bb0523a66e7efd415236e2b44 Mon Sep 17 00:00:00 2001 From: Mathieu Leplatre Date: Tue, 18 Jul 2023 11:57:23 +0200 Subject: [PATCH] Fix #642: count total number of contacts (#759) * Fix #642: count total number of contacts * Rewrite using a count estimate * Lint * Add comment about count * Do not rely on vacuum/analyze in metrics tests --- ctms/crud.py | 15 +++++++++++++++ ctms/metrics.py | 15 ++++++++++++--- ctms/routers/platform.py | 10 +++++++++- tests/unit/test_crud.py | 34 ++++++++++++++++++++++++++++++++++ tests/unit/test_metrics.py | 10 ++++++++++ 5 files changed, 80 insertions(+), 4 deletions(-) diff --git a/ctms/crud.py b/ctms/crud.py index 22a27ffb..88504703 100644 --- a/ctms/crud.py +++ b/ctms/crud.py @@ -70,6 +70,21 @@ def ping(db: Session): return False +def count_total_contacts(db: Session): + """Return the total number of email records. + Since the table is huge, we rely on the PostgreSQL internal + catalog to retrieve an approximate size efficiently. + This metadata is refreshed on `VACUUM` or `ANALYZE` which + is run regularly by default on our database instances. 
+ """ + query = text( + "SELECT reltuples AS estimate " + "FROM pg_class " + f"where relname = '{Email.__tablename__}'" + ) + return int(db.execute(query).first()["estimate"]) + + def get_amo_by_email_id(db: Session, email_id: UUID4): return db.query(AmoAccount).filter(AmoAccount.email_id == email_id).one_or_none() diff --git a/ctms/metrics.py b/ctms/metrics.py index 109418b6..187c51e5 100644 --- a/ctms/metrics.py +++ b/ctms/metrics.py @@ -5,7 +5,7 @@ from fastapi import FastAPI from fastapi.security import HTTPBasic -from prometheus_client import CollectorRegistry, Counter, Histogram +from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram from prometheus_client.utils import INF from sqlalchemy.orm import Session from starlette.routing import Route @@ -13,7 +13,9 @@ from ctms.auth import OAuth2ClientCredentials from ctms.crud import get_active_api_client_ids -METRICS_PARAMS: dict[str, tuple[Type[Counter] | Type[Histogram], dict]] = { +METRICS_PARAMS: dict[ + str, tuple[Type[Counter] | Type[Histogram] | type[Gauge], dict] +] = { "requests": ( Counter, { @@ -63,6 +65,13 @@ "documentation": "Total count of API calls that use the legacy waitlists format", }, ), + "contacts": ( + Gauge, + { + "name": "ctms_contacts_total", + "documentation": "Total count of contacts in the database", + }, + ), } # We could use the default prometheus_client.REGISTRY, but it makes tests @@ -84,7 +93,7 @@ def set_metrics(metrics: Any) -> None: token_scheme = HTTPBasic(auto_error=False) -def init_metrics(registry: CollectorRegistry) -> dict[str, Counter | Histogram]: +def init_metrics(registry: CollectorRegistry) -> dict[str, Counter | Histogram | Gauge]: """Initialize the metrics with the registry.""" metrics = {} for name, init_bits in METRICS_PARAMS.items(): diff --git a/ctms/routers/platform.py b/ctms/routers/platform.py index ea9ea91b..7d63f047 100644 --- a/ctms/routers/platform.py +++ b/ctms/routers/platform.py @@ -18,6 +18,7 @@ ) from ctms.config import Settings, 
get_version from ctms.crud import ( + count_total_contacts, get_all_acoustic_fields, get_all_acoustic_newsletters_mapping, get_api_client_by_id, @@ -29,7 +30,7 @@ get_settings, get_token_settings, ) -from ctms.metrics import get_metrics_registry, token_scheme +from ctms.metrics import get_metrics, get_metrics_registry, token_scheme from ctms.schemas.api_client import ApiClientSchema from ctms.schemas.web import BadRequestResponse, TokenResponse @@ -150,6 +151,13 @@ def heartbeat( **details, } + if alive and (appmetrics := get_metrics()): + # Report number of contacts in the database. + # Sending the metric in this heartbeat endpoint is simpler than reporting + # it in every write endpoint. Plus, performance does not matter much here. + total_contacts = count_total_contacts(db) + appmetrics["contacts"].set(total_contacts) + status_code = 200 if (alive and acoustic_success) else 503 return JSONResponse(content=result, status_code=status_code) diff --git a/tests/unit/test_crud.py b/tests/unit/test_crud.py index dd2c1bc4..28c563c9 100644 --- a/tests/unit/test_crud.py +++ b/tests/unit/test_crud.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import Session from ctms.crud import ( + count_total_contacts, create_acoustic_field, create_acoustic_newsletters_mapping, create_amo, @@ -31,6 +32,7 @@ retry_acoustic_record, schedule_acoustic_record, ) +from ctms.database import ScopedSessionLocal from ctms.models import ( AcousticField, AcousticNewsletterMapping, @@ -51,6 +53,38 @@ pytestmark = pytest.mark.filterwarnings("error::sqlalchemy.exc.SAWarning") +def test_email_count(connection, email_factory): + # The default `dbsession` fixture will run in a nested transaction + # that is rolled back. + # In this test, we manipulate raw connections and transactions because + # we need to force a VACUUM operation outside a running transaction. + + # Insert contacts in the table. 
+ transaction = connection.begin() + session = ScopedSessionLocal() + email_factory.create_batch(3) + session.commit() + session.close() + transaction.commit() + + # Force an analysis of the table. + old_isolation_level = connection.connection.isolation_level + connection.connection.set_isolation_level(0) + session.execute(sqlalchemy.text(f"VACUUM ANALYZE {Email.__tablename__}")) + session.close() + connection.connection.set_isolation_level(old_isolation_level) + + # Query the count result (since last analyze) + session = ScopedSessionLocal() + count = count_total_contacts(session) + assert count == 3 + + # Delete created objects (since our transaction was not rolled back automatically) + session.query(Email).delete() + session.commit() + session.close() + + def test_get_email(dbsession, email_factory): email = email_factory() dbsession.commit() diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index e74f0313..6a362ee2 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -1,4 +1,6 @@ # Test for metrics +from unittest import mock + import pytest from prometheus_client import CollectorRegistry, generate_latest from prometheus_client.parser import text_string_to_metric_families @@ -165,6 +167,14 @@ def test_homepage_request(anon_client, registry): assert_duration_metric_obs(registry, "GET", "/docs", "2xx") +def test_contacts_total(anon_client, dbsession, registry): + """Total number of contacts is reported in heartbeat.""" + with mock.patch("ctms.routers.platform.count_total_contacts", return_value=3): + anon_client.get("/__heartbeat__") + + assert registry.get_sample_value("ctms_contacts_total") == 3 + + def test_api_request(client, minimal_contact, registry): """An API request emits API metrics as well.""" email_id = minimal_contact.email.email_id