Skip to content

Commit

Permalink
api-server: fix postgres label sorting (#957)
Browse files Browse the repository at this point in the history
* fix postgres label sorting

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>

* successful test with both postgres and sqlite

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>

* fix lint errors

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>

---------

Signed-off-by: Teo Koon Peng <teokoonpeng@gmail.com>
  • Loading branch information
koonpeng committed Jun 24, 2024
1 parent 49a866a commit bb88e1c
Show file tree
Hide file tree
Showing 21 changed files with 394 additions and 326 deletions.
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ isort = "==5.13.2"
pylint = "==3.1.0"
coverage = "~=5.5"
# api-server
api-server = {editable = true, path = "./packages/api-server"}
api-server = {editable = true, path = "./packages/api-server", extras = ["postgres"]}
httpx = "~=0.26.0"
datamodel-code-generator = "==0.25.4"
requests = "~=2.25"
Expand Down
420 changes: 240 additions & 180 deletions Pipfile.lock

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions packages/api-server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,18 +252,24 @@ Restart the `api-server` and the changes to the database should be reflected.
### Running unit tests

```bash
npm test
pnpm test
```

By default, an in-memory SQLite database is used for testing. To test against another database, set the `RMF_API_SERVER_TEST_DB_URL` environment variable.

```bash
RMF_API_SERVER_TEST_DB_URL=<db_url> pnpm test
```

### Collecting code coverage

```bash
npm run test:cov
pnpm run test:cov
```

Generate coverage report
```bash
npm run test:report
pnpm run test:report
```

## Live reload
Expand Down
6 changes: 5 additions & 1 deletion packages/api-server/api_server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def pagination_query(
) -> Pagination:
limit = limit or 100
offset = offset or 0
return Pagination(limit=limit, offset=offset, order_by=order_by)
return Pagination(
limit=limit,
offset=offset,
order_by=order_by.split(",") if order_by else [],
)


# hacky way to get the sio user
Expand Down
4 changes: 1 addition & 3 deletions packages/api-server/api_server/models/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import Optional

from pydantic import BaseModel


class Pagination(BaseModel):
limit: int
offset: int
order_by: Optional[str]
order_by: list[str]
51 changes: 6 additions & 45 deletions packages/api-server/api_server/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import tortoise.functions as tfuncs
from tortoise.expressions import Q
from tortoise.queryset import MODEL, QuerySet

from api_server.models.pagination import Pagination
Expand All @@ -8,47 +6,10 @@
def add_pagination(
query: QuerySet[MODEL],
pagination: Pagination,
field_mappings: dict[str, str] | None = None,
group_by: str | None = None,
) -> QuerySet[MODEL]:
"""
Adds pagination and ordering to a query. If the order field starts with `label=`, it is
assumed to be a label and label sorting will used. In this case, the model must have
a reverse relation named "labels" and the `group_by` param is required.
:param field_mapping: A dict mapping the order fields to the fields used to build the
query. e.g. a url of `?order_by=order_field` and a field mapping of `{"order_field": "db_field"}`
will order the query result according to `db_field`.
:param group_by: Required when sorting by labels, must be the foreign key column of the label table.
"""
field_mappings = field_mappings or {}
annotations = {}
query = query.limit(pagination.limit).offset(pagination.offset)
if pagination.order_by is not None:
order_fields = []
order_values = pagination.order_by.split(",")
for v in order_values:
# perform the mapping after stripping the order prefix
order_prefix = ""
order_field = v
if v[0] in ["-", "+"]:
order_prefix = v[0]
order_field = v[1:]
order_field = field_mappings.get(order_field, order_field)

# add annotations required for sorting by labels
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = query.annotate(**annotations)
if group_by is not None:
query = query.group_by(group_by)
query = query.order_by(*order_fields)
return query
"""Adds pagination and ordering to a query"""
return (
query.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*pagination.order_by)
)
85 changes: 80 additions & 5 deletions packages/api-server/api_server/repositories/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from datetime import datetime
from typing import Dict, List, Optional, Sequence, Tuple

import tortoise.functions as tfuncs
from fastapi import Depends, HTTPException
from tortoise.exceptions import FieldError, IntegrityError
from tortoise.expressions import Expression, Q
from tortoise.query_utils import Prefetch
from tortoise.queryset import QuerySet
from tortoise.transactions import in_transaction

from api_server.authenticator import user_dep
Expand All @@ -18,14 +19,14 @@
TaskEventLog,
TaskRequest,
TaskState,
TaskStatus,
User,
)
from api_server.models import tortoise_models as ttm
from api_server.models.rmf_api.log_entry import Tier
from api_server.models.rmf_api.task_state import Category, Id, Phase
from api_server.models.tortoise_models import TaskRequest as DbTaskRequest
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.query import add_pagination
from api_server.rmf_io import task_events


Expand Down Expand Up @@ -96,11 +97,85 @@ async def save_task_state(self, task_state: TaskState) -> None:
await self.save_task_labels(db_task_state, labels)

async def query_task_states(
self, query: QuerySet[DbTaskState], pagination: Optional[Pagination] = None
self,
task_id: list[str] | None = None,
category: list[str] | None = None,
assigned_to: list[str] | None = None,
start_time_between: tuple[datetime, datetime] | None = None,
finish_time_between: tuple[datetime, datetime] | None = None,
status: list[str] | None = None,
label: Labels | None = None,
pagination: Optional[Pagination] = None,
) -> List[TaskState]:
filters = {}
if task_id is not None:
filters["id___in"] = task_id
if category is not None:
filters["category__in"] = category
if assigned_to is not None:
filters["assigned_to__in"] = assigned_to
if start_time_between is not None:
filters["unix_millis_start_time__gte"] = start_time_between[0]
filters["unix_millis_start_time__lte"] = start_time_between[1]
if finish_time_between is not None:
filters["unix_millis_finish_time__gte"] = finish_time_between[0]
filters["unix_millis_finish_time__lte"] = finish_time_between[1]
if status is not None:
valid_values = [member.value for member in TaskStatus]
filters["status__in"] = []
for status_string in status:
if status_string not in valid_values:
continue
filters["status__in"].append(TaskStatus(status_string))
query = DbTaskState.filter(**filters)

need_group_by = False
label_filters = {}
if label is not None:
label_filters.update(
{
f"label_filter_{k}": tfuncs.Count(
"id_",
_filter=Q(labels__label_name=k, labels__label_value=v),
)
for k, v in label.root.items()
}
)

if len(label_filters) > 0:
filter_gt = {f"{f}__gt": 0 for f in label_filters}
query = query.annotate(**label_filters).filter(**filter_gt)
need_group_by = True

if pagination:
order_fields: list[str] = []
annotations: dict[str, Expression] = {}
# add annotations required for sorting by labels
for f in pagination.order_by:
order_prefix = f[0] if f[0] == "-" else ""
order_field = f[1:] if order_prefix == "-" else f
if order_field.startswith("label="):
f = order_field[6:]
annotations[f"label_sort_{f}"] = tfuncs.Max(
"labels__label_value",
_filter=Q(labels__label_name=f),
)
order_field = f"label_sort_{f}"

order_fields.append(order_prefix + order_field)

query = (
query.annotate(**annotations)
.limit(pagination.limit)
.offset(pagination.offset)
.order_by(*order_fields)
)
need_group_by = True

if need_group_by:
query = query.group_by("id_", "labels__state_id")

try:
if pagination:
query = add_pagination(query, pagination, group_by="labels__state_id")
# TODO: enforce with authz
results = await query.values_list("data")
return [TaskState(**r[0]) for r in results]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def get_scheduled_tasks(
.offset(pagination.offset)
)
if pagination.order_by:
q.order_by(*pagination.order_by.split(","))
q.order_by(*pagination.order_by)
results = await q
await ttm.ScheduledTask.fetch_for_list(results)
return [ScheduledTask.model_validate(x) for x in results]
Expand Down
58 changes: 10 additions & 48 deletions packages/api-server/api_server/routes/tasks/tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from datetime import datetime
from typing import List, Optional, Tuple, cast

import tortoise.functions as tfuncs
from fastapi import Body, Depends, HTTPException, Path, Query
from reactivex import operators as rxops
from tortoise.expressions import Q

from api_server import models as mdl
from api_server.dependencies import (
Expand All @@ -15,7 +13,6 @@
start_time_between_query,
)
from api_server.fast_io import FastIORouter, SubscriptionRequest
from api_server.models.tortoise_models import TaskState as DbTaskState
from api_server.repositories import TaskRepository, task_repo_dep
from api_server.response import RawJSONResponse
from api_server.rmf_io import task_events, tasks_service
Expand Down Expand Up @@ -60,51 +57,16 @@ async def query_task_states(
),
pagination: mdl.Pagination = Depends(pagination_query),
):
filters = {}
if task_id is not None:
filters["id___in"] = task_id.split(",")
if category is not None:
filters["category__in"] = category.split(",")
if assigned_to is not None:
filters["assigned_to__in"] = assigned_to.split(",")
if start_time_between is not None:
filters["unix_millis_start_time__gte"] = start_time_between[0]
filters["unix_millis_start_time__lte"] = start_time_between[1]
if finish_time_between is not None:
filters["unix_millis_finish_time__gte"] = finish_time_between[0]
filters["unix_millis_finish_time__lte"] = finish_time_between[1]
if status is not None:
valid_values = [member.value for member in mdl.TaskStatus]
filters["status__in"] = []
for status_string in status.split(","):
if status_string not in valid_values:
continue
filters["status__in"].append(mdl.TaskStatus(status_string))
query = DbTaskState.filter(**filters)

label_filters = {}
if label is not None:
labels = mdl.Labels.from_strings(label.split(","))
label_filters.update(
{
f"label_filter_{k}": tfuncs.Count(
"id_", _filter=Q(labels__label_name=k, labels__label_value=v)
)
for k, v in labels.root.items()
}
)

if len(label_filters) > 0:
filter_gt = {f"{f}__gt": 0 for f in label_filters}
query = (
query.annotate(**label_filters)
.group_by(
"labels__state_id"
) # need to group by a related field to make tortoise-orm generate joins
.filter(**filter_gt)
)

return await task_repo.query_task_states(query, pagination)
return await task_repo.query_task_states(
task_id=task_id.split(",") if task_id else None,
category=category.split(",") if category else None,
assigned_to=assigned_to.split(",") if assigned_to else None,
start_time_between=start_time_between,
finish_time_between=finish_time_between,
status=status.split(",") if status else None,
label=mdl.Labels.from_strings(label.split(",")) if label else None,
pagination=pagination,
)


@router.get("/{task_id}/state", response_model=mdl.TaskState)
Expand Down
9 changes: 3 additions & 6 deletions packages/api-server/api_server/routes/tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,12 @@ def setUpClass(cls):
cls.task_logs = [make_task_log(task_id=f"test_{x}") for x in task_ids]
cls.clsSetupErr: str | None = None

if cls.client.portal is None:
cls.clsSetupErr = "missing client portal, is the client context entered?"
return

portal = cls.get_portal()
repo = TaskRepository(cls.admin_user)
for x in cls.task_states:
cls.client.portal.call(repo.save_task_state, x)
portal.call(repo.save_task_state, x)
for x in cls.task_logs:
cls.client.portal.call(repo.save_task_log, x)
portal.call(repo.save_task_log, x)

def setUp(self):
super().setUp()
Expand Down
4 changes: 2 additions & 2 deletions packages/api-server/api_server/routes/test_building_map.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from api_server.rmf_io import rmf_events
from api_server.test import AppFixture, make_building_map, try_until


class TestBuildingMapRoute(AppFixture):
def test_get_building_map(self):
building_map = make_building_map()
rmf_events.building_map.on_next(building_map)
portal = self.get_portal()
portal.call(building_map.save)

resp = try_until(
lambda: self.client.get("/building_map"), lambda x: x.status_code == 200
Expand Down
4 changes: 2 additions & 2 deletions packages/api-server/api_server/routes/test_dispensers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from typing import List
from uuid import uuid4

Expand All @@ -12,8 +11,9 @@ def setUpClass(cls):
super().setUpClass()
cls.dispenser_states = [make_dispenser_state(f"test_{uuid4()}")]

portal = cls.get_portal()
for x in cls.dispenser_states:
asyncio.run(x.save())
portal.call(x.save)

def test_get_dispensers(self):
resp = self.client.get("/dispensers")
Expand Down
6 changes: 3 additions & 3 deletions packages/api-server/api_server/routes/test_doors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from uuid import uuid4

from rmf_door_msgs.msg import DoorMode as RmfDoorMode
Expand All @@ -12,12 +11,13 @@ class TestDoorsRoute(AppFixture):
def setUpClass(cls):
super().setUpClass()
cls.building_map = make_building_map()
asyncio.run(cls.building_map.save())
portal = cls.get_portal()
portal.call(cls.building_map.save)

cls.door_states = [make_door_state(f"test_{uuid4()}")]

for x in cls.door_states:
asyncio.run(x.save())
portal.call(x.save)

def test_get_doors(self):
resp = self.client.get("/doors")
Expand Down
Loading

0 comments on commit bb88e1c

Please sign in to comment.