Skip to content

Commit

Permalink
/data/ endpoint for multiple metrics and dimensions (DataJunction#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
samredai authored May 3, 2023
1 parent 525a195 commit 1f11c21
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 2 deletions.
66 changes: 65 additions & 1 deletion dj/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from fastapi.responses import JSONResponse
from sqlmodel import Session

from dj.api.helpers import get_engine, get_node_by_name, get_query
from dj.api.helpers import get_engine, get_node_by_name, get_query, validate_cube
from dj.construction.build import build_metric_nodes
from dj.errors import DJException, DJInvalidInputException
from dj.models.metric import TranslatedSQL
from dj.models.node import AvailabilityState, AvailabilityStateBase, NodeType
Expand Down Expand Up @@ -146,3 +147,66 @@ def get_data( # pylint: disable=too-many-locals
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = columns
return result


@router.get("/data/", response_model=QueryWithResults)
def get_data_for_metrics( # pylint: disable=R0914
metrics: List[str] = Query([]),
dimensions: List[str] = Query([]),
filters: List[str] = Query([]),
async_: bool = False,
*,
session: Session = Depends(get_session),
query_service_client: QueryServiceClient = Depends(get_query_service_client),
engine_name: Optional[str] = None,
engine_version: Optional[str] = None,
) -> QueryWithResults:
"""
Return data for a set of metrics with dimensions and filters
"""
leading_metric_node = get_node_by_name(session, metrics[0])
available_engines = leading_metric_node.current.catalog.engines
engine = (
get_engine(session, engine_name, engine_version) # type: ignore
if engine_name
else available_engines[0]
)
if engine not in available_engines:
raise DJInvalidInputException( # pragma: no cover
f"The selected engine is not available for the node {metrics[0]}. "
f"Available engines include: {', '.join(engine.name for engine in available_engines)}",
)

_, metric_nodes, _, _ = validate_cube(
session,
metrics,
dimensions,
)
query_ast = build_metric_nodes(
session,
metric_nodes,
filters=filters or [],
dimensions=dimensions or [],
)
columns = [
ColumnMetadata(name=col.alias_or_name.name, type=str(col.type)) # type: ignore
for col in query_ast.select.projection
]
query = TranslatedSQL(
sql=str(query_ast),
columns=columns,
dialect=engine.dialect if engine else None,
)

query_create = QueryCreate(
engine_name=engine.name,
catalog_name=leading_metric_node.current.catalog.name,
engine_version=engine.version,
submitted_query=query.sql,
async_=async_,
)
result = query_service_client.submit_query(query_create)
# Inject column info if there are results
if result.results.__root__: # pragma: no cover
result.results.__root__[0].columns = columns
return result
57 changes: 56 additions & 1 deletion tests/api/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_get_metric_data(
client_with_query_service: TestClient,
) -> None:
"""
Trying to get transform or source data should fail
Test retrieving data for a metric
"""
response = client_with_query_service.get("/data/basic.num_comments/")
data = response.json()
Expand Down Expand Up @@ -200,6 +200,61 @@ def test_get_metric_data(
"errors": [],
}

def test_get_multiple_metrics_and_dimensions_data(
self,
client_with_query_service: TestClient,
) -> None:
"""
Test getting multiple metrics and dimensions
"""
response = client_with_query_service.get(
"/data?metrics=num_repair_orders&metrics="
"avg_repair_price&dimensions=dispatcher.company_name",
)
data = response.json()
assert response.status_code == 200
assert data == {
"engine_name": None,
"engine_version": None,
"errors": [],
"executed_query": None,
"finished": None,
"id": "bd98d6be-e2d2-413e-94c7-96d9411ddee2",
"next": None,
"output_table": None,
"previous": None,
"progress": 0.0,
"results": [
{
"columns": [
{"name": "avg_repair_price", "type": "double"},
{"name": "company_name", "type": "string"},
{"name": "num_repair_orders", "type": "bigint"},
],
"row_count": 0,
"rows": [[1.0, "Foo", 100], [2.0, "Bar", 200]],
"sql": "",
},
],
"scheduled": None,
"started": None,
"state": "FINISHED",
"submitted_query": "SELECT avg(repair_order_details.price) AS "
"avg_repair_price,\\n\\tdispatcher.company_name,\\n\\t"
"count(repair_orders.repair_order_id) "
"AS num_repair_orders \\n FROM roads.repair_order_details "
"AS repair_order_details LEFT OUTER JOIN (SELECT "
"repair_orders.dispatcher_id,\\n\\trepair_orders.hard_hat_id,"
"\\n\\trepair_orders.municipality_id,\\n\\trepair_orders.repair_order_id "
"\\n FROM roads.repair_orders AS repair_orders) AS "
"repair_order ON repair_order_details.repair_order_id = "
"repair_order.repair_order_id\\nLEFT OUTER JOIN (SELECT "
"dispatchers.company_name,\\n\\tdispatchers.dispatcher_id "
"\\n FROM roads.dispatchers AS dispatchers) AS dispatcher "
"ON repair_order.dispatcher_id = dispatcher.dispatcher_id "
"\\n GROUP BY dispatcher.company_name\\n",
}


class TestAvailabilityState: # pylint: disable=too-many-public-methods
"""
Expand Down
56 changes: 56 additions & 0 deletions tests/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,6 +1689,62 @@
}

QUERY_DATA_MAPPINGS = {
(
"SELECT avg(repair_order_details.price) AS "
"avg_repair_price,\n\tdispatcher.company_name,"
"\n\tcount(repair_orders.repair_order_id) AS "
"num_repair_orders \n FROM roads.repair_order_details "
"AS repair_order_details LEFT OUTER JOIN (SELECT "
"repair_orders.dispatcher_id,\n\t"
"repair_orders.hard_hat_id,\n\trepair_orders.municipality_id"
",\n\trepair_orders.repair_order_id \n FROM "
"roads.repair_orders AS repair_orders) AS repair_order "
"ON repair_order_details.repair_order_id = "
"repair_order.repair_order_id\nLEFT OUTER JOIN (SELECT "
"dispatchers.company_name,\n\tdispatchers.dispatcher_id "
"\n FROM roads.dispatchers AS dispatchers) AS dispatcher "
"ON repair_order.dispatcher_id = dispatcher.dispatcher_id "
"\n GROUP BY dispatcher.company_name"
)
.strip()
.replace('"', "")
.replace("\n", "")
.replace(" ", ""): QueryWithResults(
**{
"id": uuid.UUID("bd98d6be-e2d2-413e-94c7-96d9411ddee2"),
"submitted_query": (
"SELECT avg(repair_order_details.price) AS "
"avg_repair_price,\\n\\tdispatcher.company_name,"
"\\n\\tcount(repair_orders.repair_order_id) "
"AS num_repair_orders \\n FROM roads.repair_order_details AS "
"repair_order_details LEFT OUTER JOIN (SELECT "
"repair_orders.dispatcher_id,\\n\\trepair_orders.hard_hat_id,\\n\\t"
"repair_orders.municipality_id,\\n\\trepair_orders.repair_order_id "
"\\n FROM roads.repair_orders AS repair_orders) AS repair_order ON "
"repair_order_details.repair_order_id = repair_order.repair_order_id\\nLEFT "
"OUTER JOIN (SELECT dispatchers.company_name,\\n\\tdispatchers.dispatcher_id "
"\\n FROM roads.dispatchers AS dispatchers) AS dispatcher ON "
"repair_order.dispatcher_id = dispatcher.dispatcher_id \\n GROUP BY "
"dispatcher.company_name\\n"
),
"state": QueryState.FINISHED,
"results": [
{
"columns": [
{"name": "avg_repair_price", "type": "float"},
{"name": "company_name", "type": "str"},
{"name": "num_repair_orders", "type": "int"},
],
"rows": [
(1.0, "Foo", 100),
(2.0, "Bar", 200),
],
"sql": "",
},
],
"errors": [],
}
),
(
"SELECT payment_type_table.id,\n\tpayment_type_table.payment_type_classification,\n\t"
"payment_type_table.payment_type_name \n FROM accounting.payment_type_table AS "
Expand Down

0 comments on commit 1f11c21

Please sign in to comment.