diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 0826d936cb4d2..21ce6269ee2bd 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -469,7 +469,7 @@ public void testInvalidateRefreshToken() throws Exception { final String clientRefreshToken = tokenFuture.get().getRefreshToken(); assertNotNull(accessToken); assertNotNull(clientRefreshToken); - mockFindTokenFromRefreshToken(newTokenBytes.v1(), newTokenBytes.v2(), tokenService, authentication, null); + mockTokenForRefreshToken(newTokenBytes.v1(), newTokenBytes.v2(), tokenService, authentication, null); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); storeTokenHeader(requestContext, accessToken); @@ -496,7 +496,7 @@ public void testInvalidateRefreshTokenThatIsAlreadyInvalidated() throws Exceptio final String clientRefreshToken = tokenFuture.get().getRefreshToken(); assertNotNull(accessToken); assertNotNull(clientRefreshToken); - mockFindTokenFromRefreshToken( + mockTokenForRefreshToken( newTokenBytes.v1(), newTokenBytes.v2(), tokenService, @@ -1035,13 +1035,13 @@ public static String tokenDocIdFromAccessTokenBytes(byte[] accessTokenBytes, Tra } } - private void mockFindTokenFromRefreshToken( + private void mockTokenForRefreshToken( byte[] accessTokenBytes, byte[] refreshTokenBytes, TokenService tokenService, Authentication authentication, @Nullable RefreshTokenStatus refreshTokenStatus - ) { + ) throws IOException { UserToken userToken = buildUserToken(tokenService, accessTokenBytes, authentication, null, Map.of()); final String storedAccessToken; final String storedRefreshToken; @@ -1055,59 +1055,74 @@ private void mockFindTokenFromRefreshToken( storedAccessToken = null; storedRefreshToken = Base64.getUrlEncoder().withoutPadding().encodeToString(refreshTokenBytes); } - doAnswer(invocationOnMock -> { - final SearchRequest request = (SearchRequest) invocationOnMock.getArguments()[0]; + final RealmRef realmRef = new RealmRef( + refreshTokenStatus == null ? randomAlphaOfLength(6) : refreshTokenStatus.getAssociatedRealm(), + "test", + randomAlphaOfLength(12) + ); + final Authentication clientAuthentication = AuthenticationTestHelper.builder() + .user(new User(refreshTokenStatus == null ? randomAlphaOfLength(8) : refreshTokenStatus.getAssociatedUser())) + .realmRef(realmRef) + .build(false); + + final SearchHit hit = new SearchHit(randomInt(), "token_" + userToken.getId()); + BytesReference source = TokenService.createTokenDocument( + userToken, + storedAccessToken, + storedRefreshToken, + clientAuthentication, + Instant.now() + ); + if (refreshTokenStatus != null) { + var sourceAsMap = XContentHelper.convertToMap(source, false, XContentType.JSON).v2(); @SuppressWarnings("unchecked") - final ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - final SearchResponse response = mock(SearchResponse.class); - - assertThat(request.source().query(), instanceOf(BoolQueryBuilder.class)); - BoolQueryBuilder bool = (BoolQueryBuilder) request.source().query(); - assertThat(bool.filter(), hasSize(2)); - - assertThat(bool.filter().get(0), instanceOf(TermQueryBuilder.class)); - TermQueryBuilder docType = (TermQueryBuilder) bool.filter().get(0); - assertThat(docType.fieldName(), is("doc_type")); - assertThat(docType.value(), is("token")); - - assertThat(bool.filter().get(1), instanceOf(TermQueryBuilder.class)); - TermQueryBuilder refreshFilter = (TermQueryBuilder) bool.filter().get(1); - assertThat(refreshFilter.fieldName(), is("refresh_token.token")); - assertThat(refreshFilter.value(), is(storedRefreshToken)); - - final RealmRef realmRef = new RealmRef( - refreshTokenStatus == null ? randomAlphaOfLength(6) : refreshTokenStatus.getAssociatedRealm(), - "test", - randomAlphaOfLength(12) - ); - final Authentication clientAuthentication = AuthenticationTestHelper.builder() - .user(new User(refreshTokenStatus == null ? randomAlphaOfLength(8) : refreshTokenStatus.getAssociatedUser())) - .realmRef(realmRef) - .build(false); - - final SearchHit hit = new SearchHit(randomInt(), "token_" + userToken.getId()); - BytesReference source = TokenService.createTokenDocument( - userToken, - storedAccessToken, - storedRefreshToken, - clientAuthentication, - Instant.now() - ); - if (refreshTokenStatus != null) { - var sourceAsMap = XContentHelper.convertToMap(source, false, XContentType.JSON).v2(); + var refreshTokenSource = (Map) sourceAsMap.get("refresh_token"); + refreshTokenSource.put("invalidated", refreshTokenStatus.isInvalidated()); + refreshTokenSource.put("refreshed", refreshTokenStatus.isRefreshed()); + source = XContentTestUtils.convertToXContent(sourceAsMap, XContentType.JSON); + } + final BytesReference docSource = source; + if (userToken.getTransportVersion().onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH)) { + doAnswer(invocationOnMock -> { + GetRequest request = (GetRequest) invocationOnMock.getArguments()[0]; @SuppressWarnings("unchecked") - var refreshTokenSource = (Map) sourceAsMap.get("refresh_token"); - refreshTokenSource.put("invalidated", refreshTokenStatus.isInvalidated()); - refreshTokenSource.put("refreshed", refreshTokenStatus.isRefreshed()); - source = XContentTestUtils.convertToXContent(sourceAsMap, XContentType.JSON); - } - hit.sourceRef(source); - - final SearchHits hits = new SearchHits(new SearchHit[] { hit }, null, 1); - when(response.getHits()).thenReturn(hits); - listener.onResponse(response); - return Void.TYPE; - }).when(client).search(any(SearchRequest.class), any()); + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + GetResponse response = mock(GetResponse.class); + if (userToken.getId().equals(request.id().substring("token_".length()))) { + when(response.isExists()).thenReturn(true); + when(response.getSource()).thenReturn(XContentHelper.convertToMap(docSource, false, XContentType.JSON).v2()); + } + listener.onResponse(response); + return null; + }).when(client).get(any(GetRequest.class), anyActionListener()); + } else { + doAnswer(invocationOnMock -> { + final SearchRequest request = (SearchRequest) invocationOnMock.getArguments()[0]; + @SuppressWarnings("unchecked") + final ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + final SearchResponse response = mock(SearchResponse.class); + + assertThat(request.source().query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bool = (BoolQueryBuilder) request.source().query(); + assertThat(bool.filter(), hasSize(2)); + + assertThat(bool.filter().get(0), instanceOf(TermQueryBuilder.class)); + TermQueryBuilder docType = (TermQueryBuilder) bool.filter().get(0); + assertThat(docType.fieldName(), is("doc_type")); + assertThat(docType.value(), is("token")); + + assertThat(bool.filter().get(1), instanceOf(TermQueryBuilder.class)); + TermQueryBuilder refreshFilter = (TermQueryBuilder) bool.filter().get(1); + assertThat(refreshFilter.fieldName(), is("refresh_token.token")); + assertThat(refreshFilter.value(), is(storedRefreshToken)); + hit.sourceRef(docSource); + + final SearchHits hits = new SearchHits(new SearchHit[] { hit }, null, 1); + when(response.getHits()).thenReturn(hits); + listener.onResponse(response); + return Void.TYPE; + }).when(client).search(any(SearchRequest.class), any()); + } } public static void assertAuthentication(Authentication result, Authentication expected) {