-
Notifications
You must be signed in to change notification settings - Fork 9
/
main.py
128 lines (103 loc) · 4.13 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
from contextlib import asynccontextmanager
from alembic import command
from alembic.config import Config
from alembic.migration import MigrationContext
from alembic.script import ScriptDirectory
from fastapi import Depends, FastAPI
from scalar_fastapi import get_scalar_api_reference
from sqlalchemy import inspect
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from shared.constants import PROJECT_VERSION
from shared.log_config import get_logger
from trustregistry import crud
from trustregistry.database import engine
from trustregistry.db import get_db
from trustregistry.registry import registry_actors, registry_schemas
logger = get_logger(__name__)
OPENAPI_NAME = os.getenv("OPENAPI_NAME", "Trust Registry")
ROOT_PATH = os.getenv("ROOT_PATH", "")
def check_migrations(db_engine: Engine, alembic_cfg: Config) -> bool:
# Check if alembic_version table exists
with db_engine.connect() as connection:
inspector = inspect(connection)
table_names = inspector.get_table_names()
has_alembic_version = "alembic_version" in table_names
has_actors_table = "actors" in table_names
script = ScriptDirectory.from_config(alembic_cfg)
if not has_alembic_version and has_actors_table:
logger.info(
"Alembic version table not found. Stamping with initial revision..."
)
try:
initial_revision = script.get_base()
command.stamp(alembic_cfg, initial_revision)
logger.info(
"Database stamped with initial migration version: {}", initial_revision
)
except Exception: # pylint: disable=W0718
logger.exception("Error stamping database")
raise
elif not has_alembic_version:
logger.info("Alembic version table not found.")
return False
# Get current revision
with db_engine.connect() as connection:
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
head_rev = script.get_current_head()
return current_rev == head_rev
@asynccontextmanager
async def lifespan(_: FastAPI):
alembic_cfg = Config("alembic.ini")
if not check_migrations(engine, alembic_cfg):
logger.info("Applying database migrations...")
try:
command.upgrade(alembic_cfg, "head")
logger.info("Database schema is up to date.")
except Exception: # pylint: disable=broad-except
logger.exception("Error during migration")
raise
else:
logger.info("Database is up to date. No migrations needed.")
logger.debug("TrustRegistry startup: Validate tables are created")
with engine.connect() as connection:
inspector = inspect(connection)
table_names = inspector.get_table_names()
logger.debug("TrustRegistry tables created: `{}`", table_names)
# start-up logic is before the yield
yield
# shutdown logic after
def create_app():
application = FastAPI(
root_path=ROOT_PATH,
title=OPENAPI_NAME,
version=PROJECT_VERSION,
description="Welcome to the OpenAPI interface to the Aries CloudAPI trust registry",
lifespan=lifespan,
redoc_url=None,
docs_url=None,
)
application.include_router(registry_actors.router)
application.include_router(registry_schemas.router)
return application
app = create_app()
# Use Scalar instead of Swagger
@app.get("/docs", include_in_schema=False)
async def scalar_html():
return get_scalar_api_reference(
openapi_url=app.openapi_url,
title=app.title,
)
@app.get("/")
async def root(db_session: Session = Depends(get_db)):
logger.debug("GET request received: Fetch actors and schemas from registry")
db_schemas = crud.get_schemas(db_session)
db_actors = crud.get_actors(db_session)
schemas_repr = [schema.id for schema in db_schemas]
logger.debug("Successfully fetched actors and schemas from registry.")
return {"actors": db_actors, "schemas": schemas_repr}
@app.get("/registry")
async def registry(db_session: Session = Depends(get_db)):
return await root(db_session)