Skip to content

Commit

Permalink
feat: Add dummy_kernels table for testing sql_json_merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Dec 9, 2024
1 parent 18906ff commit ee4f67c
Showing 1 changed file with 248 additions and 2 deletions.
250 changes: 248 additions & 2 deletions tests/manager/models/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from datetime import datetime

import pytest
import sqlalchemy as sa
from dateutil.tz import tzutc
from sqlalchemy.dialects import postgresql as pgsql

from ai.backend.manager.models import KernelRow, SessionRow, kernels
from ai.backend.manager.models.utils import agg_to_array, agg_to_str
from ai.backend.manager.models import KernelRow, SessionRow, kernels, metadata
from ai.backend.manager.models.utils import (
ExtendedAsyncSAEngine,
agg_to_array,
agg_to_str,
sql_json_merge,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -133,3 +142,240 @@ async def test_agg_to_array(session_info) -> None:
(kernels.c.tag == test_data2) & (kernels.c.session_id == session_id)
)
)


@pytest.fixture
async def dummy_kernels(database_engine: ExtendedAsyncSAEngine):
# dummy_kernels, designed solely for testing sql_json_merge, only includes the status_history column, unlike legacy kernels table.
dummy_kernels = sa.Table(
"dummy_kernels",
metadata,
sa.Column("id", sa.Integer(), primary_key=True, default=1),
sa.Column(
"status_history", pgsql.JSONB(), nullable=True, default=sa.null()
), # JSONB column for testing
extend_existing=True,
)

async with database_engine.begin() as db_sess:
await db_sess.run_sync(metadata.create_all)
await db_sess.execute(dummy_kernels.insert()) # insert fixture data for testing
await db_sess.commit()

yield dummy_kernels

async with database_engine.begin() as db_sess:
await db_sess.run_sync(dummy_kernels.drop)
await db_sess.commit()


@pytest.mark.asyncio
async def test_sql_json_merge__deeper_object(
dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine
):
async with database_engine.begin() as db_sess:
timestamp = datetime.now(tzutc()).isoformat()
expected = {
"kernel": {
"session": {
"PENDING": timestamp,
"PREPARING": timestamp,
},
},
}
query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
("kernel", "session"),
{
"PENDING": timestamp,
"PREPARING": timestamp,
},
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)
result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar()
assert result == expected


@pytest.mark.asyncio
async def test_sql_json_merge__append_values(
dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine
):
async with database_engine.begin() as db_sess:
timestamp = datetime.now(tzutc()).isoformat()
expected = {
"kernel": {
"session": {
"PENDING": timestamp,
"PREPARING": timestamp,
"TERMINATED": timestamp,
"TERMINATING": timestamp,
},
},
}
query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
("kernel", "session"),
{
"PENDING": timestamp,
"PREPARING": timestamp,
},
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)
query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
("kernel", "session"),
{
"TERMINATING": timestamp,
"TERMINATED": timestamp,
},
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)

result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar()
assert result == expected


@pytest.mark.asyncio
async def test_sql_json_merge__kernel_status_history(
dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine
):
async with database_engine.begin() as db_sess:
timestamp = datetime.now(tzutc()).isoformat()
expected = {
"PENDING": timestamp,
"PREPARING": timestamp,
"TERMINATING": timestamp,
"TERMINATED": timestamp,
}
query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
(),
{
"PENDING": timestamp,
"PREPARING": timestamp,
},
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)
query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
(),
{
"TERMINATING": timestamp,
"TERMINATED": timestamp,
},
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)

result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar()
assert result == expected


@pytest.mark.asyncio
async def test_sql_json_merge__mixed_formats(
dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine
):
async with database_engine.begin() as db_sess:
timestamp = datetime.now(tzutc()).isoformat()
expected = {
"PENDING": timestamp,
"kernel": {
"PREPARING": timestamp,
},
}
query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
(),
{
"PENDING": timestamp,
},
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)

query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
("kernel",),
{
"PREPARING": timestamp,
},
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)

result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar()
assert result == expected


@pytest.mark.asyncio
async def test_sql_json_merge__json_serializable_types(
dummy_kernels: sa.Table, database_engine: ExtendedAsyncSAEngine
):
async with database_engine.begin() as db_sess:
expected = {
"boolean": True,
"integer": 10101010,
"float": 1010.1010,
"string": "10101010",
# "bytes": b"10101010",
"list": [
10101010,
"10101010",
],
"dict": {
"10101010": 10101010,
},
}
query = (
dummy_kernels.update()
.values({
"status_history": sql_json_merge(
dummy_kernels.c.status_history,
(),
expected,
),
})
.where(dummy_kernels.c.id == 1)
)
await db_sess.execute(query)
result = (await db_sess.execute(sa.select(dummy_kernels.c.status_history))).scalar()
assert result == expected

0 comments on commit ee4f67c

Please sign in to comment.