diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index 7197315371f67..3165ce63e68b8 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -476,51 +476,91 @@ public void getAuthenticationAndMetadata(String token, ActionListener listener + ) { + getTokenDocById( + tokenId, + tokenVersion, + storedAccessToken, + storedRefreshToken, + listener.map( + doc -> UserToken.fromSourceMap( + (Map) ((Map) doc.sourceAsMap().get("access_token")).get("user_token") + ) + ) + ); + } + + @SuppressWarnings("unchecked") + private void getTokenDocById( + String tokenId, + TransportVersion tokenVersion, + @Nullable String storedAccessToken, + @Nullable String storedRefreshToken, + ActionListener listener ) { final SecurityIndexManager tokensIndex = getTokensIndexForVersion(tokenVersion); final SecurityIndexManager frozenTokensIndex = tokensIndex.freeze(); if (frozenTokensIndex.isAvailable() == false) { - logger.warn("failed to get access token [{}] because index [{}] is not available", userTokenId, tokensIndex.aliasName()); + logger.warn("failed to get access token [{}] because index [{}] is not available", tokenId, tokensIndex.aliasName()); listener.onFailure(frozenTokensIndex.getUnavailableReason()); } else { - final GetRequest getRequest = client.prepareGet(tokensIndex.aliasName(), getTokenDocumentId(userTokenId)).request(); - final Consumer onFailure = ex -> listener.onFailure(traceLog("get token from id", userTokenId, ex)); + final GetRequest getRequest = client.prepareGet(tokensIndex.aliasName(), getTokenDocumentId(tokenId)).request(); + final Consumer onFailure = ex -> listener.onFailure(traceLog("get token from id", tokenId, ex)); tokensIndex.checkIndexVersionThenExecute( - ex -> listener.onFailure(traceLog("prepare tokens index [" + tokensIndex.aliasName() + "]", userTokenId, ex)), + ex -> listener.onFailure(traceLog("prepare tokens index [" + tokensIndex.aliasName() + "]", tokenId, ex)), () -> executeAsyncWithOrigin( client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, ActionListener.wrap(response -> { if (response.isExists()) { - @SuppressWarnings("unchecked") Map accessTokenSource = (Map) response.getSource().get("access_token"); if (accessTokenSource == null) { onFailure.accept(new IllegalStateException("token document is missing the access_token field")); } else if (accessTokenSource.containsKey("user_token") == false) { onFailure.accept(new IllegalStateException("token document is missing the user_token field")); - } else if ((accessToken == null && accessTokenSource.containsKey("token")) - || (accessToken != null && accessToken.equals(accessTokenSource.get("token")) == false)) { - logger.trace("The access token [{}] is invalid", userTokenId); - listener.onResponse(null); - } else { - @SuppressWarnings("unchecked") - Map userTokenSource = (Map) accessTokenSource.get("user_token"); - listener.onResponse(UserToken.fromSourceMap(userTokenSource)); - } + } else if (tokenVersion.onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH) + && accessTokenSource.containsKey("token") == false) { + onFailure.accept(new IllegalStateException("token document is missing the user_token.token field")); + } else if (tokenVersion.onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH) + && response.getSource().get("refresh_token") != null + && ((Map) response.getSource().get("refresh_token")).containsKey("token") == false) { + onFailure.accept( + new IllegalStateException("token document is missing the refresh_token.token field") + ); + } else if (storedAccessToken != null + && storedAccessToken.equals(accessTokenSource.get("token")) == false) { + logger.error( + "The stored access token [{}] for token doc id [{}] could not be verified", + storedAccessToken, + tokenId + ); + listener.onResponse(null); + } else if (storedRefreshToken != null + && (response.getSource().get("refresh_token") == null + || storedRefreshToken.equals( + ((Map) response.getSource().get("refresh_token")).get("token") + ) == false)) { + logger.error( + "The stored refresh token [{}] for token doc id [{}] could not be verified", + storedRefreshToken, + tokenId + ); + listener.onResponse(null); + } else { + listener.onResponse(new Doc(response)); + } } else { // The chances of a random token string decoding to something that we can read is minimal, so // we assume that this was a token we have created but is now expired/revoked and deleted - logger.trace("The access token [{}] is expired and already deleted", userTokenId); + logger.trace("The token [{}] probably expired and has already been deleted", tokenId); listener.onResponse(null); } }, e -> { @@ -528,12 +568,12 @@ private void getUserTokenFromId( // the token is not valid if (isShardNotAvailableException(e)) { logger.warn( - "failed to get access token [{}] because index [{}] is not available", - userTokenId, + "failed to get token doc [{}] because index [{}] is not available", + tokenId, tokensIndex.aliasName() ); } else { - logger.error(() -> "failed to get access token [" + userTokenId + "]", e); + logger.error(() -> "failed to get token doc [" + tokenId + "]", e); } listener.onFailure(e); }), @@ -587,8 +627,8 @@ void decodeToken(String token, ActionListener listener) { MessageDigest userTokenIdDigest = sha256(); userTokenIdDigest.update(accessTokenBytes, RAW_TOKEN_BYTES_LENGTH, RAW_TOKEN_DOC_ID_BYTES_LENGTH); final String userTokenId = Base64.getUrlEncoder().withoutPadding().encodeToString(userTokenIdDigest.digest()); - final String accessToken = Base64.getUrlEncoder().withoutPadding().encodeToString(sha256().digest(accessTokenBytes)); - getUserTokenFromId(userTokenId, accessToken, version, listener); + final String storedAccessToken = Base64.getUrlEncoder().withoutPadding().encodeToString(sha256().digest(accessTokenBytes)); + getUserTokenById(userTokenId, version, storedAccessToken, null, listener); } else if (version.onOrAfter(VERSION_ACCESS_TOKENS_AS_UUIDS)) { // The token was created in a > VERSION_ACCESS_TOKENS_UUIDS cluster if (in.available() < MINIMUM_BYTES) { @@ -598,7 +638,7 @@ void decodeToken(String token, ActionListener listener) { } final String accessToken = in.readString(); final String userTokenId = hashTokenString(accessToken); - getUserTokenFromId(userTokenId, null, version, listener); + getUserTokenById(userTokenId, version, null, null, listener); } else { // The token was created in a < VERSION_ACCESS_TOKENS_UUIDS cluster so we need to decrypt it to get the tokenId if (in.available() < LEGACY_MINIMUM_BYTES) { @@ -619,7 +659,7 @@ void decodeToken(String token, ActionListener listener) { try { final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt); final String tokenId = decryptTokenId(encryptedTokenId, cipher, version); - getUserTokenFromId(tokenId, null, version, listener); + getUserTokenById(tokenId, version, null, null, listener); } catch (IOException | GeneralSecurityException e) { // could happen with a token that is not ours logger.warn("invalid token", e); @@ -1021,6 +1061,7 @@ private void findTokenFromRefreshToken(String refreshToken, Iterator final String hashedRefreshToken = Base64.getUrlEncoder() .withoutPadding() .encodeToString(sha256().digest(unencodedRefreshToken)); + // TODO findTokenFromRefreshToken(hashedRefreshToken, securityTokensIndex, backoff, listener); } } else if (version.onOrAfter(VERSION_HASHED_TOKENS)) {