Skip to content

Commit

Permalink
*: support fixed dimension vector (#55002)
Browse files Browse the repository at this point in the history
ref #54245
  • Loading branch information
EricZequan authored Aug 6, 2024
1 parent 5389de9 commit 7fff125
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 20 deletions.
2 changes: 2 additions & 0 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ func needChangeColumnData(oldCol, newCol *model.ColumnInfo) bool {
if types.IsBinaryStr(&oldCol.FieldType) {
return newCol.GetFlen() != oldCol.GetFlen()
}
case mysql.TypeTiDBVectorFloat32:
return newCol.GetFlen() != types.UnspecifiedLength && oldCol.GetFlen() != newCol.GetFlen()
}

return needTruncationOrToggleSign()
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func newReturnFieldTypeForBaseBuiltinFunc(funcName string, retType types.EvalTyp
case types.ETJson:
fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeJSON).SetFlag(mysql.BinaryFlag).SetFlen(mysql.MaxBlobWidth).SetCharset(mysql.DefaultCharset).SetCollate(mysql.DefaultCollationName).BuildP()
case types.ETVectorFloat32:
fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeTiDBVectorFloat32).SetFlag(mysql.BinaryFlag).SetFlen(mysql.MaxBlobWidth).BuildP()
fieldType = types.NewFieldTypeBuilder().SetType(mysql.TypeTiDBVectorFloat32).SetFlag(mysql.BinaryFlag).SetFlen(types.UnspecifiedLength).BuildP()
}
if mysql.HasBinaryFlag(fieldType.GetFlag()) && fieldType.GetType() != mysql.TypeJSON {
fieldType.SetCharset(charset.CharsetBin)
Expand Down
19 changes: 16 additions & 3 deletions pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,14 @@ func (b *builtinCastStringAsVectorFloat32Sig) evalVectorFloat32(ctx EvalContext,
if isNull || err != nil {
return types.ZeroVectorFloat32, isNull, err
}
res, err := types.ParseVectorFloat32(val)
return res, false, err
vec, err := types.ParseVectorFloat32(val)
if err != nil {
return types.ZeroVectorFloat32, false, err
}
if err = vec.CheckDimsFitColumn(b.tp.GetFlen()); err != nil {
return types.ZeroVectorFloat32, isNull, err
}
return vec, false, nil
}

type builtinCastVectorFloat32AsVectorFloat32Sig struct {
Expand All @@ -796,7 +802,14 @@ func (b *builtinCastVectorFloat32AsVectorFloat32Sig) Clone() builtinFunc {
}

// evalVectorFloat32 evaluates the vector argument and returns it unchanged
// after verifying that its dimension fits the flen recorded in b.tp.
// A NULL argument or an evaluation error is propagated together with the
// zero vector; a dimension mismatch is reported as an error.
func (b *builtinCastVectorFloat32AsVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
	val, isNull, err := b.args[0].EvalVectorFloat32(ctx, row)
	if isNull || err != nil {
		return types.ZeroVectorFloat32, isNull, err
	}
	// The argument is known to be non-NULL from here on, so report false
	// explicitly instead of echoing isNull.
	if err = val.CheckDimsFitColumn(b.tp.GetFlen()); err != nil {
		return types.ZeroVectorFloat32, false, err
	}
	return val, false, nil
}

type builtinCastIntAsIntSig struct {
Expand Down
5 changes: 4 additions & 1 deletion pkg/expression/builtin_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ func (b *builtinVecFromTextSig) evalVectorFloat32(ctx EvalContext, row chunk.Row

vec, err := types.ParseVectorFloat32(v)
if err != nil {
return res, false, err
return types.ZeroVectorFloat32, false, err
}
if err = vec.CheckDimsFitColumn(b.tp.GetFlen()); err != nil {
return types.ZeroVectorFloat32, isNull, err
}

return vec, false, nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/integration_test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ go_test(
"main_test.go",
],
flaky = True,
shard_count = 33,
shard_count = 35,
deps = [
"//pkg/config",
"//pkg/domain",
Expand Down
116 changes: 111 additions & 5 deletions pkg/expression/integration_test/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,90 @@ import (
"github.com/tikv/client-go/v2/oracle"
)

// TestVectorColumnInfo checks how VECTOR columns, declared with and without
// an explicit dimension, are reported by SHOW CREATE TABLE, SHOW COLUMNS and
// INFORMATION_SCHEMA.COLUMNS, and that an oversized dimension is rejected.
func TestVectorColumnInfo(t *testing.T) {
	kit := testkit.NewTestKit(t, testkit.CreateMockStore(t))
	kit.MustExec("use test")

	// Both spellings of a dimension-less vector column are accepted.
	kit.MustExec("create table t(embedding VECTOR)")
	kit.MustExec("drop table if exists t;")
	kit.MustExec("create table t(embedding VECTOR<FLOAT>)")

	// The dimension-less column is rendered without a length suffix.
	unsizedDDL := "t CREATE TABLE `t` (\n" +
		" `embedding` vector<float> DEFAULT NULL\n" +
		") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"
	kit.MustQuery("show create table t").Check(testkit.Rows(unsizedDDL))
	kit.MustQuery("show columns from t").Check(testkit.Rows(
		"embedding vector<float> YES <nil> ",
	))

	// Every spelling with an explicit dimension is accepted; the table is
	// recreated each time and the last definition is the one inspected below.
	for _, ddl := range []string{
		"create table t(embedding VECTOR(3))",
		"create table t(embedding VECTOR<FLOAT>(3))",
		"create table t(embedding VECTOR<FLOAT>(0))",
		"create table t(embedding VECTOR(3))",
	} {
		kit.MustExec("drop table if exists t;")
		kit.MustExec(ddl)
	}

	// The fixed dimension shows up in the rendered column type.
	sizedDDL := "t CREATE TABLE `t` (\n" +
		" `embedding` vector<float>(3) DEFAULT NULL\n" +
		") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"
	kit.MustQuery("show create table t").Check(testkit.Rows(sizedDDL))
	kit.MustQuery("show columns from t").Check(testkit.Rows(
		"embedding vector<float>(3) YES <nil> ",
	))
	kit.MustQuery("SELECT data_type, column_type FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 't'").Check(testkit.Rows(
		"vector<float> vector<float>(3)",
	))

	// The vector dimension must be less than or equal to 16383.
	kit.MustExec("drop table if exists t;")
	kit.MustGetErrMsg("create table t(embedding VECTOR<FLOAT>(16384))", "vector cannot have more than 16383 dimensions")
}

// TestFixedVector exercises converting a column between a dimension-less
// VECTOR and a fixed-dimension VECTOR(3), and checks that values whose
// dimension does not match the fixed column are rejected.
func TestFixedVector(t *testing.T) {
	kit := testkit.NewTestKit(t, testkit.CreateMockStore(t))
	kit.MustExec("use test")

	kit.MustExec("create table t(embedding VECTOR)")
	for _, insert := range []string{
		"insert into t values ('[1,2,3]')",
		"insert into t values ('[1,2,3,4]')",
	} {
		kit.MustExec(insert)
	}

	// The column cannot be fixed to 3 dimensions while a 4-dimension row exists.
	kit.MustContainErrMsg("alter table t modify column embedding VECTOR(3)", "vector has 4 dimensions, does not fit VECTOR(3)")

	// Mixed dimension to fixed dimension: succeeds once only 3-dimension rows remain.
	kit.MustExec("delete from t where vec_dims(embedding) != 3")
	kit.MustExec("alter table t modify column embedding VECTOR(3)")

	// Writes with any other dimension must now fail with the exact message.
	rejected := []struct {
		sql string
		msg string
	}{
		{"insert into t values ('[]')", "vector has 0 dimensions, does not fit VECTOR(3)"},
		{"insert into t values ('[1,2,3,4]')", "vector has 4 dimensions, does not fit VECTOR(3)"},
		{"insert into t values (VEC_FROM_TEXT('[]'))", "vector has 0 dimensions, does not fit VECTOR(3)"},
		{"insert into t values (VEC_FROM_TEXT('[1,2,3,4]'))", "vector has 4 dimensions, does not fit VECTOR(3)"},
		{"update t set embedding = '[1,2,3,4]' where embedding = '[1,2,3]'", "vector has 4 dimensions, does not fit VECTOR(3)"},
		{"update t set embedding = '[]' where embedding = '[1,2,3]'", "vector has 0 dimensions, does not fit VECTOR(3)"},
	}
	for _, c := range rejected {
		kit.MustGetErrMsg(c.sql, c.msg)
	}

	// Fixed dimension to mixed dimension: any dimension is accepted again.
	kit.MustExec("alter table t modify column embedding VECTOR")
	kit.MustExec("insert into t values ('[1,2,3,4]')")

	// The vector dimension must be less than or equal to 16383.
	kit.MustGetErrMsg("alter table t modify column embedding VECTOR(16384)", "vector cannot have more than 16383 dimensions")
}

func TestVector(t *testing.T) {
store := testkit.CreateMockStore(t)

Expand Down Expand Up @@ -106,6 +190,7 @@ func TestVectorOperators(t *testing.T) {

tk := testkit.NewTestKit(t, store)
tk.MustExec("USE test;")

tk.MustExec(`CREATE TABLE t(embedding VECTOR);`)
tk.MustExec(`INSERT INTO t VALUES
('[1, 2, 3]'),
Expand All @@ -119,7 +204,7 @@ func TestVectorOperators(t *testing.T) {
tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NOT NULL`).Check(testkit.Rows("1"))
tk.MustQuery(`SELECT VEC_FROM_TEXT('[]') IS NULL`).Check(testkit.Rows("0"))
tk.MustQuery(`SELECT * FROM t WHERE embedding = VEC_FROM_TEXT('[1,2,3]');`).Check(testkit.Rows("[1,2,3]"))
tk.MustQuery(`SELECT * FROM t WHERE embedding BETWEEN '[1, 2, 3]' AND '[4, 5, 6]'`).Check(testkit.Rows("[1,2,3]", "[4,5,6]"))
tk.MustQuery(`SELECT * FROM t WHERE embedding BETWEEN '[1,2,3]' AND '[4,5,6]'`).Check(testkit.Rows("[1,2,3]", "[4,5,6]"))
tk.MustExecToErr(`SELECT * FROM t WHERE embedding IN ('[1, 2, 3]', '[4, 5, 6]')`)
tk.MustExecToErr(`SELECT * FROM t WHERE embedding NOT IN ('[1, 2, 3]', '[4, 5, 6]')`)
}
Expand Down Expand Up @@ -173,10 +258,18 @@ func TestVectorConversion(t *testing.T) {
tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS DATE);")
tk.MustQueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS TIME);")

// expect error result
tk.MustExecToErr("SELECT CAST('[1,2,3]' AS VECTOR);")
tk.MustExecToErr("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>);")
tk.MustExecToErr("SELECT CAST('[1,2,3]' AS VECTOR<DOUBLE>);")
tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR);").Check(testkit.Rows("[1,2,3]"))
tk.MustQuery("SELECT CAST('[]' AS VECTOR);").Check(testkit.Rows("[]"))
tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>);").Check(testkit.Rows("[1,2,3]"))
tk.MustContainErrMsg("SELECT CAST('[1,2,3]' AS VECTOR<DOUBLE>);", "Only VECTOR is supported for now")

tk.MustQuery("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err := tk.QueryToErr("SELECT CAST('[1,2,3]' AS VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")

tk.MustQuery("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err = tk.QueryToErr("SELECT CAST(VEC_FROM_TEXT('[1,2,3]') AS VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")

// CONVERT
tk.MustQuery("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), BINARY);").Check(testkit.Rows("[1,2,3]"))
Expand All @@ -192,6 +285,19 @@ func TestVectorConversion(t *testing.T) {
tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DATETIME);")
tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), DATE);")
tk.MustQueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), TIME);")

tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR);").Check(testkit.Rows("[1,2,3]"))
tk.MustQuery("SELECT CONVERT('[]', VECTOR);").Check(testkit.Rows("[]"))
tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR<FLOAT>);").Check(testkit.Rows("[1,2,3]"))
tk.MustContainErrMsg("SELECT CONVERT('[1,2,3]', VECTOR<DOUBLE>);", "Only VECTOR is supported for now")

tk.MustQuery("SELECT CONVERT('[1,2,3]', VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err = tk.QueryToErr("SELECT CONVERT('[1,2,3]', VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")

tk.MustQuery("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), VECTOR<FLOAT>(3));").Check(testkit.Rows("[1,2,3]"))
err = tk.QueryToErr("SELECT CONVERT(VEC_FROM_TEXT('[1,2,3]'), VECTOR<FLOAT>(2));")
require.EqualError(t, err, "vector has 3 dimensions, does not fit VECTOR(2)")
}

func TestVectorAggregations(t *testing.T) {
Expand Down
10 changes: 6 additions & 4 deletions pkg/planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,12 @@ func checkColumn(colDef *ast.ColumnDef) error {
if tp.GetFlen() > mysql.MaxBitDisplayWidth {
return types.ErrTooBigDisplayWidth.GenWithStackByArgs(colDef.Name.Name.O, mysql.MaxBitDisplayWidth)
}
case mysql.TypeTiDBVectorFloat32:
if tp.GetFlen() != types.UnspecifiedLength {
if err := types.CheckVectorDimValid(tp.GetFlen()); err != nil {
return err
}
}
default:
// TODO: Add more types.
}
Expand Down Expand Up @@ -1742,10 +1748,6 @@ func (p *preprocessor) checkFuncCastExpr(node *ast.FuncCastExpr) {
return
}
}
if node.Tp.GetType() == mysql.TypeTiDBVectorFloat32 {
p.err = errors.Errorf("vector type is not supported")
return
}
}

func (p *preprocessor) updateStateFromStaleReadProcessor() error {
Expand Down
16 changes: 12 additions & 4 deletions pkg/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ func (d *Datum) compareMysqlTime(ctx Context, time Time) (int, error) {
}
}

func (d *Datum) compareVectorFloat32(sc Context, vec VectorFloat32) (int, error) {
func (d *Datum) compareVectorFloat32(ctx Context, vec VectorFloat32) (int, error) {
switch d.k {
case KindNull, KindMinNotNull:
return -1, nil
Expand Down Expand Up @@ -1796,15 +1796,23 @@ func (d *Datum) convertToMysqlJSON(_ *FieldType) (ret Datum, err error) {
return ret, errors.Trace(err)
}

func (d *Datum) convertToVectorFloat32(_ Context, _ *FieldType) (ret Datum, err error) {
func (d *Datum) convertToVectorFloat32(_ Context, target *FieldType) (ret Datum, err error) {
switch d.k {
case KindVectorFloat32:
v := d.GetVectorFloat32()
if err = v.CheckDimsFitColumn(target.GetFlen()); err != nil {
return ret, errors.Trace(err)
}
ret = *d
case KindString, KindBytes:
var v VectorFloat32
if v, err = ParseVectorFloat32(d.GetString()); err == nil {
ret.SetVectorFloat32(v)
if v, err = ParseVectorFloat32(d.GetString()); err != nil {
return ret, errors.Trace(err)
}
if err = v.CheckDimsFitColumn(target.GetFlen()); err != nil {
return ret, errors.Trace(err)
}
ret.SetVectorFloat32(v)
default:
return invalidConv(d, mysql.TypeTiDBVectorFloat32)
}
Expand Down
30 changes: 29 additions & 1 deletion pkg/types/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

jsoniter "github.com/json-iterator/go"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/types"
)

func init() {
Expand Down Expand Up @@ -55,6 +56,28 @@ func InitVectorFloat32(dims int) VectorFloat32 {
return VectorFloat32{data: data}
}

// CheckVectorDimValid reports whether dim is an acceptable vector dimension.
// It returns an error when dim is negative or greater than 16383, and nil
// otherwise.
func CheckVectorDimValid(dim int) error {
	const maxVectorDimension = 16383

	switch {
	case dim < 0:
		return errors.Errorf("dimensions for type vector must be at least 0")
	case dim > maxVectorDimension:
		return errors.Errorf("vector cannot have more than %d dimensions", maxVectorDimension)
	default:
		return nil
	}
}

// CheckDimsFitColumn verifies that the vector's dimension equals expectedFlen,
// the length declared by a column type or cast type. When expectedFlen is
// types.UnspecifiedLength any dimension is accepted.
func (v VectorFloat32) CheckDimsFitColumn(expectedFlen int) error {
	// No declared dimension: nothing to enforce.
	if expectedFlen == types.UnspecifiedLength {
		return nil
	}
	if v.Len() == expectedFlen {
		return nil
	}
	return errors.Errorf("vector has %d dimensions, does not fit VECTOR(%d)", v.Len(), expectedFlen)
}

// Len returns the length (dimension) of the vector.
func (v VectorFloat32) Len() int {
return int(binary.LittleEndian.Uint32(v.data))
Expand Down Expand Up @@ -139,7 +162,12 @@ func ParseVectorFloat32(s string) (VectorFloat32, error) {
return ZeroVectorFloat32, errors.Errorf("Invalid vector text: %s", s)
}

vec := InitVectorFloat32(len(values))
dim := len(values)
if err := CheckVectorDimValid(dim); err != nil {
return ZeroVectorFloat32, err
}

vec := InitVectorFloat32(dim)
copy(vec.Elements(), values)
return vec, nil
}
Expand Down

0 comments on commit 7fff125

Please sign in to comment.