-
Notifications
You must be signed in to change notification settings - Fork 10
/
test_db.py
127 lines (104 loc) · 3.43 KB
/
test_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from async_asgi_testclient import TestClient
from fastapi import Depends
from fastapi import FastAPI
from fastapi.exceptions import HTTPException
from fastapi_asyncpg import configure_asyncpg
from fastapi_asyncpg import sql
from pytest_docker_fixtures import images
from typing import Optional
import pydantic as pd
import pytest
import asyncio
# Describe the "postgresql" docker fixture: postgres image, tag 11.1, with
# POSTGRES_DB naming the database the tests connect to.
images.configure(
    "postgresql",
    "postgres",
    "11.1",
    env={"POSTGRES_DB": "test_db"},
)

# Apply the asyncio marker to every test in this module.
pytestmark = pytest.mark.asyncio
class KeyVal(pd.BaseModel):
    """Key/value payload exchanged with the test application."""

    key: str
    value: str
# DDL run against the database before the tests use it: drops any previous
# ``keyval`` table and recreates it empty, with ``key`` constrained unique.
SCHEMA = """
DROP TABLE IF EXISTS keyval;
CREATE TABLE keyval (
key varchar,
value varchar,
UNIQUE(key)
);
"""
@pytest.fixture(scope="function")
async def asgiapp(pg):
    """Build a FastAPI app backed by the PostgreSQL test database.

    ``pg`` is unpacked as a ``(host, port)`` pair. The fixture yields
    ``(app, bdd)``, where ``bdd`` is the object returned by
    ``configure_asyncpg`` and provides the ``connection`` / ``transaction``
    dependencies used by the routes below.

    Routes:
        POST /            insert a ``KeyVal`` row and echo it back.
        GET  /transaction insert ten rows through the transaction
                          dependency; ``?q=1`` makes it fail with 412.
        GET  /{key}       return the row stored under ``key`` or 404.
    """
    host, port = pg
    url = f"postgresql://postgres@{host}:{port}/test_db"
    app = FastAPI()
    bdd = configure_asyncpg(app, url, min_size=1, max_size=2)

    @bdd.on_init
    async def on_init(conn):
        # Recreate the table so each test starts from an empty ``keyval``.
        await conn.execute(SCHEMA)

    @app.post("/", response_model=KeyVal)
    async def add_resource(data: KeyVal, db=Depends(bdd.connection)):
        result = await db.fetchrow(
            "\nINSERT into keyval values ($1, $2) returning *\n",
            data.key,
            data.value,
        )
        return dict(result)

    @app.get("/transaction")
    async def with_transaction(
        q: Optional[int] = 0, db=Depends(bdd.transaction)
    ):
        for i in range(10):
            await db.execute(
                "INSERT INTO keyval values ($1, $2)", f"t{i}", f"t{i}"
            )
        # Fail only after writing, so callers can check that the rows
        # inserted above do not survive the error.
        if q == 1:
            raise HTTPException(412)
        return dict(result="ok")

    @app.get("/{key:str}", response_model=KeyVal)
    async def get_resource(key: str, db=Depends(bdd.connection)):
        result = await db.fetchrow(
            "\nSELECT * from keyval where key=$1\n",
            key,
        )
        if result is None:
            # Falling through would return ``None``, which cannot satisfy
            # ``response_model=KeyVal``; report the missing key explicitly.
            raise HTTPException(404)
        return dict(result)

    yield app, bdd
async def test_dependency(asgiapp):
    """Insert a row through ``POST /`` and read it back with ``GET /test``."""
    app, _ = asgiapp
    async with TestClient(app) as client:
        created = await client.post(
            "/", json={"key": "test", "value": "val1"}
        )
        assert created.status_code == 200

        fetched = await client.get("/test")
        assert fetched.status_code == 200

        body = fetched.json()
        assert body["key"] == "test"
        assert body["value"] == "val1"
async def test_transaction(asgiapp):
    """A successful ``GET /transaction`` must leave its ten rows committed."""
    app, _ = asgiapp
    async with TestClient(app) as client:
        res = await client.get("/transaction")
        assert res.status_code == 200
        # Query while the client context is still open; presumably the pool
        # on ``app.state`` only exists while the app is started.
        async with app.state.pool.acquire() as db:
            # The comparison used to be a bare expression, so the row count
            # was never actually checked.
            assert await sql.count(db, "keyval") == 10
async def test_transaction_fails(asgiapp):
    """A failing ``GET /transaction?q=1`` must leave no rows behind."""
    app, _ = asgiapp
    async with TestClient(app) as client:
        res = await client.get("/transaction?q=1")
        assert res.status_code == 412
        # Query while the client context is still open; presumably the pool
        # on ``app.state`` only exists while the app is started.
        async with app.state.pool.acquire() as db:
            # The comparison used to be a bare expression, so the rollback
            # was never actually verified.
            assert await sql.count(db, "keyval") == 0
async def test_pool_releases_connections(asgiapp):
    """Issue 20 concurrent reads, then check the server's backend count."""
    app, _ = asgiapp
    async with TestClient(app) as client:
        created = await client.post(
            "/", json={"key": "test", "value": "val1"}
        )
        assert created.status_code == 200

        # Schedule all reads at once so they compete for pooled connections.
        pending = [client.get("/test") for _ in range(20)]
        await asyncio.gather(*pending)

        async with app.state.pool.acquire() as conn:
            backends = await conn.fetchval(
                "SELECT sum(numbackends) FROM pg_stat_database;"
            )
        # NOTE(review): 2 presumably mirrors the pool's max_size -- confirm.
        assert backends == 2