Skip to content

Commit

Permalink
Merge pull request #31 from basenana/feature/knn
Browse files Browse the repository at this point in the history
support knn in pg client
  • Loading branch information
zwwhdls authored Dec 6, 2023
2 parents 564ca73 + 92f8524 commit d93081b
Show file tree
Hide file tree
Showing 35 changed files with 803 additions and 399 deletions.
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ type EmbeddingConfig struct {
type VectorStoreConfig struct {
VectorStoreType VectorStoreType `json:"vector_store_type"`
VectorUrl string `json:"vector_url"`
TopK *int `json:"top_k,omitempty"` // topk of knn, default is 6
EmbeddingDim int `json:"embedding_dim,omitempty"` // embedding dimension, default is 1536
}

Expand Down Expand Up @@ -102,4 +103,5 @@ type VectorStoreType string
const (
VectorStoreRedis VectorStoreType = "redis"
VectorStorePostgres VectorStoreType = "postgres"
VectorStorePGVector VectorStoreType = "pgvector"
)
2 changes: 1 addition & 1 deletion flow/operator/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (i *ingestOperator) Do(ctx context.Context, param *flow.Parameter) error {
source := i.spec.Parameters["source"]
knowledge := i.spec.Parameters["knowledge"]
doc := models.File{
Source: source,
Name: source,
Content: knowledge,
}
return friday.Fri.IngestFromFile(context.TODO(), doc)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
)

require (
github.com/cdipaolo/goml v0.0.0-20220715001353-00e0c845ae1c // indirect
github.com/RoaringBitmap/roaring v1.2.3 // indirect
github.com/bits-and-blooms/bitset v1.2.0 // indirect
github.com/blevesearch/bleve_index_api v1.0.6 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVO
github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE=
github.com/basenana/go-flow v0.0.0-20230801131009-d05f1f41b706 h1:FxXoMwMZsufBjSZg8yWpqfV6FNs5F0JuLnO+iJxojnw=
github.com/basenana/go-flow v0.0.0-20230801131009-d05f1f41b706/go.mod h1:Rs13PWsg/ITdXRiVJcI+yS0iqCfNHxCbIFEt5DCt/RQ=
github.com/cdipaolo/goml v0.0.0-20220715001353-00e0c845ae1c h1:uqJXOhayPfl/QruVBP6VF0KUWNDzO/F14X8CPEkkFD8=
github.com/cdipaolo/goml v0.0.0-20220715001353-00e0c845ae1c/go.mod h1:Ue8jgVLdBDCtsh1laikvraXqXzKCyKiruCcCcaeNDFE=
github.com/bits-and-blooms/bitset v1.2.0 h1:Kn4yilvwNtMACtf1eYDlG8H77R07mZSPbMjLyS07ChA=
github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA=
github.com/blevesearch/bleve/v2 v2.3.10 h1:z8V0wwGoL4rp7nG/O3qVVLYxUqCbEwskMt4iRJsPLgg=
Expand Down
16 changes: 15 additions & 1 deletion pkg/build/common/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,21 @@ import (
"github.com/basenana/friday/config"
"github.com/basenana/friday/pkg/build/withvector"
"github.com/basenana/friday/pkg/friday"
"github.com/basenana/friday/pkg/utils/logger"
"github.com/basenana/friday/pkg/vectorstore"
"github.com/basenana/friday/pkg/vectorstore/pgvector"
"github.com/basenana/friday/pkg/vectorstore/postgres"
"github.com/basenana/friday/pkg/vectorstore/redis"
)

func NewFriday(conf *config.Config) (f *friday.Friday, err error) {
log := conf.Logger
if conf.Logger == nil {
log = logger.NewLogger("friday")
}
log.SetDebug(conf.Debug)
conf.Logger = log

var vectorStore vectorstore.VectorStore
// init vector store
if conf.VectorStoreConfig.VectorStoreType == config.VectorStoreRedis {
Expand All @@ -40,8 +49,13 @@ func NewFriday(conf *config.Config) (f *friday.Friday, err error) {
return nil, err
}
}
} else if conf.VectorStoreConfig.VectorStoreType == config.VectorStorePGVector {
vectorStore, err = pgvector.NewPgVectorClient(conf.Logger, conf.VectorStoreConfig.VectorUrl)
if err != nil {
return nil, err
}
} else if conf.VectorStoreConfig.VectorStoreType == config.VectorStorePostgres {
vectorStore, err = postgres.NewPostgresClient(conf.VectorStoreConfig.VectorUrl)
vectorStore, err = postgres.NewPostgresClient(conf.Logger, conf.VectorStoreConfig.VectorUrl)
if err != nil {
return nil, err
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/build/withvector/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto
conf.VectorStoreConfig.EmbeddingDim = len(testEmbed)
}

defaultVectorTopK := friday.DefaultTopK
if conf.VectorStoreConfig.TopK == nil {
conf.VectorStoreConfig.TopK = &defaultVectorTopK
}

// init text spliter
chunkSize := spliter.DefaultChunkSize
overlapSize := spliter.DefaultChunkOverlap
Expand All @@ -97,6 +102,7 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto
Prompts: prompts,
Embedding: embeddingModel,
Vector: vectorClient,
VectorTopK: conf.VectorStoreConfig.TopK,
Spliter: textSpliter,
}
return
Expand Down
14 changes: 9 additions & 5 deletions pkg/friday/friday.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

const (
defaultTopK = 6
DefaultTopK = 6
questionPromptKey = "question"
keywordsPromptKey = "keywords"
wechatPromptKey = "wechat"
Expand All @@ -41,9 +41,13 @@ type Friday struct {

LimitToken int

LLM llm.LLM
Prompts map[string]string
LLM llm.LLM
Prompts map[string]string

Embedding embedding.Embedding
Vector vectorstore.VectorStore
Spliter spliter.Spliter

Vector vectorstore.VectorStore
VectorTopK *int

Spliter spliter.Spliter
}
55 changes: 25 additions & 30 deletions pkg/friday/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@ package friday

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/google/uuid"

"github.com/basenana/friday/pkg/models"
"github.com/basenana/friday/pkg/utils/files"
Expand All @@ -34,17 +30,15 @@ import (
// IngestFromFile ingest a whole file providing models.File
func (f *Friday) IngestFromFile(ctx context.Context, file models.File) error {
elements := []models.Element{}
parentDir := filepath.Dir(file.Source)
// split doc
subDocs := f.Spliter.Split(file.Content)
for i, subDoc := range subDocs {
e := models.Element{
Content: subDoc,
Metadata: models.Metadata{
Source: file.Source,
Group: strconv.Itoa(i),
ParentDir: parentDir,
},
Name: file.Name,
Group: i,
OID: file.OID,
ParentId: file.ParentId,
Content: subDoc,
}
elements = append(elements, e)
}
Expand All @@ -56,15 +50,12 @@ func (f *Friday) IngestFromFile(ctx context.Context, file models.File) error {
func (f *Friday) Ingest(ctx context.Context, elements []models.Element) error {
f.Log.Debugf("Ingesting %d ...", len(elements))
for i, element := range elements {
// id: sha256(source)-group
h := sha256.New()
h.Write([]byte(element.Metadata.Source))
val := hex.EncodeToString(h.Sum(nil))[:64]
id := fmt.Sprintf("%s-%s", val, element.Metadata.Group)
if exist, err := f.Vector.Exist(id); err != nil {
exist, err := f.Vector.Get(ctx, element.Name, element.Group)
if err != nil {
return err
} else if exist {
f.Log.Debugf("vector %d(th) id(%s) source(%s) exist, skip ...", i, id, element.Metadata.Source)
}
if exist != nil && exist.Content == element.Content {
f.Log.Debugf("vector %d(th) name(%s) group(%d) exist, skip ...", i, element.Name, element.Group)
continue
}

Expand All @@ -73,10 +64,17 @@ func (f *Friday) Ingest(ctx context.Context, elements []models.Element) error {
return err
}

t := strings.TrimSpace(element.Content)
if exist != nil {
element.ID = exist.ID
element.OID = exist.OID
element.ParentId = exist.ParentId
} else {
element.ID = uuid.New().String()
}
element.Vector = vectors

f.Log.Debugf("store %d(th) vector id (%s) source(%s) ...", i, id, element.Metadata.Source)
if err := f.Vector.Store(id, t, element.Metadata, m, vectors); err != nil {
f.Log.Debugf("store %d(th) vector name(%s) group(%d) ...", i, element.Name, element.Group)
if err := f.Vector.Store(ctx, &element, m); err != nil {
return err
}
}
Expand All @@ -90,6 +88,7 @@ func (f *Friday) IngestFromElementFile(ctx context.Context, ps string) error {
return err
}
elements := []models.Element{}

if err := json.Unmarshal(doc, &elements); err != nil {
return err
}
Expand All @@ -106,16 +105,12 @@ func (f *Friday) IngestFromOriginFile(ctx context.Context, ps string) error {

elements := []models.Element{}
for n, file := range fs {
parentDir := filepath.Dir(n)
subDocs := f.Spliter.Split(file)
for i, subDoc := range subDocs {
e := models.Element{
Content: subDoc,
Metadata: models.Metadata{
Source: n,
Group: strconv.Itoa(i),
ParentDir: parentDir,
},
Name: n,
Group: i,
}
elements = append(elements, e)
}
Expand Down
17 changes: 7 additions & 10 deletions pkg/friday/ingest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,8 @@ var _ = Describe("TestIngest", func() {
elements := []models.Element{
{
Content: "test-content",
Metadata: models.Metadata{
Source: "test-source",
Title: "test-title",
ParentDir: "/",
},
Name: "test-title",
Group: 0,
},
}
err := loFriday.Ingest(context.TODO(), elements)
Expand All @@ -63,16 +60,16 @@ type FakeStore struct{}

var _ vectorstore.VectorStore = &FakeStore{}

func (f FakeStore) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error {
func (f FakeStore) Store(ctx context.Context, element *models.Element, extra map[string]any) error {
return nil
}

func (f FakeStore) Search(vectors []float32, k int) ([]models.Doc, error) {
return []models.Doc{}, nil
func (f FakeStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) {
return []*models.Doc{}, nil
}

func (f FakeStore) Exist(id string) (bool, error) {
return false, nil
func (f FakeStore) Get(ctx context.Context, name string, group int) (*models.Element, error) {
return &models.Element{}, nil
}

type FakeEmbedding struct{}
Expand Down
8 changes: 4 additions & 4 deletions pkg/friday/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

func (f *Friday) Question(ctx context.Context, q string) (string, error) {
prompt := prompts.NewQuestionPrompt(questionPromptKey)
prompt := prompts.NewQuestionPrompt(f.Prompts[questionPromptKey])
c, err := f.searchDocs(ctx, q)
if err != nil {
return "", err
Expand All @@ -50,14 +50,14 @@ func (f *Friday) searchDocs(ctx context.Context, q string) (string, error) {
if err != nil {
return "", fmt.Errorf("vector embedding error: %w", err)
}
contexts, err := f.Vector.Search(qv, defaultTopK)
docs, err := f.Vector.Search(ctx, qv, *f.VectorTopK)
if err != nil {
return "", fmt.Errorf("vector search error: %w", err)
}

cs := []string{}
for _, c := range contexts {
f.Log.Debugf("searched from [%s] for %s", c.Metadata["source"], c.Content)
for _, c := range docs {
f.Log.Debugf("searched from [%s] for %s", c.Name, c.Content)
cs = append(cs, c.Content)
}
return strings.Join(cs, "\n"), nil
Expand Down
17 changes: 9 additions & 8 deletions pkg/friday/question_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ var _ = Describe("TestQuestion", func() {
)

BeforeEach(func() {
topk := 6
loFriday.Vector = FakeStore{}
loFriday.Log = logger.NewLogger("test-question")
loFriday.Spliter = spliter.NewTextSpliter(loFriday.Log, spliter.DefaultChunkSize, spliter.DefaultChunkOverlap, "\n")
loFriday.Embedding = FakeQuestionEmbedding{}
loFriday.LLM = FakeQuestionLLM{}
loFriday.Vector = FakeQuestionStore{}
loFriday.VectorTopK = &topk
})

Context("question", func() {
Expand All @@ -63,20 +65,19 @@ type FakeQuestionStore struct{}

var _ vectorstore.VectorStore = &FakeQuestionStore{}

func (f FakeQuestionStore) Store(id, content string, metadata models.Metadata, extra map[string]interface{}, vectors []float32) error {
func (f FakeQuestionStore) Store(ctx context.Context, element *models.Element, extra map[string]any) error {
return nil
}

func (f FakeQuestionStore) Search(vectors []float32, k int) ([]models.Doc, error) {
return []models.Doc{{
Id: "abc",
Metadata: map[string]interface{}{},
Content: "There are logs of questions",
func (f FakeQuestionStore) Search(ctx context.Context, vectors []float32, k int) ([]*models.Doc, error) {
return []*models.Doc{{
Id: "abc",
Content: "There are logs of questions",
}}, nil
}

func (f FakeQuestionStore) Exist(id string) (bool, error) {
return false, nil
func (f FakeQuestionStore) Get(ctx context.Context, name string, group int) (*models.Element, error) {
return &models.Element{}, nil
}

type FakeQuestionEmbedding struct{}
Expand Down
8 changes: 4 additions & 4 deletions pkg/friday/summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ func (f *Friday) Summary(ctx context.Context, elements []models.Element, summary

docs := make(map[string][]string)
for _, element := range elements {
if _, ok := docs[element.Metadata.Source]; !ok {
docs[element.Metadata.Source] = []string{element.Content}
if _, ok := docs[element.Name]; !ok {
docs[element.Name] = []string{element.Content}
} else {
docs[element.Metadata.Source] = append(docs[element.Metadata.Source], element.Content)
docs[element.Name] = append(docs[element.Name], element.Content)
}
}
for source, doc := range docs {
Expand All @@ -57,7 +57,7 @@ func (f *Friday) SummaryFromFile(ctx context.Context, file models.File, summaryT
return nil, err
}
return map[string]string{
file.Source: summaryOfFile,
file.Name: summaryOfFile,
}, err
}

Expand Down
Loading

0 comments on commit d93081b

Please sign in to comment.