Skip to content

Commit

Permalink
Fix TokenServiceTests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertzaharovits committed Jul 19, 2023
1 parent 217e8de commit f23d01e
Showing 1 changed file with 70 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ public void testInvalidateRefreshToken() throws Exception {
final String clientRefreshToken = tokenFuture.get().getRefreshToken();
assertNotNull(accessToken);
assertNotNull(clientRefreshToken);
mockFindTokenFromRefreshToken(newTokenBytes.v1(), newTokenBytes.v2(), tokenService, authentication, null);
mockTokenForRefreshToken(newTokenBytes.v1(), newTokenBytes.v2(), tokenService, authentication, null);

ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
storeTokenHeader(requestContext, accessToken);
Expand All @@ -496,7 +496,7 @@ public void testInvalidateRefreshTokenThatIsAlreadyInvalidated() throws Exceptio
final String clientRefreshToken = tokenFuture.get().getRefreshToken();
assertNotNull(accessToken);
assertNotNull(clientRefreshToken);
mockFindTokenFromRefreshToken(
mockTokenForRefreshToken(
newTokenBytes.v1(),
newTokenBytes.v2(),
tokenService,
Expand Down Expand Up @@ -1035,13 +1035,13 @@ public static String tokenDocIdFromAccessTokenBytes(byte[] accessTokenBytes, Tra
}
}

private void mockFindTokenFromRefreshToken(
private void mockTokenForRefreshToken(
byte[] accessTokenBytes,
byte[] refreshTokenBytes,
TokenService tokenService,
Authentication authentication,
@Nullable RefreshTokenStatus refreshTokenStatus
) {
) throws IOException {
UserToken userToken = buildUserToken(tokenService, accessTokenBytes, authentication, null, Map.of());
final String storedAccessToken;
final String storedRefreshToken;
Expand All @@ -1055,59 +1055,74 @@ private void mockFindTokenFromRefreshToken(
storedAccessToken = null;
storedRefreshToken = Base64.getUrlEncoder().withoutPadding().encodeToString(refreshTokenBytes);
}
doAnswer(invocationOnMock -> {
final SearchRequest request = (SearchRequest) invocationOnMock.getArguments()[0];
final RealmRef realmRef = new RealmRef(
refreshTokenStatus == null ? randomAlphaOfLength(6) : refreshTokenStatus.getAssociatedRealm(),
"test",
randomAlphaOfLength(12)
);
final Authentication clientAuthentication = AuthenticationTestHelper.builder()
.user(new User(refreshTokenStatus == null ? randomAlphaOfLength(8) : refreshTokenStatus.getAssociatedUser()))
.realmRef(realmRef)
.build(false);

final SearchHit hit = new SearchHit(randomInt(), "token_" + userToken.getId());
BytesReference source = TokenService.createTokenDocument(
userToken,
storedAccessToken,
storedRefreshToken,
clientAuthentication,
Instant.now()
);
if (refreshTokenStatus != null) {
var sourceAsMap = XContentHelper.convertToMap(source, false, XContentType.JSON).v2();
@SuppressWarnings("unchecked")
final ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[1];
final SearchResponse response = mock(SearchResponse.class);

assertThat(request.source().query(), instanceOf(BoolQueryBuilder.class));
BoolQueryBuilder bool = (BoolQueryBuilder) request.source().query();
assertThat(bool.filter(), hasSize(2));

assertThat(bool.filter().get(0), instanceOf(TermQueryBuilder.class));
TermQueryBuilder docType = (TermQueryBuilder) bool.filter().get(0);
assertThat(docType.fieldName(), is("doc_type"));
assertThat(docType.value(), is("token"));

assertThat(bool.filter().get(1), instanceOf(TermQueryBuilder.class));
TermQueryBuilder refreshFilter = (TermQueryBuilder) bool.filter().get(1);
assertThat(refreshFilter.fieldName(), is("refresh_token.token"));
assertThat(refreshFilter.value(), is(storedRefreshToken));

final RealmRef realmRef = new RealmRef(
refreshTokenStatus == null ? randomAlphaOfLength(6) : refreshTokenStatus.getAssociatedRealm(),
"test",
randomAlphaOfLength(12)
);
final Authentication clientAuthentication = AuthenticationTestHelper.builder()
.user(new User(refreshTokenStatus == null ? randomAlphaOfLength(8) : refreshTokenStatus.getAssociatedUser()))
.realmRef(realmRef)
.build(false);

final SearchHit hit = new SearchHit(randomInt(), "token_" + userToken.getId());
BytesReference source = TokenService.createTokenDocument(
userToken,
storedAccessToken,
storedRefreshToken,
clientAuthentication,
Instant.now()
);
if (refreshTokenStatus != null) {
var sourceAsMap = XContentHelper.convertToMap(source, false, XContentType.JSON).v2();
var refreshTokenSource = (Map<String, Object>) sourceAsMap.get("refresh_token");
refreshTokenSource.put("invalidated", refreshTokenStatus.isInvalidated());
refreshTokenSource.put("refreshed", refreshTokenStatus.isRefreshed());
source = XContentTestUtils.convertToXContent(sourceAsMap, XContentType.JSON);
}
final BytesReference docSource = source;
if (userToken.getTransportVersion().onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH)) {
doAnswer(invocationOnMock -> {
GetRequest request = (GetRequest) invocationOnMock.getArguments()[0];
@SuppressWarnings("unchecked")
var refreshTokenSource = (Map<String, Object>) sourceAsMap.get("refresh_token");
refreshTokenSource.put("invalidated", refreshTokenStatus.isInvalidated());
refreshTokenSource.put("refreshed", refreshTokenStatus.isRefreshed());
source = XContentTestUtils.convertToXContent(sourceAsMap, XContentType.JSON);
}
hit.sourceRef(source);

final SearchHits hits = new SearchHits(new SearchHit[] { hit }, null, 1);
when(response.getHits()).thenReturn(hits);
listener.onResponse(response);
return Void.TYPE;
}).when(client).search(any(SearchRequest.class), any());
ActionListener<GetResponse> listener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1];
GetResponse response = mock(GetResponse.class);
if (userToken.getId().equals(request.id().substring("token_".length()))) {
when(response.isExists()).thenReturn(true);
when(response.getSource()).thenReturn(XContentHelper.convertToMap(docSource, false, XContentType.JSON).v2());
}
listener.onResponse(response);
return null;
}).when(client).get(any(GetRequest.class), anyActionListener());
} else {
doAnswer(invocationOnMock -> {
final SearchRequest request = (SearchRequest) invocationOnMock.getArguments()[0];
@SuppressWarnings("unchecked")
final ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[1];
final SearchResponse response = mock(SearchResponse.class);

assertThat(request.source().query(), instanceOf(BoolQueryBuilder.class));
BoolQueryBuilder bool = (BoolQueryBuilder) request.source().query();
assertThat(bool.filter(), hasSize(2));

assertThat(bool.filter().get(0), instanceOf(TermQueryBuilder.class));
TermQueryBuilder docType = (TermQueryBuilder) bool.filter().get(0);
assertThat(docType.fieldName(), is("doc_type"));
assertThat(docType.value(), is("token"));

assertThat(bool.filter().get(1), instanceOf(TermQueryBuilder.class));
TermQueryBuilder refreshFilter = (TermQueryBuilder) bool.filter().get(1);
assertThat(refreshFilter.fieldName(), is("refresh_token.token"));
assertThat(refreshFilter.value(), is(storedRefreshToken));
hit.sourceRef(docSource);

final SearchHits hits = new SearchHits(new SearchHit[] { hit }, null, 1);
when(response.getHits()).thenReturn(hits);
listener.onResponse(response);
return Void.TYPE;
}).when(client).search(any(SearchRequest.class), any());
}
}

public static void assertAuthentication(Authentication result, Authentication expected) {
Expand Down

0 comments on commit f23d01e

Please sign in to comment.