From a83e9ca0d39c595b8dff8944f6748177d5a20b3b Mon Sep 17 00:00:00 2001 From: congqixia Date: Mon, 30 Sep 2024 14:27:16 +0800 Subject: [PATCH] feat: Support unmarshal resultset into orm receiver (#827) Related to #800 --------- Signed-off-by: Congqi Xia --- client/data.go | 1 + client/results.go | 151 ++++++++++++++++++++++++++ client/results_test.go | 109 +++++++++++++++++++ test/testcases/groupby_search_test.go | 86 ++++++++++----- 4 files changed, 317 insertions(+), 30 deletions(-) create mode 100644 client/results_test.go diff --git a/client/data.go b/client/data.go index 0e61a71d..79e5a435 100644 --- a/client/data.go +++ b/client/data.go @@ -134,6 +134,7 @@ func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []st for i := 0; i < int(results.GetNumQueries()); i++ { rc := int(results.GetTopks()[i]) // result entry count for current query entry := SearchResult{ + sch: schema, ResultCount: rc, Scores: results.GetScores()[offset : offset+rc], } diff --git a/client/results.go b/client/results.go index 418abb57..8aba95ac 100644 --- a/client/results.go +++ b/client/results.go @@ -1,6 +1,10 @@ package client import ( + "go/ast" + "reflect" + + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-sdk-go/v2/entity" ) @@ -9,6 +13,9 @@ import ( // Fields contains the data of `outputFieleds` specified or all columns if non // Scores is actually the distance between the vector current record contains and the search target vector type SearchResult struct { + // internal schema for unmarshaling + sch *entity.Schema + ResultCount int // the returning entry count GroupByValue entity.Column IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API @@ -44,6 +51,66 @@ func (sr *SearchResult) Slice(start, end int) *SearchResult { return result } +func (sr *SearchResult) Unmarshal(receiver interface{}) (err error) { + err = sr.Fields.Unmarshal(receiver) + if err != nil { + return err + } + return sr.fillPKEntry(receiver) +} + +func (sr *SearchResult) fillPKEntry(receiver interface{}) (err error) { + defer func() { + if x := recover(); x != nil { + err = errors.Newf("failed to unmarshal result set: %v", x) + } + }() + rr := reflect.ValueOf(receiver) + + if rr.Kind() == reflect.Ptr { + if rr.IsNil() && rr.CanAddr() { + rr.Set(reflect.New(rr.Type().Elem())) + } + rr = rr.Elem() + } + + rt := rr.Type() + rv := rr + + switch rt.Kind() { + case reflect.Slice: + pkField := sr.sch.PKField() + + et := rt.Elem() + for et.Kind() == reflect.Ptr { + et = et.Elem() + } + + candidates := parseCandidates(et) + candi, ok := candidates[pkField.Name] + if !ok { + // pk field not found in struct, skip + return nil + } + for i := 0; i < sr.IDs.Len(); i++ { + row := rv.Index(i) + for row.Kind() == reflect.Ptr { + row = row.Elem() + } + + val, err := sr.IDs.Get(i) + if err != nil { + return err + } + row.Field(candi).Set(reflect.ValueOf(val)) + } + rr.Set(rv) + default: + return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) + } + return nil +} + // ResultSet is an alias type for column slice. type ResultSet []entity.Column @@ -71,3 +138,87 @@ func (rs ResultSet) GetColumn(fieldName string) entity.Column { } return nil } + +func (rs ResultSet) Unmarshal(receiver interface{}) (err error) { + defer func() { + if x := recover(); x != nil { + err = errors.Newf("failed to unmarshal result set: %v", x) + } + }() + rr := reflect.ValueOf(receiver) + + if rr.Kind() == reflect.Ptr { + if rr.IsNil() && rr.CanAddr() { + rr.Set(reflect.New(rr.Type().Elem())) + } + rr = rr.Elem() + } + + rt := rr.Type() + rv := rr + + switch rt.Kind() { + // TODO maybe support Array and just fill data + // case reflect.Array: + case reflect.Slice: + et := rt.Elem() + if et.Kind() != reflect.Ptr { + return errors.Newf("receiver must be slice of pointers but get: %v", et.Kind()) + } + for et.Kind() == reflect.Ptr { + et = et.Elem() + } + for i := 0; i < rs.Len(); i++ { + data := reflect.New(et) + err := rs.fillData(data.Elem(), et, i) + if err != nil { + return err + } + rv = reflect.Append(rv, data) + } + rr.Set(rv) + default: + return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) + } + return nil +} + +func parseCandidates(dataType reflect.Type) map[string]int { + result := make(map[string]int) + for i := 0; i < dataType.NumField(); i++ { + f := dataType.Field(i) + // ignore anonymous field for now + if f.Anonymous || !ast.IsExported(f.Name) { + continue + } + + name := f.Name + tag := f.Tag.Get(entity.MilvusTag) + tagSettings := entity.ParseTagSetting(tag, entity.MilvusTagSep) + if tagName, has := tagSettings[entity.MilvusTagName]; has { + name = tagName + } + + result[name] = i + } + return result +} + +func (rs ResultSet) fillData(data reflect.Value, dataType reflect.Type, idx int) error { + m := parseCandidates(dataType) + for i := 0; i < len(rs); i++ { + name := rs[i].Name() + fidx, ok := m[name] + if !ok { + // maybe return error + continue + } + val, err := rs[i].Get(idx) + if err != nil { + return err + } + // TODO check datatype + data.Field(fidx).Set(reflect.ValueOf(val)) + } + return nil +} diff --git a/client/results_test.go b/client/results_test.go new file mode 100644 index 00000000..9181b70d --- /dev/null +++ b/client/results_test.go @@ -0,0 +1,109 @@ +package client + +import ( + "testing" + + "github.com/milvus-io/milvus-sdk-go/v2/entity" + "github.com/stretchr/testify/suite" +) + +type ResultSetSuite struct { + suite.Suite +} + +func (s *ResultSetSuite) TestResultsetUnmarshal() { + type MyData struct { + A int64 `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + type OtherData struct { + A string `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + + var ( + idData = []int64{1, 2, 3} + vectorData = [][]float32{ + {0.1, 0.2}, + {0.1, 0.2}, + {0.1, 0.2}, + } + ) + + rs := ResultSet([]entity.Column{ + entity.NewColumnInt64("id", idData), + entity.NewColumnFloatVector("vector", 2, vectorData), + }) + err := rs.Unmarshal([]MyData{}) + s.Error(err) + + receiver := []MyData{} + err = rs.Unmarshal(&receiver) + s.Error(err) + + var ptrReceiver []*MyData + err = rs.Unmarshal(&ptrReceiver) + s.NoError(err) + + for idx, row := range ptrReceiver { + s.Equal(row.A, idData[idx]) + s.Equal(row.V, vectorData[idx]) + } + + var otherReceiver []*OtherData + err = rs.Unmarshal(&otherReceiver) + s.Error(err) +} + +func (s *ResultSetSuite) TestSearchResultUnmarshal() { + type MyData struct { + A int64 `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + type OtherData struct { + A string `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + + var ( + idData = []int64{1, 2, 3} + vectorData = [][]float32{ + {0.1, 0.2}, + {0.1, 0.2}, + {0.1, 0.2}, + } + ) + + sr := SearchResult{ + sch: entity.NewSchema(). + WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)). + WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)), + IDs: entity.NewColumnInt64("id", idData), + Fields: ResultSet([]entity.Column{ + entity.NewColumnFloatVector("vector", 2, vectorData), + }), + } + err := sr.Unmarshal([]MyData{}) + s.Error(err) + + receiver := []MyData{} + err = sr.Unmarshal(&receiver) + s.Error(err) + + var ptrReceiver []*MyData + err = sr.Unmarshal(&ptrReceiver) + s.NoError(err) + + for idx, row := range ptrReceiver { + s.Equal(row.A, idData[idx]) + s.Equal(row.V, vectorData[idx]) + } + + var otherReceiver []*OtherData + err = sr.Unmarshal(&otherReceiver) + s.Error(err) +} + +func TestResults(t *testing.T) { + suite.Run(t, new(ResultSetSuite)) +} diff --git a/test/testcases/groupby_search_test.go b/test/testcases/groupby_search_test.go index c2d4b665..8bc1285d 100644 --- a/test/testcases/groupby_search_test.go +++ b/test/testcases/groupby_search_test.go @@ -64,13 +64,17 @@ func prepareDataForGroupBySearch(t *testing.T, loopInsert int, insertNi int, idx mc := createMilvusClient(ctx, t) // create collection with all datatype - cp := CollectionParams{CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: true, - ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + cp := CollectionParams{ + CollectionFieldsType: AllFields, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim, + } collName := createCollection(ctx, t, mc, cp) // insert - dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: AllFields, - start: 0, nb: insertNi, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false} + dp := DataParams{ + CollectionName: collName, PartitionName: "", CollectionFieldsType: AllFields, + start: 0, nb: insertNi, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false, + } for i := 0; i < loopInsert; i++ { _, _ = insertData(ctx, t, mc, dp) } @@ -79,9 +83,11 @@ func prepareDataForGroupBySearch(t *testing.T, loopInsert int, insertNi int, idx mc.Flush(ctx, collName, false) } - //create scalar index - supportedGroupByFields := []string{common.DefaultIntFieldName, common.DefaultInt8FieldName, common.DefaultInt16FieldName, - common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName} + // create scalar index + supportedGroupByFields := []string{ + common.DefaultIntFieldName, common.DefaultInt8FieldName, common.DefaultInt16FieldName, + common.DefaultInt32FieldName, common.DefaultVarcharFieldName, common.DefaultBoolFieldName, + } for _, groupByField := range supportedGroupByFields { err := mc.CreateIndex(ctx, collName, groupByField, entity.NewScalarIndex(), false) common.CheckErr(t, err, true) @@ -148,10 +154,12 @@ func TestSearchGroupByFloatDefault(t *testing.T) { expr = fmt.Sprintf("%s == %v", groupByField, groupByValue) } // search filter with groupByValue is the top1 - resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, - groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp) + resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{ + common.DefaultIntFieldName, + groupByField, + }, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp) filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) - //log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", + // log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", // groupByField, groupByValue, pkValue, filterTop1Pk) if filterTop1Pk == pkValue { hitsNum += 1 @@ -185,13 +193,17 @@ func TestGroupBySearchSparseVector(t *testing.T) { mc := createMilvusClient(ctx, t) // create -> insert [0, 3000) -> flush -> index -> load - cp := CollectionParams{CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true, - ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen} + cp := CollectionParams{ + CollectionFieldsType: Int64VarcharSparseVec, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim, MaxLength: common.TestMaxLen, + } collName := createCollection(ctx, t, mc, cp, client.WithConsistencyLevel(entity.ClStrong)) // insert data - dp := DataParams{DoInsert: true, CollectionName: collName, CollectionFieldsType: Int64VarcharSparseVec, start: 0, - nb: 200, dim: common.DefaultDim, EnableDynamicField: true} + dp := DataParams{ + DoInsert: true, CollectionName: collName, CollectionFieldsType: Int64VarcharSparseVec, start: 0, + nb: 200, dim: common.DefaultDim, EnableDynamicField: true, + } for i := 0; i < 100; i++ { _, _ = insertData(ctx, t, mc, dp) } @@ -219,8 +231,10 @@ func TestGroupBySearchSparseVector(t *testing.T) { pkValue, _ := resGroupBy[i].IDs.GetAsInt64(j) expr := fmt.Sprintf("%s == '%v' ", common.DefaultVarcharFieldName, groupByValue) // search filter with groupByValue is the top1 - resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, - common.DefaultVarcharFieldName}, []entity.Vector{queryVec[i]}, common.DefaultSparseVecFieldName, entity.IP, 1, sp) + resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{ + common.DefaultIntFieldName, + common.DefaultVarcharFieldName, + }, []entity.Vector{queryVec[i]}, common.DefaultSparseVecFieldName, entity.IP, 1, sp) filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", common.DefaultVarcharFieldName, groupByValue, pkValue, filterTop1Pk) @@ -251,13 +265,17 @@ func TestSearchGroupByBinaryDefault(t *testing.T) { mc := createMilvusClient(ctx, t) // create collection with all datatype - cp := CollectionParams{CollectionFieldsType: VarcharBinaryVec, AutoID: false, EnableDynamicField: true, - ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + cp := CollectionParams{ + CollectionFieldsType: VarcharBinaryVec, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim, + } collName := createCollection(ctx, t, mc, cp) // insert - dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: VarcharBinaryVec, - start: 0, nb: 1000, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false} + dp := DataParams{ + CollectionName: collName, PartitionName: "", CollectionFieldsType: VarcharBinaryVec, + start: 0, nb: 1000, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false, + } for i := 0; i < 2; i++ { _, _ = insertData(ctx, t, mc, dp) } @@ -296,8 +314,10 @@ func TestSearchGroupByBinaryGrowing(t *testing.T) { mc := createMilvusClient(ctx, t) // create collection with all datatype - cp := CollectionParams{CollectionFieldsType: VarcharBinaryVec, AutoID: false, EnableDynamicField: true, - ShardsNum: common.DefaultShards, Dim: common.DefaultDim} + cp := CollectionParams{ + CollectionFieldsType: VarcharBinaryVec, AutoID: false, EnableDynamicField: true, + ShardsNum: common.DefaultShards, Dim: common.DefaultDim, + } collName := createCollection(ctx, t, mc, cp) // create index and load @@ -307,8 +327,10 @@ func TestSearchGroupByBinaryGrowing(t *testing.T) { common.CheckErr(t, err, true) // insert - dp := DataParams{CollectionName: collName, PartitionName: "", CollectionFieldsType: VarcharBinaryVec, - start: 0, nb: 1000, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false} + dp := DataParams{ + CollectionName: collName, PartitionName: "", CollectionFieldsType: VarcharBinaryVec, + start: 0, nb: 1000, dim: common.DefaultDim, EnableDynamicField: true, WithRows: false, + } _, _ = insertData(ctx, t, mc, dp) // search params @@ -318,8 +340,10 @@ func TestSearchGroupByBinaryGrowing(t *testing.T) { // search with groupBy field for _, groupByField := range supportedGroupByFields { - _, err := mc.Search(ctx, collName, []string{}, "", []string{common.DefaultVarcharFieldName, - groupByField}, queryVec, common.DefaultBinaryVecFieldName, metricType, common.DefaultTopK, sp, + _, err := mc.Search(ctx, collName, []string{}, "", []string{ + common.DefaultVarcharFieldName, + groupByField, + }, queryVec, common.DefaultBinaryVecFieldName, metricType, common.DefaultTopK, sp, client.WithGroupByField(groupByField), client.WithSearchQueryConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, false, "not support search_group_by operation based on binary vector column") } @@ -356,12 +380,14 @@ func TestSearchGroupByFloatGrowing(t *testing.T) { } else { expr = fmt.Sprintf("%s == %v", groupByField, groupByValue) } - resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{common.DefaultIntFieldName, - groupByField}, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong)) + resFilter, _ := mc.Search(ctx, collName, []string{}, expr, []string{ + common.DefaultIntFieldName, + groupByField, + }, []entity.Vector{queryVec[i]}, common.DefaultFloatVecFieldName, metricType, 1, sp, client.WithSearchQueryConsistencyLevel(entity.ClStrong)) // search filter with groupByValue is the top1 filterTop1Pk, _ := resFilter[0].IDs.GetAsInt64(0) - //log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", + // log.Printf("Search top1 with %s: groupByValue: %v, pkValue: %d. The returned pk by filter search is: %d", // groupByField, groupByValue, pkValue, filterTop1Pk) if filterTop1Pk == pkValue { hitsNum += 1 @@ -390,7 +416,7 @@ func TestSearchGroupByPagination(t *testing.T) { // search params queryVec := common.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) sp, _ := entity.NewIndexIvfFlatSearchParam(32) - var offset = int64(10) + offset := int64(10) // search pagination & groupBy resGroupByPagination, _ := mc.Search(ctx, collName, []string{}, "", []string{common.DefaultIntFieldName, common.DefaultVarcharFieldName},