Skip to content

Commit

Permalink
Add vectorize config to tables (#1489)
Browse files Browse the repository at this point in the history
Co-authored-by: Tatu Saloranta <tatu.saloranta@datastax.com>
  • Loading branch information
maheshrajamani and tatu-at-datastax authored Oct 4, 2024
1 parent 9eaa25a commit 3f70a10
Show file tree
Hide file tree
Showing 13 changed files with 1,166 additions and 920 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig;
import io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype.ColumnType;
import io.stargate.sgv2.jsonapi.exception.SchemaException;
import java.io.IOException;
Expand All @@ -25,13 +26,14 @@ public ColumnType deserialize(
throws IOException, JacksonException {
JsonNode definition = deserializationContext.readTree(jsonParser);
if (definition.isTextual()) {
return ColumnType.fromString(definition.asText(), null, null, -1);
return ColumnType.fromString(definition.asText(), null, null, -1, null);
}
if (definition.isObject() && definition.has("type")) {
String type = definition.path("type").asText();
String keyType = null;
String valueType = null;
int dimension = -1;
VectorizeConfig vectorConfig = null;
if (definition.has("keyType")) {
keyType = definition.path("keyType").asText();
}
Expand All @@ -41,7 +43,15 @@ public ColumnType deserialize(
if (definition.has("dimension")) {
dimension = definition.path("dimension").asInt();
}
return ColumnType.fromString(type, keyType, valueType, dimension);
if (definition.has("service")) {
JsonNode service = definition.path("service");
try {
vectorConfig = deserializationContext.readTreeAsValue(service, VectorizeConfig.class);
} catch (JacksonException je) {
throw SchemaException.Code.VECTOR_TYPE_INCORRECT_DEFINITION.get();
}
}
return ColumnType.fromString(type, keyType, valueType, dimension, vectorConfig);
}
throw SchemaException.Code.COLUMN_TYPE_INCORRECT.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import io.stargate.sgv2.jsonapi.api.model.command.CollectionOnlyCommand;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants;
import jakarta.validation.Valid;
import jakarta.validation.constraints.*;
import java.util.*;
Expand Down Expand Up @@ -111,84 +110,6 @@ public VectorSearchConfig(Integer dimension, String metric, VectorizeConfig vect
this.metric = metric == null ? "cosine" : metric;
this.vectorizeConfig = vectorizeConfig;
}

public record VectorizeConfig(
@NotNull
@Schema(
description = "Registered Embedding service provider",
type = SchemaType.STRING,
implementation = String.class)
@JsonProperty("provider")
String provider,
@Schema(
description = "Registered Embedding service model",
type = SchemaType.STRING,
implementation = String.class)
@JsonProperty("modelName")
String modelName,
@Valid
@Nullable
@Schema(
description = "Authentication config for chosen embedding service",
type = SchemaType.OBJECT)
@JsonProperty("authentication")
@JsonInclude(JsonInclude.Include.NON_NULL)
Map<String, String> authentication,
@Nullable
@Schema(
description =
"Optional parameters that match the messageTemplate provided for the provider",
type = SchemaType.OBJECT)
@JsonProperty("parameters")
@JsonInclude(JsonInclude.Include.NON_NULL)
Map<String, Object> parameters) {

public VectorizeConfig(
String provider,
String modelName,
Map<String, String> authentication,
Map<String, Object> parameters) {
this.provider = provider;
// HuggingfaceDedicated does not need user to specify model explicitly
// If user specifies modelName other than endpoint-defined-model, will error out
// By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder
if (provider.equals(ProviderConstants.HUGGINGFACE_DEDICATED)) {
if (modelName != null
&& !modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) {
throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
"'modelName' is not needed for provider %s explicitly, only '%s' is accepted",
ProviderConstants.HUGGINGFACE_DEDICATED,
ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL);
}
this.modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL;
} else {
this.modelName = modelName;
}
if (authentication != null && !authentication.isEmpty()) {
Map<String, String> updatedAuth = new HashMap<>();
for (Map.Entry<String, String> userAuth : authentication.entrySet()) {
// Determine the full credential name based on the sharedKeyValue pair
// If the sharedKeyValue does not contain a dot (e.g. myKey) or the part after the dot
// does not match the key (e.g. myKey.test), append the key to the sharedKeyValue with
// a dot (e.g. myKey.providerKey or myKey.test.providerKey). Otherwise, use the
// sharedKeyValue (e.g. myKey.providerKey) as is.
String sharedKeyValue = userAuth.getValue();
String credentialName =
sharedKeyValue.lastIndexOf('.') <= 0
|| !sharedKeyValue
.substring(sharedKeyValue.lastIndexOf('.') + 1)
.equals(userAuth.getKey())
? sharedKeyValue + "." + userAuth.getKey()
: sharedKeyValue;
updatedAuth.put(userAuth.getKey(), credentialName);
}
this.authentication = updatedAuth;
} else {
this.authentication = authentication;
}
this.parameters = parameters;
}
}
}

public record IndexingConfig(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package io.stargate.sgv2.jsonapi.api.model.command.impl;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants;
import jakarta.validation.Valid;
import jakarta.validation.constraints.*;
import java.util.*;
import javax.annotation.Nullable;
import org.eclipse.microprofile.openapi.annotations.enums.SchemaType;
import org.eclipse.microprofile.openapi.annotations.media.Schema;

/**
 * Settings describing the embedding service to use: the provider, the model, the authentication
 * entries and any provider-specific parameters.
 *
 * <p>The canonical constructor normalizes its input:
 *
 * <ul>
 *   <li>for the {@link ProviderConstants#HUGGINGFACE_DEDICATED} provider the model name is always
 *       stored as {@link ProviderConstants#HUGGINGFACE_DEDICATED_DEFINED_MODEL}; any other
 *       explicit model name is rejected;
 *   <li>every authentication value is expanded to a full credential name that ends with {@code
 *       "." + key} (see {@link #toCredentialName(String, String)}).
 * </ul>
 *
 * @param provider registered embedding service provider; annotated {@code @NotNull}
 * @param modelName registered embedding service model
 * @param authentication authentication config for the chosen embedding service; may be null
 * @param parameters optional parameters that match the messageTemplate provided for the provider;
 *     may be null
 */
public record VectorizeConfig(
    @NotNull
        @Schema(
            description = "Registered Embedding service provider",
            type = SchemaType.STRING,
            implementation = String.class)
        @JsonProperty("provider")
        String provider,
    @Schema(
            description = "Registered Embedding service model",
            type = SchemaType.STRING,
            implementation = String.class)
        @JsonProperty("modelName")
        String modelName,
    @Valid
        @Nullable
        @Schema(
            description = "Authentication config for chosen embedding service",
            type = SchemaType.OBJECT)
        @JsonProperty("authentication")
        @JsonInclude(JsonInclude.Include.NON_NULL)
        Map<String, String> authentication,
    @Nullable
        @Schema(
            description =
                "Optional parameters that match the messageTemplate provided for the provider",
            type = SchemaType.OBJECT)
        @JsonProperty("parameters")
        @JsonInclude(JsonInclude.Include.NON_NULL)
        Map<String, Object> parameters) {

  /**
   * Creates the config, normalizing the model name and the authentication values.
   *
   * @throws RuntimeException the exception produced by {@code
   *     ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(...)} when the provider is
   *     {@link ProviderConstants#HUGGINGFACE_DEDICATED} and {@code modelName} is neither null nor
   *     {@link ProviderConstants#HUGGINGFACE_DEDICATED_DEFINED_MODEL}
   */
  public VectorizeConfig(
      String provider,
      String modelName,
      Map<String, String> authentication,
      Map<String, Object> parameters) {
    this.provider = provider;
    // HuggingfaceDedicated does not need user to specify model explicitly
    // If user specifies modelName other than endpoint-defined-model, will error out
    // By default, huggingfaceDedicated provider use endpoint-defined-model as placeholder
    //
    // Null-safe comparison: a missing provider must not fail here with a NullPointerException;
    // the component is annotated @NotNull, so reporting the missing value is left to validation.
    if (Objects.equals(provider, ProviderConstants.HUGGINGFACE_DEDICATED)) {
      if (modelName != null
          && !modelName.equals(ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL)) {
        throw ErrorCodeV1.INVALID_CREATE_COLLECTION_OPTIONS.toApiException(
            "'modelName' is not needed for provider %s explicitly, only '%s' is accepted",
            ProviderConstants.HUGGINGFACE_DEDICATED,
            ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL);
      }
      this.modelName = ProviderConstants.HUGGINGFACE_DEDICATED_DEFINED_MODEL;
    } else {
      this.modelName = modelName;
    }

    if (authentication != null && !authentication.isEmpty()) {
      // Store a rewritten copy; the caller's map is left untouched.
      Map<String, String> updatedAuth = new HashMap<>();
      for (Map.Entry<String, String> userAuth : authentication.entrySet()) {
        updatedAuth.put(
            userAuth.getKey(), toCredentialName(userAuth.getKey(), userAuth.getValue()));
      }
      this.authentication = updatedAuth;
    } else {
      // null or empty: kept exactly as given
      this.authentication = authentication;
    }
    this.parameters = parameters;
  }

  /**
   * Determines the full credential name for one authentication entry.
   *
   * <p>If {@code sharedKeyValue} does not contain a dot after its first character (e.g. {@code
   * myKey}) or the part after the last dot does not match {@code key} (e.g. {@code myKey.test}),
   * the key is appended with a dot (e.g. {@code myKey.providerKey} or {@code
   * myKey.test.providerKey}). Otherwise {@code sharedKeyValue} (e.g. {@code myKey.providerKey})
   * is returned as is.
   */
  private static String toCredentialName(String key, String sharedKeyValue) {
    final int lastDot = sharedKeyValue.lastIndexOf('.');
    final boolean alreadyQualified =
        lastDot > 0 && sharedKeyValue.substring(lastDot + 1).equals(key);
    return alreadyQualified ? sharedKeyValue : sharedKeyValue + "." + key;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.stargate.sgv2.jsonapi.api.model.command.deserializers.ColumnDefinitionDeserializer;
import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig;
import io.stargate.sgv2.jsonapi.exception.SchemaException;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiDataType;
import java.util.List;
Expand Down Expand Up @@ -40,7 +41,8 @@ static List<String> getSupportedTypes() {
}

// Returns the column type from the string.
static ColumnType fromString(String type, String keyType, String valueType, int dimension) {
static ColumnType fromString(
String type, String keyType, String valueType, int dimension, VectorizeConfig vectorConfig) {
// TODO: the name of the type should be a part of the ColumnType interface, and use a map for
// the lookup
switch (type) {
Expand Down Expand Up @@ -87,8 +89,8 @@ static ColumnType fromString(String type, String keyType, String valueType, int
}
try {
return new ComplexTypes.MapType(
fromString(keyType, null, null, dimension),
fromString(valueType, null, null, dimension));
fromString(keyType, null, null, dimension, vectorConfig),
fromString(valueType, null, null, dimension, vectorConfig));
} catch (SchemaException se) {
throw SchemaException.Code.MAP_TYPE_INCORRECT_DEFINITION.get();
}
Expand All @@ -99,7 +101,8 @@ static ColumnType fromString(String type, String keyType, String valueType, int
throw SchemaException.Code.LIST_TYPE_INCORRECT_DEFINITION.get();
}
try {
return new ComplexTypes.ListType(fromString(valueType, null, null, dimension));
return new ComplexTypes.ListType(
fromString(valueType, null, null, dimension, vectorConfig));
} catch (SchemaException se) {
throw SchemaException.Code.LIST_TYPE_INCORRECT_DEFINITION.get();
}
Expand All @@ -111,7 +114,8 @@ static ColumnType fromString(String type, String keyType, String valueType, int
throw SchemaException.Code.SET_TYPE_INCORRECT_DEFINITION.get();
}
try {
return new ComplexTypes.SetType(fromString(valueType, null, null, dimension));
return new ComplexTypes.SetType(
fromString(valueType, null, null, dimension, vectorConfig));
} catch (SchemaException se) {
throw SchemaException.Code.SET_TYPE_INCORRECT_DEFINITION.get();
}
Expand All @@ -123,7 +127,7 @@ static ColumnType fromString(String type, String keyType, String valueType, int
throw SchemaException.Code.VECTOR_TYPE_INCORRECT_DEFINITION.get();
}
try {
return new ComplexTypes.VectorType(PrimitiveTypes.FLOAT, dimension);
return new ComplexTypes.VectorType(PrimitiveTypes.FLOAT, dimension, vectorConfig);
} catch (SchemaException se) {
throw SchemaException.Code.VECTOR_TYPE_INCORRECT_DEFINITION.get();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.stargate.sgv2.jsonapi.api.model.command.table.definition.datatype;

import io.stargate.sgv2.jsonapi.api.model.command.impl.VectorizeConfig;
import io.stargate.sgv2.jsonapi.service.schema.tables.ApiDataType;
import io.stargate.sgv2.jsonapi.service.schema.tables.ComplexApiDataType;
import io.stargate.sgv2.jsonapi.service.schema.tables.PrimitiveApiDataType;
Expand Down Expand Up @@ -58,16 +59,26 @@ public static class VectorType implements ColumnType {
// Float will be default type for vector
private final ColumnType valueType;
private final int vectorSize;
private final VectorizeConfig vectorConfig;

public VectorType(ColumnType valueType, int vectorSize) {
public VectorType(ColumnType valueType, int vectorSize, VectorizeConfig vectorConfig) {
this.valueType = valueType;
this.vectorSize = vectorSize;
this.vectorConfig = vectorConfig;
}

/**
 * Builds the API data type for this vector type.
 *
 * <p>The element type's API data type is cast to {@link PrimitiveApiDataType} (a {@link
 * ClassCastException} results if it is not one) and combined with the vector size.
 *
 * @return a new {@link ComplexApiDataType.VectorType}
 */
@Override
public ApiDataType getApiDataType() {
  final PrimitiveApiDataType elementType = (PrimitiveApiDataType) valueType.getApiDataType();
  return new ComplexApiDataType.VectorType(elementType, vectorSize);
}

/**
 * Returns the vectorize configuration held by this vector type.
 *
 * @return the stored {@link VectorizeConfig}, exactly as it was assigned (no copy is made)
 */
public VectorizeConfig getVectorConfig() {
  return this.vectorConfig;
}

/**
 * Returns the dimension of this vector type.
 *
 * @return the value of the {@code vectorSize} field
 */
public int getDimension() {
  return this.vectorSize;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import com.datastax.oss.driver.api.core.CqlIdentifier;
import com.datastax.oss.driver.api.core.cql.SimpleStatement;
import com.datastax.oss.driver.api.core.data.ByteUtils;
import com.datastax.oss.driver.api.core.metadata.schema.ClusteringOrder;
import com.datastax.oss.driver.api.core.type.DataType;
import com.datastax.oss.driver.api.querybuilder.schema.CreateTable;
Expand All @@ -23,14 +24,15 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

public class CreateTableAttempt extends SchemaAttempt<KeyspaceSchemaObject> {

private final String tableName;
private final Map<String, ApiDataType> columnTypes;
private final List<String> partitionKeys;
private final List<PrimaryKey.OrderingKey> clusteringKeys;
private final String comment;
private final Map<String, String> customProperties;
private final boolean ifNotExists;

protected CreateTableAttempt(
Expand All @@ -43,7 +45,7 @@ protected CreateTableAttempt(
List<String> partitionKeys,
List<PrimaryKey.OrderingKey> clusteringKeys,
boolean ifNotExists,
String comment) {
Map<String, String> customProperties) {
super(
position,
schemaObject,
Expand All @@ -54,7 +56,7 @@ protected CreateTableAttempt(
this.partitionKeys = partitionKeys;
this.clusteringKeys = clusteringKeys;
this.ifNotExists = ifNotExists;
this.comment = comment;
this.customProperties = customProperties;

setStatus(OperationStatus.READY);
}
Expand All @@ -73,8 +75,17 @@ protected SimpleStatement buildStatement() {
// Add all primary keys and columns
CreateTable createTable = addColumnsAndKeys(create);

// Add comment which has table properties for vectorize
CreateTableWithOptions createWithOptions = createTable.withComment(comment);
// Add customProperties which has table properties for vectorize
// Convert value to hex string using the ByteUtils.toHexString
// This needs to use `createTable.withExtensions()` method in driver when PR
// (https://github.com/apache/cassandra-java-driver/pull/1964) is released
final Map<String, String> extensions =
customProperties.entrySet().stream()
.collect(
Collectors.toMap(
e -> e.getKey(), e -> ByteUtils.toHexString(e.getValue().getBytes())));

CreateTableWithOptions createWithOptions = createTable.withOption("extensions", extensions);

// Add the clustering key order
createWithOptions = addClusteringOrder(createWithOptions);
Expand Down
Loading

0 comments on commit 3f70a10

Please sign in to comment.