Skip to content

Commit

Permalink
Implement roles support in dbapi.connect()
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Sep 23, 2022
1 parent 4ff22e8 commit 75e6556
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 4 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,20 @@ cur.execute('SELECT * FROM system.runtime.nodes')
rows = cur.fetchall()
```

## Roles

Authorization roles to use for catalogs, specified as a dict with key-value pairs for the catalog and role. For example, `{"catalog1": "roleA", "catalog2": "roleB"}` sets roleA for catalog1 and roleB for catalog2.

```python
import trino
conn = trino.dbapi.connect(
host='localhost',
port=443,
user='the-user',
roles={"catalog1": "roleA", "catalog2": "roleB"},
)
```

## SSL

### SSL verification
Expand Down
25 changes: 22 additions & 3 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import trino
from tests.integration.conftest import trino_version
from trino import constants
from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError
from trino.transaction import IsolationLevel

Expand Down Expand Up @@ -1045,11 +1046,11 @@ def test_set_role_trino_higher_351(run_trino):
cur = trino_connection.cursor()
cur.execute('SHOW TABLES FROM information_schema')
cur.fetchall()
assert cur._request._client_session.roles is ""
assert cur._request._client_session.roles == ""

cur.execute("SET ROLE ALL")
cur.fetchall()
assert cur._request._client_session.roles == "system=ALL"
assert_role_headers(cur, "system=ALL")


@pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog")
Expand All @@ -1066,4 +1067,22 @@ def test_set_role_trino_351(run_trino):

cur.execute("SET ROLE ALL")
cur.fetchall()
assert cur._request._client_session.roles == "tpch=ALL"
assert_role_headers(cur, "tpch=ALL")


@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role")
def test_set_role_in_connection_trino_higher_351(run_trino):
_, host, port = run_trino

trino_connection = trino.dbapi.Connection(
host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"}
)
cur = trino_connection.cursor()
cur.execute('SHOW TABLES FROM information_schema')
cur.fetchall()
assert_role_headers(cur, "system=ALL")


def assert_role_headers(cursor, expected_header):
assert cursor._request._client_session.roles == expected_header
assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header
10 changes: 10 additions & 0 deletions tests/unit/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ def setup(self):
experimental_python_types=True,
),
),
(
make_url('trino://user@localhost:8080?roles={"hive":"finance","system":"analyst"}'),
list(),
dict(host="localhost",
port=8080,
catalog="system",
user="user",
roles={"hive": "finance", "system": "analyst"},
source="trino-sqlalchemy"),
),
],
)
def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]):
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,15 @@ def test_tags_are_set_when_specified(mock_client):
# THEN
_, passed_client_tags = mock_client.ClientSession.call_args
assert passed_client_tags["client_tags"] == client_tags


@patch("trino.dbapi.trino.client")
def test_role_is_set_when_specified(mock_client):
# WHEN
roles = {"system": "finance"}
with connect("sample_trino_cluster:443", roles=roles) as conn:
conn.cursor().execute("SOME FAKE QUERY")

# THEN
_, passed_role = mock_client.ClientSession.call_args
assert passed_role["roles"] == roles
4 changes: 3 additions & 1 deletion trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
http_session=None,
client_tags=None,
experimental_python_types=False,
roles=None,
):
self.host = host
self.port = port
Expand All @@ -128,7 +129,8 @@ def __init__(
headers=http_headers,
transaction_id=NO_TRANSACTION,
extra_credential=extra_credential,
client_tags=client_tags
client_tags=client_tags,
roles=roles,
)
# mypy cannot follow module import
if http_session is None:
Expand Down

0 comments on commit 75e6556

Please sign in to comment.