Skip to content

Commit

Permalink
Implement roles support in dbapi.connect()
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Oct 3, 2022
1 parent 523283d commit de433c7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 1 deletion.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,20 @@ cur.execute('SELECT * FROM system.runtime.nodes')
rows = cur.fetchall()
```

## Roles

Authorization roles to use for catalogs, specified as a dict with key-value pairs for the catalog and role. For example, `{"catalog1": "roleA", "catalog2": "roleB"}` sets `roleA` for `catalog1` and `roleB` for `catalog2`. See Trino docs.

```python
import trino
conn = trino.dbapi.connect(
host='localhost',
port=443,
user='the-user',
roles={"catalog1": "roleA", "catalog2": "roleB"},
)
```

## SSL

### SSL verification
Expand Down
13 changes: 13 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,19 @@ def test_set_role_trino_351(run_trino):
assert_role_headers(cur, "tpch=ALL")


@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role")
def test_set_role_in_connection_trino_higher_351(run_trino):
_, host, port = run_trino

trino_connection = trino.dbapi.Connection(
host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"}
)
cur = trino_connection.cursor()
cur.execute('SHOW TABLES FROM information_schema')
cur.fetchall()
assert_role_headers(cur, "system=ALL")


def assert_role_headers(cursor, expected_header):
assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header

Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,13 @@ def test_tags_are_set_when_specified(mock_client):

_, passed_client_tags = mock_client.ClientSession.call_args
assert passed_client_tags["client_tags"] == client_tags


@patch("trino.dbapi.trino.client")
def test_role_is_set_when_specified(mock_client):
roles = {"system": "finance"}
with connect("sample_trino_cluster:443", roles=roles) as conn:
conn.cursor().execute("SOME FAKE QUERY")

_, passed_role = mock_client.ClientSession.call_args
assert passed_role["roles"] == roles
4 changes: 3 additions & 1 deletion trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
http_session=None,
client_tags=None,
experimental_python_types=False,
roles=None,
):
self.host = host
self.port = port
Expand All @@ -127,7 +128,8 @@ def __init__(
headers=http_headers,
transaction_id=NO_TRANSACTION,
extra_credential=extra_credential,
client_tags=client_tags
client_tags=client_tags,
roles=roles,
)
# mypy cannot follow module import
if http_session is None:
Expand Down

0 comments on commit de433c7

Please sign in to comment.