diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java index d684477267..44911a6d2b 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/mappers/ThrowableToErrorMapper.java @@ -3,8 +3,8 @@ import com.datastax.oss.driver.api.core.AllNodesFailedException; import com.datastax.oss.driver.api.core.DriverException; import com.datastax.oss.driver.api.core.DriverTimeoutException; -import com.datastax.oss.driver.api.core.NoNodeAvailableException; -import com.datastax.oss.driver.api.core.NodeUnavailableException; +import com.datastax.oss.driver.api.core.auth.AuthenticationException; +import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.servererrors.QueryValidationException; import com.datastax.oss.driver.api.core.servererrors.WriteTimeoutException; import io.grpc.Status; @@ -15,6 +15,7 @@ import io.stargate.sgv2.jsonapi.config.DebugModeConfig; import io.stargate.sgv2.jsonapi.exception.JsonApiException; import jakarta.ws.rs.core.Response; +import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Function; @@ -70,10 +71,40 @@ public final class ThrowableToErrorMapper { message = "Mismatched vector dimension"; } return new CommandResult.Error(message, fieldsForMetricsTag, fields, Response.Status.OK); - } else if (throwable instanceof NodeUnavailableException - || throwable instanceof DriverException - || throwable instanceof AllNodesFailedException - || throwable instanceof NoNodeAvailableException) { + } else if (throwable instanceof DriverException) { + if (throwable instanceof AllNodesFailedException) { + Map> nodewiseErrors = + ((AllNodesFailedException) throwable).getAllErrors(); + if (!nodewiseErrors.isEmpty()) { + List errors = nodewiseErrors.values().iterator().next(); + if (errors != null && !errors.isEmpty()) { + Throwable error = + errors.stream() + .findAny() + .filter( + t -> + t instanceof AuthenticationException + || t instanceof IllegalArgumentException) + .orElse(null); + // connecting to oss cassandra throws AuthenticationException for invalid + // credentials connecting to AstraDB throws IllegalArgumentException for invalid + // token/credentials + if (error instanceof AuthenticationException + || (error instanceof IllegalArgumentException + && (error.getMessage().contains("AUTHENTICATION ERROR") + || error + .getMessage() + .contains( + "Provided username token and/or password are incorrect")))) { + return new CommandResult.Error( + "UNAUTHENTICATED: Invalid token", + fieldsForMetricsTag, + fields, + Response.Status.UNAUTHORIZED); + } + } + } + } return new CommandResult.Error( message, fieldsForMetricsTag, fields, Response.Status.INTERNAL_SERVER_ERROR); } else if (throwable instanceof DriverTimeoutException diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/InvalidCredentialsProfile.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/InvalidCredentialsProfile.java new file mode 100644 index 0000000000..ee04b8aba8 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/InvalidCredentialsProfile.java @@ -0,0 +1,16 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver; + +import io.quarkus.test.junit.QuarkusTestProfile; +import java.util.Map; +import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; + +public class InvalidCredentialsProfile implements QuarkusTestProfile { + + @Override + public Map getConfigOverrides() { + return ImmutableMap.builder() + .put("stargate.jsonapi.operations.database-config.fixed-token", "test-token") + .put("stargate.jsonapi.operations.database-config.password", "invalid-password") + .build(); + } +} diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/InvalidCredentialsTests.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/InvalidCredentialsTests.java new file mode 100644 index 0000000000..e3d94a2481 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/InvalidCredentialsTests.java @@ -0,0 +1,111 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.catchThrowable; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.AllNodesFailedException; +import com.datastax.oss.driver.api.core.CqlSession; +import io.micrometer.core.instrument.FunctionCounter; +import io.micrometer.core.instrument.Gauge; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.stargate.sgv2.api.common.StargateRequestInfo; +import io.stargate.sgv2.jsonapi.api.model.command.CommandResult; +import io.stargate.sgv2.jsonapi.config.OperationsConfig; +import io.stargate.sgv2.jsonapi.exception.mappers.ThrowableToErrorMapper; +import jakarta.inject.Inject; +import jakarta.ws.rs.core.Response; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@QuarkusTest +@TestProfile(InvalidCredentialsProfile.class) +public class InvalidCredentialsTests { + + private static final String TENANT_ID_FOR_TEST = "test_tenant"; + + @Inject OperationsConfig operationsConfig; + + private MeterRegistry meterRegistry; + + /** + * List of sessions created in the tests. This is used to close the sessions after each test. This + * is needed because, though the sessions evicted from the cache are closed, the sessions left + * active on the cache are not closed, so we have to close them explicitly. + */ + private List sessionsCreatedInTests; + + @BeforeEach + public void tearUpEachTest() { + meterRegistry = new SimpleMeterRegistry(); + sessionsCreatedInTests = new ArrayList<>(); + } + + @AfterEach + public void tearDownEachTest() { + sessionsCreatedInTests.forEach(CqlSession::close); + } + + @Test + public void testOSSCxCQLSessionCacheWithInvalidCredentials() + throws NoSuchFieldException, IllegalAccessException { + // set request info + StargateRequestInfo stargateRequestInfo = mock(StargateRequestInfo.class); + when(stargateRequestInfo.getTenantId()).thenReturn(Optional.of(TENANT_ID_FOR_TEST)); + when(stargateRequestInfo.getCassandraToken()) + .thenReturn(operationsConfig.databaseConfig().fixedToken()); + CQLSessionCache cqlSessionCacheForTest = new CQLSessionCache(operationsConfig, meterRegistry); + Field stargateRequestInfoField = + cqlSessionCacheForTest.getClass().getDeclaredField("stargateRequestInfo"); + stargateRequestInfoField.setAccessible(true); + stargateRequestInfoField.set(cqlSessionCacheForTest, stargateRequestInfo); + // set operation config + Field operationsConfigField = + cqlSessionCacheForTest.getClass().getDeclaredField("operationsConfig"); + operationsConfigField.setAccessible(true); + operationsConfigField.set(cqlSessionCacheForTest, operationsConfig); + // Throwable + Throwable t = catchThrowable(cqlSessionCacheForTest::getSession); + assertThat(t).isInstanceOf(AllNodesFailedException.class); + CommandResult.Error error = + ThrowableToErrorMapper.getMapperWithMessageFunction().apply(t, t.getMessage()); + assertThat(error).isNotNull(); + assertThat(error.message()).contains("UNAUTHENTICATED: Invalid token"); + assertThat(error.status()).isEqualTo(Response.Status.UNAUTHORIZED); + assertThat(cqlSessionCacheForTest.cacheSize()).isEqualTo(0); + // metrics test + Gauge cacheSizeMetric = + meterRegistry.find("cache.size").tag("cache", "cql_sessions_cache").gauge(); + assertThat(cacheSizeMetric).isNotNull(); + assertThat(cacheSizeMetric.value()).isEqualTo(0); + FunctionCounter cachePutMetric = + meterRegistry.find("cache.puts").tag("cache", "cql_sessions_cache").functionCounter(); + assertThat(cachePutMetric).isNotNull(); + assertThat(cachePutMetric.count()).isEqualTo(1); + FunctionCounter cacheLoadSuccessMetric = + meterRegistry + .find("cache.load") + .tag("cache", "cql_sessions_cache") + .tag("result", "success") + .functionCounter(); + assertThat(cacheLoadSuccessMetric).isNotNull(); + assertThat(cacheLoadSuccessMetric.count()).isEqualTo(0); + FunctionCounter cacheLoadFailureMetric = + meterRegistry + .find("cache.load") + .tag("cache", "cql_sessions_cache") + .tag("result", "failure") + .functionCounter(); + assertThat(cacheLoadFailureMetric).isNotNull(); + assertThat(cacheLoadFailureMetric.count()).isEqualTo(1); + } +}