diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 7f4d591a77984..0f8d7995deab8 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -431,6 +431,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods oauth2_scope = "" oauth2_authorization_request_uri = "" # pylint: disable=invalid-name oauth2_token_request_uri = "" + oauth2_token_request_type = "" # Driver-specific exception that should be mapped to OAuth2RedirectError oauth2_exception = OAuth2RedirectError @@ -515,6 +516,9 @@ def get_oauth2_config(cls) -> OAuth2ClientConfig | None: "token_request_uri", cls.oauth2_token_request_uri, ), + "request_content_type": db_engine_spec_config.get( + "request_content_type", cls.oauth2_token_request_type + ), } return config @@ -552,18 +556,16 @@ def get_oauth2_token( """ timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] - response = requests.post( - uri, - json={ - "code": code, - "client_id": config["id"], - "client_secret": config["secret"], - "redirect_uri": config["redirect_uri"], - "grant_type": "authorization_code", - }, - timeout=timeout, - ) - return response.json() + req_body = { + "code": code, + "client_id": config["id"], + "client_secret": config["secret"], + "redirect_uri": config["redirect_uri"], + "grant_type": "authorization_code", + } + if config["request_content_type"] == "data": + return requests.post(uri, data=req_body, timeout=timeout).json() + return requests.post(uri, json=req_body, timeout=timeout).json() @classmethod def get_oauth2_fresh_token( @@ -576,17 +578,15 @@ def get_oauth2_fresh_token( """ timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] - response = requests.post( - uri, - json={ - "client_id": config["id"], - "client_secret": config["secret"], - "refresh_token": refresh_token, - "grant_type": "refresh_token", - }, - timeout=timeout, - ) - return response.json() + req_body = { + "client_id": config["id"], + "client_secret": config["secret"], + "refresh_token": refresh_token, + "grant_type": "refresh_token", + } + if config["request_content_type"] == "data": + return requests.post(uri, data=req_body, timeout=timeout).json() + return requests.post(uri, json=req_body, timeout=timeout).json() @classmethod def get_allows_alias_in_select( diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 3a850e0acb672..c3c40cd31a918 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict): # expired access token. token_request_uri: str + # Not all identity providers expect json. Keycloak expects a form encoded request, + # which in the `requests` package context means using the `data` param, not `json`. + request_content_type: str + class OAuth2TokenResponse(TypedDict, total=False): """ diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index b889ef83c5e75..95db2921f6cd6 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -23,7 +23,7 @@ import backoff import jwt from flask import current_app, url_for -from marshmallow import EXCLUDE, fields, post_load, Schema +from marshmallow import EXCLUDE, fields, post_load, Schema, validate from superset import db from superset.distributed_lock import KeyValueDistributedLock @@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema): ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True) + request_content_type = fields.String( + required=False, + load_default=lambda: "json", + validate=validate.OneOf(["json", "data"]), + ) diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 5d2ddb807bbc1..4e17054db9e63 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig: "redirect_uri": "http://localhost:8088/api/v1/oauth2/", "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", "token_request_uri": "https://oauth2.googleapis.com/token", + "request_content_type": "json", } diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 990ae891c465c..b169a484a1e03 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -45,7 +45,12 @@ SupersetDBAPIProgrammingError, ) from superset.sql_parse import Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType +from superset.superset_typing import ( + OAuth2ClientConfig, + ResultSetColumnType, + SQLAColumnType, + SQLType, +) from superset.utils import json from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -788,3 +793,57 @@ def test_where_latest_partition( ) == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}""" ) + + +@pytest.fixture +def oauth2_config() -> OAuth2ClientConfig: + """ + Config for GSheets OAuth2. + """ + return { + "id": "trino", + "secret": "very-secret", + "scope": "", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth", + "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token", + "request_content_type": "data", + } + + +def test_get_oauth2_token( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test `get_oauth2_token`. + """ + from superset.db_engine_specs.trino import TrinoEngineSpec + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().json.return_value = { + "access_token": "access-token", + "expires_in": 3600, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "refresh-token", + } + + assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == { + "access_token": "access-token", + "expires_in": 3600, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "refresh-token", + } + requests.post.assert_called_with( + "https://trino.auth.server.example/master/protocol/openid-connect/token", + data={ + "code": "code", + "client_id": "trino", + "client_secret": "very-secret", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "grant_type": "authorization_code", + }, + timeout=30.0, + ) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 6f588cde2408c..7d475cd8250a3 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -427,6 +427,7 @@ def test_get_oauth2_config(app_context: None) -> None: "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", "scope": "refresh_token session:role:SYSADMIN", "redirect_uri": "http://example.com/api/v1/database/oauth2/", + "request_content_type": "json", }