Skip to content

Commit

Permalink
Malformed tokens tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertzaharovits committed Jul 24, 2023
1 parent bd3defd commit 7ec895b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.license.MockLicenseState;
import org.elasticsearch.node.Node;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ClusterServiceUtils;
Expand Down Expand Up @@ -719,25 +720,97 @@ public void testBytesKeyEqualsHashCode() {
});
}

public void testMalformedToken() throws Exception {
final int numBytes = randomIntBetween(1, TokenService.MINIMUM_BYTES + 32);
final byte[] randomBytes = new byte[numBytes];
random().nextBytes(randomBytes);
public void testMalformedAccessTokens() throws Exception {
TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC());
// mock another random token so that we don't find a token in TokenService#getUserTokenFromId
Authentication authentication = AuthenticationTestHelper.builder()
.user(new User("joe", "admin"))
.realmRef(new RealmRef("native_realm", "native", "node1"))
.build(false);
mockGetTokenFromAccessTokenBytes(tokenService, tokenService.getRandomTokenBytes(randomBoolean()).v1(), authentication, false, null);
byte[] accessTokenBytes = tokenService.getRandomTokenBytes(randomBoolean()).v1();
mockGetTokenFromAccessTokenBytes(tokenService, accessTokenBytes, authentication, false, null);
String mockedAccessToken = tokenService.prependVersionAndEncodeAccessToken(
tokenService.getTokenVersionCompatibility(),
accessTokenBytes
);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
storeTokenHeader(requestContext, Base64.getEncoder().encodeToString(randomBytes));
// test some random access tokens
for (int numBytes = 1; numBytes < TokenService.MINIMUM_BYTES + 32; numBytes++) {
final byte[] randomBytes = new byte[numBytes];
random().nextBytes(randomBytes);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) {
String testAccessToken = Base64.getEncoder().encodeToString(randomBytes);
assumeFalse("Test token must be different from mock", mockedAccessToken.equals(testAccessToken));
storeTokenHeader(requestContext, testAccessToken);
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
tokenService.tryAuthenticateToken(bearerToken, future);
assertNull(future.get());
}
}
// test garbled mocked access tokens
for (int garbledByteIdx = 0; garbledByteIdx < accessTokenBytes.length; garbledByteIdx++) {
final byte[] garbledAccessToken = new byte[accessTokenBytes.length];
System.arraycopy(accessTokenBytes, 0, garbledAccessToken, 0, accessTokenBytes.length);
garbledAccessToken[garbledByteIdx] = (byte) (garbledAccessToken[garbledByteIdx] ^ (byte) random().nextInt(255));
String testAccessToken = tokenService.prependVersionAndEncodeAccessToken(
tokenService.getTokenVersionCompatibility(),
garbledAccessToken
);
assumeFalse("Test token must be different from mock", mockedAccessToken.equals(testAccessToken));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) {
storeTokenHeader(requestContext, testAccessToken);
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
tokenService.tryAuthenticateToken(bearerToken, future);
assertNull(future.get());
}
}
}

try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext);
tokenService.tryAuthenticateToken(bearerToken, future);
assertNull(future.get());
public void testMalformedRefreshTokens() throws Exception {
TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC());
Authentication authentication = AuthenticationTestHelper.builder()
.user(new User("joe", "admin"))
.realmRef(new RealmRef("native_realm", "native", "node1"))
.build(false);
PlainActionFuture<TokenService.CreateTokenResult> tokenFuture = new PlainActionFuture<>();
Tuple<byte[], byte[]> newTokenBytes = tokenService.getRandomTokenBytes(true);
tokenService.createOAuth2Tokens(newTokenBytes.v1(), newTokenBytes.v2(), authentication, authentication, Map.of(), tokenFuture);
byte[] mockedRawRefreshToken = newTokenBytes.v2();
String mockedClientRefreshToken = tokenFuture.get().getRefreshToken();
assertNotNull(mockedClientRefreshToken);
mockTokenForRefreshToken(newTokenBytes.v1(), newTokenBytes.v2(), tokenService, authentication, null);
// test some random access tokens
for (int numBytes = 1; numBytes < RAW_TOKEN_BYTES_TOTAL_LENGTH + 8; numBytes++) {
final byte[] randomBytes = new byte[numBytes];
random().nextBytes(randomBytes);
String testRefreshToken = Base64.getEncoder().encodeToString(randomBytes);
assumeFalse("Test token must be different from mock", mockedClientRefreshToken.equals(testRefreshToken));
PlainActionFuture<TokensInvalidationResult> future = new PlainActionFuture<>();
tokenService.invalidateRefreshToken(testRefreshToken, future);
final TokensInvalidationResult result = future.get();
assertThat(result.getInvalidatedTokens(), hasSize(0));
assertThat(result.getPreviouslyInvalidatedTokens(), empty());
assertThat(result.getErrors(), empty());
assertThat(result.getRestStatus(), is(RestStatus.NOT_FOUND));
}
// test garbled mocked refresh tokens
for (int garbledByteIdx = 0; garbledByteIdx < mockedRawRefreshToken.length; garbledByteIdx++) {
final byte[] garbledRefreshToken = new byte[mockedRawRefreshToken.length];
System.arraycopy(mockedRawRefreshToken, 0, garbledRefreshToken, 0, mockedRawRefreshToken.length);
garbledRefreshToken[garbledByteIdx] = (byte) (garbledRefreshToken[garbledByteIdx] ^ (byte) random().nextInt(255));
String testRefreshToken = TokenService.prependVersionAndEncodeRefreshToken(
tokenService.getTokenVersionCompatibility(),
garbledRefreshToken
);
assumeFalse("Test token must be different from mock", mockedClientRefreshToken.equals(testRefreshToken));
PlainActionFuture<TokensInvalidationResult> future = new PlainActionFuture<>();
tokenService.invalidateRefreshToken(testRefreshToken, future);
final TokensInvalidationResult result = future.get();
assertThat(result.getInvalidatedTokens(), hasSize(0));
assertThat(result.getPreviouslyInvalidatedTokens(), empty());
assertThat(result.getErrors(), empty());
assertThat(result.getRestStatus(), is(RestStatus.NOT_FOUND));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ private Map<String, Object> createTokens(RestClient client, String username, Str
"password": "%s",
"grant_type": "password"
}""", username, password));
Response response = client().performRequest(createTokenRequest);
Response response = client.performRequest(createTokenRequest);
assertOK(response);
return entityAsMap(response);
}
Expand Down

0 comments on commit 7ec895b

Please sign in to comment.