Skip to content

Commit

Permalink
Update OidcWireMock to include the client_id in the ID token audience…
Browse files Browse the repository at this point in the history
… dynamically
  • Loading branch information
douglas444 committed Oct 21, 2024
1 parent 744bc75 commit 64446f7
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public String access() {
return "access token verified: " + (routingContext.get("code_flow_access_token_result") != null)
+ ", id_token issuer: " + idToken.getIssuer()
+ ", access_token issuer: " + accessToken.getIssuer()
+ ", id_token audience: " + idToken.getAudience().iterator().next()
+ ", id_token audience: " + String.join(";", idToken.getAudience().stream().sorted().toList())
+ ", access_token audience: " + accessToken.getAudience().iterator().next()
+ ", cache size: " + tokenCache.getCacheSize();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ public void testCodeFlowVerifyIdAndAccessToken() throws IOException {
assertEquals("access token verified: true,"
+ " id_token issuer: https://server.example.com,"
+ " access_token issuer: https://server.example.com,"
+ " id_token audience: https://id.server.example.com,"
+ " id_token audience: https://id.server.example.com;quarkus-web-app,"
+ " access_token audience: https://server.example.com,"
+ " cache size: 0", textPage.getContent());
assertNotNull(getSessionCookie(webClient, "code-flow-verify-id-and-access-tokens"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;

import jakarta.json.Json;
import jakarta.json.JsonObject;
Expand All @@ -25,11 +28,15 @@
import org.jose4j.keys.X509Util;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.common.ListOrSingle;
import com.github.tomakehurst.wiremock.extension.TemplateHelperProviderExtension;
import com.google.common.collect.ImmutableSet;

import io.quarkus.test.common.QuarkusTestResourceLifecycleManager;
import io.smallrye.jwt.build.Jwt;
import io.smallrye.jwt.build.JwtClaimsBuilder;
import wiremock.com.github.jknack.handlebars.Helper;
import wiremock.com.github.jknack.handlebars.Options;

/**
* Provides a mock OIDC server to tests.
Expand Down Expand Up @@ -58,7 +65,20 @@ public class OidcWiremockTestResource implements QuarkusTestResourceLifecycleMan
@Override
public Map<String, String> start() {

server = new WireMockServer(wireMockConfig().dynamicPort());
server = new WireMockServer(wireMockConfig().dynamicPort().extensions(
new TemplateHelperProviderExtension() {
@Override
public String getName() {
return "custom-helpers";
}

@Override
public Map<String, Helper<?>> provideTemplateHelpers() {
Helper<String> idTokenHelper = OidcWiremockTestResource.this::buildBasicSchemeIdToken;
return Map.ofEntries(Map.entry("basic-scheme-id-token", idTokenHelper));
}
}));

server.start();

server.stubFor(
Expand Down Expand Up @@ -298,9 +318,9 @@ private void defineCodeFlowAuthorizationMockTokenStub() {
" \"access_token\": \""
+ getAccessToken("alice", getAdminRoles()) + "\",\n" +
" \"refresh_token\": \"07e08903-1263-4dd1-9fd1-4a59b0db5283\",\n" +
" \"id_token\": \"" + getIdToken("alice", getAdminRoles())
+ "\"\n" +
"}")));
" \"id_token\": \"{{basic-scheme-id-token 'alice'}}\"\n" +
"}")
.withTransformers("response-template")));
}

private void definePasswordGrantTokenStub() {
Expand Down Expand Up @@ -378,6 +398,10 @@ public static String getIdToken(String userName, Set<String> groups) {
return generateJwtToken(userName, groups, TOKEN_SUBJECT, ID_TOKEN_TYPE);
}

public static String getIdToken(String userName, Set<String> groups, String clientId) {
return generateJwtToken(userName, groups, TOKEN_SUBJECT, ID_TOKEN_TYPE, Set.of(clientId, ID_TOKEN_AUDIENCE));
}

public static String generateJwtToken(String userName, Set<String> groups) {
return generateJwtToken(userName, groups, TOKEN_SUBJECT);
}
Expand All @@ -387,11 +411,14 @@ public static String generateJwtToken(String userName, Set<String> groups, Strin
}

public static String generateJwtToken(String userName, Set<String> groups, String sub, String type) {
final String audience = ID_TOKEN_TYPE.equals(type) ? ID_TOKEN_AUDIENCE : TOKEN_AUDIENCE;
return generateJwtToken(userName, groups, sub, type, Set.of(TOKEN_AUDIENCE));
}

public static String generateJwtToken(String userName, Set<String> groups, String sub, String type, Set<String> aud) {
JwtClaimsBuilder builder = Jwt.preferredUserName(userName)
.groups(groups)
.issuer(TOKEN_ISSUER)
.audience(audience)
.audience(aud)
.claim("sid", "session-id")
.subject(sub);
if (type != null) {
Expand Down Expand Up @@ -438,4 +465,45 @@ public synchronized void stop() {
server = null;
}
}

private String buildBasicSchemeIdToken(String context, Options options) {

String clientId = getHeader("Authorization", options)
.map(OidcWiremockTestResource::removerBasicPrefix)
.map(OidcWiremockTestResource::decodeBase64)
.map(OidcWiremockTestResource::getClientIdFromCredentials)
.orElseThrow(() -> new RuntimeException("Invalid Authorization header"));

return getIdToken(context, getAdminRoles(), clientId);
}

private static Optional<String> getHeader(String header, Options options) {

TreeMap<String, ListOrSingle<String>> map = options.get("request.headers");
if (map == null || !map.containsKey(header) || map.get(header).isEmpty()) {
return Optional.empty();
}

return Optional.of(map.get(header).getFirst());
}

private static String removerBasicPrefix(String value) {
if (value.startsWith("Basic ")) {
return value.substring("Basic ".length());
}
return value;
}

private static String decodeBase64(String base64String) {
return new String(Base64.getDecoder().decode(base64String));
}

private static String getClientIdFromCredentials(String credentials) {
String[] tokens = credentials.split(":");
if (tokens.length >= 1) {
return tokens[0];
}
return credentials;
}

}

0 comments on commit 64446f7

Please sign in to comment.