Skip to content

Commit

Permalink
Merge pull request #32 from basenana/feature/search_knn
Browse files Browse the repository at this point in the history
add parentId in search vector
  • Loading branch information
zwwhdls authored Dec 6, 2023
2 parents d93081b + c9351b1 commit cfd37d8
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 68 deletions.
2 changes: 1 addition & 1 deletion cmd/apps/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var QuestionCmd = &cobra.Command{
}

func run(question string) error {
a, err := friday.Fri.Question(context.TODO(), question)
a, err := friday.Fri.Question(context.TODO(), 0, question)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/friday/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (f *Friday) IngestFromFile(ctx context.Context, file models.File) error {
func (f *Friday) Ingest(ctx context.Context, elements []models.Element) error {
f.Log.Debugf("Ingesting %d ...", len(elements))
for i, element := range elements {
exist, err := f.Vector.Get(ctx, element.Name, element.Group)
exist, err := f.Vector.Get(ctx, element.OID, element.Name, element.Group)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/friday/ingest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ func (f FakeStore) Store(ctx context.Context, element *models.Element, extra map
return nil
}

func (f FakeStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) {
func (f FakeStore) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) {
return []*models.Doc{}, nil
}

func (f FakeStore) Get(ctx context.Context, name string, group int) (*models.Element, error) {
func (f FakeStore) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) {
return &models.Element{}, nil
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/friday/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import (
"github.com/basenana/friday/pkg/llm/prompts"
)

func (f *Friday) Question(ctx context.Context, q string) (string, error) {
func (f *Friday) Question(ctx context.Context, parentId int64, q string) (string, error) {
prompt := prompts.NewQuestionPrompt(f.Prompts[questionPromptKey])
c, err := f.searchDocs(ctx, q)
c, err := f.searchDocs(ctx, parentId, q)
if err != nil {
return "", err
}
Expand All @@ -44,13 +44,13 @@ func (f *Friday) Question(ctx context.Context, q string) (string, error) {
return c, nil
}

func (f *Friday) searchDocs(ctx context.Context, q string) (string, error) {
func (f *Friday) searchDocs(ctx context.Context, parentId int64, q string) (string, error) {
f.Log.Debugf("vector query for %s ...", q)
qv, _, err := f.Embedding.VectorQuery(ctx, q)
if err != nil {
return "", fmt.Errorf("vector embedding error: %w", err)
}
docs, err := f.Vector.Search(ctx, qv, *f.VectorTopK)
docs, err := f.Vector.Search(ctx, parentId, qv, *f.VectorTopK)
if err != nil {
return "", fmt.Errorf("vector search error: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/friday/question_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ var _ = Describe("TestQuestion", func() {

Context("question", func() {
It("question should be succeed", func() {
ans, err := loFriday.Question(context.TODO(), "I am a question")
ans, err := loFriday.Question(context.TODO(), 0, "I am a question")
Expect(err).Should(BeNil())
Expect(ans).Should(Equal("I am an answer"))
})
It("searchDocs should be succeed", func() {
ans, err := loFriday.searchDocs(context.TODO(), "I am a question")
ans, err := loFriday.searchDocs(context.TODO(), 0, "I am a question")
Expect(err).Should(BeNil())
Expect(ans).Should(Equal("There are logs of questions"))
})
Expand All @@ -69,14 +69,14 @@ func (f FakeQuestionStore) Store(ctx context.Context, element *models.Element, e
return nil
}

func (f FakeQuestionStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) {
func (f FakeQuestionStore) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) {
return []*models.Doc{{
Id: "abc",
Content: "There are logs of questions",
}}, nil
}

func (f FakeQuestionStore) Get(ctx context.Context, name string, group int) (*models.Element, error) {
func (f FakeQuestionStore) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) {
return &models.Element{}, nil
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/vectorstore/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ import (

type VectorStore interface {
Store(ctx context.Context, element *models.Element, extra map[string]any) error
Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error)
Get(ctx context.Context, name string, group int) (*models.Element, error)
Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error)
Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error)
}
23 changes: 10 additions & 13 deletions pkg/vectorstore/pgvector/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ type Index struct {
ID string `gorm:"column:id;primaryKey"`
Name string `gorm:"column:name;type:varchar(256);index:index_name"`
OID int64 `gorm:"column:oid;index:index_oid"`
Group int `gorm:"column:group;index:index_group"`
ParentID *int64 `gorm:"column:parent_entry_id;index:index_parent_id"`
Group int `gorm:"column:idx_group;index:index_group"`
ParentID int64 `gorm:"column:parent_entry_id;index:index_parent_id"`
Content string `gorm:"column:content"`
Vector string `gorm:"column:vector;type:vector(1536)"`
Extra string `gorm:"column:metadata"`
Expand All @@ -52,16 +52,13 @@ func (v *Index) Update(vector *Index) {
}

func (v *Index) From(element *models.Element) *Index {
parentId := element.ParentId
i := &Index{
ID: element.ID,
Name: element.Name,
OID: element.OID,
Group: element.Group,
Content: element.Content,
}
if parentId != 0 {
i.ParentID = &parentId
ID: element.ID,
Name: element.Name,
OID: element.OID,
Group: element.Group,
Content: element.Content,
ParentID: element.ParentId,
}
return i
}
Expand All @@ -71,7 +68,7 @@ func (v *Index) To() *models.Doc {
OID: v.OID,
Name: v.Name,
Group: v.Group,
ParentId: *v.ParentID,
ParentId: v.ParentID,
Content: v.Content,
}
}
Expand All @@ -81,7 +78,7 @@ func (v *Index) ToElement() *models.Element {
OID: v.OID,
Name: v.Name,
Group: v.Group,
ParentId: *v.ParentID,
ParentId: v.ParentID,
Content: v.Content,
}
}
18 changes: 14 additions & 4 deletions pkg/vectorstore/pgvector/pgvector.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,20 @@ func (p *PgVectorClient) Store(ctx context.Context, element *models.Element, ext
})
}

func (p *PgVectorClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) {
func (p *PgVectorClient) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) {
var (
vectorModels = make([]Index, 0)
result = make([]*models.Doc, 0)
)
if err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
query := p.dEntity.DB.WithContext(ctx)
vectorJson, _ := json.Marshal(vectors)
res := query.Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels)
var res *gorm.DB
if parentId == 0 {
res = query.Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels)
} else {
res = query.Where("parent_entry_id = ?", parentId).Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels)
}
if res.Error != nil {
return res.Error
}
Expand All @@ -139,10 +144,15 @@ func (p *PgVectorClient) Search(ctx context.Context, vectors []float32, k int) (
return result, nil
}

func (p *PgVectorClient) Get(ctx context.Context, name string, group int) (*models.Element, error) {
func (p *PgVectorClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) {
vModel := &Index{}
err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
res := tx.Where("name = ? AND group = ?", name, group).First(vModel)
var res *gorm.DB
if oid == 0 {
res = tx.Where("name = ? AND group = ?", name, group).First(vModel)
} else {
res = tx.Where("name = ? AND oid = ? AND group = ?", name, oid, group).First(vModel)
}
if res.Error != nil && res.Error != gorm.ErrRecordNotFound {
return res.Error
}
Expand Down
47 changes: 19 additions & 28 deletions pkg/vectorstore/postgres/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type Index struct {
Name string `gorm:"column:name;index:index_name"`
OID int64 `gorm:"column:oid;index:index_oid"`
Group int `gorm:"column:idx_group;index:index_group"`
ParentID *int64 `gorm:"column:parent_entry_id;index:index_parent_id"`
ParentID int64 `gorm:"column:parent_entry_id;index:index_parent_id"`
Content string `gorm:"column:content"`
Vector string `gorm:"column:vector;type:json"`
Extra string `gorm:"column:extra"`
Expand All @@ -53,37 +53,31 @@ func (v *Index) Update(vector *Index) {
}

func (v *Index) From(element *models.Element) (*Index, error) {
parentId := element.ParentId
i := &Index{
ID: element.ID,
Name: element.Name,
OID: element.OID,
Group: element.Group,
Content: element.Content,
ID: element.ID,
Name: element.Name,
OID: element.OID,
Group: element.Group,
ParentID: element.ParentId,
Content: element.Content,
}
vector, err := json.Marshal(element.Vector)
if err != nil {
return nil, err
}
i.Vector = string(vector)

if parentId != 0 {
i.ParentID = &parentId
}
return i, nil
}

func (v *Index) To() (*models.Element, error) {
parentId := v.ParentID
res := &models.Element{
ID: v.ID,
Name: v.Name,
Group: v.Group,
OID: v.OID,
Content: v.Content,
}
if parentId != nil {
res.ParentId = *parentId
ID: v.ID,
Name: v.Name,
Group: v.Group,
OID: v.OID,
ParentId: v.ParentID,
Content: v.Content,
}
var vector []float32
err := json.Unmarshal([]byte(v.Vector), &vector)
Expand All @@ -96,16 +90,13 @@ func (v *Index) To() (*models.Element, error) {
}

func (v *Index) ToDoc() *models.Doc {
parentId := v.ParentID
res := &models.Doc{
Id: v.ID,
OID: v.OID,
Name: v.Name,
Group: v.Group,
Content: v.Content,
}
if parentId != nil {
res.ParentId = *parentId
Id: v.ID,
OID: v.OID,
Name: v.Name,
Group: v.Group,
Content: v.Content,
ParentId: v.ParentID,
}

return res
Expand Down
18 changes: 14 additions & 4 deletions pkg/vectorstore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,19 @@ func (p *PostgresClient) Store(ctx context.Context, element *models.Element, ext
})
}

func (p *PostgresClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) {
func (p *PostgresClient) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) {
vectors64 := make([]float64, 0)
for _, v := range vectors {
vectors64 = append(vectors64, float64(v))
}
// query from db
existIndexes := make([]Index, 0)
res := p.dEntity.WithContext(ctx).Find(&existIndexes)
var res *gorm.DB
if parentId == 0 {
res = p.dEntity.WithContext(ctx).Find(&existIndexes)
} else {
res = p.dEntity.WithContext(ctx).Where("parent_entry_id = ?", parentId).Find(&existIndexes)
}
if res.Error != nil {
return nil, res.Error
}
Expand Down Expand Up @@ -155,9 +160,14 @@ func (p *PostgresClient) Search(ctx context.Context, vectors []float32, k int) (
return results, nil
}

func (p *PostgresClient) Get(ctx context.Context, name string, group int) (*models.Element, error) {
func (p *PostgresClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) {
vModel := &Index{}
res := p.dEntity.WithContext(ctx).Where("name = ? AND idx_group = ?", name, group).First(vModel)
var res *gorm.DB
if oid == 0 {
res = p.dEntity.WithContext(ctx).Where("name = ? AND idx_group = ?", name, group).First(vModel)
} else {
res = p.dEntity.WithContext(ctx).Where("name = ? AND oid = ? AND idx_group = ?", name, oid, group).First(vModel)
}
if res.Error != nil {
if res.Error == gorm.ErrRecordNotFound {
return nil, nil
Expand Down
13 changes: 8 additions & 5 deletions pkg/vectorstore/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (r RedisClient) Store(ctx context.Context, element *models.Element, extra m
FieldValue("vector", rueidis.VectorString32(element.Vector)).Build()).Error()
}

func (r RedisClient) Get(ctx context.Context, name string, group int) (*models.Element, error) {
func (r RedisClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) {
resp, err := r.client.Do(ctx, r.client.B().Get().Key(fmt.Sprintf("%s:%s-%d", r.prefix, name, group)).Build()).ToMessage()
if err != nil {
return nil, err
Expand All @@ -126,10 +126,13 @@ func (r RedisClient) Get(ctx context.Context, name string, group int) (*models.E
return nil, err
}

oid, err := files.StrToInt64(res["oid"])
if err != nil {
return nil, err
if oid == 0 {
oid, err = files.StrToInt64(res["oid"])
if err != nil {
return nil, err
}
}

parentId, err := files.StrToInt64(res["parentid"])
if err != nil {
return nil, err
Expand All @@ -144,7 +147,7 @@ func (r RedisClient) Get(ctx context.Context, name string, group int) (*models.E
}, nil
}

func (r RedisClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) {
func (r RedisClient) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) {
resp, err := r.client.Do(ctx, r.client.B().FtSearch().Index(r.index).
Query("*=>[KNN 10 @vector $B AS vector_score]").
Return("4").Identifier("id").Identifier("content").
Expand Down

0 comments on commit cfd37d8

Please sign in to comment.