diff --git a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java index 1680672ca2..1754100969 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/exception/ErrorCode.java @@ -59,6 +59,8 @@ public enum ErrorCode { INVALID_COLLECTION_NAME("Invalid collection name "), + INVALID_JSONAPI_COLLECTION_SCHEMA("Not a valid json api collection schema: "), + TOO_MANY_COLLECTIONS("Too many collections"), UNSUPPORTED_FILTER_DATA_TYPE("Unsupported filter data type"), diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java index 202eaf1bf9..1e94eeb00a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCache.java @@ -7,6 +7,7 @@ import io.smallrye.mutiny.Uni; import io.stargate.sgv2.jsonapi.exception.ErrorCode; import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import io.stargate.sgv2.jsonapi.service.schema.model.JsonapiTableMatcher; import java.time.Duration; /** Caches the vector enabled status for the namespace */ @@ -42,10 +43,21 @@ protected Uni getCollectionProperties(String collectionName) .transformToUni( (result, error) -> { if (null != error) { + // not a valid collection schema + if (error instanceof JsonApiException + && ((JsonApiException) error).getErrorCode() + == ErrorCode.VECTORIZECONFIG_CHECK_FAIL) { + return Uni.createFrom() + .failure( + new JsonApiException( + ErrorCode.INVALID_JSONAPI_COLLECTION_SCHEMA, + ErrorCode.INVALID_JSONAPI_COLLECTION_SCHEMA + .getMessage() + .concat(collectionName))); + } // collection does not exist if (error instanceof RuntimeException rte - && rte.getMessage() - .startsWith(ErrorCode.INVALID_COLLECTION_NAME.getMessage())) { + && rte.getMessage().startsWith(ErrorCode.COLLECTION_NOT_EXIST.getMessage())) { return Uni.createFrom() .failure( new JsonApiException( @@ -79,10 +91,16 @@ private Uni getVectorProperties(String collectionName) { .transform( table -> { if (table.isPresent()) { + // check if its a valid json api table + if (!new JsonapiTableMatcher().test(table.get())) { + throw new JsonApiException( + ErrorCode.INVALID_JSONAPI_COLLECTION_SCHEMA, + ErrorCode.INVALID_JSONAPI_COLLECTION_SCHEMA.getMessage() + collectionName); + } return CollectionSettings.getCollectionSettings(table.get(), objectMapper); } else { throw new RuntimeException( - ErrorCode.INVALID_COLLECTION_NAME.getMessage() + collectionName); + ErrorCode.COLLECTION_NOT_EXIST.getMessage() + collectionName); } }); } diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCacheTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCacheTest.java new file mode 100644 index 0000000000..11a1737fd2 --- /dev/null +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/cqldriver/executor/NamespaceCacheTest.java @@ -0,0 +1,233 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.executor; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.datastax.oss.driver.api.core.CqlIdentifier; +import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata; +import com.datastax.oss.driver.api.core.type.DataTypes; +import com.datastax.oss.driver.internal.core.metadata.schema.DefaultColumnMetadata; +import com.datastax.oss.driver.internal.core.metadata.schema.DefaultTableMetadata; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Lists; +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.junit.TestProfile; +import io.smallrye.mutiny.Uni; +import io.smallrye.mutiny.helpers.test.UniAssertSubscriber; +import io.stargate.sgv2.common.testprofiles.NoGlobalResourcesTestProfile; +import io.stargate.sgv2.jsonapi.exception.ErrorCode; +import io.stargate.sgv2.jsonapi.exception.JsonApiException; +import jakarta.inject.Inject; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +@QuarkusTest +@TestProfile(NoGlobalResourcesTestProfile.Impl.class) +public class NamespaceCacheTest { + + @Inject ObjectMapper objectMapper; + + @Nested + class Execute { + + @Test + public void checkValidJsonApiTable() { + QueryExecutor queryExecutor = mock(QueryExecutor.class); + when(queryExecutor.getSchema(any(), any())) + .then( + i -> { + List partitionColumn = + Lists.newArrayList( + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("key"), + DataTypes.tupleOf(DataTypes.TINYINT, DataTypes.TEXT), + false)); + Map columns = new HashMap<>(); + columns.put( + CqlIdentifier.fromInternal("tx_id"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("tx_id"), + DataTypes.TIMEUUID, + false)); + columns.put( + CqlIdentifier.fromInternal("doc_json"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("doc_json"), + DataTypes.TEXT, + false)); + columns.put( + CqlIdentifier.fromInternal("exist_keys"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("exist_keys"), + DataTypes.setOf(DataTypes.TEXT), + false)); + columns.put( + CqlIdentifier.fromInternal("array_size"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("array_size"), + DataTypes.mapOf(DataTypes.TEXT, DataTypes.INT), + false)); + columns.put( + CqlIdentifier.fromInternal("array_contains"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("array_contains"), + DataTypes.setOf(DataTypes.TEXT), + false)); + columns.put( + CqlIdentifier.fromInternal("query_bool_values"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("query_bool_values"), + DataTypes.mapOf(DataTypes.TEXT, DataTypes.TINYINT), + false)); + columns.put( + CqlIdentifier.fromInternal("query_dbl_values"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("query_dbl_values"), + DataTypes.mapOf(DataTypes.TEXT, DataTypes.DECIMAL), + false)); + columns.put( + CqlIdentifier.fromInternal("query_text_values"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("query_text_values"), + DataTypes.mapOf(DataTypes.TEXT, DataTypes.TEXT), + false)); + columns.put( + CqlIdentifier.fromInternal("query_timestamp_values"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("query_timestamp_values"), + DataTypes.mapOf(DataTypes.TEXT, DataTypes.TIMESTAMP), + false)); + columns.put( + CqlIdentifier.fromInternal("query_null_values"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("query_null_values"), + DataTypes.setOf(DataTypes.TEXT), + false)); + + return Uni.createFrom() + .item( + Optional.of( + new DefaultTableMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + UUID.randomUUID(), + false, + false, + partitionColumn, + new HashMap<>(), + columns, + new HashMap<>(), + new HashMap<>()))); + }); + NamespaceCache namespaceCache = new NamespaceCache("ks", queryExecutor, objectMapper); + CollectionSettings collectionSettings = + namespaceCache + .getCollectionProperties("table") + .subscribe() + .withSubscriber(UniAssertSubscriber.create()) + .awaitItem() + .getItem(); + + assertThat(collectionSettings) + .satisfies( + s -> { + assertThat(s.vectorEnabled()).isFalse(); + assertThat(s.collectionName()).isEqualTo("table"); + }); + } + + @Test + public void checkInvalidJsonApiTable() { + QueryExecutor queryExecutor = mock(QueryExecutor.class); + when(queryExecutor.getSchema(any(), any())) + .then( + i -> { + List partitionColumn = + Lists.newArrayList( + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("key"), + DataTypes.tupleOf(DataTypes.TINYINT, DataTypes.TEXT), + false)); + Map columns = new HashMap<>(); + columns.put( + CqlIdentifier.fromInternal("tx_id"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("tx_id"), + DataTypes.TIMEUUID, + false)); + columns.put( + CqlIdentifier.fromInternal("doc"), + new DefaultColumnMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + CqlIdentifier.fromInternal("doc"), + DataTypes.TEXT, + false)); + return Uni.createFrom() + .item( + Optional.of( + new DefaultTableMetadata( + CqlIdentifier.fromInternal("ks"), + CqlIdentifier.fromInternal("table"), + UUID.randomUUID(), + false, + false, + partitionColumn, + new HashMap<>(), + columns, + new HashMap<>(), + new HashMap<>()))); + }); + NamespaceCache namespaceCache = new NamespaceCache("ks", queryExecutor, objectMapper); + Throwable error = + namespaceCache + .getCollectionProperties("table") + .subscribe() + .withSubscriber(UniAssertSubscriber.create()) + .awaitFailure() + .getFailure(); + + assertThat(error) + .isInstanceOfSatisfying( + JsonApiException.class, + s -> { + assertThat(s.getErrorCode()).isEqualTo(ErrorCode.INVALID_JSONAPI_COLLECTION_SCHEMA); + assertThat(s.getMessage()) + .isEqualTo(ErrorCode.INVALID_JSONAPI_COLLECTION_SCHEMA.getMessage() + "table"); + }); + } + } +}