Skip to content

Commit

Permalink
feat(oauth2): allow override encoding of token http request
Browse files Browse the repository at this point in the history
  • Loading branch information
joaoferrao committed Aug 31, 2024
1 parent 0fadd94 commit e08135d
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 25 deletions.
46 changes: 23 additions & 23 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
oauth2_scope = ""
oauth2_authorization_request_uri = "" # pylint: disable=invalid-name
oauth2_token_request_uri = ""
oauth2_token_request_type = ""

# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
Expand Down Expand Up @@ -515,6 +516,9 @@ def get_oauth2_config(cls) -> OAuth2ClientConfig | None:
"token_request_uri",
cls.oauth2_token_request_uri,
),
"request_content_type": db_engine_spec_config.get(
"request_content_type", cls.oauth2_token_request_type
),
}

return config
Expand Down Expand Up @@ -552,18 +556,16 @@ def get_oauth2_token(
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
},
timeout=timeout,
)
return response.json()
req_body = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()

@classmethod
def get_oauth2_fresh_token(
Expand All @@ -576,17 +578,15 @@ def get_oauth2_fresh_token(
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
timeout=timeout,
)
return response.json()
req_body = {
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()

@classmethod
def get_allows_alias_in_select(
Expand Down
4 changes: 4 additions & 0 deletions superset/superset_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
# expired access token.
token_request_uri: str

# Not all identity providers expect json. Keycloak expects a form encoded request,
# which in the `requests` package context means using the `data` param, not `json`.
request_content_type: str


class OAuth2TokenResponse(TypedDict, total=False):
"""
Expand Down
7 changes: 6 additions & 1 deletion superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import backoff
import jwt
from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema
from marshmallow import EXCLUDE, fields, post_load, Schema, validate

from superset import db
from superset.distributed_lock import KeyValueDistributedLock
Expand Down Expand Up @@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
request_content_type = fields.String(
required=False,
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)
1 change: 1 addition & 0 deletions tests/unit_tests/db_engine_specs/test_gsheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
"token_request_uri": "https://oauth2.googleapis.com/token",
"request_content_type": "json",
}


Expand Down
61 changes: 60 additions & 1 deletion tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
from superset.superset_typing import (
OAuth2ClientConfig,
ResultSetColumnType,
SQLAColumnType,
SQLType,
)
from superset.utils import json
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -788,3 +793,57 @@ def test_where_latest_partition(
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
)


@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
"""
Config for GSheets OAuth2.
"""
return {
"id": "trino",
"secret": "very-secret",
"scope": "",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
"token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
"request_content_type": "data",
}


def test_get_oauth2_token(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec

requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}

assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://trino.auth.server.example/master/protocol/openid-connect/token",
data={
"code": "code",
"client_id": "trino",
"client_secret": "very-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)
1 change: 1 addition & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def test_get_oauth2_config(app_context: None) -> None:
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:SYSADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
"request_content_type": "json",
}


Expand Down

0 comments on commit e08135d

Please sign in to comment.