Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offline Mode - Index cql & Vector fixes #1115

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,12 @@ private void resetColumnValues(List<Object> boundValues) {
boundValues.size()
- 1; // TODO-SL: Need to find a better way to identify the vector column index
CqlVector<Float> cqlVector = (CqlVector<Float>) boundValues.get(vectorColumnIndex);
ByteBuffer encodedVectorData =
TypeCodecs.vectorOf(cqlVector.size(), TypeCodecs.FLOAT)
.encode(cqlVector, ProtocolVersion.DEFAULT);
boundValues.set(vectorColumnIndex, encodedVectorData);
if (cqlVector != null) {
ByteBuffer encodedVectorData =
TypeCodecs.vectorOf(cqlVector.size(), TypeCodecs.FLOAT)
.encode(cqlVector, ProtocolVersion.DEFAULT);
boundValues.set(vectorColumnIndex, encodedVectorData);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public boolean canEndSession(
>= createNewSessionAfterDataInBytes;
}

public Triple<BeginOfflineSessionResponse, CommandContext, String> beginSession(
public Triple<BeginOfflineSessionResponse, CommandContext, SchemaInfo> beginSession(
CreateCollectionCommand createCollectionCommand,
String namespace,
String ssTablesOutputDirectory,
Expand Down Expand Up @@ -163,7 +163,11 @@ public Triple<BeginOfflineSessionResponse, CommandContext, String> beginSession(
return new ImmutableTriple<>(
beginOfflineSessionResponse,
commandContext,
beginOfflineSessionCommand.getFileWriterParams().createTableCQL());
new SchemaInfo(
kathirsvn marked this conversation as resolved.
Show resolved Hide resolved
beginOfflineSessionCommand.getFileWriterParams().keyspaceName(),
beginOfflineSessionCommand.getFileWriterParams().tableName(),
beginOfflineSessionCommand.getFileWriterParams().createTableCQL(),
beginOfflineSessionCommand.getFileWriterParams().indexCQLs()));
}

public OfflineInsertManyResponse loadData(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.sstablewriter;

import java.util.List;

public record SchemaInfo(
/* Name of the keypsace */
String keyspaceName,
/* Name of the table */
String tableName,
/* CQL to create the table */
String createTableCQL,
/* CQL to create the indices*/
List<String> indexCQLs) {
public SchemaInfo {
if (keyspaceName == null || keyspaceName.isBlank()) {
throw new IllegalArgumentException("keyspaceName cannot be null or empty");
}
if (tableName == null || tableName.isBlank()) {
throw new IllegalArgumentException("tableName cannot be null or empty");
}
if (createTableCQL == null || createTableCQL.isBlank()) {
throw new IllegalArgumentException("createTableCQL cannot be null or empty");
}
// when no-index option is specified by the user, indexCQLs will be null or empty
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.sstablewriter;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

Expand All @@ -20,11 +21,14 @@
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Triple;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

public class OfflineCommandsProcessorIT {

Expand Down Expand Up @@ -61,19 +65,17 @@ private static void deleteRecursively(File sstablesTestDirectory) {
}
}

@Test
public void testOfflineCommandsProcessor()
throws ExecutionException, InterruptedException, IOException {
testOfflineCommandsProcessor(false);
}

@Test
public void testOfflineCommandsProcessorVector()
throws ExecutionException, InterruptedException, IOException {
testOfflineCommandsProcessor(true);
public static Stream<Arguments> testScenarios() {
return Stream.of(
Arguments.of(false, false),
Arguments.of(false, true),
Arguments.of(true, false),
Arguments.of(true, true));
}

private void testOfflineCommandsProcessor(boolean isVectorSearch)
@ParameterizedTest
@MethodSource("testScenarios")
public void testOfflineCommandsProcessor(boolean isVectorTable, boolean includeVectorData)
throws ExecutionException, InterruptedException, IOException {
String testId = UUID.randomUUID().toString();
String namespace = "test_namespace";
Expand All @@ -85,22 +87,23 @@ private void testOfflineCommandsProcessor(boolean isVectorSearch)
}
OfflineCommandsProcessor offlineCommandsProcessor = OfflineCommandsProcessor.getInstance();
// begin session
Triple<BeginOfflineSessionResponse, CommandContext, String> beginSessionResponse =
Triple<BeginOfflineSessionResponse, CommandContext, SchemaInfo> beginSessionResponse =
beginSession(
offlineCommandsProcessor,
namespace,
sstablesOutputDirectory,
fileWriterBufferSizeInMB,
embeddingProvider,
isVectorSearch);
isVectorTable);
BeginOfflineSessionResponse beginOfflineSessionResponse = beginSessionResponse.getLeft();
if (beginOfflineSessionResponse.errors() != null
&& !beginOfflineSessionResponse.errors().isEmpty()) {
throw new RuntimeException(
"Error while beginning session : " + beginOfflineSessionResponse.errors());
}
CommandContext commandContext = beginSessionResponse.getMiddle();
String createTableCQL = beginSessionResponse.getRight();
SchemaInfo schemaInfo = beginSessionResponse.getRight();
String createTableCQL = schemaInfo.createTableCQL();
String expectedCreateCQL =
"""
CREATE TABLE IF NOT EXISTS "test_namespace"."test_collection_false"(
Expand All @@ -121,20 +124,63 @@ PRIMARY KEY (key))
""";
// assertThat(createTableCQL).isEqualTo(expectedCreateCQL);//TODO-SL fix assertion
assertThat(createTableCQL).isNotNull();
String tableName = "test_collection" + (isVectorTable ? "_true" : "_false");
assertThat(createTableCQL)
.startsWith(
"CREATE TABLE IF NOT EXISTS \"test_namespace\".\"test_collection"
+ (isVectorSearch ? "_true" : "_false"));
.startsWith("CREATE TABLE IF NOT EXISTS \"test_namespace\".\"" + tableName);
assertThat(schemaInfo.keyspaceName()).isEqualTo("test_namespace");
assertThat(schemaInfo.tableName())
.isEqualTo("test_collection" + (isVectorTable ? "_true" : "_false"));
List<String> indexCQLs = new ArrayList<>(schemaInfo.indexCQLs());
assertThat(indexCQLs.size()).isEqualTo(isVectorTable ? 9 : 8);
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_exists_keys ON \"test_namespace\".\"%s\" (exist_keys) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_array_size ON \"test_namespace\".\"%s\" (entries(array_size)) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_array_contains ON \"test_namespace\".\"%s\" (array_contains) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_query_bool_values ON \"test_namespace\".\"%s\" (entries(query_bool_values)) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_query_dbl_values ON \"test_namespace\".\"%s\" (entries(query_dbl_values)) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_query_text_values ON \"test_namespace\".\"%s\" (entries(query_text_values)) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_query_timestamp_values ON \"test_namespace\".\"%s\" (entries(query_timestamp_values)) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS %s_query_null_values ON \"test_namespace\".\"%s\" (query_null_values) USING 'StorageAttachedIndex'"
.formatted(tableName, tableName));
if (isVectorTable) {
indexCQLs.remove(
"CREATE CUSTOM INDEX IF NOT EXISTS test_collection_true_query_vector_value ON \"test_namespace\".\"test_collection_true\" (query_vector_value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'similarity_function': 'COSINE'}");
}
assertThat(indexCQLs.size()).isEqualTo(0);
String sessionId = beginOfflineSessionResponse.sessionId();
// load data
List<JsonNode> jsonNodes = getRecords(isVectorSearch);
List<JsonNode> jsonNodes = getRecords(includeVectorData);
OfflineInsertManyResponse offlineInsertManyResponse =
loadTestData(offlineCommandsProcessor, commandContext, sessionId, jsonNodes);
boolean verifyVectorDataForNonVectorTable = !isVectorTable && includeVectorData;
if (offlineInsertManyResponse.errors() != null
&& !offlineInsertManyResponse.errors().isEmpty()) {
if (verifyVectorDataForNonVectorTable) {
assertThat(offlineInsertManyResponse.errors().size()).isEqualTo(1);
assertThat(offlineInsertManyResponse.errors().get(0).message())
.contains("Vector search is not enabled for the collection %s".formatted(tableName));
return;
}
throw new RuntimeException(
"Error while inserting data : " + offlineInsertManyResponse.errors());
}
if (verifyVectorDataForNonVectorTable) {
fail("Should have failed for vector data in non vector table");
}
// get statsus
OfflineGetStatusResponse offlineGetStatusResponse =
getStatus(offlineCommandsProcessor, commandContext, sessionId);
Expand All @@ -145,7 +191,7 @@ PRIMARY KEY (key))
assertEquals(sessionId, offlineGetStatusResponse.offlineWriterSessionStatus().sessionId());
assertEquals(namespace, offlineGetStatusResponse.offlineWriterSessionStatus().keyspace());
assertEquals(
"test_collection_" + isVectorSearch,
"test_collection_" + isVectorTable,
offlineGetStatusResponse.offlineWriterSessionStatus().tableName());
assertEquals(
sstablesOutputDirectory,
Expand All @@ -169,7 +215,7 @@ PRIMARY KEY (key))
assertEquals(sessionId, endOfflineSessionResponse.offlineWriterSessionStatus().sessionId());
assertEquals(namespace, endOfflineSessionResponse.offlineWriterSessionStatus().keyspace());
assertEquals(
"test_collection_" + isVectorSearch,
"test_collection_" + isVectorTable,
endOfflineSessionResponse.offlineWriterSessionStatus().tableName());
assertEquals(
sstablesOutputDirectory,
Expand Down Expand Up @@ -268,7 +314,7 @@ private OfflineInsertManyResponse loadTestData(
return offlineCommandsProcessor.loadData(sessionId, commandContext, jsonNodes);
}

private Triple<BeginOfflineSessionResponse, CommandContext, String> beginSession(
private Triple<BeginOfflineSessionResponse, CommandContext, SchemaInfo> beginSession(
OfflineCommandsProcessor offlineCommandsProcessor,
String namespace,
String ssTablesOutputDirectory,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.sstablewriter;

import static org.junit.Assert.assertThrows;

import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

public class SchemaInfoTest {
public static Stream<Arguments> provideInvalidSchemaInfo() {
return Stream.of(
Arguments.of(
null, "table1", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()),
Arguments.of(
"", "table1", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()),
Arguments.of(
" ", "table1", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()),
Arguments.of(
"keyspace1", null, "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()),
Arguments.of(
"keyspace1", "", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()),
Arguments.of(
"keyspace1", " ", "CREATE TABLE table1 (id UUID PRIMARY KEY, name TEXT)", List.of()),
Arguments.of("keyspace1", "table1", null, List.of()),
Arguments.of("keyspace1", "table1", "", List.of()),
Arguments.of("keyspace1", "table1", " ", List.of()));
}

@ParameterizedTest
@MethodSource("provideInvalidSchemaInfo")
public void testInvalidSchemaInfo(
String keyspaceName, String tableName, String createTableCQL, List<String> indexCQLs) {
assertThrows(
IllegalArgumentException.class,
() -> new SchemaInfo(keyspaceName, tableName, createTableCQL, indexCQLs));
}
}