Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(event_handler): add support for OpenAPI security schemes #4103

Merged
merged 22 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 125 additions & 24 deletions aws_lambda_powertools/event_handler/api_gateway.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def get( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
security = None

return super(BedrockAgentResolver, self).get(
rule,
cors,
Expand All @@ -114,6 +116,7 @@ def get( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -134,6 +137,8 @@ def post( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
):
security = None

return super().post(
rule,
cors,
Expand All @@ -146,6 +151,7 @@ def post( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -166,6 +172,8 @@ def put( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
):
security = None

return super().put(
rule,
cors,
Expand All @@ -178,6 +186,7 @@ def put( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -198,6 +207,8 @@ def patch( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable]] = None,
):
security = None

return super().patch(
rule,
cors,
Expand All @@ -210,6 +221,7 @@ def patch( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand All @@ -230,6 +242,8 @@ def delete( # type: ignore[override]
include_in_schema: bool = True,
middlewares: Optional[List[Callable[..., Any]]] = None,
):
security = None

return super().delete(
rule,
cors,
Expand All @@ -242,6 +256,7 @@ def delete( # type: ignore[override]
tags,
operation_id,
include_in_schema,
security,
middlewares,
)

Expand Down
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,13 @@ class SecurityBase(BaseModel):
description: Optional[str] = None

if PYDANTIC_V2:
model_config = {"extra": "allow"}
model_config = {"extra": "allow", "populate_by_name": True}

else:

class Config:
extra = "allow"
allow_population_by_field_name = True


class APIKeyIn(Enum):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import (
generate_swagger_html,
)
from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import (
OAuth2Config,
generate_oauth2_redirect_html,
)

__all__ = [
"generate_swagger_html",
"generate_oauth2_redirect_html",
"OAuth2Config",
]
39 changes: 33 additions & 6 deletions aws_lambda_powertools/event_handler/openapi/swagger_ui/html.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: str, swagger_base_url: str) -> str:
from typing import Optional

from aws_lambda_powertools.event_handler.openapi.swagger_ui.oauth2 import OAuth2Config


def generate_swagger_html(
spec: str,
path: str,
swagger_js: str,
swagger_css: str,
swagger_base_url: str,
oauth2_config: Optional[OAuth2Config],
) -> str:
"""
Generate Swagger UI HTML page

Expand All @@ -8,10 +20,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
The OpenAPI spec
path: str
The path to the Swagger documentation
js_url: str
The URL to the Swagger UI JavaScript file
css_url: str
The URL to the Swagger UI CSS file
swagger_js: str
Swagger UI JavaScript source code or URL
swagger_css: str
Swagger UI CSS source code or URL
swagger_base_url: str
The base URL for Swagger UI
oauth2_config: OAuth2Config, optional
The OAuth2 configuration.
"""

# If Swagger base URL is present, generate HTML content with linked CSS and JavaScript files
Expand All @@ -23,6 +39,11 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
swagger_css_content = f"<style>{swagger_css}</style>"
swagger_js_content = f"<script>{swagger_js}</script>"

# Prepare oauth2 config
oauth2_content = (
f"ui.initOAuth({oauth2_config.json(exclude_none=True, exclude_unset=True)});" if oauth2_config else ""
)

return f"""
<!DOCTYPE html>
<html>
Expand All @@ -45,6 +66,9 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
{swagger_js_content}

<script>
var currentUrl = new URL(window.location.href);
var baseUrl = currentUrl.protocol + "//" + currentUrl.host + currentUrl.pathname;

var swaggerUIOptions = {{
dom_id: "#swagger-ui",
docExpansion: "list",
Expand All @@ -60,11 +84,14 @@ def generate_swagger_html(spec: str, path: str, swagger_js: str, swagger_css: st
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
]
],
withCredentials: true,
oauth2RedirectUrl: baseUrl + "?format=oauth2-redirect",
}}

var ui = SwaggerUIBundle(swaggerUIOptions)
ui.specActions.updateUrl('{path}?format=json');
{oauth2_content}
</script>
</html>
""".strip()
148 changes: 148 additions & 0 deletions aws_lambda_powertools/event_handler/openapi/swagger_ui/oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# ruff: noqa: E501
from typing import Dict, Optional, Sequence

from pydantic import BaseModel, Field, validator

from aws_lambda_powertools.event_handler.openapi.pydantic_loader import PYDANTIC_V2
from aws_lambda_powertools.shared.functions import powertools_dev_is_set


# Based on https://swagger.io/docs/open-source-tools/swagger-ui/usage/oauth2/
class OAuth2Config(BaseModel):
"""
OAuth2 configuration for Swagger UI
"""

# The client ID for the OAuth2 application
clientId: Optional[str] = Field(alias="client_id", default=None)

# The client secret for the OAuth2 application. This is sensitive information and requires the explicit presence
# of the POWERTOOLS_DEV environment variable.
clientSecret: Optional[str] = Field(alias="client_secret", default=None)

# The realm in which the OAuth2 application is registered. Optional.
realm: Optional[str] = Field(default=None)

# The name of the OAuth2 application
appName: str = Field(alias="app_name")

# The scopes that the OAuth2 application requires. Defaults to an empty list.
scopes: Sequence[str] = Field(default=[])

# Additional query string parameters to be included in the OAuth2 request. Defaults to an empty dictionary.
additionalQueryStringParams: Dict[str, str] = Field(alias="additional_query_string_params", default={})

# Whether to use basic authentication with the access code grant type. Defaults to False.
useBasicAuthenticationWithAccessCodeGrant: bool = Field(
alias="use_basic_authentication_with_access_code_grant",
default=False,
)

# Whether to use PKCE with the authorization code grant type. Defaults to False.
usePkceWithAuthorizationCodeGrant: bool = Field(alias="use_pkce_with_authorization_code_grant", default=False)

if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:

class Config:
extra = "allow"
allow_population_by_field_name = True

@validator("clientSecret", always=True)
def client_secret_only_on_dev(cls, v: Optional[str]) -> Optional[str]:
if v and not powertools_dev_is_set():
raise ValueError(
"cannot use client_secret without POWERTOOLS_DEV mode. See "
"https://docs.powertools.aws.dev/lambda/python/latest/#optimizing-for-non-production-environments",
)
return v


def generate_oauth2_redirect_html() -> str:
"""
Generates the HTML content for the OAuth2 redirect page.

Source: https://github.com/swagger-api/swagger-ui/blob/master/dist/oauth2-redirect.html
"""
return """
<!doctype html>
<html lang="en-US">
<head>
<title>Swagger UI: OAuth2 Redirect</title>
</head>
<body>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;

if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1).replace('?', '&');
} else {
qp = location.search.substring(1);
}

arr = qp.split("&");
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value);
}
) : {};

isValid = qp.state === sentState;

if ((
oauth2.auth.schema.get("flow") === "accessCode" ||
oauth2.auth.schema.get("flow") === "authorizationCode" ||
oauth2.auth.schema.get("flow") === "authorization_code"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server. The passed state wasn't returned from auth server."
});
}

if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg;
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}

oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server."
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}

if (document.readyState !== 'loading') {
run();
} else {
document.addEventListener('DOMContentLoaded', function () {
run();
});
}
</script>
</body>
</html>
""".strip()
Loading
Loading