From 7ec895b5a0ff4f76d6bb3cc74b9470822a9685f1 Mon Sep 17 00:00:00 2001 From: Albert Zaharovits Date: Mon, 24 Jul 2023 17:46:37 +0300 Subject: [PATCH] Malformed tokens tests --- .../security/authc/TokenServiceTests.java | 97 ++++++++++++++++--- .../TokenBackwardsCompatibilityIT.java | 2 +- 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 6937b759cc312..7f95f0c1737ee 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -61,6 +61,7 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.node.Node; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ClusterServiceUtils; @@ -719,25 +720,97 @@ public void testBytesKeyEqualsHashCode() { }); } - public void testMalformedToken() throws Exception { - final int numBytes = randomIntBetween(1, TokenService.MINIMUM_BYTES + 32); - final byte[] randomBytes = new byte[numBytes]; - random().nextBytes(randomBytes); + public void testMalformedAccessTokens() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - // mock another random token so that we don't find a token in TokenService#getUserTokenFromId Authentication authentication = AuthenticationTestHelper.builder() .user(new User("joe", "admin")) .realmRef(new RealmRef("native_realm", "native", "node1")) .build(false); - mockGetTokenFromAccessTokenBytes(tokenService, tokenService.getRandomTokenBytes(randomBoolean()).v1(), authentication, false, null); + byte[] accessTokenBytes = tokenService.getRandomTokenBytes(randomBoolean()).v1(); + mockGetTokenFromAccessTokenBytes(tokenService, accessTokenBytes, authentication, false, null); + String mockedAccessToken = tokenService.prependVersionAndEncodeAccessToken( + tokenService.getTokenVersionCompatibility(), + accessTokenBytes + ); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, Base64.getEncoder().encodeToString(randomBytes)); + // test some random access tokens + for (int numBytes = 1; numBytes < TokenService.MINIMUM_BYTES + 32; numBytes++) { + final byte[] randomBytes = new byte[numBytes]; + random().nextBytes(randomBytes); + try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { + String testAccessToken = Base64.getEncoder().encodeToString(randomBytes); + assumeFalse("Test token must be different from mock", mockedAccessToken.equals(testAccessToken)); + storeTokenHeader(requestContext, testAccessToken); + PlainActionFuture future = new PlainActionFuture<>(); + final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); + tokenService.tryAuthenticateToken(bearerToken, future); + assertNull(future.get()); + } + } + // test garbled mocked access tokens + for (int garbledByteIdx = 0; garbledByteIdx < accessTokenBytes.length; garbledByteIdx++) { + final byte[] garbledAccessToken = new byte[accessTokenBytes.length]; + System.arraycopy(accessTokenBytes, 0, garbledAccessToken, 0, accessTokenBytes.length); + garbledAccessToken[garbledByteIdx] = (byte) (garbledAccessToken[garbledByteIdx] ^ (byte) random().nextInt(255)); + String testAccessToken = tokenService.prependVersionAndEncodeAccessToken( + tokenService.getTokenVersionCompatibility(), + garbledAccessToken + ); + assumeFalse("Test token must be different from mock", mockedAccessToken.equals(testAccessToken)); + try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { + storeTokenHeader(requestContext, testAccessToken); + PlainActionFuture future = new PlainActionFuture<>(); + final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); + tokenService.tryAuthenticateToken(bearerToken, future); + assertNull(future.get()); + } + } + } - try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { - PlainActionFuture future = new PlainActionFuture<>(); - final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); - tokenService.tryAuthenticateToken(bearerToken, future); - assertNull(future.get()); + public void testMalformedRefreshTokens() throws Exception { + TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); + Authentication authentication = AuthenticationTestHelper.builder() + .user(new User("joe", "admin")) + .realmRef(new RealmRef("native_realm", "native", "node1")) + .build(false); + PlainActionFuture tokenFuture = new PlainActionFuture<>(); + Tuple newTokenBytes = tokenService.getRandomTokenBytes(true); + tokenService.createOAuth2Tokens(newTokenBytes.v1(), newTokenBytes.v2(), authentication, authentication, Map.of(), tokenFuture); + byte[] mockedRawRefreshToken = newTokenBytes.v2(); + String mockedClientRefreshToken = tokenFuture.get().getRefreshToken(); + assertNotNull(mockedClientRefreshToken); + mockTokenForRefreshToken(newTokenBytes.v1(), newTokenBytes.v2(), tokenService, authentication, null); + // test some random access tokens + for (int numBytes = 1; numBytes < RAW_TOKEN_BYTES_TOTAL_LENGTH + 8; numBytes++) { + final byte[] randomBytes = new byte[numBytes]; + random().nextBytes(randomBytes); + String testRefreshToken = Base64.getEncoder().encodeToString(randomBytes); + assumeFalse("Test token must be different from mock", mockedClientRefreshToken.equals(testRefreshToken)); + PlainActionFuture future = new PlainActionFuture<>(); + tokenService.invalidateRefreshToken(testRefreshToken, future); + final TokensInvalidationResult result = future.get(); + assertThat(result.getInvalidatedTokens(), hasSize(0)); + assertThat(result.getPreviouslyInvalidatedTokens(), empty()); + assertThat(result.getErrors(), empty()); + assertThat(result.getRestStatus(), is(RestStatus.NOT_FOUND)); + } + // test garbled mocked refresh tokens + for (int garbledByteIdx = 0; garbledByteIdx < mockedRawRefreshToken.length; garbledByteIdx++) { + final byte[] garbledRefreshToken = new byte[mockedRawRefreshToken.length]; + System.arraycopy(mockedRawRefreshToken, 0, garbledRefreshToken, 0, mockedRawRefreshToken.length); + garbledRefreshToken[garbledByteIdx] = (byte) (garbledRefreshToken[garbledByteIdx] ^ (byte) random().nextInt(255)); + String testRefreshToken = TokenService.prependVersionAndEncodeRefreshToken( + tokenService.getTokenVersionCompatibility(), + garbledRefreshToken + ); + assumeFalse("Test token must be different from mock", mockedClientRefreshToken.equals(testRefreshToken)); + PlainActionFuture future = new PlainActionFuture<>(); + tokenService.invalidateRefreshToken(testRefreshToken, future); + final TokensInvalidationResult result = future.get(); + assertThat(result.getInvalidatedTokens(), hasSize(0)); + assertThat(result.getPreviouslyInvalidatedTokens(), empty()); + assertThat(result.getErrors(), empty()); + assertThat(result.getRestStatus(), is(RestStatus.NOT_FOUND)); } } diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TokenBackwardsCompatibilityIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TokenBackwardsCompatibilityIT.java index 6d65aff668ff0..d82d6d5dd6747 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TokenBackwardsCompatibilityIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TokenBackwardsCompatibilityIT.java @@ -343,7 +343,7 @@ private Map createTokens(RestClient client, String username, Str "password": "%s", "grant_type": "password" }""", username, password)); - Response response = client().performRequest(createTokenRequest); + Response response = client.performRequest(createTokenRequest); assertOK(response); return entityAsMap(response); }