Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose source_model parameter for vector-enabled collections #1606

Merged
merged 24 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bc05add
remove hardcoded value
Hazel-Datastax Oct 24, 2024
cfefcf3
Add source model support in collection
Hazel-Datastax Oct 24, 2024
dea9077
Merge branch 'refs/heads/main' into hazel/collection_source_model
Hazel-Datastax Oct 24, 2024
489606d
Remove unused method (was moved to VectorConfig)
Hazel-Datastax Oct 25, 2024
9f612f6
Add support for older collections
Hazel-Datastax Oct 25, 2024
fc2342a
Fix typo
Hazel-Datastax Oct 25, 2024
931dd7a
rename
Hazel-Datastax Oct 25, 2024
da96598
Add validation
Hazel-Datastax Oct 28, 2024
af1d2e6
Add ITs
Hazel-Datastax Oct 28, 2024
538bfcc
Merge branch 'main' into hazel/collection_source_model
Hazel-Datastax Oct 28, 2024
c19a8cc
Fix unit tests
Hazel-Datastax Oct 28, 2024
6d1d5a0
Fix ITs
Hazel-Datastax Oct 28, 2024
2311ee6
Merge remote-tracking branch 'refs/remotes/origin/main' into hazel/co…
Hazel-Datastax Oct 30, 2024
11d068b
resolve conflicts
Hazel-Datastax Oct 30, 2024
59367d1
Merge remote-tracking branch 'refs/remotes/origin/main' into hazel/co…
Hazel-Datastax Oct 31, 2024
983a754
update the naming
Hazel-Datastax Oct 31, 2024
5f36442
Merge branch 'main' into hazel/collection_source_model
Hazel-Datastax Oct 31, 2024
04b2327
Add IT
Hazel-Datastax Oct 31, 2024
51e7fc4
Add default
Hazel-Datastax Oct 31, 2024
402a047
move `SUPPORTED_SOURCE_MODELS` map from `VectorConstants` to `SourceM…
Hazel-Datastax Oct 31, 2024
ee7ce7b
Merge branch 'main' into hazel/collection_source_model
Hazel-Datastax Oct 31, 2024
ec06bd7
create a method to get all the source model names
Hazel-Datastax Oct 31, 2024
5f6f21c
fix
Hazel-Datastax Oct 31, 2024
834ec17
Merge remote-tracking branch 'origin/hazel/collection_source_model' i…
Hazel-Datastax Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ public record VectorSearchConfig(
@JsonProperty("metric")
@JsonAlias("function") // old name
String metric,
@Nullable
@Pattern(
regexp =
"(openai-v3-large|openai-v3-small|ada002|gecko|nv-qa-4|cohere-v3|bert|other)",
Hazel-Datastax marked this conversation as resolved.
Show resolved Hide resolved
message =
"sourceModel options are 'openai-v3-large', 'openai-v3-small', 'ada002', 'gecko', 'nv-qa-4', 'cohere-v3', 'bert', and 'other'")
@Schema(
description =
"The 'sourceModel' option configures the index with the fastest settings for a given source of embeddings vectors",
defaultValue = "other",
type = SchemaType.STRING,
implementation = String.class)
@JsonProperty("sourceModel")
String sourceModel,
@Valid
@Nullable
@JsonInclude(JsonInclude.Include.NON_NULL)
Expand All @@ -105,9 +119,11 @@ public record VectorSearchConfig(
@JsonProperty("service")
VectorizeConfig vectorizeConfig) {

public VectorSearchConfig(Integer dimension, String metric, VectorizeConfig vectorizeConfig) {
public VectorSearchConfig(
Integer dimension, String metric, String sourceModel, VectorizeConfig vectorizeConfig) {
this.dimension = dimension;
this.metric = metric == null ? "cosine" : metric;
tatu-at-datastax marked this conversation as resolved.
Show resolved Hide resolved
this.sourceModel = sourceModel == null ? "other" : sourceModel;
this.vectorizeConfig = vectorizeConfig;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ interface Fields {
/** Key for vector function name definition in cql index. */
String VECTOR_INDEX_FUNCTION_NAME = "similarity_function";

/** Key for vector source model name definition in cql index. */
String VECTOR_INDEX_SOURCE_MODEL_NAME = "source_model";

/** Field name used in projection clause to get similarity score in response. */
String VECTOR_FUNCTION_SIMILARITY_FIELD = "$similarity";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ public enum ErrorCodeV1 {

VECTOR_SEARCH_INVALID_FUNCTION_NAME("Invalid vector search function name"),

VECTOR_SEARCH_INVALID_SOURCE_MODEL_NAME("Invalid vector search source model name"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure but should we use "Unrecognized" instead of "Invalid" for this (I know we use "invalid" above so maybe it's more consistent)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes...I used "invalid" because the above used it. Changed to "Unrecognized", should we change the above to "Unrecognized" as well?


VECTOR_SEARCH_TOO_BIG_VALUE("Vector embedding property '$vector' length too big"),
VECTOR_SIZE_MISMATCH("Length of vector parameter different from declared '$vector' dimension"),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import com.datastax.oss.driver.api.core.metadata.schema.TableMetadata;
import com.datastax.oss.driver.api.core.type.VectorType;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.config.constants.VectorConstant;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -50,15 +52,19 @@ public static TableSchemaObject from(TableMetadata tableMetadata, ObjectMapper o
indexMetadata -> indexMetadata.getTarget().equals(column.getKey().asCql(true)))
.findFirst();
SimilarityFunction similarityFunction = SimilarityFunction.COSINE;
SourceModel sourceModel = SourceModel.OTHER;
if (index.isPresent()) {
final IndexMetadata indexMetadata = index.get();
final Map<String, String> indexOptions = indexMetadata.getOptions();
final String sourceModel = indexOptions.get("source_model");
final String similarityFunctionValue = indexOptions.get("similarity_function");
final String sourceModelValue =
indexOptions.get(DocumentConstants.Fields.VECTOR_INDEX_SOURCE_MODEL_NAME);
final String similarityFunctionValue =
indexOptions.get(DocumentConstants.Fields.VECTOR_INDEX_FUNCTION_NAME);
if (similarityFunctionValue != null) {
similarityFunction = SimilarityFunction.fromString(similarityFunctionValue);
} else if (sourceModel != null) {
similarityFunction = VectorConstant.SUPPORTED_SOURCES.get(sourceModel);
} else if (sourceModelValue != null) {
similarityFunction = VectorConstant.SUPPORTED_SOURCES.get(sourceModelValue);
sourceModel = SourceModel.fromString(sourceModelValue);
}
}
int dimension = vectorType.getDimensions();
Expand All @@ -67,6 +73,7 @@ public static TableSchemaObject from(TableMetadata tableMetadata, ObjectMapper o
column.getKey().asInternal(),
dimension,
similarityFunction,
sourceModel,
vectorizeConfigMap.get(column.getKey().asInternal()));
columnVectorDefinitions.add(columnVectorDefinition);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -72,6 +73,7 @@ public record ColumnVectorDefinition(
String fieldName,
int vectorSize,
SimilarityFunction similarityFunction,
SourceModel sourceModel,
VectorizeConfig vectorizeConfig) {

// convert a vector jsonNode from comment option to vectorConfig, used for collection
Expand All @@ -80,11 +82,15 @@ public static ColumnVectorDefinition fromJson(JsonNode jsonNode, ObjectMapper ob
int dimension = jsonNode.get("dimension").asInt();
SimilarityFunction similarityFunction =
SimilarityFunction.fromString(jsonNode.get("metric").asText());
// sourceModel doesn't exist if the collection was created before supporting sourceModel; if
// missing, it will be an empty string and sourceModel becomes OTHER.
SourceModel sourceModel = SourceModel.fromString(jsonNode.path("sourceModel").asText());

return fromJson(
DocumentConstants.Fields.VECTOR_EMBEDDING_TEXT_FIELD,
dimension,
similarityFunction,
sourceModel,
jsonNode,
objectMapper);
}
Expand All @@ -94,6 +100,7 @@ public static ColumnVectorDefinition fromJson(
String fieldName,
int dimension,
SimilarityFunction similarityFunction,
SourceModel sourceModel,
JsonNode jsonNode,
ObjectMapper objectMapper) {
VectorizeConfig vectorizeConfig = null;
Expand All @@ -102,7 +109,8 @@ public static ColumnVectorDefinition fromJson(
if (vectorizeServiceNode != null) {
vectorizeConfig = VectorizeConfig.fromJson(vectorizeServiceNode, objectMapper);
}
return new ColumnVectorDefinition(fieldName, dimension, similarityFunction, vectorizeConfig);
return new ColumnVectorDefinition(
fieldName, dimension, similarityFunction, sourceModel, vectorizeConfig);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.QueryExecutor;
import io.stargate.sgv2.jsonapi.service.operation.Operation;
import io.stargate.sgv2.jsonapi.service.schema.SimilarityFunction;
import io.stargate.sgv2.jsonapi.service.schema.SourceModel;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionSchemaObject;
import io.stargate.sgv2.jsonapi.service.schema.collections.CollectionTableMatcher;
import java.time.Duration;
Expand All @@ -41,6 +42,7 @@ public record CreateCollectionOperation(
boolean vectorSearch,
int vectorSize,
String vectorFunction,
String sourceModel,
String comment,
int ddlDelayMillis,
boolean tooManyIndexesRollbackEnabled,
Expand All @@ -60,6 +62,7 @@ public static CreateCollectionOperation withVectorSearch(
String name,
int vectorSize,
String vectorFunction,
String sourceModel,
String comment,
int ddlDelayMillis,
boolean tooManyIndexesRollbackEnabled,
Expand All @@ -73,6 +76,7 @@ public static CreateCollectionOperation withVectorSearch(
true,
vectorSize,
vectorFunction,
sourceModel,
comment,
ddlDelayMillis,
tooManyIndexesRollbackEnabled,
Expand All @@ -98,6 +102,7 @@ public static CreateCollectionOperation withoutVectorSearch(
false,
0,
null,
null,
comment,
ddlDelayMillis,
tooManyIndexesRollbackEnabled,
Expand Down Expand Up @@ -142,6 +147,7 @@ public Uni<Supplier<CommandResult>> execute(
vectorSearch,
vectorSize,
SimilarityFunction.fromString(vectorFunction),
SourceModel.fromString(sourceModel),
comment,
objectMapper);
// if table exists we have to choices:
Expand Down Expand Up @@ -492,6 +498,8 @@ public List<SimpleStatement> getIndexStatements(
appender
+ " \"%s_query_vector_value\" ON \"%s\".\"%s\" (query_vector_value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'similarity_function': '"
+ vectorFunction()
+ "', 'source_model': '"
+ sourceModel()
+ "'}";
statements.add(
SimpleStatement.newInstance(String.format(vectorSearch, table, keyspace, table)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexOnTable;
import com.datastax.oss.driver.api.querybuilder.schema.CreateIndexStart;
import com.datastax.oss.driver.internal.querybuilder.schema.DefaultCreateIndex;
import io.stargate.sgv2.jsonapi.config.constants.DocumentConstants;
import io.stargate.sgv2.jsonapi.service.cqldriver.executor.TableSchemaObject;
import io.stargate.sgv2.jsonapi.service.cqldriver.override.ExtendedCreateIndex;
import io.stargate.sgv2.jsonapi.service.operation.SchemaAttempt;
Expand Down Expand Up @@ -90,10 +91,11 @@ public record VectorIndexOptions(SimilarityFunction similarityFunction, String s
public Map<String, Object> getOptions() {
Map<String, Object> options = new HashMap<>();
if (similarityFunction != null) {
options.put("similarity_function", similarityFunction.getMetric());
options.put(
DocumentConstants.Fields.VECTOR_INDEX_FUNCTION_NAME, similarityFunction.getMetric());
}
if (sourceModel != null) {
options.put("source_model", sourceModel);
options.put(DocumentConstants.Fields.VECTOR_INDEX_SOURCE_MODEL_NAME, sourceModel);
}
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ public Operation resolveKeyspaceCommand(
command.name(),
vector.dimension(),
vector.metric(),
vector.sourceModel(),
comment,
operationsConfig.databaseConfig().ddlDelayMillis(),
operationsConfig.tooManyIndexesRollbackEnabled(),
Expand Down Expand Up @@ -198,7 +199,7 @@ private CreateCollectionCommand.Options.VectorSearchConfig validateVectorOptions
vectorDimension = validateVectorize.validateService(service, vectorDimension);
vector =
new CreateCollectionCommand.Options.VectorSearchConfig(
vectorDimension, vector.metric(), vector.vectorizeConfig());
vectorDimension, vector.metric(), vector.sourceModel(), vector.vectorizeConfig());
} else {
// Ensure vector dimension is provided when service configuration is absent.
if (vectorDimension == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.stargate.sgv2.jsonapi.service.schema;

import io.stargate.sgv2.jsonapi.exception.ErrorCodeV1;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* The source model used for the vector index. This is only applicable if the vector index is
* enabled.
*/
public enum SourceModel {
ADA002("ada002"),
BERT("bert"),
COHERE_V3("cohere-v3"),
GECKO("gecko"),
NV_QA_4("nv-qa-4"),
OPENAI_V3_LARGE("openai-v3-large"),
OPENAI_V3_SMALL("openai-v3-small"),
OTHER("other"),
UNDEFINED("undefined");

private final String sourceModel;
Hazel-Datastax marked this conversation as resolved.
Show resolved Hide resolved
private static final java.util.Map<String, SourceModel> FUNCTIONS_MAP =
Stream.of(SourceModel.values())
.collect(Collectors.toMap(SourceModel::getSourceModel, sourceModel -> sourceModel));

SourceModel(String sourceModel) {
this.sourceModel = sourceModel;
}

public String getSourceModel() {
Hazel-Datastax marked this conversation as resolved.
Show resolved Hide resolved
return sourceModel;
}

public static SourceModel fromString(String sourceModel) {
Hazel-Datastax marked this conversation as resolved.
Show resolved Hide resolved
if (sourceModel == null) return UNDEFINED;
// The string may be empty if the collection was created before supporting source models
if (sourceModel.isEmpty()) return OTHER;
SourceModel model = FUNCTIONS_MAP.get(sourceModel);
if (model == null) {
throw ErrorCodeV1.VECTOR_SEARCH_INVALID_SOURCE_MODEL_NAME.toApiException("'%s'", sourceModel);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should then also list valid/known source model names, not just invalid/unrecognized value

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I followed the pattern in SimilarityFunction. Do we want to change it as well?

}
return model;
}
}
Loading
Loading