diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Service.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Service.java index 28428e4641c8..20b8e3f6aa9a 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Service.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2Service.java @@ -136,15 +136,24 @@ public OAuth2Service(OAuth2Client client, @ForOAuth2 SigningKeyResolver signingK public Response startOAuth2Challenge(UriInfo uriInfo) { - return startOAuth2Challenge(uriInfo, Optional.empty()); + return startOAuth2Challenge( + uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT), + Optional.empty()); } public Response startOAuth2Challenge(UriInfo uriInfo, String handlerState) { - return startOAuth2Challenge(uriInfo, Optional.of(handlerState)); + return startOAuth2Challenge( + uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT), + Optional.of(handlerState)); } - private Response startOAuth2Challenge(UriInfo uriInfo, Optional handlerState) + public Response startOAuth2Challenge(URI callbackUri, String handlerState) + { + return startOAuth2Challenge(callbackUri, Optional.of(handlerState)); + } + + private Response startOAuth2Challenge(URI callbackUri, Optional handlerState) { Instant challengeExpiration = now().plus(challengeTimeout); String state = Jwts.builder() @@ -166,7 +175,7 @@ private Response startOAuth2Challenge(UriInfo uriInfo, Optional handlerS Response.ResponseBuilder response = Response.seeOther( client.getAuthorizationUri( state, - uriInfo.getBaseUri().resolve(CALLBACK_ENDPOINT), + callbackUri, nonce.map(OAuth2Service::hashNonce))); nonce.ifPresent(nce -> response.cookie(NonceCookie.create(nce, challengeExpiration))); return response.build();