Skip to content

Commit

Permalink
Add Cursor.describe to retrieve the schema of a query
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and hashhar committed Dec 22, 2022
1 parent d96bff2 commit 8affa61
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
67 changes: 67 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import trino
from tests.integration.conftest import trino_version
from trino import constants
from trino.dbapi import DescribeOutput
from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError
from trino.transaction import IsolationLevel

Expand Down Expand Up @@ -1155,3 +1156,69 @@ def test_connection_without_timezone(run_trino):
assert session_tz == localzone or \
(session_tz == "UTC" and localzone == "Etc/UTC") \
# Workaround for difference between Trino timezone and tzlocal for UTC


def test_describe(run_trino):
    """Describe a table-less SELECT and compare the reported column metadata.

    The statement yields one unnamed column and one column with an explicit
    alias, so the two expected rows differ in ``name`` and ``aliased``.
    """
    _, host, port = run_trino
    connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch")
    described = connection.cursor().describe("SELECT 1, DECIMAL '1.0' as a")

    unnamed_integer = DescribeOutput(
        name='_col0', catalog='', schema='', table='', type='integer', type_size=4, aliased=False,
    )
    aliased_decimal = DescribeOutput(
        name='a', catalog='', schema='', table='', type='decimal(2,1)', type_size=8, aliased=True,
    )
    assert described == [unnamed_integer, aliased_decimal]


def test_describe_table_query(run_trino):
    """Describe ``SELECT *`` over ``tpch.tiny.nation`` and compare every column.

    All expected rows share the same catalog, schema and table and are never
    aliased; only the column name, type and type size vary.
    """
    _, host, port = run_trino
    connection = trino.dbapi.Connection(host=host, port=port, user="test", catalog="tpch")
    described = connection.cursor().describe("SELECT * from tpch.tiny.nation")

    # (column name, type, type size) for each expected output column, in order.
    column_specs = [
        ('nationkey', 'bigint', 8),
        ('name', 'varchar(25)', 0),
        ('regionkey', 'bigint', 8),
        ('comment', 'varchar(152)', 0),
    ]
    expected = [
        DescribeOutput(
            name=column_name,
            catalog='tpch',
            schema='tiny',
            table='nation',
            type=column_type,
            type_size=column_size,
            aliased=False,
        )
        for column_name, column_type, column_size in column_specs
    ]
    assert described == expected
38 changes: 37 additions & 1 deletion trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import math
import uuid
from decimal import Decimal
from typing import Any, Dict, List, Optional # NOQA for mypy types
from typing import Any, Dict, List, NamedTuple, Optional # NOQA for mypy types

import trino.client
import trino.exceptions
Expand Down Expand Up @@ -223,6 +223,20 @@ def cursor(self, experimental_python_types: bool = None):
)


class DescribeOutput(NamedTuple):
    """One output column of a statement, as returned by ``Cursor.describe``.

    The field order matches the order of the values in a result row, which
    is what ``from_row`` relies on.
    """

    name: str  # column name, or its alias when the column is aliased
    catalog: str
    schema: str
    table: str
    type: str
    type_size: int  # size of the type in bytes
    aliased: bool  # whether the column is aliased

    @classmethod
    def from_row(cls, row: List[Any]) -> "DescribeOutput":
        """Build an instance from a row whose values are in field order."""
        return cls._make(row)


class Cursor(object):
"""Database cursor.
Expand Down Expand Up @@ -523,6 +537,28 @@ def fetchmany(self, size=None) -> List[List[Any]]:

return result

def describe(self, sql: str) -> List[DescribeOutput]:
    """
    Describe the output columns of a SQL statement.

    The statement is prepared under a freshly generated name, a
    ``DESCRIBE OUTPUT`` query is run against that prepared statement, and the
    prepared statement is deallocated again whether or not the query succeeds.

    Each returned entry holds the column name (or alias), catalog, schema, table, type,
    type size in bytes, and a boolean indicating if the column is aliased.

    :param sql: SQL statement
    :return: one ``DescribeOutput`` per row of the ``DESCRIBE OUTPUT`` result
    """
    prepared_name = self._generate_unique_statement_name()
    # Prepared outside the ``try`` so nothing is deallocated if preparing fails.
    self._prepare_statement(sql, prepared_name)
    try:
        self._query = trino.client.TrinoQuery(
            self._request,
            sql=f"DESCRIBE OUTPUT {prepared_name}",
            # NOTE(review): "pyton" spelling presumably matches the attribute's
            # definition elsewhere in this class -- confirm before renaming.
            experimental_python_types=self._experimental_pyton_types,
        )
        rows = self._query.execute()
    finally:
        self._deallocate_prepared_statement(prepared_name)

    return [DescribeOutput.from_row(row) for row in rows]

def genall(self):
    """Return the ``result`` attribute of this cursor's current query."""
    current_query = self._query
    return current_query.result

Expand Down

0 comments on commit 8affa61

Please sign in to comment.