From 75e65560ff8afcdcff1ee9f2c864c2d52af607b3 Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Fri, 16 Sep 2022 15:22:24 +0200 Subject: [PATCH] Implement roles support in `dbapi.connect()` --- README.md | 14 ++++++++++++ tests/integration/test_dbapi_integration.py | 25 ++++++++++++++++++--- tests/unit/sqlalchemy/test_dialect.py | 10 +++++++++ tests/unit/test_dbapi.py | 12 ++++++++++ trino/dbapi.py | 4 +++- 5 files changed, 61 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c6502885..19dda261 100644 --- a/README.md +++ b/README.md @@ -321,6 +321,20 @@ cur.execute('SELECT * FROM system.runtime.nodes') rows = cur.fetchall() ``` +## Roles + +Authorization roles to use for catalogs, specified as a dict with key-value pairs for the catalog and role. For example, `{"catalog1": "roleA", "catalog2": "roleB"}` sets roleA for catalog1 and roleB for catalog2. + +```python +import trino +conn = trino.dbapi.connect( + host='localhost', + port=443, + user='the-user', + roles={"catalog1": "roleA", "catalog2": "roleB"}, +) +``` + ## SSL ### SSL verification diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 33ce69fa..4575e4d9 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -19,6 +19,7 @@ import trino from tests.integration.conftest import trino_version +from trino import constants from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError from trino.transaction import IsolationLevel @@ -1045,11 +1046,11 @@ def test_set_role_trino_higher_351(run_trino): cur = trino_connection.cursor() cur.execute('SHOW TABLES FROM information_schema') cur.fetchall() - assert cur._request._client_session.roles is "" + assert cur._request._client_session.roles == "" cur.execute("SET ROLE ALL") cur.fetchall() - assert cur._request._client_session.roles == "system=ALL" + assert_role_headers(cur, "system=ALL") @pytest.mark.skipif(trino_version() != '351', reason="Trino 351 returns the role for the current catalog") @@ -1066,4 +1067,22 @@ def test_set_role_trino_351(run_trino): cur.execute("SET ROLE ALL") cur.fetchall() - assert cur._request._client_session.roles == "tpch=ALL" + assert_role_headers(cur, "tpch=ALL") + + +@pytest.mark.skipif(trino_version() == '351', reason="Newer Trino versions return the system role") +def test_set_role_in_connection_trino_higher_351(run_trino): + _, host, port = run_trino + + trino_connection = trino.dbapi.Connection( + host=host, port=port, user="test", catalog="tpch", roles={"system": "ALL"} + ) + cur = trino_connection.cursor() + cur.execute('SHOW TABLES FROM information_schema') + cur.fetchall() + assert_role_headers(cur, "system=ALL") + + +def assert_role_headers(cursor, expected_header): + assert cursor._request._client_session.roles == expected_header + assert cursor._request.http_headers[constants.HEADER_ROLE] == expected_header diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index b17f8cfe..963b27f7 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -63,6 +63,16 @@ def setup(self): experimental_python_types=True, ), ), + ( + make_url('trino://user@localhost:8080?roles={"hive":"finance","system":"analyst"}'), + list(), + dict(host="localhost", + port=8080, + catalog="system", + user="user", + roles={"hive": "finance", "system": "analyst"}, + source="trino-sqlalchemy"), + ), ], ) def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]): diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 6f2cc50c..4f5d5367 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -256,3 +256,15 @@ def test_tags_are_set_when_specified(mock_client): # THEN _, passed_client_tags = mock_client.ClientSession.call_args assert passed_client_tags["client_tags"] == client_tags + + +@patch("trino.dbapi.trino.client") +def test_role_is_set_when_specified(mock_client): + # WHEN + roles = {"system": "finance"} + with connect("sample_trino_cluster:443", roles=roles) as conn: + conn.cursor().execute("SOME FAKE QUERY") + + # THEN + _, passed_role = mock_client.ClientSession.call_args + assert passed_role["roles"] == roles diff --git a/trino/dbapi.py b/trino/dbapi.py index 70fb43bb..8c82a647 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -111,6 +111,7 @@ def __init__( http_session=None, client_tags=None, experimental_python_types=False, + roles=None, ): self.host = host self.port = port @@ -128,7 +129,8 @@ def __init__( headers=http_headers, transaction_id=NO_TRANSACTION, extra_credential=extra_credential, - client_tags=client_tags + client_tags=client_tags, + roles=roles, ) # mypy cannot follow module import if http_session is None: