Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to return 401 instead of 500 #715

Merged
merged 2 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import com.datastax.oss.driver.api.core.AllNodesFailedException;
import com.datastax.oss.driver.api.core.DriverException;
import com.datastax.oss.driver.api.core.DriverTimeoutException;
import com.datastax.oss.driver.api.core.NoNodeAvailableException;
import com.datastax.oss.driver.api.core.NodeUnavailableException;
import com.datastax.oss.driver.api.core.auth.AuthenticationException;
import com.datastax.oss.driver.api.core.metadata.Node;
import com.datastax.oss.driver.api.core.servererrors.QueryValidationException;
import com.datastax.oss.driver.api.core.servererrors.WriteTimeoutException;
import io.grpc.Status;
Expand All @@ -15,6 +15,7 @@
import io.stargate.sgv2.jsonapi.config.DebugModeConfig;
import io.stargate.sgv2.jsonapi.exception.JsonApiException;
import jakarta.ws.rs.core.Response;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand Down Expand Up @@ -70,10 +71,40 @@ public final class ThrowableToErrorMapper {
message = "Mismatched vector dimension";
}
return new CommandResult.Error(message, fieldsForMetricsTag, fields, Response.Status.OK);
} else if (throwable instanceof NodeUnavailableException
|| throwable instanceof DriverException
|| throwable instanceof AllNodesFailedException
|| throwable instanceof NoNodeAvailableException) {
} else if (throwable instanceof DriverException) {
if (throwable instanceof AllNodesFailedException) {
Map<Node, List<Throwable>> nodewiseErrors =
((AllNodesFailedException) throwable).getAllErrors();
if (!nodewiseErrors.isEmpty()) {
List<Throwable> errors = nodewiseErrors.values().iterator().next();
if (errors != null && !errors.isEmpty()) {
Throwable error =
errors.stream()
.findAny()
.filter(
t ->
t instanceof AuthenticationException
|| t instanceof IllegalArgumentException)
.orElse(null);
// connecting to oss cassandra throws AuthenticationException for invalid
// credentials connecting to AstraDB throws IllegalArgumentException for invalid
// token/credentials
if (error instanceof AuthenticationException
|| (error instanceof IllegalArgumentException
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, So IllegalArgumentException is the issue.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct.

&& (error.getMessage().contains("AUTHENTICATION ERROR")
|| error
.getMessage()
.contains(
"Provided username token and/or password are incorrect")))) {
return new CommandResult.Error(
"UNAUTHENTICATED: Invalid token",
fieldsForMetricsTag,
fields,
Response.Status.UNAUTHORIZED);
}
}
}
}
return new CommandResult.Error(
message, fieldsForMetricsTag, fields, Response.Status.INTERNAL_SERVER_ERROR);
} else if (throwable instanceof DriverTimeoutException
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.stargate.sgv2.jsonapi.service.cqldriver;

import io.quarkus.test.junit.QuarkusTestProfile;
import java.util.Map;
import org.testcontainers.shaded.com.google.common.collect.ImmutableMap;

public class InvalidCredentialsProfile implements QuarkusTestProfile {

@Override
public Map<String, String> getConfigOverrides() {
return ImmutableMap.<String, String>builder()
.put("stargate.jsonapi.operations.database-config.fixed-token", "test-token")
.put("stargate.jsonapi.operations.database-config.password", "invalid-password")
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package io.stargate.sgv2.jsonapi.service.cqldriver;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.catchThrowable;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.datastax.oss.driver.api.core.AllNodesFailedException;
import com.datastax.oss.driver.api.core.CqlSession;
import io.micrometer.core.instrument.FunctionCounter;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.TestProfile;
import io.stargate.sgv2.api.common.StargateRequestInfo;
import io.stargate.sgv2.jsonapi.api.model.command.CommandResult;
import io.stargate.sgv2.jsonapi.config.OperationsConfig;
import io.stargate.sgv2.jsonapi.exception.mappers.ThrowableToErrorMapper;
import jakarta.inject.Inject;
import jakarta.ws.rs.core.Response;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

@QuarkusTest
@TestProfile(InvalidCredentialsProfile.class)
public class InvalidCredentialsTests {

private static final String TENANT_ID_FOR_TEST = "test_tenant";

@Inject OperationsConfig operationsConfig;

private MeterRegistry meterRegistry;

/**
* List of sessions created in the tests. This is used to close the sessions after each test. This
* is needed because, though the sessions evicted from the cache are closed, the sessions left
* active on the cache are not closed, so we have to close them explicitly.
*/
private List<CqlSession> sessionsCreatedInTests;

@BeforeEach
public void tearUpEachTest() {
meterRegistry = new SimpleMeterRegistry();
sessionsCreatedInTests = new ArrayList<>();
}

@AfterEach
public void tearDownEachTest() {
sessionsCreatedInTests.forEach(CqlSession::close);
}

@Test
public void testOSSCxCQLSessionCacheWithInvalidCredentials()
throws NoSuchFieldException, IllegalAccessException {
// set request info
StargateRequestInfo stargateRequestInfo = mock(StargateRequestInfo.class);
when(stargateRequestInfo.getTenantId()).thenReturn(Optional.of(TENANT_ID_FOR_TEST));
when(stargateRequestInfo.getCassandraToken())
.thenReturn(operationsConfig.databaseConfig().fixedToken());
CQLSessionCache cqlSessionCacheForTest = new CQLSessionCache(operationsConfig, meterRegistry);
Field stargateRequestInfoField =
cqlSessionCacheForTest.getClass().getDeclaredField("stargateRequestInfo");
stargateRequestInfoField.setAccessible(true);
stargateRequestInfoField.set(cqlSessionCacheForTest, stargateRequestInfo);
// set operation config
Field operationsConfigField =
cqlSessionCacheForTest.getClass().getDeclaredField("operationsConfig");
operationsConfigField.setAccessible(true);
operationsConfigField.set(cqlSessionCacheForTest, operationsConfig);
// Throwable
Throwable t = catchThrowable(cqlSessionCacheForTest::getSession);
assertThat(t).isInstanceOf(AllNodesFailedException.class);
CommandResult.Error error =
ThrowableToErrorMapper.getMapperWithMessageFunction().apply(t, t.getMessage());
assertThat(error).isNotNull();
assertThat(error.message()).contains("UNAUTHENTICATED: Invalid token");
assertThat(error.status()).isEqualTo(Response.Status.UNAUTHORIZED);
assertThat(cqlSessionCacheForTest.cacheSize()).isEqualTo(0);
// metrics test
Gauge cacheSizeMetric =
meterRegistry.find("cache.size").tag("cache", "cql_sessions_cache").gauge();
assertThat(cacheSizeMetric).isNotNull();
assertThat(cacheSizeMetric.value()).isEqualTo(0);
FunctionCounter cachePutMetric =
meterRegistry.find("cache.puts").tag("cache", "cql_sessions_cache").functionCounter();
assertThat(cachePutMetric).isNotNull();
assertThat(cachePutMetric.count()).isEqualTo(1);
FunctionCounter cacheLoadSuccessMetric =
meterRegistry
.find("cache.load")
.tag("cache", "cql_sessions_cache")
.tag("result", "success")
.functionCounter();
assertThat(cacheLoadSuccessMetric).isNotNull();
assertThat(cacheLoadSuccessMetric.count()).isEqualTo(0);
FunctionCounter cacheLoadFailureMetric =
meterRegistry
.find("cache.load")
.tag("cache", "cql_sessions_cache")
.tag("result", "failure")
.functionCounter();
assertThat(cacheLoadFailureMetric).isNotNull();
assertThat(cacheLoadFailureMetric.count()).isEqualTo(1);
}
}