From ce5c351dcb4fe8914379a8a350e200ceac39484b Mon Sep 17 00:00:00 2001 From: Michiel De Smet Date: Fri, 16 Sep 2022 15:22:32 +0200 Subject: [PATCH] Implement roles support in sqlalchemy --- README.md | 4 +++- tests/unit/sqlalchemy/test_dialect.py | 10 ++++++++++ trino/sqlalchemy/dialect.py | 3 +++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 205bfa8b..cab2188a 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ engine = create_engine( "session_properties": {'query_max_run_time': '1d'}, "client_tags": ["tag1", "tag2"], "experimental_python_types": True, + "roles": {"catalog1": "role1"}, } ) @@ -115,7 +116,8 @@ engine = create_engine( 'trino://user@localhost:8080/system?' 'session_properties={"query_max_run_time": "1d"}' '&client_tags=["tag1", "tag2"]' - '&experimental_python_types=true', + '&experimental_python_types=true' + '&roles={"catalog1": "role1"}' ) ``` diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index b17f8cfe..963b27f7 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -63,6 +63,16 @@ def setup(self): experimental_python_types=True, ), ), + ( + make_url('trino://user@localhost:8080?roles={"hive":"finance","system":"analyst"}'), + list(), + dict(host="localhost", + port=8080, + catalog="system", + user="user", + roles={"hive": "finance", "system": "analyst"}, + source="trino-sqlalchemy"), + ), ], ) def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]): diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index e967cb6b..c5bccf37 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -124,6 +124,9 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any if "experimental_python_types" in url.query: kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"]) + if "roles" in url.query: + kwargs["roles"] = json.loads(url.query["roles"]) + return args, kwargs def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: