Skip to content

Commit

Permalink
SchemaObject changes to support multiple vector configs (#1499)
Browse files Browse the repository at this point in the history
  • Loading branch information
maheshrajamani authored Oct 7, 2024
1 parent 1d4f2bd commit 647012c
Show file tree
Hide file tree
Showing 24 changed files with 375 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,24 @@ public Uni<RestResponse<CommandResult>> postCommand(
}
// TODO: refactor this code to be cleaner so it assigns on one line
EmbeddingProvider embeddingProvider = null;
final VectorConfig.VectorizeConfig vectorizeConfig =
schemaObject.vectorConfig().vectorizeConfig();
VectorConfig vectorConfig = schemaObject.vectorConfig();
final VectorConfig.ColumnVectorDefinition columnVectorDefinition =
vectorConfig.columnVectorDefinitions() == null
|| vectorConfig.columnVectorDefinitions().isEmpty()
? null
: vectorConfig.columnVectorDefinitions().get(0);
final VectorConfig.ColumnVectorDefinition.VectorizeConfig vectorizeConfig =
columnVectorDefinition != null
? columnVectorDefinition.vectorizeConfig()
: null;
if (vectorizeConfig != null) {
embeddingProvider =
embeddingProviderFactory.getConfiguration(
dataApiRequestInfo.getTenantId(),
dataApiRequestInfo.getCassandraToken(),
vectorizeConfig.provider(),
vectorizeConfig.modelName(),
schemaObject.vectorConfig().vectorSize(),
columnVectorDefinition.vectorSize(),
vectorizeConfig.parameters(),
vectorizeConfig.authentication(),
command.getClass().getSimpleName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public enum Code implements ErrorCode<SchemaException> {
COLUMN_DEFINITION_MISSING,
COLUMN_TYPE_INCORRECT,
COLUMN_TYPE_UNSUPPORTED,
INVALID_CONFIGURATION,
INVALID_VECTORIZE_CONFIGURATION,
LIST_TYPE_INCORRECT_DEFINITION,
MAP_TYPE_INCORRECT_DEFINITION,
MISSING_PRIMARY_KEYS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ private Uni<SchemaObject> loadSchemaObject(
optionalTable.get(), objectMapper);
}

// 04-Sep-2024, tatu: Used to check that API Tables enabled; no longer checked here
return new TableSchemaObject(table);
return TableSchemaObject.from(table, objectMapper);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ public SchemaObjectName name() {
}

/**
* Subclasses must always return an instance of VectorConfig, if there is no vector config they
* should return VectorConfig.notEnabledVectorConfig()
* Subclasses must always return VectorConfig, if there is no vector config they should return
* VectorConfig.notEnabledVectorConfig().
*
* @return
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,118 @@
package io.stargate.sgv2.jsonapi.service.cqldriver.executor;

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.data.ByteUtils;
import com.datastax.oss.driver.api.core.metadata.schema.ColumnMetadata;
import com.datastax.oss.driver.api.core.metadata.schema.IndexMetadata;
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.exception.SchemaException;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class TableSchemaObject extends TableBasedSchemaObject {

public static final SchemaObjectType TYPE = SchemaObjectType.TABLE;

public TableSchemaObject(TableMetadata tableMetadata) {
private final VectorConfig vectorConfig;

private TableSchemaObject(TableMetadata tableMetadata, VectorConfig vectorConfig) {
super(TYPE, tableMetadata);
this.vectorConfig = vectorConfig;
}

@Override
public VectorConfig vectorConfig() {
return VectorConfig.notEnabledVectorConfig();
return vectorConfig;
}

@Override
public IndexUsage newIndexUsage() {
return IndexUsage.NO_OP;
}

/**
* Get table schema object from table metadata
*
* @param tableMetadata
* @param objectMapper
* @return
*/
public static TableSchemaObject from(TableMetadata tableMetadata, ObjectMapper objectMapper) {
Map<String, ByteBuffer> extensions =
(Map<String, ByteBuffer>)
tableMetadata.getOptions().get(CqlIdentifier.fromInternal("extensions"));
String vectorizeJson = null;
if (extensions != null) {
ByteBuffer vectorizeBuffer =
(ByteBuffer) extensions.get("com.datastax.data-api.vectorize-config");
vectorizeJson =
vectorizeBuffer != null
? new String(ByteUtils.getArray(vectorizeBuffer.duplicate()), StandardCharsets.UTF_8)
: null;
}
Map<String, VectorConfig.ColumnVectorDefinition.VectorizeConfig> vectorizeConfigMap =
new HashMap<>();
if (vectorizeJson != null) {
try {
JsonNode vectorizeByColumns = objectMapper.readTree(vectorizeJson);
Iterator<Map.Entry<String, JsonNode>> it = vectorizeByColumns.fields();
while (it.hasNext()) {
Map.Entry<String, JsonNode> entry = it.next();
try {
VectorConfig.ColumnVectorDefinition.VectorizeConfig vectorizeConfig =
objectMapper.treeToValue(
entry.getValue(), VectorConfig.ColumnVectorDefinition.VectorizeConfig.class);
vectorizeConfigMap.put(entry.getKey(), vectorizeConfig);
} catch (JsonProcessingException | IllegalArgumentException e) {
throw SchemaException.Code.INVALID_VECTORIZE_CONFIGURATION.get(
Map.of("field", entry.getKey()));
}
}
} catch (JsonProcessingException e) {
throw SchemaException.Code.INVALID_CONFIGURATION.get();
}
}
VectorConfig vectorConfig;
List<VectorConfig.ColumnVectorDefinition> columnVectorDefinitions = new ArrayList<>();
for (Map.Entry<CqlIdentifier, ColumnMetadata> column : tableMetadata.getColumns().entrySet()) {
if (column.getValue().getType() instanceof VectorType vectorType) {
final Optional<IndexMetadata> index = tableMetadata.getIndex(column.getKey());
SimilarityFunction similarityFunction = SimilarityFunction.COSINE;
if (index.isPresent()) {
final IndexMetadata indexMetadata = index.get();
final Map<String, String> indexOptions = indexMetadata.getOptions();
final String similarityFunctionValue = indexOptions.get("similarity_function");
if (similarityFunctionValue != null) {
similarityFunction = SimilarityFunction.fromString(similarityFunctionValue);
}
}
int dimension = vectorType.getDimensions();
VectorConfig.ColumnVectorDefinition columnVectorDefinition =
new VectorConfig.ColumnVectorDefinition(
column.getKey().asInternal(),
dimension,
similarityFunction,
vectorizeConfigMap.get(column.getKey().asInternal()));
columnVectorDefinitions.add(columnVectorDefinition);
}
}
if (columnVectorDefinitions.isEmpty()) {
vectorConfig = VectorConfig.notEnabledVectorConfig();
} else {
vectorConfig = new VectorConfig(true, Collections.unmodifiableList(columnVectorDefinitions));
}
return new TableSchemaObject(tableMetadata, vectorConfig);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,106 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import java.util.List;
import java.util.Map;

/**
* incorporates vectorizeConfig into vectorConfig
*
* @param vectorEnabled
* @param vectorSize
* @param similarityFunction
* @param vectorizeConfig
* @param vectorEnabled - If vector field is available for the table
* @param columnVectorDefinitions - List of columnVectorDefinitions each with respect to a
* column/field
*/
public record VectorConfig(
boolean vectorEnabled,
int vectorSize,
SimilarityFunction similarityFunction,
VectorizeConfig vectorizeConfig) {
boolean vectorEnabled, List<ColumnVectorDefinition> columnVectorDefinitions) {

// TODO: this is an immutable record, this can be singleton
// TODO: Remove the use of NULL for the objects like vectorizeConfig
public static VectorConfig notEnabledVectorConfig() {
return new VectorConfig(false, -1, null, null);
return new VectorConfig(false, null);
}

// convert a vector jsonNode from table comment to vectorConfig
public static VectorConfig fromJson(JsonNode jsonNode, ObjectMapper objectMapper) {
// dimension, similarityFunction, must exist
int dimension = jsonNode.get("dimension").asInt();
SimilarityFunction similarityFunction =
SimilarityFunction.fromString(jsonNode.get("metric").asText());
/**
* Configuration for a column, In case of collection this will be of size one
*
* @param fieldName
* @param vectorSize
* @param similarityFunction
* @param vectorizeConfig
*/
public record ColumnVectorDefinition(
String fieldName,
int vectorSize,
SimilarityFunction similarityFunction,
VectorizeConfig vectorizeConfig) {

VectorizeConfig vectorizeConfig = null;
// construct vectorizeConfig
JsonNode vectorizeServiceNode = jsonNode.get("service");
if (vectorizeServiceNode != null) {
// provider, modelName, must exist
String provider = vectorizeServiceNode.get("provider").asText();
String modelName = vectorizeServiceNode.get("modelName").asText();
// construct VectorizeConfig.authentication, can be null
JsonNode vectorizeServiceAuthenticationNode = vectorizeServiceNode.get("authentication");
Map<String, String> vectorizeServiceAuthentication =
vectorizeServiceAuthenticationNode == null
? null
: objectMapper.convertValue(vectorizeServiceAuthenticationNode, Map.class);
// construct VectorizeConfig.parameters, can be null
JsonNode vectorizeServiceParameterNode = vectorizeServiceNode.get("parameters");
Map<String, Object> vectorizeServiceParameter =
vectorizeServiceParameterNode == null
? null
: objectMapper.convertValue(vectorizeServiceParameterNode, Map.class);
vectorizeConfig =
new VectorizeConfig(
provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter);
// convert a vector jsonNode from comment option to vectorConfig, used for collection
public static ColumnVectorDefinition fromJson(JsonNode jsonNode, ObjectMapper objectMapper) {
// dimension, similarityFunction, must exist
int dimension = jsonNode.get("dimension").asInt();
SimilarityFunction similarityFunction =
SimilarityFunction.fromString(jsonNode.get("metric").asText());

return fromJson(
DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD,
dimension,
similarityFunction,
jsonNode,
objectMapper);
}

return new VectorConfig(true, dimension, similarityFunction, vectorizeConfig);
}
// convert a vector jsonNode from table extension to vectorConfig, used for tables
public static ColumnVectorDefinition fromJson(
String fieldName,
int dimension,
SimilarityFunction similarityFunction,
JsonNode jsonNode,
ObjectMapper objectMapper) {
VectorizeConfig vectorizeConfig = null;
// construct vectorizeConfig
JsonNode vectorizeServiceNode = jsonNode.get("service");
if (vectorizeServiceNode != null) {
vectorizeConfig = VectorizeConfig.fromJson(vectorizeServiceNode, objectMapper);
}
return new ColumnVectorDefinition(fieldName, dimension, similarityFunction, vectorizeConfig);
}

public record VectorizeConfig(
String provider,
String modelName,
Map<String, String> authentication,
Map<String, Object> parameters) {}
/**
* Represent the vectorize configuration defined for a column
*
* @param provider
* @param modelName
* @param authentication
* @param parameters
*/
public record VectorizeConfig(
String provider,
String modelName,
Map<String, String> authentication,
Map<String, Object> parameters) {

protected static VectorizeConfig fromJson(
JsonNode vectorizeServiceNode, ObjectMapper objectMapper) {
// provider, modelName, must exist
String provider = vectorizeServiceNode.get("provider").asText();
String modelName = vectorizeServiceNode.get("modelName").asText();
// construct VectorizeConfig.authentication, can be null
JsonNode vectorizeServiceAuthenticationNode = vectorizeServiceNode.get("authentication");
Map<String, String> vectorizeServiceAuthentication =
vectorizeServiceAuthenticationNode == null
? null
: objectMapper.convertValue(vectorizeServiceAuthenticationNode, Map.class);
// construct VectorizeConfig.parameters, can be null
JsonNode vectorizeServiceParameterNode = vectorizeServiceNode.get("parameters");
Map<String, Object> vectorizeServiceParameter =
vectorizeServiceParameterNode == null
? null
: objectMapper.convertValue(vectorizeServiceParameterNode, Map.class);
return new VectorizeConfig(
provider, modelName, vectorizeServiceAuthentication, vectorizeServiceParameter);
}
}
}
}
Loading

0 comments on commit 647012c

Please sign in to comment.