diff --git a/cmd/apps/question.go b/cmd/apps/question.go index 3186b5e..f2e22e2 100644 --- a/cmd/apps/question.go +++ b/cmd/apps/question.go @@ -39,7 +39,7 @@ var QuestionCmd = &cobra.Command{ } func run(question string) error { - a, err := friday.Fri.Question(context.TODO(), question) + a, err := friday.Fri.Question(context.TODO(), 0, question) if err != nil { return err } diff --git a/pkg/friday/ingest.go b/pkg/friday/ingest.go index b1afe39..b857fe5 100644 --- a/pkg/friday/ingest.go +++ b/pkg/friday/ingest.go @@ -50,7 +50,7 @@ func (f *Friday) IngestFromFile(ctx context.Context, file models.File) error { func (f *Friday) Ingest(ctx context.Context, elements []models.Element) error { f.Log.Debugf("Ingesting %d ...", len(elements)) for i, element := range elements { - exist, err := f.Vector.Get(ctx, element.Name, element.Group) + exist, err := f.Vector.Get(ctx, element.OID, element.Name, element.Group) if err != nil { return err } diff --git a/pkg/friday/ingest_test.go b/pkg/friday/ingest_test.go index 5009e1a..637092f 100644 --- a/pkg/friday/ingest_test.go +++ b/pkg/friday/ingest_test.go @@ -64,11 +64,11 @@ func (f FakeStore) Store(ctx context.Context, element *models.Element, extra map return nil } -func (f FakeStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { +func (f FakeStore) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) { return []*models.Doc{}, nil } -func (f FakeStore) Get(ctx context.Context, name string, group int) (*models.Element, error) { +func (f FakeStore) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) { return &models.Element{}, nil } diff --git a/pkg/friday/question.go b/pkg/friday/question.go index 326d3c2..e732a70 100644 --- a/pkg/friday/question.go +++ b/pkg/friday/question.go @@ -24,9 +24,9 @@ import ( "github.com/basenana/friday/pkg/llm/prompts" ) -func (f *Friday) Question(ctx context.Context, q string) (string, error) { +func (f *Friday) Question(ctx context.Context, parentId int64, q string) (string, error) { prompt := prompts.NewQuestionPrompt(f.Prompts[questionPromptKey]) - c, err := f.searchDocs(ctx, q) + c, err := f.searchDocs(ctx, parentId, q) if err != nil { return "", err } @@ -44,13 +44,13 @@ func (f *Friday) Question(ctx context.Context, q string) (string, error) { return c, nil } -func (f *Friday) searchDocs(ctx context.Context, q string) (string, error) { +func (f *Friday) searchDocs(ctx context.Context, parentId int64, q string) (string, error) { f.Log.Debugf("vector query for %s ...", q) qv, _, err := f.Embedding.VectorQuery(ctx, q) if err != nil { return "", fmt.Errorf("vector embedding error: %w", err) } - docs, err := f.Vector.Search(ctx, qv, *f.VectorTopK) + docs, err := f.Vector.Search(ctx, parentId, qv, *f.VectorTopK) if err != nil { return "", fmt.Errorf("vector search error: %w", err) } diff --git a/pkg/friday/question_test.go b/pkg/friday/question_test.go index 850cddb..db5dac8 100644 --- a/pkg/friday/question_test.go +++ b/pkg/friday/question_test.go @@ -49,12 +49,12 @@ var _ = Describe("TestQuestion", func() { Context("question", func() { It("question should be succeed", func() { - ans, err := loFriday.Question(context.TODO(), "I am a question") + ans, err := loFriday.Question(context.TODO(), 0, "I am a question") Expect(err).Should(BeNil()) Expect(ans).Should(Equal("I am an answer")) }) It("searchDocs should be succeed", func() { - ans, err := loFriday.searchDocs(context.TODO(), "I am a question") + ans, err := loFriday.searchDocs(context.TODO(), 0, "I am a question") Expect(err).Should(BeNil()) Expect(ans).Should(Equal("There are logs of questions")) }) @@ -69,14 +69,14 @@ func (f FakeQuestionStore) Store(ctx context.Context, element *models.Element, e return nil } -func (f FakeQuestionStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { +func (f FakeQuestionStore) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) { return []*models.Doc{{ Id: "abc", Content: "There are logs of questions", }}, nil } -func (f FakeQuestionStore) Get(ctx context.Context, name string, group int) (*models.Element, error) { +func (f FakeQuestionStore) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) { return &models.Element{}, nil } diff --git a/pkg/vectorstore/interface.go b/pkg/vectorstore/interface.go index 03b35a4..e351c01 100644 --- a/pkg/vectorstore/interface.go +++ b/pkg/vectorstore/interface.go @@ -24,6 +24,6 @@ import ( type VectorStore interface { Store(ctx context.Context, element *models.Element, extra map[string]any) error - Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) - Get(ctx context.Context, name string, group int) (*models.Element, error) + Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) + Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) } diff --git a/pkg/vectorstore/pgvector/model.go b/pkg/vectorstore/pgvector/model.go index 25ea359..3150bc5 100644 --- a/pkg/vectorstore/pgvector/model.go +++ b/pkg/vectorstore/pgvector/model.go @@ -26,8 +26,8 @@ type Index struct { ID string `gorm:"column:id;primaryKey"` Name string `gorm:"column:name;type:varchar(256);index:index_name"` OID int64 `gorm:"column:oid;index:index_oid"` - Group int `gorm:"column:group;index:index_group"` - ParentID *int64 `gorm:"column:parent_entry_id;index:index_parent_id"` + Group int `gorm:"column:idx_group;index:index_group"` + ParentID int64 `gorm:"column:parent_entry_id;index:index_parent_id"` Content string `gorm:"column:content"` Vector string `gorm:"column:vector;type:vector(1536)"` Extra string `gorm:"column:metadata"` @@ -52,16 +52,13 @@ func (v *Index) Update(vector *Index) { } func (v *Index) From(element *models.Element) *Index { - parentId := element.ParentId i := &Index{ - ID: element.ID, - Name: element.Name, - OID: element.OID, - Group: element.Group, - Content: element.Content, - } - if parentId != 0 { - i.ParentID = &parentId + ID: element.ID, + Name: element.Name, + OID: element.OID, + Group: element.Group, + Content: element.Content, + ParentID: element.ParentId, } return i } @@ -71,7 +68,7 @@ func (v *Index) To() *models.Doc { OID: v.OID, Name: v.Name, Group: v.Group, - ParentId: *v.ParentID, + ParentId: v.ParentID, Content: v.Content, } } @@ -81,7 +78,7 @@ func (v *Index) ToElement() *models.Element { OID: v.OID, Name: v.Name, Group: v.Group, - ParentId: *v.ParentID, + ParentId: v.ParentID, Content: v.Content, } } diff --git a/pkg/vectorstore/pgvector/pgvector.go b/pkg/vectorstore/pgvector/pgvector.go index 6e6290d..11d8569 100644 --- a/pkg/vectorstore/pgvector/pgvector.go +++ b/pkg/vectorstore/pgvector/pgvector.go @@ -116,7 +116,7 @@ func (p *PgVectorClient) Store(ctx context.Context, element *models.Element, ext }) } -func (p *PgVectorClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { +func (p *PgVectorClient) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) { var ( vectorModels = make([]Index, 0) result = make([]*models.Doc, 0) @@ -124,7 +124,12 @@ func (p *PgVectorClient) Search(ctx context.Context, vectors []float32, k int) ( if err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { query := p.dEntity.DB.WithContext(ctx) vectorJson, _ := json.Marshal(vectors) - res := query.Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels) + var res *gorm.DB + if parentId == 0 { + res = query.Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels) + } else { + res = query.Where("parent_entry_id = ?", parentId).Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels) + } if res.Error != nil { return res.Error } @@ -139,10 +144,15 @@ func (p *PgVectorClient) Search(ctx context.Context, vectors []float32, k int) ( return result, nil } -func (p *PgVectorClient) Get(ctx context.Context, name string, group int) (*models.Element, error) { +func (p *PgVectorClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) { vModel := &Index{} err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - res := tx.Where("name = ? AND group = ?", name, group).First(vModel) + var res *gorm.DB + if oid == 0 { + res = tx.Where("name = ? AND group = ?", name, group).First(vModel) + } else { + res = tx.Where("name = ? AND oid = ? AND group = ?", name, oid, group).First(vModel) + } if res.Error != nil && res.Error != gorm.ErrRecordNotFound { return res.Error } diff --git a/pkg/vectorstore/postgres/model.go b/pkg/vectorstore/postgres/model.go index 76ce900..1d1bbaa 100644 --- a/pkg/vectorstore/postgres/model.go +++ b/pkg/vectorstore/postgres/model.go @@ -28,7 +28,7 @@ type Index struct { Name string `gorm:"column:name;index:index_name"` OID int64 `gorm:"column:oid;index:index_oid"` Group int `gorm:"column:idx_group;index:index_group"` - ParentID *int64 `gorm:"column:parent_entry_id;index:index_parent_id"` + ParentID int64 `gorm:"column:parent_entry_id;index:index_parent_id"` Content string `gorm:"column:content"` Vector string `gorm:"column:vector;type:json"` Extra string `gorm:"column:extra"` @@ -53,13 +53,13 @@ func (v *Index) Update(vector *Index) { } func (v *Index) From(element *models.Element) (*Index, error) { - parentId := element.ParentId i := &Index{ - ID: element.ID, - Name: element.Name, - OID: element.OID, - Group: element.Group, - Content: element.Content, + ID: element.ID, + Name: element.Name, + OID: element.OID, + Group: element.Group, + ParentID: element.ParentId, + Content: element.Content, } vector, err := json.Marshal(element.Vector) if err != nil { @@ -67,23 +67,17 @@ func (v *Index) From(element *models.Element) (*Index, error) { } i.Vector = string(vector) - if parentId != 0 { - i.ParentID = &parentId - } return i, nil } func (v *Index) To() (*models.Element, error) { - parentId := v.ParentID res := &models.Element{ - ID: v.ID, - Name: v.Name, - Group: v.Group, - OID: v.OID, - Content: v.Content, - } - if parentId != nil { - res.ParentId = *parentId + ID: v.ID, + Name: v.Name, + Group: v.Group, + OID: v.OID, + ParentId: v.ParentID, + Content: v.Content, } var vector []float32 err := json.Unmarshal([]byte(v.Vector), &vector) @@ -96,16 +90,13 @@ func (v *Index) To() (*models.Element, error) { } func (v *Index) ToDoc() *models.Doc { - parentId := v.ParentID res := &models.Doc{ - Id: v.ID, - OID: v.OID, - Name: v.Name, - Group: v.Group, - Content: v.Content, - } - if parentId != nil { - res.ParentId = *parentId + Id: v.ID, + OID: v.OID, + Name: v.Name, + Group: v.Group, + Content: v.Content, + ParentId: v.ParentID, } return res diff --git a/pkg/vectorstore/postgres/postgres.go b/pkg/vectorstore/postgres/postgres.go index 02f4d65..74b794f 100644 --- a/pkg/vectorstore/postgres/postgres.go +++ b/pkg/vectorstore/postgres/postgres.go @@ -117,14 +117,19 @@ func (p *PostgresClient) Store(ctx context.Context, element *models.Element, ext }) } -func (p *PostgresClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { +func (p *PostgresClient) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) { vectors64 := make([]float64, 0) for _, v := range vectors { vectors64 = append(vectors64, float64(v)) } // query from db existIndexes := make([]Index, 0) - res := p.dEntity.WithContext(ctx).Find(&existIndexes) + var res *gorm.DB + if parentId == 0 { + res = p.dEntity.WithContext(ctx).Find(&existIndexes) + } else { + res = p.dEntity.WithContext(ctx).Where("parent_entry_id = ?", parentId).Find(&existIndexes) + } if res.Error != nil { return nil, res.Error } @@ -155,9 +160,14 @@ func (p *PostgresClient) Search(ctx context.Context, vectors []float32, k int) ( return results, nil } -func (p *PostgresClient) Get(ctx context.Context, name string, group int) (*models.Element, error) { +func (p *PostgresClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) { vModel := &Index{} - res := p.dEntity.WithContext(ctx).Where("name = ? AND idx_group = ?", name, group).First(vModel) + var res *gorm.DB + if oid == 0 { + res = p.dEntity.WithContext(ctx).Where("name = ? AND idx_group = ?", name, group).First(vModel) + } else { + res = p.dEntity.WithContext(ctx).Where("name = ? AND oid = ? AND idx_group = ?", name, oid, group).First(vModel) + } if res.Error != nil { if res.Error == gorm.ErrRecordNotFound { return nil, nil diff --git a/pkg/vectorstore/redis/redis.go b/pkg/vectorstore/redis/redis.go index 45c4f55..e289994 100644 --- a/pkg/vectorstore/redis/redis.go +++ b/pkg/vectorstore/redis/redis.go @@ -116,7 +116,7 @@ func (r RedisClient) Store(ctx context.Context, element *models.Element, extra m FieldValue("vector", rueidis.VectorString32(element.Vector)).Build()).Error() } -func (r RedisClient) Get(ctx context.Context, name string, group int) (*models.Element, error) { +func (r RedisClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) { resp, err := r.client.Do(ctx, r.client.B().Get().Key(fmt.Sprintf("%s:%s-%d", r.prefix, name, group)).Build()).ToMessage() if err != nil { return nil, err @@ -126,10 +126,13 @@ func (r RedisClient) Get(ctx context.Context, name string, group int) (*models.E return nil, err } - oid, err := files.StrToInt64(res["oid"]) - if err != nil { - return nil, err + if oid == 0 { + oid, err = files.StrToInt64(res["oid"]) + if err != nil { + return nil, err + } } + parentId, err := files.StrToInt64(res["parentid"]) if err != nil { return nil, err @@ -144,7 +147,7 @@ func (r RedisClient) Get(ctx context.Context, name string, group int) (*models.E }, nil } -func (r RedisClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { +func (r RedisClient) Search(ctx context.Context, parentId int64, vectors []float32, k int) ([]*models.Doc, error) { resp, err := r.client.Do(ctx, r.client.B().FtSearch().Index(r.index). Query("*=>[KNN 10 @vector $B AS vector_score]"). Return("4").Identifier("id").Identifier("content").