From 510bbc3f720b7c99bdc82c76509a2dfb86afb7ba Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Sun, 31 Mar 2024 20:53:40 -0400 Subject: [PATCH] index creation --- pkg/sql/BUILD.bazel | 1 + pkg/sql/backfill.go | 14 ++- pkg/sql/catalog/catpb/enum.proto | 3 + pkg/sql/catalog/colinfo/col_type_info.go | 2 +- pkg/sql/catalog/descs/hydrate.go | 19 +++- pkg/sql/colenc/inverted.go | 5 ++ pkg/sql/create_index.go | 27 ++++++ pkg/sql/parser/sql.y | 2 +- pkg/sql/rowenc/BUILD.bazel | 2 + pkg/sql/rowenc/index_encoding.go | 35 +++++++- .../scbuild/internal/scbuildstmt/BUILD.bazel | 1 + .../internal/scbuildstmt/create_index.go | 38 +++++++- .../scexec/scmutationexec/index.go | 3 + pkg/sql/schemachanger/scpb/BUILD.bazel | 2 + pkg/sql/schemachanger/scpb/elements.proto | 3 + pkg/sql/sem/builtins/builtins.go | 16 ++++ pkg/sql/sem/builtins/fixed_oids.go | 1 + .../indexstorageparam/index_storage_param.go | 16 ++++ pkg/util/encoding/encoding.go | 17 ++++ pkg/util/encoding/float.go | 86 +++++++++++++++++++ pkg/util/vector/BUILD.bazel | 2 + pkg/util/vector/vector.go | 43 ++++++++++ pkg/util/vector/vectorpb/config.go | 2 +- pkg/util/vector/vectorpb/config.proto | 23 +++-- 24 files changed, 347 insertions(+), 16 deletions(-) diff --git a/pkg/sql/BUILD.bazel b/pkg/sql/BUILD.bazel index 94eb8bee673c..e7ffcfb9913d 100644 --- a/pkg/sql/BUILD.bazel +++ b/pkg/sql/BUILD.bazel @@ -590,6 +590,7 @@ go_library( "//pkg/util/uint128", "//pkg/util/ulid", "//pkg/util/uuid", + "//pkg/util/vector/vectorpb", "@com_github_cockroachdb_apd_v3//:apd", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_errors//hintdetail", diff --git a/pkg/sql/backfill.go b/pkg/sql/backfill.go index 20f7e14947d4..7e309d68983d 100644 --- a/pkg/sql/backfill.go +++ b/pkg/sql/backfill.go @@ -1806,12 +1806,17 @@ func countExpectedRowsForInvertedIndex( ctx context.Context, txn descs.Txn, ) error { var stmt string - geoConfig := idx.GetGeoConfig() - if geoConfig.IsEmpty() { + if geoConfig := idx.GetGeoConfig(); geoConfig.IsEmpty() { stmt = fmt.Sprintf( `SELECT coalesce(sum_int(crdb_internal.num_inverted_index_entries(%s, %d)), 0) FROM [%d AS t]`, colNameOrExpr, idx.GetVersion(), desc.GetID(), ) + } else if vectorConfig := idx.GetVectorConfig(); vectorConfig.IsEmpty() { + nLists := vectorConfig.GetIvfFlat().NLists + stmt = fmt.Sprintf( + `SELECT coalesce(sum_int(crdb_internal.num_inverted_index_entries(%s, %d)), 0) + least(%d, count(colNameOrExpr)) FROM [%d AS t]`, + colNameOrExpr, idx.GetVersion(), nLists, desc.GetID(), + ) } else { stmt = fmt.Sprintf( `SELECT coalesce(sum_int(crdb_internal.num_geo_inverted_index_entries(%d, %d, %s)), 0) FROM [%d AS t]`, @@ -1832,6 +1837,11 @@ func countExpectedRowsForInvertedIndex( return errors.New("failed to verify inverted index count") } expectedCount = int64(tree.MustBeDInt(row[0])) + // For ivf indexes, the expected count is the sum of the number of + // entries in the inverted index and the number of rows in the table. + if len(row) > 1 { + expectedCount += int64(tree.MustBeDInt(row[1])) + } return nil }) }); err != nil { diff --git a/pkg/sql/catalog/catpb/enum.proto b/pkg/sql/catalog/catpb/enum.proto index f3e61fc9b0da..cb588f63d64f 100644 --- a/pkg/sql/catalog/catpb/enum.proto +++ b/pkg/sql/catalog/catpb/enum.proto @@ -48,4 +48,7 @@ enum InvertedIndexColumnKind { // TRIGRAM is the trigram kind of inverted index column. It's only valid on // text columns. TRIGRAM = 1; + // IVFFLAT is the IVFFLAT kind of inverted index column. It's only valid on + // vector columns. + IVFFLAT = 2; } diff --git a/pkg/sql/catalog/colinfo/col_type_info.go b/pkg/sql/catalog/colinfo/col_type_info.go index 4476cd7f4ac0..d9b298233b63 100644 --- a/pkg/sql/catalog/colinfo/col_type_info.go +++ b/pkg/sql/catalog/colinfo/col_type_info.go @@ -181,7 +181,7 @@ func ColumnTypeIsInvertedIndexable(t *types.T) bool { switch t.Family() { case types.ArrayFamily: return t.ArrayContents().Family() != types.RefCursorFamily - case types.JsonFamily, types.StringFamily: + case types.JsonFamily, types.StringFamily, types.PGVectorFamily: return true } return ColumnTypeIsOnlyInvertedIndexable(t) diff --git a/pkg/sql/catalog/descs/hydrate.go b/pkg/sql/catalog/descs/hydrate.go index a9ab1d8e3716..addf62a831f9 100644 --- a/pkg/sql/catalog/descs/hydrate.go +++ b/pkg/sql/catalog/descs/hydrate.go @@ -243,7 +243,24 @@ func hydrate( if !isHydratable(desc) { return nil } - return typedesc.HydrateTypesInDescriptor(ctx, desc, typeLookupFunc) + err := typedesc.HydrateTypesInDescriptor(ctx, desc, typeLookupFunc) + if err != nil { + return err + } + if tableDesc, ok := desc.(catalog.TableDescriptor); ok { + for _, idx := range tableDesc.NonDropIndexes() { + vectorConfig := idx.GetVectorConfig() + if vectorConfig.IsEmpty() { + continue + } + ivfFlat := vectorConfig.GetIvfFlat() + if ivfFlat == nil { + continue + } + ivfFlat.Centroids = nil // scan centroids + } + } + return nil } // makeTypeLookupFuncForHydration builds a typedesc.TypeLookupFunc for the diff --git a/pkg/sql/colenc/inverted.go b/pkg/sql/colenc/inverted.go index 75553a38b1d3..1781dea79e7a 100644 --- a/pkg/sql/colenc/inverted.go +++ b/pkg/sql/colenc/inverted.go @@ -57,6 +57,7 @@ func (b *BatchEncoder) encodeInvertedSecondaryIndex( vec = b.b.ColVecs()[i] } indexGeoConfig := index.GetGeoConfig() + indexVectorConfig := index.GetVectorConfig() for row := 0; row < b.count; row++ { if kys[row] == nil { continue @@ -67,6 +68,10 @@ func (b *BatchEncoder) encodeInvertedSecondaryIndex( if keys, err = rowenc.EncodeGeoInvertedIndexTableKeys(ctx, val, kys[row], indexGeoConfig); err != nil { return err } + } else if !indexVectorConfig.IsEmpty() { + if keys, err = rowenc.EncodeVectorInvertedIndexTableKeys(ctx, val, kys[row], indexVectorConfig); err != nil { + return err + } } else { if keys, err = rowenc.EncodeInvertedIndexTableKeys(val, kys[row], index.GetVersion()); err != nil { return err diff --git a/pkg/sql/create_index.go b/pkg/sql/create_index.go index 08d08873f7e4..6b9584988f3f 100644 --- a/pkg/sql/create_index.go +++ b/pkg/sql/create_index.go @@ -37,6 +37,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/log/eventpb" + "github.com/cockroachdb/cockroach/pkg/util/vector/vectorpb" "github.com/cockroachdb/errors" ) @@ -437,6 +438,32 @@ func populateInvertedIndexDescriptor( default: return newUndefinedOpclassError(invCol.OpClass) } + case types.PGVectorFamily: + width := column.GetType().Width() + if width == 0 { + return pgerror.New(pgcode.InvalidObjectDefinition, "column does not have dimensions") + } + indexDesc.VectorConfig = vectorpb.Config{ + IndexType: &vectorpb.Config_IvfFlat{ + IvfFlat: &vectorpb.IVFFlatConfig{ + // lists defaults to 100. + NLists: 100, + }, + }, + Dimensions: width, + } + indexDesc.InvertedColumnKinds[0] = catpb.InvertedIndexColumnKind_IVFFLAT + switch invCol.OpClass { + // The default operator class is "vector_l2_ops". + case "vector_l2_ops", "": + indexDesc.VectorConfig.DistanceFunction = vectorpb.DistanceFunction_L2 + case "vector_ip_ops": + indexDesc.VectorConfig.DistanceFunction = vectorpb.DistanceFunction_IP + case "vector_cosine_ops": + indexDesc.VectorConfig.DistanceFunction = vectorpb.DistanceFunction_COSINE + default: + return newUndefinedOpclassError(invCol.OpClass) + } default: return tabledesc.NewInvalidInvertedColumnError(column.GetName(), column.GetType().Name()) } diff --git a/pkg/sql/parser/sql.y b/pkg/sql/parser/sql.y index a34202f2ae8d..c9709edc9cb7 100644 --- a/pkg/sql/parser/sql.y +++ b/pkg/sql/parser/sql.y @@ -11387,7 +11387,7 @@ opt_index_access_method: { /* FORCE DOC */ switch $2 { - case "gin", "gist": + case "gin", "gist", "ivfflat": $$.val = true case "btree": $$.val = false diff --git a/pkg/sql/rowenc/BUILD.bazel b/pkg/sql/rowenc/BUILD.bazel index ecf48682638f..02ccc00d3d22 100644 --- a/pkg/sql/rowenc/BUILD.bazel +++ b/pkg/sql/rowenc/BUILD.bazel @@ -41,6 +41,8 @@ go_library( "//pkg/util/trigram", "//pkg/util/tsearch", "//pkg/util/unique", + "//pkg/util/vector", + "//pkg/util/vector/vectorpb", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_redact//:redact", ], diff --git a/pkg/sql/rowenc/index_encoding.go b/pkg/sql/rowenc/index_encoding.go index 62b297763eee..9de620f2afd3 100644 --- a/pkg/sql/rowenc/index_encoding.go +++ b/pkg/sql/rowenc/index_encoding.go @@ -42,6 +42,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/trigram" "github.com/cockroachdb/cockroach/pkg/util/tsearch" "github.com/cockroachdb/cockroach/pkg/util/unique" + "github.com/cockroachdb/cockroach/pkg/util/vector" + "github.com/cockroachdb/cockroach/pkg/util/vector/vectorpb" "github.com/cockroachdb/errors" ) @@ -624,10 +626,13 @@ func EncodeInvertedIndexKeys( } else { val = tree.DNull } - indexGeoConfig := index.GetGeoConfig() - if !indexGeoConfig.IsEmpty() { + config := index.GetVectorConfig() + if indexGeoConfig := index.GetGeoConfig(); !indexGeoConfig.IsEmpty() { return EncodeGeoInvertedIndexTableKeys(ctx, val, keyPrefix, indexGeoConfig) + } else if vectorConfig := config; !vectorConfig.IsEmpty() { + return EncodeVectorInvertedIndexTableKeys(ctx, val, keyPrefix, vectorConfig) } + return EncodeInvertedIndexTableKeys(val, keyPrefix, index.GetVersion()) } @@ -1087,6 +1092,32 @@ func EncodeGeoInvertedIndexTableKeys( } } +// EncodeVectorInvertedIndexTableKeys is the equivalent of EncodeInvertedIndexTableKeys +// for vectors. +func EncodeVectorInvertedIndexTableKeys(_ context.Context, val tree.Datum, keyPrefix []byte, + vectorConfig vectorpb.Config) ([][]byte, error) { + if val == tree.DNull { + return nil, nil + } + vec := tree.MustBeDPGVector(val).T + centroid, err := vector.GetClosestCentroid(vec, vectorConfig) + if err != nil { + return nil, err + } + if vectorConfig.Dimensions != int32(len(centroid)) { + return nil, errors.Errorf("centroid has %d dimensions, expected %d", len(centroid), vectorConfig.Dimensions) + } + if vectorConfig.Dimensions != int32(len(vec)) { + return nil, errors.Errorf("vector has %d dimensions, expected %d", len(vec), vectorConfig.Dimensions) + } + // The buffer will be used to encode the centroid and the input vector. 4 bytes per + // float32 times 2 vectors. + b := make([]byte, 0, len(keyPrefix)+int(1+4*2*vectorConfig.Dimensions)) + b = append(b, keyPrefix...) + encoding.EncodeIvfCentroidVector(b, centroid, vec) + return [][]byte{b}, nil +} + func encodeGeoKeys( inKey []byte, geoKeys []geoindex.Key, bbox geopb.BoundingBox, ) (keys [][]byte, err error) { diff --git a/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/BUILD.bazel b/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/BUILD.bazel index b86ef09fc882..95c60b46baf5 100644 --- a/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/BUILD.bazel +++ b/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/BUILD.bazel @@ -81,6 +81,7 @@ go_library( "//pkg/sql/types", "//pkg/util/errorutil/unimplemented", "//pkg/util/protoutil", + "//pkg/util/vector/vectorpb", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_redact//:redact", "@com_github_lib_pq//oid", diff --git a/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/create_index.go b/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/create_index.go index 9dfed832cfc2..d64816c2dad1 100644 --- a/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/create_index.go +++ b/pkg/sql/schemachanger/scbuild/internal/scbuildstmt/create_index.go @@ -43,6 +43,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/cockroach/pkg/util/protoutil" + "github.com/cockroachdb/cockroach/pkg/util/vector/vectorpb" "github.com/cockroachdb/errors" ) @@ -348,6 +349,33 @@ func processColNodeType( } indexSpec.secondary.GeoConfig = geoindex.DefaultGeographyIndexConfig() b.IncrementSchemaChangeIndexCounter("geography_inverted") + case types.PGVectorFamily: + width := columnType.Type.Width() + if width == 0 { + panic(pgerror.New(pgcode.InvalidObjectDefinition, "column does not have dimensions")) + } + indexSpec.secondary.VectorConfig = &vectorpb.Config{ + IndexType: &vectorpb.Config_IvfFlat{ + IvfFlat: &vectorpb.IVFFlatConfig{ + // lists defaults to 100. + NLists: 100, + }, + }, + Dimensions: width, + } + invertedKind = catpb.InvertedIndexColumnKind_IVFFLAT + switch columnNode.OpClass { + // The default operator class is "vector_l2_ops". + case "vector_l2_ops", "": + indexSpec.secondary.VectorConfig.DistanceFunction = vectorpb.DistanceFunction_L2 + case "vector_ip_ops": + indexSpec.secondary.VectorConfig.DistanceFunction = vectorpb.DistanceFunction_IP + case "vector_cosine_ops": + indexSpec.secondary.VectorConfig.DistanceFunction = vectorpb.DistanceFunction_COSINE + default: + panic(newUndefinedOpclassError(columnNode.OpClass)) + } + b.IncrementSchemaChangeIndexCounter("ivfflat_inverted") case types.StringFamily: // Check the opclass of the last column in the list, which is the column // we're going to inverted index. @@ -930,7 +958,7 @@ func maybeAddIndexPredicate(b BuildCtx, n *tree.CreateIndex, idxSpec *indexSpec) } // maybeApplyStorageParameters apply any storage parameters into the index spec, -// this is only used for GeoConfig today. +// this is only used for GeoConfig and VectorConfig today. func maybeApplyStorageParameters(b BuildCtx, n *tree.CreateIndex, idxSpec *indexSpec) { if len(n.StorageParams) == 0 { return @@ -939,6 +967,9 @@ func maybeApplyStorageParameters(b BuildCtx, n *tree.CreateIndex, idxSpec *index if idxSpec.secondary.GeoConfig != nil { dummyIndexDesc.GeoConfig = *idxSpec.secondary.GeoConfig } + if idxSpec.secondary.VectorConfig != nil { + dummyIndexDesc.VectorConfig = *idxSpec.secondary.VectorConfig + } storageParamSetter := &indexstorageparam.Setter{ IndexDesc: dummyIndexDesc, } @@ -951,6 +982,11 @@ func maybeApplyStorageParameters(b BuildCtx, n *tree.CreateIndex, idxSpec *index } else { idxSpec.secondary.GeoConfig = nil } + if !dummyIndexDesc.VectorConfig.IsEmpty() { + idxSpec.secondary.VectorConfig = &dummyIndexDesc.VectorConfig + } else { + idxSpec.secondary.VectorConfig = nil + } } // fallbackIfRelationIsNotTable falls back if a relation element is diff --git a/pkg/sql/schemachanger/scexec/scmutationexec/index.go b/pkg/sql/schemachanger/scexec/scmutationexec/index.go index 0cb7a2ae6547..2e357b094c57 100644 --- a/pkg/sql/schemachanger/scexec/scmutationexec/index.go +++ b/pkg/sql/schemachanger/scexec/scmutationexec/index.go @@ -99,6 +99,9 @@ func addNewIndexMutation( if opIndex.GeoConfig != nil { idx.GeoConfig = *opIndex.GeoConfig } + if opIndex.VectorConfig != nil { + idx.VectorConfig = *opIndex.VectorConfig + } return enqueueIndexMutation(tbl, idx, state, descpb.DescriptorMutation_ADD) } diff --git a/pkg/sql/schemachanger/scpb/BUILD.bazel b/pkg/sql/schemachanger/scpb/BUILD.bazel index 6677b36a4319..4e431f3cb0c2 100644 --- a/pkg/sql/schemachanger/scpb/BUILD.bazel +++ b/pkg/sql/schemachanger/scpb/BUILD.bazel @@ -42,6 +42,7 @@ go_proto_library( "//pkg/sql/sem/catid", # keep "//pkg/sql/sem/semenumpb", "//pkg/sql/types", + "//pkg/util/vector/vectorpb", "@com_github_gogo_protobuf//gogoproto", ], ) @@ -60,6 +61,7 @@ proto_library( "//pkg/sql/catalog/catpb:catpb_proto", "//pkg/sql/sem/semenumpb:semenumpb_proto", "//pkg/sql/types:types_proto", + "//pkg/util/vector/vectorpb:vectorpb_proto", "@com_github_gogo_protobuf//gogoproto:gogo_proto", ], ) diff --git a/pkg/sql/schemachanger/scpb/elements.proto b/pkg/sql/schemachanger/scpb/elements.proto index e99c16a2c6d6..84a8f3d831dd 100644 --- a/pkg/sql/schemachanger/scpb/elements.proto +++ b/pkg/sql/schemachanger/scpb/elements.proto @@ -19,6 +19,7 @@ import "sql/catalog/catpb/function.proto"; import "sql/types/types.proto"; import "gogoproto/gogo.proto"; import "geo/geopb/config.proto"; +import "util/vector/vectorpb/config.proto"; option (gogoproto.equal_all) = true; @@ -262,6 +263,8 @@ message Index { // Invisibility specifies index invisibility to the optimizer. double invisibility = 25; + cockroach.vector.vectorindex.Config vector_config = 26 [(gogoproto.nullable) = true]; + reserved 3, 4, 5, 6, 7; } diff --git a/pkg/sql/sem/builtins/builtins.go b/pkg/sql/sem/builtins/builtins.go index 29e42f20b123..405c225169ae 100644 --- a/pkg/sql/sem/builtins/builtins.go +++ b/pkg/sql/sem/builtins/builtins.go @@ -6414,6 +6414,22 @@ SELECT Volatility: volatility.Stable, CalledOnNullInput: true, }, + tree.Overload{ + Types: tree.ParamTypes{ + {Name: "val", Typ: types.PGVector}, + {Name: "version", Typ: types.Int}, + }, + ReturnType: tree.FixedReturnType(types.Int), + Fn: func(ctx context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { + if args[0] == tree.DNull { + return tree.DZero, nil + } + return tree.NewDInt(tree.DInt(1)), nil + }, + Info: "This function is used only by CockroachDB's developers for testing purposes.", + Volatility: volatility.Stable, + CalledOnNullInput: true, + }, ), "crdb_internal.assignment_cast": makeBuiltin( diff --git a/pkg/sql/sem/builtins/fixed_oids.go b/pkg/sql/sem/builtins/fixed_oids.go index b3de3a16b491..e46137efbccc 100644 --- a/pkg/sql/sem/builtins/fixed_oids.go +++ b/pkg/sql/sem/builtins/fixed_oids.go @@ -2592,6 +2592,7 @@ var builtinOidsArray = []string{ 2626: `inner_product(v1: vector, v2: vector) -> float`, 2627: `vector_dims(vector: vector) -> int`, 2628: `vector_norm(vector: vector) -> float`, + 2629: `crdb_internal.num_inverted_index_entries(val: vector, version: int) -> int`, } var builtinOidsBySignature map[string]oid.Oid diff --git a/pkg/sql/storageparam/indexstorageparam/index_storage_param.go b/pkg/sql/storageparam/indexstorageparam/index_storage_param.go index 8c6b3d317c25..5861516e7731 100644 --- a/pkg/sql/storageparam/indexstorageparam/index_storage_param.go +++ b/pkg/sql/storageparam/indexstorageparam/index_storage_param.go @@ -126,6 +126,8 @@ func (po *Setter) Set( return po.applyS2ConfigSetting(ctx, evalCtx, key, expr, 1, 32) case `geometry_min_x`, `geometry_max_x`, `geometry_min_y`, `geometry_max_y`: return po.applyGeometryIndexSetting(ctx, evalCtx, key, expr) + case `lists`: + return po.applyIvfflatConfigSetting(ctx, evalCtx, key, expr) // `bucket_count` is handled in schema changer when creating hash sharded // indexes. case `bucket_count`: @@ -180,3 +182,17 @@ func (po *Setter) RunPostChecks() error { } return nil } + +func (po *Setter) applyIvfflatConfigSetting(ctx context.Context, evalCtx *eval.Context, key string, + expr tree.Datum) error { + cfg := po.IndexDesc.VectorConfig.GetIvfFlat() + if cfg == nil { + return pgerror.Newf(pgcode.InvalidParameterValue, "%q can only be applied to ivfflat indexes", key) + } + val, err := paramparse.DatumAsInt(ctx, evalCtx, key, expr) + if err != nil { + return errors.Wrapf(err, "error decoding %q", key) + } + cfg.NLists = int32(val) + return nil +} diff --git a/pkg/util/encoding/encoding.go b/pkg/util/encoding/encoding.go index c3f62460a452..2cdde57332c8 100644 --- a/pkg/util/encoding/encoding.go +++ b/pkg/util/encoding/encoding.go @@ -87,6 +87,11 @@ const ( geoMarker = timeTZMarker + 1 geoDescMarker = geoMarker + 1 + // Markers for the 2 sections of a vector ivf index. The centroids are stored + // separately from the inverted index itself. + vectorIvfCentroidsMarker = 0x00 + vectorIvfIndexMarker = 0x01 + // Markers and terminators for key encoding Datum arrays in sorted order. // For the arrayKeyMarker and other types like bytes and bit arrays, it // might be unclear why we have a separate marker for the ascending and @@ -1126,6 +1131,18 @@ func DecodeGeoInvertedKey(b []byte) (loX, loY, hiX, hiY float64, remaining []byt return loX, loY, hiX, hiY, b, nil } +// EncodeIvfCentroidVector encodes the centroid and vector for an ivfflat index. +func EncodeIvfCentroidVector(b []byte, centroid []float32, vec []float32) []byte { + b = append(b, vectorIvfIndexMarker) + for _, f := range centroid { + b = EncodeFloat32Ascending(b, f) + } + for _, f := range vec { + b = EncodeFloat32Ascending(b, f) + } + return b +} + // EncodeNullDescending is the descending equivalent of EncodeNullAscending. func EncodeNullDescending(b []byte) []byte { return append(b, encodedNullDesc) diff --git a/pkg/util/encoding/float.go b/pkg/util/encoding/float.go index 4a3f4091f255..ed8c71c20cf2 100644 --- a/pkg/util/encoding/float.go +++ b/pkg/util/encoding/float.go @@ -101,3 +101,89 @@ func DecodeFloatDescending(buf []byte) ([]byte, float64, error) { } return b, r, err } + +// EncodeFloat32Ascending returns the resulting byte slice with the encoded float32 +// appended to b. The encoded format for a float32 value f is, for positive f, the +// encoding of the 32 bits (in IEEE 754 format) re-interpreted as an int64 and +// encoded using EncodeUint32Ascending. For negative f, we keep the sign bit and +// invert all other bits, encoding this value using EncodeUint32Descending. This +// approach was inspired by in github.com/google/orderedcode/orderedcode.go. +// +// One of five single-byte prefix tags are appended to the front of the encoding. +// These tags enforce logical ordering of keys for both ascending and descending +// encoding directions. The tags split the encoded floats into five categories: +// - NaN for an ascending encoding direction +// - Negative valued floats +// - Zero (positive and negative) +// - Positive valued floats +// - NaN for a descending encoding direction +// This ordering ensures that NaNs are always sorted first in either encoding +// direction, and that after them a logical ordering is followed. +func EncodeFloat32Ascending(b []byte, f float32) []byte { + // Handle the simplistic cases first. + switch { + case math.IsNaN(float64(f)): + return append(b, floatNaN) + case f == 0: + // This encodes both positive and negative zero the same. Negative zero uses + // composite indexes to decode itself correctly. + return append(b, floatZero) + } + u := math.Float32bits(f) + if u&(1<<31) != 0 { + u = ^u + b = append(b, floatNeg) + } else { + b = append(b, floatPos) + } + return EncodeUint32Ascending(b, u) +} + +// EncodeFloat32Descending is the descending version of EncodeFloat32Ascending. +func EncodeFloat32Descending(b []byte, f float32) []byte { + if math.IsNaN(float64(f)) { + return append(b, floatNaNDesc) + } + return EncodeFloat32Ascending(b, -f) +} + +// DecodeFloat32Ascending returns the remaining byte slice after decoding and the decoded +// float32 from buf. +func DecodeFloat32Ascending(buf []byte) ([]byte, float32, error) { + if PeekType(buf) != Float { + return buf, 0, errors.Errorf("did not find marker") + } + switch buf[0] { + case floatNaN, floatNaNDesc: + return buf[1:], float32(math.NaN()), nil + case floatNeg: + b, u, err := DecodeUint32Ascending(buf[1:]) + if err != nil { + return b, 0, err + } + u = ^u + return b, math.Float32frombits(u), nil + case floatZero: + return buf[1:], 0, nil + case floatPos: + b, u, err := DecodeUint32Ascending(buf[1:]) + if err != nil { + return b, 0, err + } + return b, math.Float32frombits(u), nil + default: + return nil, 0, errors.Errorf("unknown prefix of the encoded byte slice: %q", buf) + } +} + +// DecodeFloat32Descending decodes floats encoded with EncodeFloat32Descending. +func DecodeFloat32Descending(buf []byte) ([]byte, float32, error) { + b, r, err := DecodeFloat32Ascending(buf) + if r != 0 && !math.IsNaN(float64(r)) { + // All values except for 0 and NaN were negated in EncodeFloat32Descending, so + // we have to negate them back. Negative zero uses composite indexes to + // decode itself correctly. + r = -r + } + return b, r, err +} diff --git a/pkg/util/vector/BUILD.bazel b/pkg/util/vector/BUILD.bazel index 1cb837ae0a1f..f5aff8a91213 100644 --- a/pkg/util/vector/BUILD.bazel +++ b/pkg/util/vector/BUILD.bazel @@ -12,6 +12,8 @@ go_library( "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", "//pkg/util/encoding", + "//pkg/util/vector/vectorpb", + "@com_github_cockroachdb_errors//:errors", ], ) diff --git a/pkg/util/vector/vector.go b/pkg/util/vector/vector.go index eab12a827d62..13ad56cbc5a8 100644 --- a/pkg/util/vector/vector.go +++ b/pkg/util/vector/vector.go @@ -10,6 +10,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/cockroach/pkg/util/encoding" + "github.com/cockroachdb/cockroach/pkg/util/vector/vectorpb" + "github.com/cockroachdb/errors" ) // MaxDim is the maximum number of dimensions a vector can have. @@ -244,6 +246,47 @@ func Mult(t T, t2 T) (T, error) { return ret, nil } +func distanceFuncFromProto(distanceMethod vectorpb.DistanceFunction) func(T, T) (float64, error) { + switch distanceMethod { + case vectorpb.DistanceFunction_L2: + return L2Distance + case vectorpb.DistanceFunction_IP: + return InnerProduct + case vectorpb.DistanceFunction_COSINE: + return CosDistance + } + panic(fmt.Sprintf("unsupported distance function %s", distanceMethod)) +} + +// GetClosestCentroid returns the centroid from the index configuration that is +// closest to the given vector t. +func GetClosestCentroid(t T, cfg vectorpb.Config) (T, error) { + switch cfg.IndexType.(type) { + case *vectorpb.Config_IvfFlat: + default: + return nil, errors.AssertionFailedf("unsupported index type %T", cfg.IndexType) + } + distanceFunc := distanceFuncFromProto(cfg.DistanceFunction) + + ivf := cfg.GetIvfFlat() + if len(ivf.Centroids) == 0 { + return nil, errors.AssertionFailedf("no centroids found in index configuration") + } + var closest T + closestDistance := math.MaxFloat64 + for _, centroid := range ivf.Centroids { + dist, err := distanceFunc(t, centroid.Centroid) + if err != nil { + return nil, err + } + if dist < closestDistance { + closest = centroid.Centroid + closestDistance = dist + } + } + return closest, nil +} + // Random returns a random vector. func Random(rng *rand.Rand) T { n := 1 + rng.Intn(1000) diff --git a/pkg/util/vector/vectorpb/config.go b/pkg/util/vector/vectorpb/config.go index 45462719a5d2..dcc57856a4a7 100644 --- a/pkg/util/vector/vectorpb/config.go +++ b/pkg/util/vector/vectorpb/config.go @@ -2,5 +2,5 @@ package vectorpb // IsEmpty returns whether the config contains an index configuration. func (cfg Config) IsEmpty() bool { - return cfg.IvfFlat == nil + return cfg.IndexType == nil } diff --git a/pkg/util/vector/vectorpb/config.proto b/pkg/util/vector/vectorpb/config.proto index fc6d5eeae14a..6885c7600742 100644 --- a/pkg/util/vector/vectorpb/config.proto +++ b/pkg/util/vector/vectorpb/config.proto @@ -10,19 +10,28 @@ import "gogoproto/gogo.proto"; // At the moment, only one major indexing strategy is implemented (ivfflat). message Config { option (gogoproto.equal) = true; - option (gogoproto.onlyone) = true; - IVFFlatConfig ivf_flat = 1; + oneof index_type { IVFFlatConfig ivf_flat = 1; } + DistanceFunction distance_function = 2; + int32 dimensions = 3; } -message IVFFlatConfig { - option (gogoproto.equal) = true; - Centroids centroids = 1; +enum DistanceFunction { + // INVALID is an invalid distance function. + INVALID = 0; + // L2 is the Euclidean distance. + L2 = 1; + // IP is the inner product. + IP = 2; + // Cosine is the cosine similarity. + COSINE = 3; } -message Centroids { +message IVFFlatConfig { option (gogoproto.equal) = true; - repeated Centroid centroids = 1; + repeated Centroid centroids = 1 [ (gogoproto.nullable) = false ]; + // NLists is the number of lists in the index. + int32 n_lists = 2; } message Centroid {