From 074c71ea084c0316ddb33aa973293f12919fb559 Mon Sep 17 00:00:00 2001 From: zwwhdls Date: Tue, 24 Oct 2023 20:31:27 +0800 Subject: [PATCH 1/3] init pgclient & knn Signed-off-by: zwwhdls --- go.mod | 1 + go.sum | 2 + pkg/build/common/init.go | 4 +- pkg/vectorstore/db/entity.go | 4 +- pkg/vectorstore/db/migrate.go | 44 ------ pkg/vectorstore/pgvector/migrate.go | 44 ++++++ pkg/vectorstore/{db => pgvector}/model.go | 32 ++-- pkg/vectorstore/pgvector/pgvector.go | 177 ++++++++++++++++++++++ pkg/vectorstore/postgres/migrate.go | 44 ++++++ pkg/vectorstore/postgres/model.go | 44 ++++++ pkg/vectorstore/postgres/postgres.go | 72 +++------ 11 files changed, 354 insertions(+), 114 deletions(-) delete mode 100644 pkg/vectorstore/db/migrate.go create mode 100644 pkg/vectorstore/pgvector/migrate.go rename pkg/vectorstore/{db => pgvector}/model.go (56%) create mode 100644 pkg/vectorstore/pgvector/pgvector.go create mode 100644 pkg/vectorstore/postgres/migrate.go create mode 100644 pkg/vectorstore/postgres/model.go diff --git a/go.mod b/go.mod index 7240c39..86e8e7c 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( ) require ( + github.com/cdipaolo/goml v0.0.0-20220715001353-00e0c845ae1c // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect diff --git a/go.sum b/go.sum index cb212bc..709e90d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/basenana/go-flow v0.0.0-20230801131009-d05f1f41b706 h1:FxXoMwMZsufBjSZg8yWpqfV6FNs5F0JuLnO+iJxojnw= github.com/basenana/go-flow v0.0.0-20230801131009-d05f1f41b706/go.mod h1:Rs13PWsg/ITdXRiVJcI+yS0iqCfNHxCbIFEt5DCt/RQ= +github.com/cdipaolo/goml v0.0.0-20220715001353-00e0c845ae1c h1:uqJXOhayPfl/QruVBP6VF0KUWNDzO/F14X8CPEkkFD8= +github.com/cdipaolo/goml v0.0.0-20220715001353-00e0c845ae1c/go.mod h1:Ue8jgVLdBDCtsh1laikvraXqXzKCyKiruCcCcaeNDFE= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/pkg/build/common/init.go b/pkg/build/common/init.go index 513c549..690d494 100644 --- a/pkg/build/common/init.go +++ b/pkg/build/common/init.go @@ -21,7 +21,7 @@ import ( "github.com/basenana/friday/pkg/build/withvector" "github.com/basenana/friday/pkg/friday" "github.com/basenana/friday/pkg/vectorstore" - "github.com/basenana/friday/pkg/vectorstore/postgres" + "github.com/basenana/friday/pkg/vectorstore/pgvector" "github.com/basenana/friday/pkg/vectorstore/redis" ) @@ -41,7 +41,7 @@ func NewFriday(conf *config.Config) (f *friday.Friday, err error) { } } } else if conf.VectorStoreType == config.VectorStorePostgres { - vectorStore, err = postgres.NewPostgresClient(conf.VectorUrl) + vectorStore, err = pgvector.NewPostgresClient(conf.VectorUrl) if err != nil { return nil, err } diff --git a/pkg/vectorstore/db/entity.go b/pkg/vectorstore/db/entity.go index c503472..cecddf5 100644 --- a/pkg/vectorstore/db/entity.go +++ b/pkg/vectorstore/db/entity.go @@ -24,9 +24,9 @@ type Entity struct { *gorm.DB } -func NewDbEntity(db *gorm.DB) (*Entity, error) { +func NewDbEntity(db *gorm.DB, migrate func(db *gorm.DB) error) (*Entity, error) { ent := &Entity{DB: db} - if err := Migrate(db); err != nil { + if err := migrate(db); err != nil { return nil, err } return ent, nil diff --git a/pkg/vectorstore/db/migrate.go b/pkg/vectorstore/db/migrate.go deleted file mode 100644 index db342e5..0000000 --- a/pkg/vectorstore/db/migrate.go +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2023 Friday Author. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package db - -import ( - "github.com/go-gormigrate/gormigrate/v2" - "gorm.io/gorm" -) - -func buildMigrations() []*gormigrate.Migration { - return []*gormigrate.Migration{ - { - ID: "2023100700", - Migrate: func(db *gorm.DB) error { - return db.AutoMigrate( - &Index{}, - ) - }, - Rollback: func(db *gorm.DB) error { - return nil - }, - }, - } -} - -func Migrate(db *gorm.DB) error { - m := gormigrate.New(db, gormigrate.DefaultOptions, buildMigrations()) - err := m.Migrate() - return err -} diff --git a/pkg/vectorstore/pgvector/migrate.go b/pkg/vectorstore/pgvector/migrate.go new file mode 100644 index 0000000..43a34ae --- /dev/null +++ b/pkg/vectorstore/pgvector/migrate.go @@ -0,0 +1,44 @@ +/* + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package pgvector + +import ( + "github.com/go-gormigrate/gormigrate/v2" + "gorm.io/gorm" +) + +func buildMigrations() []*gormigrate.Migration { + return []*gormigrate.Migration{ + { + ID: "2023100700", + Migrate: func(db *gorm.DB) error { + return db.AutoMigrate( + &Index{}, + ) + }, + Rollback: func(db *gorm.DB) error { + return nil + }, + }, + } +} + +func Migrate(db *gorm.DB) error { + m := gormigrate.New(db, gormigrate.DefaultOptions, buildMigrations()) + err := m.Migrate() + return err +} diff --git a/pkg/vectorstore/db/model.go b/pkg/vectorstore/pgvector/model.go similarity index 56% rename from pkg/vectorstore/db/model.go rename to pkg/vectorstore/pgvector/model.go index d073005..5b0b41b 100644 --- a/pkg/vectorstore/db/model.go +++ b/pkg/vectorstore/pgvector/model.go @@ -1,20 +1,20 @@ /* - * Copyright 2023 Friday Author. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package db + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package pgvector import "time" diff --git a/pkg/vectorstore/pgvector/pgvector.go b/pkg/vectorstore/pgvector/pgvector.go new file mode 100644 index 0000000..539e899 --- /dev/null +++ b/pkg/vectorstore/pgvector/pgvector.go @@ -0,0 +1,177 @@ +/* + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package pgvector + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + + "github.com/basenana/friday/pkg/models" + "github.com/basenana/friday/pkg/utils/logger" + "github.com/basenana/friday/pkg/vectorstore" + "github.com/basenana/friday/pkg/vectorstore/db" +) + +type PgVectorClient struct { + log logger.Logger + dEntity *db.Entity +} + +var _ vectorstore.VectorStore = &PgVectorClient{} + +func NewPgVectorClient(postgresUrl string) (*PgVectorClient, error) { + dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger()}) + if err != nil { + panic(err) + } + + dbConn, err := dbObj.DB() + if err != nil { + return nil, err + } + + dbConn.SetMaxIdleConns(5) + dbConn.SetMaxOpenConns(50) + dbConn.SetConnMaxLifetime(time.Hour) + + if err = dbConn.Ping(); err != nil { + return nil, err + } + + dbEnt, err := db.NewDbEntity(dbObj, Migrate) + if err != nil { + return nil, err + } + + return &PgVectorClient{ + log: logger.NewLogger("postgres"), + dEntity: dbEnt, + }, nil +} + +func (p *PgVectorClient) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error { + ctx := context.Background() + + if extra == nil { + extra = make(map[string]interface{}) + } + extra["category"] = metadata.Category + extra["group"] = metadata.Group + + var m string + b, err := json.Marshal(metadata) + if err != nil { + return err + } + m = string(b) + + vectorJson, _ := json.Marshal(vectors) + v := &Index{ + ID: id, + Name: metadata.Source, + ParentDir: metadata.ParentDir, + Context: content, + Metadata: m, + Vector: string(vectorJson), + CreatedAt: time.Now().UnixNano(), + ChangedAt: time.Now().UnixNano(), + } + return p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + vModel := Index{ID: id} + res := tx.First(vModel) + if res.Error != nil && res.Error != gorm.ErrRecordNotFound { + return res.Error + } + + if res.Error == gorm.ErrRecordNotFound { + res = tx.Create(v) + if res.Error != nil { + return res.Error + } + return nil + } + + vModel.Update(v) + res = tx.Where("id = ?", id).Updates(vModel) + if res.Error != nil || res.RowsAffected == 0 { + if res.RowsAffected == 0 { + return errors.New("operation conflict") + } + return res.Error + } + return nil + }) +} + +func (p *PgVectorClient) Search(vectors []float32, k int) ([]models.Doc, error) { + ctx := context.Background() + var ( + vectorModels = make([]Index, 0) + result = make([]models.Doc, 0) + ) + if err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + query := p.dEntity.DB.WithContext(ctx) + vectorJson, _ := json.Marshal(vectors) + res := query.Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels) + if res.Error != nil { + return res.Error + } + return nil + }); err != nil { + return nil, err + } + + for _, v := range vectorModels { + metadata := make(map[string]interface{}) + if err := json.Unmarshal([]byte(v.Metadata), &metadata); err != nil { + return nil, err + } + result = append(result, models.Doc{ + Id: v.ID, + Metadata: metadata, + Content: v.Context, + }) + } + return result, nil +} + +func (p *PgVectorClient) Exist(id string) (bool, error) { + ctx := context.Background() + var exist = false + err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + vModel := Index{ID: id} + res := tx.First(vModel) + if res.Error != nil && res.Error != gorm.ErrRecordNotFound { + return res.Error + } + + if res.Error == gorm.ErrRecordNotFound { + exist = false + return nil + } + exist = true + return nil + }) + + return exist, err +} diff --git a/pkg/vectorstore/postgres/migrate.go b/pkg/vectorstore/postgres/migrate.go new file mode 100644 index 0000000..8cdcc19 --- /dev/null +++ b/pkg/vectorstore/postgres/migrate.go @@ -0,0 +1,44 @@ +/* + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package postgres + +import ( + "github.com/go-gormigrate/gormigrate/v2" + "gorm.io/gorm" +) + +func buildMigrations() []*gormigrate.Migration { + return []*gormigrate.Migration{ + { + ID: "2023101100", + Migrate: func(db *gorm.DB) error { + return db.AutoMigrate( + &Index{}, + ) + }, + Rollback: func(db *gorm.DB) error { + return nil + }, + }, + } +} + +func Migrate(db *gorm.DB) error { + m := gormigrate.New(db, gormigrate.DefaultOptions, buildMigrations()) + err := m.Migrate() + return err +} diff --git a/pkg/vectorstore/postgres/model.go b/pkg/vectorstore/postgres/model.go new file mode 100644 index 0000000..575acd3 --- /dev/null +++ b/pkg/vectorstore/postgres/model.go @@ -0,0 +1,44 @@ +/* + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package postgres + +import "time" + +type Index struct { + ID string `gorm:"column:id;type:varchar(256);primaryKey"` + Name string `gorm:"column:name;type:varchar(256);index:source"` + ParentDir string `gorm:"column:parent_dir;type:varchar(256);index:parent_dir"` + Context string `gorm:"column:context"` + Metadata string `gorm:"column:metadata"` + Vector string `gorm:"column:vector;type:json"` + CreatedAt int64 `gorm:"column:created_at"` + ChangedAt int64 `gorm:"column:changed_at"` +} + +func (v *Index) TableName() string { + return "friday_idx" +} + +func (v *Index) Update(vector *Index) { + v.ID = vector.ID + v.Name = vector.Name + v.ParentDir = vector.ParentDir + v.Context = vector.Context + v.Metadata = vector.Metadata + v.Vector = vector.Vector + v.ChangedAt = time.Now().UnixNano() +} diff --git a/pkg/vectorstore/postgres/postgres.go b/pkg/vectorstore/postgres/postgres.go index 4bef6fa..3ba5f4a 100644 --- a/pkg/vectorstore/postgres/postgres.go +++ b/pkg/vectorstore/postgres/postgres.go @@ -20,9 +20,10 @@ import ( "context" "encoding/json" "errors" - "fmt" "time" + "github.com/cdipaolo/goml/base" + "github.com/cdipaolo/goml/cluster" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -37,8 +38,6 @@ type PostgresClient struct { dEntity *db.Entity } -var _ vectorstore.VectorStore = &PostgresClient{} - func NewPostgresClient(postgresUrl string) (*PostgresClient, error) { dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger()}) if err != nil { @@ -58,7 +57,7 @@ func NewPostgresClient(postgresUrl string) (*PostgresClient, error) { return nil, err } - dbEnt, err := db.NewDbEntity(dbObj) + dbEnt, err := db.NewDbEntity(dbObj, Migrate) if err != nil { return nil, err } @@ -86,7 +85,7 @@ func (p *PostgresClient) Store(id, content string, metadata models.Metadata, ext m = string(b) vectorJson, _ := json.Marshal(vectors) - v := &db.Index{ + v := &Index{ ID: id, Name: metadata.Source, ParentDir: metadata.ParentDir, @@ -97,7 +96,7 @@ func (p *PostgresClient) Store(id, content string, metadata models.Metadata, ext ChangedAt: time.Now().UnixNano(), } return p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - vModel := db.Index{ID: id} + vModel := Index{ID: id} res := tx.First(vModel) if res.Error != nil && res.Error != gorm.ErrRecordNotFound { return res.Error @@ -124,54 +123,27 @@ func (p *PostgresClient) Store(id, content string, metadata models.Metadata, ext } func (p *PostgresClient) Search(vectors []float32, k int) ([]models.Doc, error) { - ctx := context.Background() - var ( - vectorModels = make([]db.Index, 0) - result = make([]models.Doc, 0) - ) - if err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - query := p.dEntity.DB.WithContext(ctx) - vectorJson, _ := json.Marshal(vectors) - res := query.Order(fmt.Sprintf("vector <-> '%s'", string(vectorJson))).Limit(k).Find(&vectorModels) - if res.Error != nil { - return res.Error - } - return nil - }); err != nil { - return nil, err + vectors64 := make([]float64, 0) + for _, v := range vectors { + vectors64 = append(vectors64, float64(v)) } + // query from db + existVectors := [][]float64{} - for _, v := range vectorModels { - metadata := make(map[string]interface{}) - if err := json.Unmarshal([]byte(v.Metadata), &metadata); err != nil { - return nil, err - } - result = append(result, models.Doc{ - Id: v.ID, - Metadata: metadata, - Content: v.Context, - }) + model := cluster.NewKNN(k, existVectors, vectors64, base.EuclideanDistance) + + // make predictions like usual + _, err := model.Predict([]float64{-10, 1}) + if err != nil { + return nil, err } - return result, nil + // todo + return nil, nil } func (p *PostgresClient) Exist(id string) (bool, error) { - ctx := context.Background() - var exist = false - err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - vModel := db.Index{ID: id} - res := tx.First(vModel) - if res.Error != nil && res.Error != gorm.ErrRecordNotFound { - return res.Error - } - - if res.Error == gorm.ErrRecordNotFound { - exist = false - return nil - } - exist = true - return nil - }) - - return exist, err + //TODO implement me + panic("implement me") } + +var _ vectorstore.VectorStore = &PostgresClient{} From b42e80803a3917be077f5216b28645ec8be41752 Mon Sep 17 00:00:00 2001 From: zwwhdls Date: Thu, 30 Nov 2023 21:47:02 +0800 Subject: [PATCH 2/3] add knn Signed-off-by: zwwhdls --- flow/operator/ingest.go | 2 +- pkg/build/common/init.go | 2 +- pkg/friday/ingest.go | 55 ++++++------ pkg/friday/ingest_test.go | 17 ++-- pkg/friday/question.go | 6 +- pkg/friday/question_test.go | 15 ++-- pkg/friday/summary.go | 8 +- pkg/friday/summary_test.go | 32 +++---- pkg/friday/wechat.go | 7 +- pkg/models/doc.go | 5 +- pkg/models/element.go | 24 +++--- pkg/spliter/text.go | 10 +-- pkg/spliter/text_test.go | 66 +++++--------- pkg/utils/files/doc.go | 10 +++ pkg/vectorstore/interface.go | 12 ++- pkg/vectorstore/pgvector/migrate.go | 4 +- pkg/vectorstore/pgvector/model.go | 61 +++++++++++-- pkg/vectorstore/pgvector/pgvector.go | 84 +++++++----------- pkg/vectorstore/postgres/migrate.go | 4 +- pkg/vectorstore/postgres/model.go | 86 +++++++++++++++++-- pkg/vectorstore/postgres/postgres.go | 124 ++++++++++++++++++--------- pkg/vectorstore/redis/redis.go | 88 +++++++++++++------ scripts/element.py | 15 ++-- 23 files changed, 434 insertions(+), 303 deletions(-) diff --git a/flow/operator/ingest.go b/flow/operator/ingest.go index 3792fc0..bb90a35 100644 --- a/flow/operator/ingest.go +++ b/flow/operator/ingest.go @@ -38,7 +38,7 @@ func (i *ingestOperator) Do(ctx context.Context, param *flow.Parameter) error { source := i.spec.Parameters["source"] knowledge := i.spec.Parameters["knowledge"] doc := models.File{ - Source: source, + Name: source, Content: knowledge, } return friday.Fri.IngestFromFile(context.TODO(), doc) diff --git a/pkg/build/common/init.go b/pkg/build/common/init.go index 690d494..9d0e045 100644 --- a/pkg/build/common/init.go +++ b/pkg/build/common/init.go @@ -41,7 +41,7 @@ func NewFriday(conf *config.Config) (f *friday.Friday, err error) { } } } else if conf.VectorStoreType == config.VectorStorePostgres { - vectorStore, err = pgvector.NewPostgresClient(conf.VectorUrl) + vectorStore, err = pgvector.NewPgVectorClient(conf.VectorUrl) if err != nil { return nil, err } diff --git a/pkg/friday/ingest.go b/pkg/friday/ingest.go index 80f135f..b1afe39 100644 --- a/pkg/friday/ingest.go +++ b/pkg/friday/ingest.go @@ -18,14 +18,10 @@ package friday import ( "context" - "crypto/sha256" - "encoding/hex" "encoding/json" - "fmt" "os" - "path/filepath" - "strconv" - "strings" + + "github.com/google/uuid" "github.com/basenana/friday/pkg/models" "github.com/basenana/friday/pkg/utils/files" @@ -34,17 +30,15 @@ import ( // IngestFromFile ingest a whole file providing models.File func (f *Friday) IngestFromFile(ctx context.Context, file models.File) error { elements := []models.Element{} - parentDir := filepath.Dir(file.Source) // split doc subDocs := f.Spliter.Split(file.Content) for i, subDoc := range subDocs { e := models.Element{ - Content: subDoc, - Metadata: models.Metadata{ - Source: file.Source, - Group: strconv.Itoa(i), - ParentDir: parentDir, - }, + Name: file.Name, + Group: i, + OID: file.OID, + ParentId: file.ParentId, + Content: subDoc, } elements = append(elements, e) } @@ -56,15 +50,12 @@ func (f *Friday) IngestFromFile(ctx context.Context, file models.File) error { func (f *Friday) Ingest(ctx context.Context, elements []models.Element) error { f.Log.Debugf("Ingesting %d ...", len(elements)) for i, element := range elements { - // id: sha256(source)-group - h := sha256.New() - h.Write([]byte(element.Metadata.Source)) - val := hex.EncodeToString(h.Sum(nil))[:64] - id := fmt.Sprintf("%s-%s", val, element.Metadata.Group) - if exist, err := f.Vector.Exist(id); err != nil { + exist, err := f.Vector.Get(ctx, element.Name, element.Group) + if err != nil { return err - } else if exist { - f.Log.Debugf("vector %d(th) id(%s) source(%s) exist, skip ...", i, id, element.Metadata.Source) + } + if exist != nil && exist.Content == element.Content { + f.Log.Debugf("vector %d(th) name(%s) group(%d) exist, skip ...", i, element.Name, element.Group) continue } @@ -73,10 +64,17 @@ func (f *Friday) Ingest(ctx context.Context, elements []models.Element) error { return err } - t := strings.TrimSpace(element.Content) + if exist != nil { + element.ID = exist.ID + element.OID = exist.OID + element.ParentId = exist.ParentId + } else { + element.ID = uuid.New().String() + } + element.Vector = vectors - f.Log.Debugf("store %d(th) vector id (%s) source(%s) ...", i, id, element.Metadata.Source) - if err := f.Vector.Store(id, t, element.Metadata, m, vectors); err != nil { + f.Log.Debugf("store %d(th) vector name(%s) group(%d) ...", i, element.Name, element.Group) + if err := f.Vector.Store(ctx, &element, m); err != nil { return err } } @@ -90,6 +88,7 @@ func (f *Friday) IngestFromElementFile(ctx context.Context, ps string) error { return err } elements := []models.Element{} + if err := json.Unmarshal(doc, &elements); err != nil { return err } @@ -106,16 +105,12 @@ func (f *Friday) IngestFromOriginFile(ctx context.Context, ps string) error { elements := []models.Element{} for n, file := range fs { - parentDir := filepath.Dir(n) subDocs := f.Spliter.Split(file) for i, subDoc := range subDocs { e := models.Element{ Content: subDoc, - Metadata: models.Metadata{ - Source: n, - Group: strconv.Itoa(i), - ParentDir: parentDir, - }, + Name: n, + Group: i, } elements = append(elements, e) } diff --git a/pkg/friday/ingest_test.go b/pkg/friday/ingest_test.go index 491e591..52b7bb7 100644 --- a/pkg/friday/ingest_test.go +++ b/pkg/friday/ingest_test.go @@ -46,11 +46,8 @@ var _ = Describe("TestIngest", func() { elements := []models.Element{ { Content: "test-content", - Metadata: models.Metadata{ - Source: "test-source", - Title: "test-title", - ParentDir: "/", - }, + Name: "test-title", + Group: 0, }, } err := loFriday.Ingest(context.TODO(), elements) @@ -63,16 +60,16 @@ type FakeStore struct{} var _ vectorstore.VectorStore = &FakeStore{} -func (f FakeStore) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error { +func (f FakeStore) Store(ctx context.Context, element *models.Element, extra map[string]any) error { return nil } -func (f FakeStore) Search(vectors []float32, k int) ([]models.Doc, error) { - return []models.Doc{}, nil +func (f FakeStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { + return []*models.Doc{}, nil } -func (f FakeStore) Exist(id string) (bool, error) { - return false, nil +func (f FakeStore) Get(ctx context.Context, name string, group int) (*models.Element, error) { + return &models.Element{}, nil } type FakeEmbedding struct{} diff --git a/pkg/friday/question.go b/pkg/friday/question.go index c822cdc..4d2cac4 100644 --- a/pkg/friday/question.go +++ b/pkg/friday/question.go @@ -50,14 +50,14 @@ func (f *Friday) searchDocs(ctx context.Context, q string) (string, error) { if err != nil { return "", fmt.Errorf("vector embedding error: %w", err) } - contexts, err := f.Vector.Search(qv, defaultTopK) + docs, err := f.Vector.Search(ctx, qv, defaultTopK) if err != nil { return "", fmt.Errorf("vector search error: %w", err) } cs := []string{} - for _, c := range contexts { - f.Log.Debugf("searched from [%s] for %s", c.Metadata["source"], c.Content) + for _, c := range docs { + f.Log.Debugf("searched from [%s] for %s", c.Name, c.Content) cs = append(cs, c.Content) } return strings.Join(cs, "\n"), nil diff --git a/pkg/friday/question_test.go b/pkg/friday/question_test.go index 17f7eca..16e7073 100644 --- a/pkg/friday/question_test.go +++ b/pkg/friday/question_test.go @@ -63,20 +63,19 @@ type FakeQuestionStore struct{} var _ vectorstore.VectorStore = &FakeQuestionStore{} -func (f FakeQuestionStore) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error { +func (f FakeQuestionStore) Store(ctx context.Context, element *models.Element, extra map[string]any) error { return nil } -func (f FakeQuestionStore) Search(vectors []float32, k int) ([]models.Doc, error) { - return []models.Doc{{ - Id: "abc", - Metadata: map[string]interface{}{}, - Content: "There are logs of questions", +func (f FakeQuestionStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { + return []*models.Doc{{ + Id: "abc", + Content: "There are logs of questions", }}, nil } -func (f FakeQuestionStore) Exist(id string) (bool, error) { - return false, nil +func (f FakeQuestionStore) Get(ctx context.Context, name string, group int) (*models.Element, error) { + return &models.Element{}, nil } type FakeQuestionEmbedding struct{} diff --git a/pkg/friday/summary.go b/pkg/friday/summary.go index a50f5a9..98c0f83 100644 --- a/pkg/friday/summary.go +++ b/pkg/friday/summary.go @@ -30,10 +30,10 @@ func (f *Friday) Summary(ctx context.Context, elements []models.Element, summary docs := make(map[string][]string) for _, element := range elements { - if _, ok := docs[element.Metadata.Source]; !ok { - docs[element.Metadata.Source] = []string{element.Content} + if _, ok := docs[element.Name]; !ok { + docs[element.Name] = []string{element.Content} } else { - docs[element.Metadata.Source] = append(docs[element.Metadata.Source], element.Content) + docs[element.Name] = append(docs[element.Name], element.Content) } } for source, doc := range docs { @@ -57,7 +57,7 @@ func (f *Friday) SummaryFromFile(ctx context.Context, file models.File, summaryT return nil, err } return map[string]string{ - file.Source: summaryOfFile, + file.Name: summaryOfFile, }, err } diff --git a/pkg/friday/summary_test.go b/pkg/friday/summary_test.go index 9a979b5..5cdc86f 100644 --- a/pkg/friday/summary_test.go +++ b/pkg/friday/summary_test.go @@ -44,15 +44,11 @@ var _ = Describe("TestStuffSummary", func() { loFriday.Spliter = spliter.NewTextSpliter(spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") elements = []models.Element{{ Content: "test-content", - Metadata: models.Metadata{ - Source: "test-source", - Title: "test-title", - ParentDir: "/", - }, + Name: "test-title", + Group: 0, }} file = models.File{ Name: "test-file", - Source: "test-file-source", Content: "test-file-content", } }) @@ -62,14 +58,14 @@ var _ = Describe("TestStuffSummary", func() { summary, err := loFriday.Summary(context.TODO(), elements, summaryType) Expect(err).Should(BeNil()) Expect(summary).Should(Equal(map[string]string{ - "test-source": "a b c", + "test-title": "a b c", })) }) It("SummaryFromFile should be succeed", func() { summary, err := loFriday.SummaryFromFile(context.TODO(), file, summaryType) Expect(err).Should(BeNil()) Expect(summary).Should(Equal(map[string]string{ - "test-file-source": "a b c", + "test-file": "a b c", })) }) }) @@ -90,15 +86,11 @@ var _ = Describe("TestMapReduceSummary", func() { loFriday.Spliter = spliter.NewTextSpliter(8, 2, "\n") elements = []models.Element{{ Content: "test-content", - Metadata: models.Metadata{ - Source: "test-source", - Title: "test-title", - ParentDir: "/", - }, + Name: "test-title", + Group: 0, }} file = models.File{ Name: "test-file", - Source: "test-file-source", Content: "test file content", } }) @@ -108,14 +100,14 @@ var _ = Describe("TestMapReduceSummary", func() { summary, err := loFriday.Summary(context.TODO(), elements, summaryType) Expect(err).Should(BeNil()) Expect(summary).Should(Equal(map[string]string{ - "test-source": "a b c", + "test-title": "a b c", })) }) It("SummaryFromFile should be succeed", func() { summary, err := loFriday.SummaryFromFile(context.TODO(), file, summaryType) Expect(err).Should(BeNil()) Expect(summary).Should(Equal(map[string]string{ - "test-file-source": "a b c", + "test-file": "a b c", })) }) }) @@ -135,15 +127,11 @@ var _ = Describe("TestRefineSummary", func() { loFriday.Spliter = spliter.NewTextSpliter(spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") elements = []models.Element{{ Content: "test-content", - Metadata: models.Metadata{ - Source: "test-source", - Title: "test-title", - ParentDir: "/", - }, + Name: "test-title", + Group: 0, }} file = models.File{ Name: "test-file", - Source: "test-file-source", Content: "test-file-content", } }) diff --git a/pkg/friday/wechat.go b/pkg/friday/wechat.go index 7961c8f..d8b11a2 100644 --- a/pkg/friday/wechat.go +++ b/pkg/friday/wechat.go @@ -21,7 +21,6 @@ import ( "encoding/json" "fmt" "os" - "strconv" "strings" "github.com/basenana/friday/pkg/llm/prompts" @@ -76,10 +75,8 @@ func (f *Friday) ChatConclusionFromFile(ctx context.Context, chatFile string) (s for i, subDoc := range subDocs { e := models.Element{ Content: subDoc, - Metadata: models.Metadata{ - Source: n, - Group: strconv.Itoa(i), - }, + Name: n, + Group: i, } elements = append(elements, e) } diff --git a/pkg/models/doc.go b/pkg/models/doc.go index 1ca0071..13d5877 100644 --- a/pkg/models/doc.go +++ b/pkg/models/doc.go @@ -30,6 +30,9 @@ type Document struct { type Doc struct { Id string - Metadata map[string]interface{} + OID int64 + Name string + Group int + ParentId int64 Content string } diff --git a/pkg/models/element.go b/pkg/models/element.go index d0a7e05..82c9cfb 100644 --- a/pkg/models/element.go +++ b/pkg/models/element.go @@ -17,20 +17,18 @@ package models type File struct { - Name string `json:"name"` - Source string `json:"source"` - Content string `json:"content"` + Name string `json:"name"` + OID int64 `json:"oid"` + ParentId int64 `json:"parent_id"` + Content string `json:"content"` } type Element struct { - Content string `json:"content"` - Metadata Metadata `json:"metadata"` -} - -type Metadata struct { - Source string `json:"source"` - Title string `json:"title"` - ParentDir string `json:"parent_dir"` - Group string `json:"group"` - Category string `json:"category"` + ID string `json:"id"` + Name string `json:"name"` + Group int `json:"group"` + OID int64 `json:"oid"` + ParentId int64 `json:"parent_id"` + Content string `json:"content"` + Vector []float32 `json:"vector"` } diff --git a/pkg/spliter/text.go b/pkg/spliter/text.go index 16c9314..e04e63a 100644 --- a/pkg/spliter/text.go +++ b/pkg/spliter/text.go @@ -17,7 +17,6 @@ package spliter import ( - "strconv" "strings" "github.com/basenana/friday/pkg/models" @@ -64,7 +63,7 @@ func (t *TextSpliter) Split(text string) []string { func (t *TextSpliter) Merge(elements []models.Element) []models.Element { elementGroups := map[string][]models.Element{} for _, element := range elements { - source := element.Metadata.Source + source := element.Name if _, ok := elementGroups[source]; !ok { elementGroups[source] = []models.Element{element} continue @@ -81,12 +80,9 @@ func (t *TextSpliter) Merge(elements []models.Element) []models.Element { merged := t.merge(splits) for i, content := range merged { mergedElements = append(mergedElements, models.Element{ + Name: source, + Group: i, Content: content, - Metadata: models.Metadata{ - Source: source, - Title: subElements[0].Metadata.Title, - Group: strconv.Itoa(i), - }, }) } } diff --git a/pkg/spliter/text_test.go b/pkg/spliter/text_test.go index 6589001..106bef7 100644 --- a/pkg/spliter/text_test.go +++ b/pkg/spliter/text_test.go @@ -49,33 +49,25 @@ func TestTextSpliter_Merge(t1 *testing.T) { args: args{ elements: []models.Element{ { - Content: "this is a test", - Metadata: models.Metadata{ - Source: "test", - Title: "test", - Group: "0", - Category: "context", - }, + ID: "123", + Name: "test", + Group: 0, + OID: 0, + ParentId: 0, + Content: "this is a test", }, { Content: "hello world", - Metadata: models.Metadata{ - Source: "test", - Title: "test", - Group: "1", - Category: "context", - }, + Name: "test", + Group: 1, }, }, }, want: []models.Element{ { Content: "this is a test\nhello world", - Metadata: models.Metadata{ - Source: "test", - Title: "test", - Group: "0", - }, + Name: "test", + Group: 0, }, }, }, @@ -90,49 +82,31 @@ func TestTextSpliter_Merge(t1 *testing.T) { elements: []models.Element{ { Content: "this is a test", - Metadata: models.Metadata{ - Source: "test", - Title: "test", - Group: "0", - Category: "context", - }, + Name: "test", + Group: 0, }, { Content: "hello world", - Metadata: models.Metadata{ - Source: "test", - Title: "test", - Group: "1", - Category: "context", - }, + Name: "test", + Group: 1, }, { Content: "你好", - Metadata: models.Metadata{ - Source: "hello", - Title: "hello", - Group: "0", - Category: "context", - }, + Name: "hello", + Group: 0, }, }, }, want: []models.Element{ { Content: "this is a test\nhello world", - Metadata: models.Metadata{ - Source: "test", - Title: "test", - Group: "0", - }, + Name: "test", + Group: 0, }, { Content: "你好", - Metadata: models.Metadata{ - Source: "hello", - Title: "hello", - Group: "0", - }, + Name: "hello", + Group: 0, }, }, }, diff --git a/pkg/utils/files/doc.go b/pkg/utils/files/doc.go index ef45fc6..27a6c8e 100644 --- a/pkg/utils/files/doc.go +++ b/pkg/utils/files/doc.go @@ -17,7 +17,9 @@ package files import ( + "fmt" "regexp" + "strconv" ) func Length(doc string) int { @@ -38,3 +40,11 @@ func Length(doc string) int { return wordCount + punctuationCount } + +func Int64ToStr(s int64) string { + return fmt.Sprintf("doc_%d", s) +} + +func StrToInt64(s string) (int64, error) { + return strconv.ParseInt(s[4:], 10, 64) +} diff --git a/pkg/vectorstore/interface.go b/pkg/vectorstore/interface.go index f47d37c..03b35a4 100644 --- a/pkg/vectorstore/interface.go +++ b/pkg/vectorstore/interface.go @@ -16,10 +16,14 @@ package vectorstore -import "github.com/basenana/friday/pkg/models" +import ( + "context" + + "github.com/basenana/friday/pkg/models" +) type VectorStore interface { - Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error - Search(vectors []float32, k int) ([]models.Doc, error) - Exist(id string) (bool, error) + Store(ctx context.Context, element *models.Element, extra map[string]any) error + Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) + Get(ctx context.Context, name string, group int) (*models.Element, error) } diff --git a/pkg/vectorstore/pgvector/migrate.go b/pkg/vectorstore/pgvector/migrate.go index 43a34ae..e1ad131 100644 --- a/pkg/vectorstore/pgvector/migrate.go +++ b/pkg/vectorstore/pgvector/migrate.go @@ -38,7 +38,9 @@ func buildMigrations() []*gormigrate.Migration { } func Migrate(db *gorm.DB) error { - m := gormigrate.New(db, gormigrate.DefaultOptions, buildMigrations()) + options := gormigrate.DefaultOptions + options.TableName = "friday_migrations" + m := gormigrate.New(db, options, buildMigrations()) err := m.Migrate() return err } diff --git a/pkg/vectorstore/pgvector/model.go b/pkg/vectorstore/pgvector/model.go index 5b0b41b..25ea359 100644 --- a/pkg/vectorstore/pgvector/model.go +++ b/pkg/vectorstore/pgvector/model.go @@ -16,15 +16,21 @@ package pgvector -import "time" +import ( + "time" + + "github.com/basenana/friday/pkg/models" +) type Index struct { - ID string `gorm:"column:id;type:varchar(256);primaryKey"` - Name string `gorm:"column:name;type:varchar(256);index:source"` - ParentDir string `gorm:"column:parent_dir;type:varchar(256);index:parent_dir"` - Context string `gorm:"column:context"` - Metadata string `gorm:"column:metadata"` + ID string `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name;type:varchar(256);index:index_name"` + OID int64 `gorm:"column:oid;index:index_oid"` + Group int `gorm:"column:group;index:index_group"` + ParentID *int64 `gorm:"column:parent_entry_id;index:index_parent_id"` + Content string `gorm:"column:content"` Vector string `gorm:"column:vector;type:vector(1536)"` + Extra string `gorm:"column:metadata"` CreatedAt int64 `gorm:"column:created_at"` ChangedAt int64 `gorm:"column:changed_at"` } @@ -36,9 +42,46 @@ func (v *Index) TableName() string { func (v *Index) Update(vector *Index) { v.ID = vector.ID v.Name = vector.Name - v.ParentDir = vector.ParentDir - v.Context = vector.Context - v.Metadata = vector.Metadata + v.OID = vector.OID + v.Group = vector.Group + v.ParentID = vector.ParentID + v.Content = vector.Content + v.Extra = vector.Extra v.Vector = vector.Vector v.ChangedAt = time.Now().UnixNano() } + +func (v *Index) From(element *models.Element) *Index { + parentId := element.ParentId + i := &Index{ + ID: element.ID, + Name: element.Name, + OID: element.OID, + Group: element.Group, + Content: element.Content, + } + if parentId != 0 { + i.ParentID = &parentId + } + return i +} +func (v *Index) To() *models.Doc { + return &models.Doc{ + Id: v.ID, + OID: v.OID, + Name: v.Name, + Group: v.Group, + ParentId: *v.ParentID, + Content: v.Content, + } +} +func (v *Index) ToElement() *models.Element { + return &models.Element{ + ID: v.ID, + OID: v.OID, + Name: v.Name, + Group: v.Group, + ParentId: *v.ParentID, + Content: v.Content, + } +} diff --git a/pkg/vectorstore/pgvector/pgvector.go b/pkg/vectorstore/pgvector/pgvector.go index 539e899..9b10e13 100644 --- a/pkg/vectorstore/pgvector/pgvector.go +++ b/pkg/vectorstore/pgvector/pgvector.go @@ -69,36 +69,29 @@ func NewPgVectorClient(postgresUrl string) (*PgVectorClient, error) { }, nil } -func (p *PgVectorClient) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error { - ctx := context.Background() +func (p *PgVectorClient) Store(ctx context.Context, element *models.Element, extra map[string]any) error { + return p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if extra == nil { + extra = make(map[string]interface{}) + } + extra["name"] = element.Name + extra["group"] = element.Group - if extra == nil { - extra = make(map[string]interface{}) - } - extra["category"] = metadata.Category - extra["group"] = metadata.Group + var m string + b, err := json.Marshal(extra) + if err != nil { + return err + } + m = string(b) + vectorJson, _ := json.Marshal(element.Vector) - var m string - b, err := json.Marshal(metadata) - if err != nil { - return err - } - m = string(b) - - vectorJson, _ := json.Marshal(vectors) - v := &Index{ - ID: id, - Name: metadata.Source, - ParentDir: metadata.ParentDir, - Context: content, - Metadata: m, - Vector: string(vectorJson), - CreatedAt: time.Now().UnixNano(), - ChangedAt: time.Now().UnixNano(), - } - return p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - vModel := Index{ID: id} - res := tx.First(vModel) + var v *Index + v = v.From(element) + v.Extra = m + v.Vector = string(vectorJson) + + vModel := &Index{} + res := tx.Where("name = ? AND group = ?", element.Name, element.Group).First(vModel) if res.Error != nil && res.Error != gorm.ErrRecordNotFound { return res.Error } @@ -112,7 +105,7 @@ func (p *PgVectorClient) Store(id, content string, metadata models.Metadata, ext } vModel.Update(v) - res = tx.Where("id = ?", id).Updates(vModel) + res = tx.Where("name = ? AND group = ?", element.Name, element.Group).Updates(vModel) if res.Error != nil || res.RowsAffected == 0 { if res.RowsAffected == 0 { return errors.New("operation conflict") @@ -123,11 +116,10 @@ func (p *PgVectorClient) Store(id, content string, metadata models.Metadata, ext }) } -func (p *PgVectorClient) Search(vectors []float32, k int) ([]models.Doc, error) { - ctx := context.Background() +func (p *PgVectorClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { var ( vectorModels = make([]Index, 0) - result = make([]models.Doc, 0) + result = make([]*models.Doc, 0) ) if err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { query := p.dEntity.DB.WithContext(ctx) @@ -142,36 +134,24 @@ func (p *PgVectorClient) Search(vectors []float32, k int) ([]models.Doc, error) } for _, v := range vectorModels { - metadata := make(map[string]interface{}) - if err := json.Unmarshal([]byte(v.Metadata), &metadata); err != nil { - return nil, err - } - result = append(result, models.Doc{ - Id: v.ID, - Metadata: metadata, - Content: v.Context, - }) + result = append(result, v.To()) } return result, nil } -func (p *PgVectorClient) Exist(id string) (bool, error) { - ctx := context.Background() - var exist = false +func (p *PgVectorClient) Get(ctx context.Context, name string, group int) (*models.Element, error) { + vModel := &Index{} err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - vModel := Index{ID: id} - res := tx.First(vModel) + res := tx.Where("name = ? AND group = ?", name, group).First(vModel) if res.Error != nil && res.Error != gorm.ErrRecordNotFound { return res.Error } - if res.Error == gorm.ErrRecordNotFound { - exist = false - return nil - } - exist = true return nil }) - return exist, err + if err != nil { + return nil, err + } + return vModel.ToElement(), err } diff --git a/pkg/vectorstore/postgres/migrate.go b/pkg/vectorstore/postgres/migrate.go index 15af737..e58e928 100644 --- a/pkg/vectorstore/postgres/migrate.go +++ b/pkg/vectorstore/postgres/migrate.go @@ -49,7 +49,9 @@ func buildMigrations() []*gormigrate.Migration { } func Migrate(db *gorm.DB) error { - m := gormigrate.New(db, gormigrate.DefaultOptions, buildMigrations()) + options := gormigrate.DefaultOptions + options.TableName = "friday_migrations" + m := gormigrate.New(db, options, buildMigrations()) err := m.Migrate() return err } diff --git a/pkg/vectorstore/postgres/model.go b/pkg/vectorstore/postgres/model.go index 303e136..a949f1f 100644 --- a/pkg/vectorstore/postgres/model.go +++ b/pkg/vectorstore/postgres/model.go @@ -16,15 +16,22 @@ package postgres -import "time" +import ( + "encoding/json" + "time" + + "github.com/basenana/friday/pkg/models" +) type Index struct { - ID string `gorm:"column:id;type:varchar(256);primaryKey"` - Name string `gorm:"column:name;type:varchar(256);index:source"` - ParentDir string `gorm:"column:parent_dir;type:varchar(256);index:parent_dir"` - Context string `gorm:"column:context"` - Metadata string `gorm:"column:metadata"` + ID string `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name;type:varchar(256);index:index_name"` + OID int64 `gorm:"column:oid;index:index_oid"` + Group int `gorm:"column:group;index:index_group"` + ParentID *int64 `gorm:"column:parent_entry_id;index:index_parent_id"` + Content string `gorm:"column:content"` Vector string `gorm:"column:vector;type:json"` + Extra string `gorm:"column:extra"` CreatedAt int64 `gorm:"column:created_at"` ChangedAt int64 `gorm:"column:changed_at"` } @@ -36,13 +43,74 @@ func (v *Index) TableName() string { func (v *Index) Update(vector *Index) { v.ID = vector.ID v.Name = vector.Name - v.ParentDir = vector.ParentDir - v.Context = vector.Context - v.Metadata = vector.Metadata + v.OID = vector.OID + v.Group = vector.Group + v.ParentID = vector.ParentID + v.Content = vector.Content + v.Extra = vector.Extra v.Vector = vector.Vector v.ChangedAt = time.Now().UnixNano() } +func (v *Index) From(element *models.Element) (*Index, error) { + parentId := element.ParentId + i := &Index{ + ID: element.ID, + Name: element.Name, + OID: element.OID, + Group: element.Group, + Content: element.Content, + } + vector, err := json.Marshal(element.Vector) + if err != nil { + return nil, err + } + i.Vector = string(vector) + + if parentId != 0 { + i.ParentID = &parentId + } + return i, nil +} + +func (v *Index) To() (*models.Element, error) { + parentId := v.ParentID + res := &models.Element{ + ID: v.ID, + Name: v.Name, + Group: v.Group, + OID: v.OID, + Content: v.Content, + } + if parentId != nil { + res.ParentId = *parentId + } + var vector []float32 + err := json.Unmarshal([]byte(v.Vector), &vector) + if err != nil { + return nil, err + } + res.Vector = vector + + return res, nil +} + +func (v *Index) ToDoc() *models.Doc { + parentId := v.ParentID + res := &models.Doc{ + Id: v.ID, + OID: v.OID, + Name: v.Name, + Group: v.Group, + Content: v.Content, + } + if parentId != nil { + res.ParentId = *parentId + } + + return res +} + type BleveKV struct { ID string `gorm:"column:id;primaryKey"` Key []byte `gorm:"column:key"` diff --git a/pkg/vectorstore/postgres/postgres.go b/pkg/vectorstore/postgres/postgres.go index 7bd7029..be89059 100644 --- a/pkg/vectorstore/postgres/postgres.go +++ b/pkg/vectorstore/postgres/postgres.go @@ -20,10 +20,10 @@ import ( "context" "encoding/json" "errors" + "sort" "time" "github.com/cdipaolo/goml/base" - "github.com/cdipaolo/goml/cluster" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -68,41 +68,36 @@ func NewPostgresClient(postgresUrl string) (*PostgresClient, error) { }, nil } -func (p *PostgresClient) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error { - ctx := context.Background() +func (p *PostgresClient) Store(ctx context.Context, element *models.Element, extra map[string]any) error { + return p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if extra == nil { + extra = make(map[string]interface{}) + } + extra["name"] = element.Name + extra["group"] = element.Group - if extra == nil { - extra = make(map[string]interface{}) - } - extra["category"] = metadata.Category - extra["group"] = metadata.Group + b, err := json.Marshal(extra) + if err != nil { + return err + } - var m string - b, err := json.Marshal(metadata) - if err != nil { - return err - } - m = string(b) - - vectorJson, _ := json.Marshal(vectors) - v := &Index{ - ID: id, - Name: metadata.Source, - ParentDir: metadata.ParentDir, - Context: content, - Metadata: m, - Vector: string(vectorJson), - CreatedAt: time.Now().UnixNano(), - ChangedAt: time.Now().UnixNano(), - } - return p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - vModel := Index{ID: id} - res := tx.First(vModel) + var v *Index + v, err = v.From(element) + if err != nil { + return err + } + + v.Extra = string(b) + + vModel := &Index{} + res := tx.Where("name = ? AND group = ?", element.Name, element.Group).First(vModel) if res.Error != nil && res.Error != gorm.ErrRecordNotFound { return res.Error } if res.Error == gorm.ErrRecordNotFound { + v.CreatedAt = time.Now().UnixNano() + v.ChangedAt = time.Now().UnixNano() res = tx.Create(v) if res.Error != nil { return res.Error @@ -111,7 +106,7 @@ func (p *PostgresClient) Store(id, content string, metadata models.Metadata, ext } vModel.Update(v) - res = tx.Where("id = ?", id).Updates(vModel) + res = tx.Where("name = ? AND group = ?", element.Name, element.Group).Updates(vModel) if res.Error != nil || res.RowsAffected == 0 { if res.RowsAffected == 0 { return errors.New("operation conflict") @@ -122,28 +117,54 @@ func (p *PostgresClient) Store(id, content string, metadata models.Metadata, ext }) } -func (p *PostgresClient) Search(vectors []float32, k int) ([]models.Doc, error) { +func (p *PostgresClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { vectors64 := make([]float64, 0) for _, v := range vectors { vectors64 = append(vectors64, float64(v)) } // query from db - existVectors := [][]float64{} + existIndexes := make([]Index, 0) + res := p.dEntity.WithContext(ctx).Find(&existIndexes) + if res.Error != nil { + return nil, res.Error + } - model := cluster.NewKNN(k, existVectors, vectors64, base.EuclideanDistance) + // knn search + dists := distances{} + for _, index := range existIndexes { + var vector []float64 + err := json.Unmarshal([]byte(index.Vector), &vector) + if err != nil { + return nil, err + } - // make predictions like usual - _, err := model.Predict([]float64{-10, 1}) - if err != nil { - return nil, err + dists = append(dists, distance{ + Index: index, + dist: base.EuclideanDistance(vector, vectors64), + }) } - // todo - return nil, nil + + sort.Sort(dists) + + minKIndexes := dists[0:k] + results := make([]*models.Doc, k) + for _, index := range minKIndexes { + results = append(results, index.ToDoc()) + } + + return results, nil } -func (p *PostgresClient) Exist(id string) (bool, error) { - //TODO implement me - panic("implement me") +func (p *PostgresClient) Get(ctx context.Context, name string, group int) (*models.Element, error) { + vModel := &Index{} + res := p.dEntity.WithContext(ctx).Where("name = ? AND group = ?", name, group).First(vModel) + if res.Error != nil { + if res.Error == gorm.ErrRecordNotFound { + return nil, nil + } + return nil, res.Error + } + return vModel.To() } var _ vectorstore.VectorStore = &PostgresClient{} @@ -157,3 +178,22 @@ func (p *PostgresClient) Inited(ctx context.Context) (bool, error) { return count > 0, nil } + +type distance struct { + Index + dist float64 +} + +type distances []distance + +func (d distances) Len() int { + return len(d) +} + +func (d distances) Less(i, j int) bool { + return d[i].dist < d[j].dist +} + +func (d distances) Swap(i, j int) { + d[i], d[j] = d[j], d[i] +} diff --git a/pkg/vectorstore/redis/redis.go b/pkg/vectorstore/redis/redis.go index a5e6315..45c4f55 100644 --- a/pkg/vectorstore/redis/redis.go +++ b/pkg/vectorstore/redis/redis.go @@ -25,6 +25,7 @@ import ( "github.com/redis/rueidis" "github.com/basenana/friday/pkg/models" + "github.com/basenana/friday/pkg/utils/files" "github.com/basenana/friday/pkg/utils/logger" "github.com/basenana/friday/pkg/vectorstore" ) @@ -79,7 +80,11 @@ func (r RedisClient) initIndex() error { context.Background(), r.client.B().Arbitrary("FT.CREATE", r.index, "ON", "HASH", "PREFIX", "1", r.prefix, "SCHEMA"). Args("id", "TEXT"). - Args("metadata", "TEXT"). + Args("name", "TEXT"). + Args("group", "TEXT"). + Args("extra", "TEXT"). + Args("oid", "TEXT"). + Args("parentid", "TEXT"). Args("content", "TEXT"). Args("vector", "VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", strconv.Itoa(r.dim), "DISTANCE_METRIC", "L2"). Build()).Error(); err != nil { @@ -88,42 +93,58 @@ func (r RedisClient) initIndex() error { return nil } -func (r RedisClient) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error { - ctx := context.Background() - +func (r RedisClient) Store(ctx context.Context, element *models.Element, extra map[string]any) error { if extra == nil { extra = make(map[string]interface{}) } - extra["category"] = metadata.Category - extra["group"] = metadata.Group + extra["group"] = element.Group var m string - b, err := json.Marshal(metadata) + b, err := json.Marshal(extra) if err != nil { return err } m = string(b) - return r.client.Do(ctx, r.client.B().Hset().Key(fmt.Sprintf("%s:%s", r.prefix, id)).FieldValue(). - FieldValue("id", id). - FieldValue("metadata", m). - FieldValue("content", content). - FieldValue("vector", rueidis.VectorString32(vectors)).Build()).Error() + return r.client.Do(ctx, r.client.B().Hset().Key(fmt.Sprintf("%s:%s-%d", r.prefix, element.Name, element.Group)).FieldValue(). + FieldValue("id", element.ID). + FieldValue("name", element.Name). + FieldValue("group", strconv.Itoa(element.Group)). + FieldValue("extra", m). + FieldValue("oid", files.Int64ToStr(element.OID)). + FieldValue("parentid", files.Int64ToStr(element.ParentId)). + FieldValue("content", element.Content). + FieldValue("vector", rueidis.VectorString32(element.Vector)).Build()).Error() } -func (r RedisClient) Exist(id string) (exist bool, err error) { - ctx := context.Background() - resp := r.client.Do(ctx, r.client.B().Get().Key(fmt.Sprintf("%s:%s", r.prefix, id)).Build()) - if resp.RedisError() != nil && resp.RedisError().IsNil() { - exist = false - return +func (r RedisClient) Get(ctx context.Context, name string, group int) (*models.Element, error) { + resp, err := r.client.Do(ctx, r.client.B().Get().Key(fmt.Sprintf("%s:%s-%d", r.prefix, name, group)).Build()).ToMessage() + if err != nil { + return nil, err + } + res, err := resp.AsStrMap() + if err != nil { + return nil, err } - exist = true - return -} -func (r RedisClient) Search(vectors []float32, k int) ([]models.Doc, error) { - ctx := context.Background() + oid, err := files.StrToInt64(res["oid"]) + if err != nil { + return nil, err + } + parentId, err := files.StrToInt64(res["parentid"]) + if err != nil { + return nil, err + } + return &models.Element{ + ID: res["id"], + Name: res["name"], + Group: group, + OID: oid, + ParentId: parentId, + Content: res["content"], + }, nil +} +func (r RedisClient) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) { resp, err := r.client.Do(ctx, r.client.B().FtSearch().Index(r.index). Query("*=>[KNN 10 @vector $B AS vector_score]"). Return("4").Identifier("id").Identifier("content"). @@ -135,7 +156,7 @@ func (r RedisClient) Search(vectors []float32, k int) ([]models.Doc, error) { if err != nil { return nil, err } - results := make([]models.Doc, 0) + results := make([]*models.Doc, 0) for i := 1; i < len(resp[1:]); i += 2 { res, err := resp[i+1].AsStrMap() @@ -146,9 +167,24 @@ func (r RedisClient) Search(vectors []float32, k int) ([]models.Doc, error) { if err := json.Unmarshal([]byte(res["metadata"]), &metadata); err != nil { return nil, err } - results = append(results, models.Doc{ + oid, err := files.StrToInt64(res["oid"]) + if err != nil { + return nil, err + } + parentId, err := files.StrToInt64(res["parentid"]) + if err != nil { + return nil, err + } + group, err := strconv.Atoi(res["group"]) + if err != nil { + return nil, err + } + results = append(results, &models.Doc{ Id: res["id"], - Metadata: metadata, + OID: oid, + Name: res["name"], + Group: group, + ParentId: parentId, Content: res["content"], }) r.log.Debugf("id: %s, content: %s, score: %s\n", res["id"], res["content"], res["vector_score"]) diff --git a/scripts/element.py b/scripts/element.py index f03472a..815d655 100644 --- a/scripts/element.py +++ b/scripts/element.py @@ -1,16 +1,15 @@ import json import os import sys -import uuid from typing import List import nltk +from unstructured.documents.elements import Text, ElementMetadata from unstructured.file_utils.filetype import FileType, detect_filetype from unstructured.partition.auto import partition from unstructured.partition.doc import partition_doc from unstructured.partition.html import partition_html from unstructured.partition.md import partition_md -from unstructured.documents.elements import Text, ElementMetadata script_path = os.path.dirname(__file__) nltk.data.path = [os.getenv("NLTK_DATA", os.path.join(os.path.dirname(__file__), "nltk_data"))] + nltk.data.path @@ -56,15 +55,15 @@ def _element(f, elements) -> List: for element in elements: if hasattr(element, "tag") and element.tag == "h2": group += 1 - metadata = { - "source": f, - "title": os.path.basename(f), + doc = { + "name": f, } if hasattr(element, "metadata"): - metadata.update(element.metadata.to_dict()) + doc.update(element.metadata.to_dict()) if hasattr(element, "category"): - metadata["category"] = element.category - res.append({"content": str(element), "metadata": metadata}) + doc["category"] = element.category + doc["content"] = str(element) + res.append(doc) return res From 92f85240ab12bc65ccfd78324bd98addfce9a3b9 Mon Sep 17 00:00:00 2001 From: zwwhdls Date: Wed, 6 Dec 2023 21:37:32 +0800 Subject: [PATCH 3/3] fix knn in pg client search Signed-off-by: zwwhdls --- config/config.go | 2 ++ pkg/build/common/init.go | 16 +++++++++++++++- pkg/build/withvector/init.go | 6 ++++++ pkg/friday/friday.go | 14 +++++++++----- pkg/friday/question.go | 4 ++-- pkg/friday/question_test.go | 2 ++ pkg/search/search.go | 9 ++++++--- pkg/utils/files/file.go | 2 +- pkg/utils/logger/dblogger.go | 4 ++-- pkg/vectorstore/pgvector/pgvector.go | 6 +++--- pkg/vectorstore/postgres/bleve.go | 4 +++- pkg/vectorstore/postgres/model.go | 6 +++--- pkg/vectorstore/postgres/postgres.go | 14 +++++++------- 13 files changed, 61 insertions(+), 28 deletions(-) diff --git a/config/config.go b/config/config.go index 8db040c..ee98986 100644 --- a/config/config.go +++ b/config/config.go @@ -74,6 +74,7 @@ type EmbeddingConfig struct { type VectorStoreConfig struct { VectorStoreType VectorStoreType `json:"vector_store_type"` VectorUrl string `json:"vector_url"` + TopK *int `json:"top_k,omitempty"` // topk of knn, default is 6 EmbeddingDim int `json:"embedding_dim,omitempty"` // embedding dimension, default is 1536 } @@ -102,4 +103,5 @@ type VectorStoreType string const ( VectorStoreRedis VectorStoreType = "redis" VectorStorePostgres VectorStoreType = "postgres" + VectorStorePGVector VectorStoreType = "pgvector" ) diff --git a/pkg/build/common/init.go b/pkg/build/common/init.go index 80bf4f0..9a6665b 100644 --- a/pkg/build/common/init.go +++ b/pkg/build/common/init.go @@ -20,12 +20,21 @@ import ( "github.com/basenana/friday/config" "github.com/basenana/friday/pkg/build/withvector" "github.com/basenana/friday/pkg/friday" + "github.com/basenana/friday/pkg/utils/logger" "github.com/basenana/friday/pkg/vectorstore" "github.com/basenana/friday/pkg/vectorstore/pgvector" + "github.com/basenana/friday/pkg/vectorstore/postgres" "github.com/basenana/friday/pkg/vectorstore/redis" ) func NewFriday(conf *config.Config) (f *friday.Friday, err error) { + log := conf.Logger + if conf.Logger == nil { + log = logger.NewLogger("friday") + } + log.SetDebug(conf.Debug) + conf.Logger = log + var vectorStore vectorstore.VectorStore // init vector store if conf.VectorStoreConfig.VectorStoreType == config.VectorStoreRedis { @@ -40,8 +49,13 @@ func NewFriday(conf *config.Config) (f *friday.Friday, err error) { return nil, err } } + } else if conf.VectorStoreConfig.VectorStoreType == config.VectorStorePGVector { + vectorStore, err = pgvector.NewPgVectorClient(conf.Logger, conf.VectorStoreConfig.VectorUrl) + if err != nil { + return nil, err + } } else if conf.VectorStoreConfig.VectorStoreType == config.VectorStorePostgres { - vectorStore, err = pgvector.NewPgVectorClient(conf.VectorStoreConfig.VectorUrl) + vectorStore, err = postgres.NewPostgresClient(conf.Logger, conf.VectorStoreConfig.VectorUrl) if err != nil { return nil, err } diff --git a/pkg/build/withvector/init.go b/pkg/build/withvector/init.go index bc24187..2db4385 100644 --- a/pkg/build/withvector/init.go +++ b/pkg/build/withvector/init.go @@ -75,6 +75,11 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto conf.VectorStoreConfig.EmbeddingDim = len(testEmbed) } + defaultVectorTopK := friday.DefaultTopK + if conf.VectorStoreConfig.TopK == nil { + conf.VectorStoreConfig.TopK = &defaultVectorTopK + } + // init text spliter chunkSize := spliter.DefaultChunkSize overlapSize := spliter.DefaultChunkOverlap @@ -97,6 +102,7 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto Prompts: prompts, Embedding: embeddingModel, Vector: vectorClient, + VectorTopK: conf.VectorStoreConfig.TopK, Spliter: textSpliter, } return diff --git a/pkg/friday/friday.go b/pkg/friday/friday.go index 03aa283..f4617a5 100644 --- a/pkg/friday/friday.go +++ b/pkg/friday/friday.go @@ -25,7 +25,7 @@ import ( ) const ( - defaultTopK = 6 + DefaultTopK = 6 questionPromptKey = "question" keywordsPromptKey = "keywords" wechatPromptKey = "wechat" @@ -41,9 +41,13 @@ type Friday struct { LimitToken int - LLM llm.LLM - Prompts map[string]string + LLM llm.LLM + Prompts map[string]string + Embedding embedding.Embedding - Vector vectorstore.VectorStore - Spliter spliter.Spliter + + Vector vectorstore.VectorStore + VectorTopK *int + + Spliter spliter.Spliter } diff --git a/pkg/friday/question.go b/pkg/friday/question.go index ced8a41..326d3c2 100644 --- a/pkg/friday/question.go +++ b/pkg/friday/question.go @@ -25,7 +25,7 @@ import ( ) func (f *Friday) Question(ctx context.Context, q string) (string, error) { - prompt := prompts.NewQuestionPrompt(questionPromptKey) + prompt := prompts.NewQuestionPrompt(f.Prompts[questionPromptKey]) c, err := f.searchDocs(ctx, q) if err != nil { return "", err @@ -50,7 +50,7 @@ func (f *Friday) searchDocs(ctx context.Context, q string) (string, error) { if err != nil { return "", fmt.Errorf("vector embedding error: %w", err) } - docs, err := f.Vector.Search(ctx, qv, defaultTopK) + docs, err := f.Vector.Search(ctx, qv, *f.VectorTopK) if err != nil { return "", fmt.Errorf("vector search error: %w", err) } diff --git a/pkg/friday/question_test.go b/pkg/friday/question_test.go index 0791e59..850cddb 100644 --- a/pkg/friday/question_test.go +++ b/pkg/friday/question_test.go @@ -37,12 +37,14 @@ var _ = Describe("TestQuestion", func() { ) BeforeEach(func() { + topk := 6 loFriday.Vector = FakeStore{} loFriday.Log = logger.NewLogger("test-question") loFriday.Spliter = spliter.NewTextSpliter(loFriday.Log, spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n") loFriday.Embedding = FakeQuestionEmbedding{} loFriday.LLM = FakeQuestionLLM{} loFriday.Vector = FakeQuestionStore{} + loFriday.VectorTopK = &topk }) Context("question", func() { diff --git a/pkg/search/search.go b/pkg/search/search.go index 0116a5c..d260feb 100644 --- a/pkg/search/search.go +++ b/pkg/search/search.go @@ -2,10 +2,13 @@ package search import ( "context" - "github.com/basenana/friday/pkg/vectorstore/postgres" + "os" + "github.com/blevesearch/bleve/v2" "github.com/blevesearch/bleve/v2/index/upsidedown" - "os" + + "github.com/basenana/friday/pkg/utils/logger" + "github.com/basenana/friday/pkg/vectorstore/postgres" ) var singleIndex bleve.Index @@ -17,7 +20,7 @@ func InitSearchEngine() error { return err } - pgCli, err := postgres.NewPostgresClient(dsn) + pgCli, err := postgres.NewPostgresClient(logger.NewLogger("database"), dsn) if err != nil { return err } diff --git a/pkg/utils/files/file.go b/pkg/utils/files/file.go index b721047..4188396 100644 --- a/pkg/utils/files/file.go +++ b/pkg/utils/files/file.go @@ -48,7 +48,7 @@ func ReadFiles(ps string) (docs map[string]string, err error) { } return } - if !strings.HasSuffix(p.Name(), ".md") && !strings.HasSuffix(p.Name(), ".txt") { + if !strings.HasSuffix(p.Name(), ".md") && !strings.HasSuffix(p.Name(), ".txt") && !strings.HasSuffix(p.Name(), ".html") { return } doc, err := os.ReadFile(ps) diff --git a/pkg/utils/logger/dblogger.go b/pkg/utils/logger/dblogger.go index 30bd97b..0afe201 100644 --- a/pkg/utils/logger/dblogger.go +++ b/pkg/utils/logger/dblogger.go @@ -55,6 +55,6 @@ func (l *DBLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql st } } -func NewDbLogger() *DBLogger { - return &DBLogger{NewLogger("database")} +func NewDbLogger(log Logger) *DBLogger { + return &DBLogger{log} } diff --git a/pkg/vectorstore/pgvector/pgvector.go b/pkg/vectorstore/pgvector/pgvector.go index 9b10e13..6e6290d 100644 --- a/pkg/vectorstore/pgvector/pgvector.go +++ b/pkg/vectorstore/pgvector/pgvector.go @@ -39,8 +39,8 @@ type PgVectorClient struct { var _ vectorstore.VectorStore = &PgVectorClient{} -func NewPgVectorClient(postgresUrl string) (*PgVectorClient, error) { - dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger()}) +func NewPgVectorClient(log logger.Logger, postgresUrl string) (*PgVectorClient, error) { + dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger(log)}) if err != nil { panic(err) } @@ -64,7 +64,7 @@ func NewPgVectorClient(postgresUrl string) (*PgVectorClient, error) { } return &PgVectorClient{ - log: logger.NewLogger("postgres"), + log: log, dEntity: dbEnt, }, nil } diff --git a/pkg/vectorstore/postgres/bleve.go b/pkg/vectorstore/postgres/bleve.go index 2e6b576..cf4f059 100644 --- a/pkg/vectorstore/postgres/bleve.go +++ b/pkg/vectorstore/postgres/bleve.go @@ -9,6 +9,8 @@ import ( "github.com/blevesearch/bleve/v2/registry" "github.com/blevesearch/upsidedown_store_api" "gorm.io/gorm" + + "github.com/basenana/friday/pkg/utils/logger" ) const ( @@ -24,7 +26,7 @@ func pgKVStoreConstructor(mo store.MergeOperator, config map[string]interface{}) if !ok { return nil, fmt.Errorf("dsn not found") } - pgCli, err := NewPostgresClient(dsnStr.(string)) + pgCli, err := NewPostgresClient(logger.NewLogger("bleve"), dsnStr.(string)) if err != nil { return nil, err } diff --git a/pkg/vectorstore/postgres/model.go b/pkg/vectorstore/postgres/model.go index a949f1f..76ce900 100644 --- a/pkg/vectorstore/postgres/model.go +++ b/pkg/vectorstore/postgres/model.go @@ -24,10 +24,10 @@ import ( ) type Index struct { - ID string `gorm:"column:id;primaryKey"` - Name string `gorm:"column:name;type:varchar(256);index:index_name"` + ID string `gorm:"column:id;type:varchar(256);primaryKey"` + Name string `gorm:"column:name;index:index_name"` OID int64 `gorm:"column:oid;index:index_oid"` - Group int `gorm:"column:group;index:index_group"` + Group int `gorm:"column:idx_group;index:index_group"` ParentID *int64 `gorm:"column:parent_entry_id;index:index_parent_id"` Content string `gorm:"column:content"` Vector string `gorm:"column:vector;type:json"` diff --git a/pkg/vectorstore/postgres/postgres.go b/pkg/vectorstore/postgres/postgres.go index be89059..02f4d65 100644 --- a/pkg/vectorstore/postgres/postgres.go +++ b/pkg/vectorstore/postgres/postgres.go @@ -38,8 +38,8 @@ type PostgresClient struct { dEntity *db.Entity } -func NewPostgresClient(postgresUrl string) (*PostgresClient, error) { - dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger()}) +func NewPostgresClient(log logger.Logger, postgresUrl string) (*PostgresClient, error) { + dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger(log)}) if err != nil { panic(err) } @@ -63,7 +63,7 @@ func NewPostgresClient(postgresUrl string) (*PostgresClient, error) { } return &PostgresClient{ - log: logger.NewLogger("postgres"), + log: log, dEntity: dbEnt, }, nil } @@ -90,7 +90,7 @@ func (p *PostgresClient) Store(ctx context.Context, element *models.Element, ext v.Extra = string(b) vModel := &Index{} - res := tx.Where("name = ? AND group = ?", element.Name, element.Group).First(vModel) + res := tx.Where("name = ? AND idx_group = ?", element.Name, element.Group).First(vModel) if res.Error != nil && res.Error != gorm.ErrRecordNotFound { return res.Error } @@ -106,7 +106,7 @@ func (p *PostgresClient) Store(ctx context.Context, element *models.Element, ext } vModel.Update(v) - res = tx.Where("name = ? AND group = ?", element.Name, element.Group).Updates(vModel) + res = tx.Where("name = ? AND idx_group = ?", element.Name, element.Group).Updates(vModel) if res.Error != nil || res.RowsAffected == 0 { if res.RowsAffected == 0 { return errors.New("operation conflict") @@ -147,7 +147,7 @@ func (p *PostgresClient) Search(ctx context.Context, vectors []float32, k int) ( sort.Sort(dists) minKIndexes := dists[0:k] - results := make([]*models.Doc, k) + results := make([]*models.Doc, 0) for _, index := range minKIndexes { results = append(results, index.ToDoc()) } @@ -157,7 +157,7 @@ func (p *PostgresClient) Search(ctx context.Context, vectors []float32, k int) ( func (p *PostgresClient) Get(ctx context.Context, name string, group int) (*models.Element, error) { vModel := &Index{} - res := p.dEntity.WithContext(ctx).Where("name = ? AND group = ?", name, group).First(vModel) + res := p.dEntity.WithContext(ctx).Where("name = ? AND idx_group = ?", name, group).First(vModel) if res.Error != nil { if res.Error == gorm.ErrRecordNotFound { return nil, nil