diff --git a/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/FileWriterSession.java b/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/FileWriterSession.java index 8710882268..dcaaa45ba5 100644 --- a/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/FileWriterSession.java +++ b/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/FileWriterSession.java @@ -244,10 +244,12 @@ private void resetColumnValues(List boundValues) { boundValues.size() - 1; // TODO-SL: Need to find a better way to identify the vector column index CqlVector cqlVector = (CqlVector) boundValues.get(vectorColumnIndex); - ByteBuffer encodedVectorData = - TypeCodecs.vectorOf(cqlVector.size(), TypeCodecs.FLOAT) - .encode(cqlVector, ProtocolVersion.DEFAULT); - boundValues.set(vectorColumnIndex, encodedVectorData); + if (cqlVector != null) { + ByteBuffer encodedVectorData = + TypeCodecs.vectorOf(cqlVector.size(), TypeCodecs.FLOAT) + .encode(cqlVector, ProtocolVersion.DEFAULT); + boundValues.set(vectorColumnIndex, encodedVectorData); + } } } diff --git a/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessor.java b/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessor.java index 6dd2df9f04..7787b080fe 100644 --- a/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessor.java +++ b/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessor.java @@ -124,7 +124,7 @@ public boolean canEndSession( >= createNewSessionAfterDataInBytes; } - public Triple beginSession( + public Triple beginSession( CreateCollectionCommand createCollectionCommand, String namespace, String ssTablesOutputDirectory, @@ -163,7 +163,11 @@ public Triple beginSession( return new ImmutableTriple<>( beginOfflineSessionResponse, commandContext, - beginOfflineSessionCommand.getFileWriterParams().createTableCQL()); + new SchemaInfo( + beginOfflineSessionCommand.getFileWriterParams().keyspaceName(), + beginOfflineSessionCommand.getFileWriterParams().tableName(), + beginOfflineSessionCommand.getFileWriterParams().createTableCQL(), + beginOfflineSessionCommand.getFileWriterParams().indexCQLs())); } public OfflineInsertManyResponse loadData( diff --git a/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/SchemaInfo.java b/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/SchemaInfo.java new file mode 100644 index 0000000000..3531ada3cd --- /dev/null +++ b/src/main/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/SchemaInfo.java @@ -0,0 +1,26 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.sstablewriter; + +import java.util.List; + +public record SchemaInfo( + /* Name of the keypsace */ + String keyspaceName, + /* Name of the table */ + String tableName, + /* CQL to create the table */ + String createTableCQL, + /* CQL to create the indices*/ + List indexCQLs) { + public SchemaInfo { + if (keyspaceName == null || keyspaceName.isBlank()) { + throw new IllegalArgumentException("keyspaceName cannot be null or empty"); + } + if (tableName == null || tableName.isBlank()) { + throw new IllegalArgumentException("tableName cannot be null or empty"); + } + if (createTableCQL == null || createTableCQL.isBlank()) { + throw new IllegalArgumentException("createTableCQL cannot be null or empty"); + } + // when no-index option is specified by the user, indexCQLs will be null or empty + } +} diff --git a/src/test/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessorIT.java b/src/test/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessorIT.java index 974d43d572..cd41e0299d 100644 --- a/src/test/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessorIT.java +++ b/src/test/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/OfflineCommandsProcessorIT.java @@ -1,6 +1,7 @@ package io.stargate.sgv2.jsonapi.service.cqldriver.sstablewriter; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.fail; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -20,11 +21,14 @@ import java.util.Objects; import java.util.UUID; import java.util.concurrent.ExecutionException; +import java.util.stream.Stream; import org.apache.commons.lang3.tuple.Triple; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; public class OfflineCommandsProcessorIT { @@ -61,19 +65,17 @@ private static void deleteRecursively(File sstablesTestDirectory) { } } - @Test - public void testOfflineCommandsProcessor() - throws ExecutionException, InterruptedException, IOException { - testOfflineCommandsProcessor(false); - } - - @Test - public void testOfflineCommandsProcessorVector() - throws ExecutionException, InterruptedException, IOException { - testOfflineCommandsProcessor(true); + public static Stream testScenarios() { + return Stream.of( + Arguments.of(false, false), + Arguments.of(false, true), + Arguments.of(true, false), + Arguments.of(true, true)); } - private void testOfflineCommandsProcessor(boolean isVectorSearch) + @ParameterizedTest + @MethodSource("testScenarios") + public void testOfflineCommandsProcessor(boolean isVectorTable, boolean includeVectorData) throws ExecutionException, InterruptedException, IOException { String testId = UUID.randomUUID().toString(); String namespace = "test_namespace"; @@ -85,14 +87,14 @@ private void testOfflineCommandsProcessor(boolean isVectorSearch) } OfflineCommandsProcessor offlineCommandsProcessor = OfflineCommandsProcessor.getInstance(); // begin session - Triple beginSessionResponse = + Triple beginSessionResponse = beginSession( offlineCommandsProcessor, namespace, sstablesOutputDirectory, fileWriterBufferSizeInMB, embeddingProvider, - isVectorSearch); + isVectorTable); BeginOfflineSessionResponse beginOfflineSessionResponse = beginSessionResponse.getLeft(); if (beginOfflineSessionResponse.errors() != null && !beginOfflineSessionResponse.errors().isEmpty()) { @@ -100,7 +102,8 @@ private void testOfflineCommandsProcessor(boolean isVectorSearch) "Error while beginning session : " + beginOfflineSessionResponse.errors()); } CommandContext commandContext = beginSessionResponse.getMiddle(); - String createTableCQL = beginSessionResponse.getRight(); + SchemaInfo schemaInfo = beginSessionResponse.getRight(); + String createTableCQL = schemaInfo.createTableCQL(); String expectedCreateCQL = """ CREATE TABLE IF NOT EXISTS "test_namespace"."test_collection_false"( @@ -121,20 +124,63 @@ PRIMARY KEY (key)) """; // assertThat(createTableCQL).isEqualTo(expectedCreateCQL);//TODO-SL fix assertion assertThat(createTableCQL).isNotNull(); + String tableName = "test_collection" + (isVectorTable ? "_true" : "_false"); assertThat(createTableCQL) - .startsWith( - "CREATE TABLE IF NOT EXISTS \"test_namespace\".\"test_collection" - + (isVectorSearch ? "_true" : "_false")); + .startsWith("CREATE TABLE IF NOT EXISTS \"test_namespace\".\"" + tableName); + assertThat(schemaInfo.keyspaceName()).isEqualTo("test_namespace"); + assertThat(schemaInfo.tableName()) + .isEqualTo("test_collection" + (isVectorTable ? "_true" : "_false")); + List indexCQLs = new ArrayList<>(schemaInfo.indexCQLs()); + assertThat(indexCQLs.size()).isEqualTo(isVectorTable ? 9 : 8); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_exists_keys ON \"test_namespace\".\"%s\" (exist_keys) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_array_size ON \"test_namespace\".\"%s\" (entries(array_size)) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_array_contains ON \"test_namespace\".\"%s\" (array_contains) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_bool_values ON \"test_namespace\".\"%s\" (entries(query_bool_values)) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_dbl_values ON \"test_namespace\".\"%s\" (entries(query_dbl_values)) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_text_values ON \"test_namespace\".\"%s\" (entries(query_text_values)) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_timestamp_values ON \"test_namespace\".\"%s\" (entries(query_timestamp_values)) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS %s_query_null_values ON \"test_namespace\".\"%s\" (query_null_values) USING 'StorageAttachedIndex'" + .formatted(tableName, tableName)); + if (isVectorTable) { + indexCQLs.remove( + "CREATE CUSTOM INDEX IF NOT EXISTS test_collection_true_query_vector_value ON \"test_namespace\".\"test_collection_true\" (query_vector_value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'similarity_function': 'COSINE'}"); + } + assertThat(indexCQLs.size()).isEqualTo(0); String sessionId = beginOfflineSessionResponse.sessionId(); // load data - List jsonNodes = getRecords(isVectorSearch); + List jsonNodes = getRecords(includeVectorData); OfflineInsertManyResponse offlineInsertManyResponse = loadTestData(offlineCommandsProcessor, commandContext, sessionId, jsonNodes); + boolean verifyVectorDataForNonVectorTable = !isVectorTable && includeVectorData; if (offlineInsertManyResponse.errors() != null && !offlineInsertManyResponse.errors().isEmpty()) { + if (verifyVectorDataForNonVectorTable) { + assertThat(offlineInsertManyResponse.errors().size()).isEqualTo(1); + assertThat(offlineInsertManyResponse.errors().get(0).message()) + .contains("Vector search is not enabled for the collection %s".formatted(tableName)); + return; + } throw new RuntimeException( "Error while inserting data : " + offlineInsertManyResponse.errors()); } + if (verifyVectorDataForNonVectorTable) { + fail("Should have failed for vector data in non vector table"); + } // get statsus OfflineGetStatusResponse offlineGetStatusResponse = getStatus(offlineCommandsProcessor, commandContext, sessionId); @@ -145,7 +191,7 @@ PRIMARY KEY (key)) assertEquals(sessionId, offlineGetStatusResponse.offlineWriterSessionStatus().sessionId()); assertEquals(namespace, offlineGetStatusResponse.offlineWriterSessionStatus().keyspace()); assertEquals( - "test_collection_" + isVectorSearch, + "test_collection_" + isVectorTable, offlineGetStatusResponse.offlineWriterSessionStatus().tableName()); assertEquals( sstablesOutputDirectory, @@ -169,7 +215,7 @@ PRIMARY KEY (key)) assertEquals(sessionId, endOfflineSessionResponse.offlineWriterSessionStatus().sessionId()); assertEquals(namespace, endOfflineSessionResponse.offlineWriterSessionStatus().keyspace()); assertEquals( - "test_collection_" + isVectorSearch, + "test_collection_" + isVectorTable, endOfflineSessionResponse.offlineWriterSessionStatus().tableName()); assertEquals( sstablesOutputDirectory, @@ -268,7 +314,7 @@ private OfflineInsertManyResponse loadTestData( return offlineCommandsProcessor.loadData(sessionId, commandContext, jsonNodes); } - private Triple beginSession( + private Triple beginSession( OfflineCommandsProcessor offlineCommandsProcessor, String namespace, String ssTablesOutputDirectory, diff --git a/src/test/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/SchemaInfoTest.java b/src/test/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/SchemaInfoTest.java new file mode 100644 index 0000000000..6eef267593 --- /dev/null +++ b/src/test/offline/io/stargate/sgv2/jsonapi/service/cqldriver/sstablewriter/SchemaInfoTest.java @@ -0,0 +1,39 @@ +package io.stargate.sgv2.jsonapi.service.cqldriver.sstablewriter; + +import static org.junit.Assert.assertThrows; + +import java.util.List; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class SchemaInfoTest { + public static Stream provideInvalidSchemaInfo() { + return Stream.of( + Arguments.of( + null, "table1", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()), + Arguments.of( + "", "table1", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()), + Arguments.of( + " ", "table1", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()), + Arguments.of( + "keyspace1", null, "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()), + Arguments.of( + "keyspace1", "", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()), + Arguments.of( + "keyspace1", " ", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()), + Arguments.of("keyspace1", "table1", null, List.of()), + Arguments.of("keyspace1", "table1", "", List.of()), + Arguments.of("keyspace1", "table1", " ", List.of())); + } + + @ParameterizedTest + @MethodSource("provideInvalidSchemaInfo") + public void testInvalidSchemaInfo( + String keyspaceName, String tableName, String createTableCQL, List indexCQLs) { + assertThrows( + IllegalArgumentException.class, + () -> new SchemaInfo(keyspaceName, tableName, createTableCQL, indexCQLs)); + } +}