Skip to content

Commit

Permalink
Specify clientRegistrationId in TokenRelayFilterFunctions (#3591)
Browse files Browse the repository at this point in the history
Closes gh-3541
  • Loading branch information
sjohnr authored Nov 19, 2024
1 parent 9050809 commit 311d9a1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

package org.springframework.cloud.gateway.server.mvc.filter;

import java.security.Principal;

import org.springframework.cloud.gateway.server.mvc.common.Shortcut;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
Expand All @@ -37,13 +36,21 @@ private TokenRelayFilterFunctions() {

@Shortcut
public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay() {
return tokenRelay(null);
}

public static HandlerFilterFunction<ServerResponse, ServerResponse> tokenRelay(String defaultClientRegistrationId) {
return (request, next) -> {
Principal principle = request.servletRequest().getUserPrincipal();
if (principle instanceof OAuth2AuthenticationToken token) {
String clientRegistrationId = token.getAuthorizedClientRegistrationId();
Authentication principal = (Authentication) request.servletRequest().getUserPrincipal();

String clientRegistrationId = defaultClientRegistrationId;
if (clientRegistrationId == null && principal instanceof OAuth2AuthenticationToken token) {
clientRegistrationId = token.getAuthorizedClientRegistrationId();
}
if (clientRegistrationId != null) {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(clientRegistrationId)
.principal(token)
.principal(principal)
.build();
OAuth2AuthorizedClientManager clientManager = getApplicationContext(request)
.getBean(OAuth2AuthorizedClientManager.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,29 @@ public void whenPrincipalExistsAuthorizationHeaderAdded() throws Exception {
});
}

@Test
public void whenDefaultClientRegistrationIdProvidedAuthorizationHeaderAdded() throws Exception {
OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
when(accessToken.getTokenValue()).thenReturn("mytoken");

ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("myregistrationid")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.clientId("myclientid")
.tokenUri("mytokenuri")
.build();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, "joe", accessToken);

when(authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class))).thenReturn(authorizedClient);

request.setUserPrincipal(new TestingAuthenticationToken("my", null));

filter = TokenRelayFilterFunctions.tokenRelay("myId");
filter.filter(ServerRequest.create(request, converters), req -> {
assertThat(req.headers().firstHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer mytoken");
return null;
});
}

@Test
public void principalIsNotOAuth2AuthenticationToken() throws Exception {
request.setUserPrincipal(new TestingAuthenticationToken("my", null));
Expand Down

0 comments on commit 311d9a1

Please sign in to comment.