From 65c3d3952624bed51439e069b5571f1b8a7d2c31 Mon Sep 17 00:00:00 2001 From: zwwhdls Date: Sun, 15 Dec 2024 23:18:43 +0800 Subject: [PATCH] feat: support pg as doc store Signed-off-by: zwwhdls --- api/doc.go | 186 ++---- api/request.go | 305 ++-------- config/config.go | 14 +- pkg/build/common/init.go | 12 +- pkg/build/withvector/init.go | 4 +- pkg/dispatch/plugin/header.go | 2 +- pkg/friday/friday.go | 4 +- pkg/friday/ingest_test.go | 4 +- pkg/friday/question_test.go | 4 +- pkg/models/doc/attr.go | 44 -- pkg/models/doc/document.go | 134 +++-- pkg/models/doc/query.go | 100 ---- pkg/models/pagination.go | 67 +++ pkg/models/types.go | 33 ++ pkg/search/config.go | 2 +- pkg/search/search.go | 7 +- pkg/service/chain.go | 247 ++------ pkg/service/chain_test.go | 194 ++---- pkg/store/{vectorstore => }/db/entity.go | 0 pkg/store/docstore/meili.go | 257 -------- pkg/store/docstore/mock.go | 232 -------- pkg/store/interface.go | 38 ++ pkg/store/meili/meili.go | 354 +++++++++++ pkg/store/meili/mock.go | 236 ++++++++ pkg/store/meili/model.go | 552 ++++++++++++++++++ .../{vectorstore => }/pgvector/migrate.go | 0 pkg/store/{vectorstore => }/pgvector/model.go | 0 .../{vectorstore => }/pgvector/pgvector.go | 9 +- pkg/store/{vectorstore => }/postgres/bleve.go | 4 +- pkg/store/postgres/document.go | 179 ++++++ .../{vectorstore => }/postgres/migrate.go | 17 + pkg/store/{vectorstore => }/postgres/model.go | 77 +++ pkg/store/postgres/postgres.go | 67 +++ .../postgres.go => postgres/vector.go} | 89 +-- pkg/store/{vectorstore => }/redis/redis.go | 10 +- pkg/store/utils/utils.go | 74 +++ pkg/store/vectorstore/interface.go | 29 - .../docstore/interface.go => utils/knn.go} | 26 +- pkg/utils/logger/dblogger.go | 60 -- 39 files changed, 2085 insertions(+), 1588 deletions(-) create mode 100644 pkg/models/pagination.go create mode 100644 pkg/models/types.go rename pkg/store/{vectorstore => }/db/entity.go (100%) delete mode 100644 pkg/store/docstore/meili.go delete mode 100644 
pkg/store/docstore/mock.go create mode 100644 pkg/store/interface.go create mode 100644 pkg/store/meili/meili.go create mode 100644 pkg/store/meili/mock.go create mode 100644 pkg/store/meili/model.go rename pkg/store/{vectorstore => }/pgvector/migrate.go (100%) rename pkg/store/{vectorstore => }/pgvector/model.go (100%) rename pkg/store/{vectorstore => }/pgvector/pgvector.go (95%) rename pkg/store/{vectorstore => }/postgres/bleve.go (97%) create mode 100644 pkg/store/postgres/document.go rename pkg/store/{vectorstore => }/postgres/migrate.go (78%) rename pkg/store/{vectorstore => }/postgres/model.go (54%) create mode 100644 pkg/store/postgres/postgres.go rename pkg/store/{vectorstore/postgres/postgres.go => postgres/vector.go} (68%) rename pkg/store/{vectorstore => }/redis/redis.go (94%) create mode 100644 pkg/store/utils/utils.go delete mode 100644 pkg/store/vectorstore/interface.go rename pkg/{store/docstore/interface.go => utils/knn.go} (57%) delete mode 100644 pkg/utils/logger/dblogger.go diff --git a/api/doc.go b/api/doc.go index a63c9fa..e545b80 100644 --- a/api/doc.go +++ b/api/doc.go @@ -20,9 +20,11 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/gin-gonic/gin" + "github.com/basenana/friday/pkg/models" "github.com/basenana/friday/pkg/models/doc" "github.com/basenana/friday/pkg/utils" ) @@ -36,29 +38,22 @@ func (s *HttpServer) store() gin.HandlerFunc { c.JSON(400, gin.H{"error": err.Error()}) return } + body.Namespace = namespace + enId, _ := strconv.Atoi(entryId) + body.EntryId = int64(enId) if err := body.Valid(); err != nil { c.JSON(400, gin.H{"error": err.Error()}) return } - body.Namespace = namespace - body.EntryId = entryId // store the document - doc := body.ToDocument() - if err := s.chain.Store(c, doc); err != nil { + if err := s.chain.CreateDocument(c, &body.Document); err != nil { if strings.Contains(err.Error(), "already exists") { return } c.String(500, fmt.Sprintf("store document error: %s", err)) return } - attrs := body.ToAttr() - 
for _, attr := range attrs { - if err := s.chain.StoreAttr(c, attr); err != nil { - c.String(500, fmt.Sprintf("create document attr error: %s", err)) - return - } - } - c.JSON(200, doc) + c.JSON(200, body.Document) } } @@ -66,20 +61,18 @@ func (s *HttpServer) update() gin.HandlerFunc { return func(c *gin.Context) { entryId := c.Param("entryId") namespace := c.Param("namespace") - body := &DocAttrRequest{} + body := &DocUpdateRequest{} if err := c.ShouldBindJSON(&body); err != nil { c.JSON(400, gin.H{"error": err.Error()}) return } body.Namespace = namespace - body.EntryId = entryId + enId, _ := strconv.Atoi(entryId) + body.EntryId = int64(enId) // update the document - attrs := body.ToDocAttr() - for _, attr := range attrs { - if err := s.chain.StoreAttr(c, attr); err != nil { - c.String(500, fmt.Sprintf("update document error: %s", err)) - return - } + if err := s.chain.UpdateDocument(c, body.ToModel()); err != nil { + c.String(500, fmt.Sprintf("update document error: %s", err)) + return } c.JSON(200, body) } @@ -89,8 +82,13 @@ func (s *HttpServer) get() gin.HandlerFunc { return func(c *gin.Context) { namespace := c.Param("namespace") entryId := c.Param("entryId") - document, err := s.chain.GetDocument(c, namespace, entryId) + enId, _ := strconv.Atoi(entryId) + document, err := s.chain.GetDocument(c, namespace, int64(enId)) if err != nil { + if err == models.ErrNotFound { + c.String(404, fmt.Sprintf("document not found: %s", entryId)) + return + } c.String(500, fmt.Sprintf("get document error: %s", err)) return } @@ -98,30 +96,7 @@ func (s *HttpServer) get() gin.HandlerFunc { c.String(404, fmt.Sprintf("document not found: %s", entryId)) return } - docWithAttr := &DocumentWithAttr{ - Document: document, - } - - attrs, err := s.chain.GetDocumentAttrs(c, namespace, entryId) - if err != nil { - c.String(500, fmt.Sprintf("get document attrs error: %s", err)) - return - } - for _, attr := range attrs { - if attr.Key == "parentId" { - docWithAttr.ParentID = 
attr.Value.(string) - } - if attr.Key == "mark" { - marked := attr.Value.(bool) - docWithAttr.Mark = &marked - } - if attr.Key == "unRead" { - unRead := attr.Value.(bool) - docWithAttr.UnRead = &unRead - } - } - - c.JSON(200, docWithAttr) + c.JSON(200, document) } } @@ -131,55 +106,22 @@ func (s *HttpServer) filter() gin.HandlerFunc { if docQuery == nil { return } - docs, err := s.chain.Search(c, docQuery.ToQuery(), docQuery.GetAttrQueries()) - if err != nil { - c.String(500, fmt.Sprintf("search document error: %s", err)) - return - } - var docWithAttrs []DocumentWithAttr - ids := []string{} - for _, doc := range docs { - ids = append(ids, doc.EntryId) - } - allAttrs, err := s.chain.ListDocumentAttrs(c, docQuery.Namespace, ids) + ctx := models.WithPagination(c, models.NewPagination(docQuery.Page, docQuery.PageSize)) + + docs, err := s.chain.Search(ctx, docQuery) if err != nil { - c.String(500, fmt.Sprintf("list document attrs error: %s", err)) + c.String(500, fmt.Sprintf("search document error: %s", err)) return } - attrsMap := map[string][]*doc.DocumentAttr{} - for _, attr := range allAttrs { - if attrsMap[attr.EntryId] == nil { - attrsMap[attr.EntryId] = []*doc.DocumentAttr{} - } - attrsMap[attr.EntryId] = append(attrsMap[attr.EntryId], attr) - } - for _, document := range docs { - docWithAttr := DocumentWithAttr{Document: document} - attrs := attrsMap[document.EntryId] - for _, attr := range attrs { - if attr.Key == "parentId" { - docWithAttr.ParentID = attr.Value.(string) - } - if attr.Key == "mark" { - marked := attr.Value.(bool) - docWithAttr.Mark = &marked - } - if attr.Key == "unRead" { - unRead := attr.Value.(bool) - docWithAttr.UnRead = &unRead - } - } - docWithAttrs = append(docWithAttrs, docWithAttr) - } - c.JSON(200, docWithAttrs) + c.JSON(200, docs) } } -func getFilterQuery(c *gin.Context) *DocQuery { +func getFilterQuery(c *gin.Context) *doc.DocumentFilter { namespace := c.Param("namespace") - page, err := strconv.Atoi(c.DefaultQuery("page", "0")) + 
page, err := strconv.Atoi(c.DefaultQuery("page", "1")) if err != nil { c.String(400, fmt.Sprintf("invalid page number: %s", c.Query("page"))) return nil @@ -190,18 +132,35 @@ func getFilterQuery(c *gin.Context) *DocQuery { return nil } - docQuery := DocQuery{ - Namespace: namespace, - Source: c.Query("source"), - WebUrl: c.Query("webUrl"), - ParentID: c.Query("parentID"), - Search: c.Query("search"), - HitsPerPage: int64(pageSize), - Page: int64(page), - Sort: c.DefaultQuery("sort", "createdAt"), - Desc: c.DefaultQuery("desc", "false") == "true", + sort, err := strconv.Atoi(c.DefaultQuery("sort", "4")) + if err != nil { + c.String(400, fmt.Sprintf("invalid sort: %s", c.Query("page"))) + return nil + } + docQuery := &doc.DocumentFilter{ + Namespace: namespace, + Search: c.Query("search"), + FuzzyName: c.Query("fuzzyName"), + Source: c.Query("source"), + Marked: utils.ToPtr(c.Query("mark") == "true"), + Unread: utils.ToPtr(c.Query("unRead") == "true"), + Page: int64(page), + PageSize: int64(pageSize), + Order: doc.DocumentOrder{ + Order: doc.DocOrder(sort), + Desc: c.Query("desc") == "true", + }, } + parentId := c.Query("parentId") + if parentId != "" { + pId, err := strconv.Atoi(c.Query("parentId")) + if err != nil { + c.String(400, fmt.Sprintf("invalid parentId: %s", c.Query("page"))) + return nil + } + docQuery.ParentID = utils.ToPtr(int64(pId)) + } createAtStart := c.Query("createAtStart") if createAtStart != "" { createAtStartTimestamp, err := strconv.Atoi(createAtStart) @@ -209,7 +168,7 @@ func getFilterQuery(c *gin.Context) *DocQuery { c.String(400, fmt.Sprintf("invalid createAtStart: %s", c.Query("page"))) return nil } - docQuery.CreatedAtStart = utils.ToPtr(int64(createAtStartTimestamp)) + docQuery.CreatedAtStart = utils.ToPtr(time.Unix(int64(createAtStartTimestamp), 0)) } createAtEnd := c.Query("createAtEnd") if createAtEnd != "" { @@ -218,7 +177,7 @@ func getFilterQuery(c *gin.Context) *DocQuery { c.String(400, fmt.Sprintf("invalid createAtEnd: %s", 
c.Query("page"))) return nil } - docQuery.ChangedAtEnd = utils.ToPtr(int64(createAtEndTimestamp)) + docQuery.ChangedAtEnd = utils.ToPtr(time.Unix(int64(createAtEndTimestamp), 0)) } updatedAtStart := c.Query("updatedAtStart") if updatedAtStart != "" { @@ -227,7 +186,7 @@ func getFilterQuery(c *gin.Context) *DocQuery { c.String(400, fmt.Sprintf("invalid updatedAtStart: %s", c.Query("page"))) return nil } - docQuery.ChangedAtStart = utils.ToPtr(int64(updatedAtStartTimestamp)) + docQuery.ChangedAtStart = utils.ToPtr(time.Unix(int64(updatedAtStartTimestamp), 0)) } updatedAtEnd := c.Query("updatedAtEnd") if updatedAtEnd != "" { @@ -236,39 +195,18 @@ func getFilterQuery(c *gin.Context) *DocQuery { c.String(400, fmt.Sprintf("invalid updatedAtEnd: %s", c.Query("page"))) return nil } - docQuery.ChangedAtEnd = utils.ToPtr(int64(updatedAtEndTimestamp)) - } - fuzzyName := c.Query("fuzzyName") - if fuzzyName != "" { - docQuery.FuzzyName = &fuzzyName - } - if c.Query("unRead") != "" { - docQuery.UnRead = utils.ToPtr(c.Query("unRead") == "true") + docQuery.ChangedAtEnd = utils.ToPtr(time.Unix(int64(updatedAtEndTimestamp), 0)) } - if c.Query("mark") != "" { - docQuery.Mark = utils.ToPtr(c.Query("mark") == "true") - } - return &docQuery + return docQuery } func (s *HttpServer) delete() gin.HandlerFunc { return func(c *gin.Context) { namespace := c.Param("namespace") - queries := []*doc.AttrQuery{} entryId := c.Param("entryId") - queries = append(queries, - &doc.AttrQuery{ - Attr: "entryId", - Option: "=", - Value: entryId, - }, - &doc.AttrQuery{ - Attr: "namespace", - Option: "=", - Value: namespace, - }, - ) - if err := s.chain.DeleteByFilter(c, doc.DocumentAttrQuery{AttrQueries: queries}); err != nil { + + enId, _ := strconv.Atoi(entryId) + if err := s.chain.Delete(c, namespace, int64(enId)); err != nil { c.String(500, fmt.Sprintf("delete document error: %s", err)) return } diff --git a/api/request.go b/api/request.go index e4f755a..c9fdc85 100644 --- a/api/request.go +++ 
b/api/request.go @@ -18,75 +18,17 @@ package api import ( "fmt" - - "github.com/google/uuid" + "time" "github.com/basenana/friday/pkg/models/doc" ) type DocRequest struct { - EntryId string `json:"entryId,omitempty"` - Name string `json:"name"` - Namespace string `json:"namespace"` - Source string `json:"source,omitempty"` - WebUrl string `json:"webUrl,omitempty"` - Content string `json:"content"` - UnRead *bool `json:"unRead,omitempty"` - Mark *bool `json:"mark,omitempty"` - ParentID string `json:"parentId,omitempty"` - CreatedAt int64 `json:"createdAt,omitempty"` - ChangedAt int64 `json:"changedAt,omitempty"` -} - -func (r *DocRequest) ToDocument() *doc.Document { - return &doc.Document{ - Id: uuid.New().String(), - EntryId: r.EntryId, - Name: r.Name, - Kind: "document", - Namespace: r.Namespace, - Source: r.Source, - WebUrl: r.WebUrl, - Content: r.Content, - CreatedAt: r.CreatedAt, - UpdatedAt: r.ChangedAt, - } -} - -func (r *DocRequest) ToAttr() doc.DocumentAttrList { - attrs := doc.DocumentAttrList{} - if r.ParentID != "" { - attrs = append(attrs, &doc.DocumentAttr{ - Id: uuid.New().String(), - Namespace: r.Namespace, - EntryId: r.EntryId, - Key: "parentId", - Value: r.ParentID, - }) - } - if r.Mark != nil { - attrs = append(attrs, &doc.DocumentAttr{ - Id: uuid.New().String(), - Namespace: r.Namespace, - EntryId: r.EntryId, - Key: "mark", - Value: *r.Mark, - }) - } - if r.UnRead != nil { - attrs = append(attrs, &doc.DocumentAttr{ - Id: uuid.New().String(), - Namespace: r.Namespace, - EntryId: r.EntryId, - Key: "unRead", - Value: *r.UnRead, - }) - } - return attrs + doc.Document } func (r *DocRequest) Valid() error { - if r.EntryId == "" || r.EntryId == "0" { + if r.EntryId == 0 { return fmt.Errorf("entryId is required") } if r.Namespace == "" { @@ -95,211 +37,74 @@ func (r *DocRequest) Valid() error { return nil } -type DocAttrRequest struct { +type DocUpdateRequest struct { Namespace string `json:"namespace"` - EntryId string `json:"entryId,omitempty"` - 
ParentID string `json:"parentId,omitempty"` + EntryId int64 `json:"entryId,omitempty"` + ParentID *int64 `json:"parentId,omitempty"` UnRead *bool `json:"unRead,omitempty"` Mark *bool `json:"mark,omitempty"` } -func (r *DocAttrRequest) ToDocAttr() []*doc.DocumentAttr { - attrs := []*doc.DocumentAttr{} - if r.ParentID != "" { - attrs = append(attrs, &doc.DocumentAttr{ - Id: uuid.New().String(), - Namespace: r.Namespace, - EntryId: r.EntryId, - Key: "parentId", - Value: r.ParentID, - }) +func (r *DocUpdateRequest) Valid() error { + if r.EntryId == 0 { + return fmt.Errorf("entryId is required") } - if r.Mark != nil { - attrs = append(attrs, &doc.DocumentAttr{ - Id: uuid.New().String(), - Namespace: r.Namespace, - EntryId: r.EntryId, - Key: "mark", - Value: *r.Mark, - }) + if r.Namespace == "" { + return fmt.Errorf("namespace is required") } - if r.UnRead != nil { - attrs = append(attrs, &doc.DocumentAttr{ - Id: uuid.New().String(), - Namespace: r.Namespace, - EntryId: r.EntryId, - Key: "unRead", - Value: *r.UnRead, - }) + return nil +} +func (r *DocUpdateRequest) ToModel() *doc.Document { + return &doc.Document{ + EntryId: r.EntryId, + Namespace: r.Namespace, + ParentEntryID: r.ParentID, + Marked: r.Mark, + Unread: r.UnRead, } - return attrs } type DocQuery struct { - IDs []string `json:"ids"` - Namespace string `json:"namespace"` - Source string `json:"source,omitempty"` - WebUrl string `json:"webUrl,omitempty"` - ParentID string `json:"parentId,omitempty"` - UnRead *bool `json:"unRead,omitempty"` - Mark *bool `json:"mark,omitempty"` - CreatedAtStart *int64 `json:"createdAtStart,omitempty"` - CreatedAtEnd *int64 `json:"createdAtEnd,omitempty"` - ChangedAtStart *int64 `json:"changedAtStart,omitempty"` - ChangedAtEnd *int64 `json:"changedAtEnd,omitempty"` - FuzzyName *string `json:"fuzzyName,omitempty"` + EntryIds []int64 `json:"entryIds,omitempty"` + Namespace string `json:"namespace"` + Source string `json:"source,omitempty"` + WebUrl string `json:"webUrl,omitempty"` 
+ ParentID *int64 `json:"parentId,omitempty"` + UnRead *bool `json:"unRead,omitempty"` + Mark *bool `json:"mark,omitempty"` + CreatedAtStart *time.Time `json:"createdAtStart,omitempty"` + CreatedAtEnd *time.Time `json:"createdAtEnd,omitempty"` + ChangedAtStart *time.Time `json:"changedAtStart,omitempty"` + ChangedAtEnd *time.Time `json:"changedAtEnd,omitempty"` + FuzzyName string `json:"fuzzyName,omitempty"` Search string `json:"search"` - HitsPerPage int64 `json:"hitsPerPage,omitempty"` - Page int64 `json:"page,omitempty"` - Limit int64 `json:"limit,omitempty"` - Sort string `json:"sort,omitempty"` - Desc bool `json:"desc,omitempty"` + PageSize int64 `json:"PageSize,omitempty"` + Page int64 `json:"page,omitempty"` + Sort int `json:"sort,omitempty"` + Desc bool `json:"desc,omitempty"` } -func (q *DocQuery) ToQuery() *doc.DocumentQuery { - query := &doc.DocumentQuery{ - Search: q.Search, - HitsPerPage: q.HitsPerPage, - Page: q.Page, - Sort: []doc.Sort{{ - Attr: q.Sort, - Asc: !q.Desc, - }}, - } - attrQueries := []*doc.AttrQuery{{ - Attr: "namespace", - Option: "=", - Value: q.Namespace, - }} - if q.Source != "" { - attrQueries = append(attrQueries, &doc.AttrQuery{ - Attr: "source", - Option: "=", - Value: q.Source, - }) - } - if q.WebUrl != "" { - attrQueries = append(attrQueries, &doc.AttrQuery{ - Attr: "webUrl", - Option: "=", - Value: q.WebUrl, - }) - } - if q.CreatedAtStart != nil { - attrQueries = append(attrQueries, &doc.AttrQuery{ - Attr: "createdAt", - Option: ">=", - Value: *q.CreatedAtStart, - }) - } - if q.ChangedAtStart != nil { - attrQueries = append(attrQueries, &doc.AttrQuery{ - Attr: "updatedAt", - Option: ">=", - Value: *q.ChangedAtStart, - }) +func (q *DocQuery) ToModel() *doc.DocumentFilter { + return &doc.DocumentFilter{ + Namespace: q.Namespace, + Search: q.Search, + FuzzyName: q.FuzzyName, + ParentID: q.ParentID, + Source: q.Source, + Marked: q.Mark, + Unread: q.UnRead, + CreatedAtStart: q.CreatedAtStart, + CreatedAtEnd: q.CreatedAtEnd, + 
ChangedAtStart: q.ChangedAtStart, + ChangedAtEnd: q.ChangedAtEnd, + Page: q.Page, + PageSize: q.PageSize, + Order: doc.DocumentOrder{ + Order: doc.DocOrder(q.Sort), + Desc: q.Desc, + }, } - if q.CreatedAtEnd != nil { - attrQueries = append(attrQueries, &doc.AttrQuery{ - Attr: "createdAt", - Option: "<=", - Value: *q.CreatedAtEnd, - }) - } - if q.ChangedAtEnd != nil { - attrQueries = append(attrQueries, &doc.AttrQuery{ - Attr: "updatedAt", - Option: "<=", - Value: *q.ChangedAtEnd, - }) - } - if q.FuzzyName != nil { - attrQueries = append(attrQueries, &doc.AttrQuery{ - Attr: "name", - Option: "CONTAINS", - Value: *q.FuzzyName, - }) - } - - query.AttrQueries = attrQueries - return query -} - -func (q *DocQuery) GetAttrQueries() []*doc.DocumentAttrQuery { - attrQueries := []*doc.DocumentAttrQuery{} - if q.UnRead != nil { - attrQueries = append(attrQueries, &doc.DocumentAttrQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: q.Namespace, - }, - { - Attr: "key", - Option: "=", - Value: "unRead", - }, - { - Attr: "value", - Option: "=", - Value: *q.UnRead, - }, - }, - }) - } - if q.Mark != nil { - attrQueries = append(attrQueries, &doc.DocumentAttrQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: q.Namespace, - }, - { - Attr: "key", - Option: "=", - Value: "mark", - }, - { - Attr: "value", - Option: "=", - Value: *q.Mark, - }, - }, - }) - } - if q.ParentID != "" { - attrQueries = append(attrQueries, &doc.DocumentAttrQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: q.Namespace, - }, - { - Attr: "key", - Option: "=", - Value: "parentId", - }, - { - Attr: "value", - Option: "=", - Value: q.ParentID, - }, - }, - }) - } - return attrQueries -} - -type DocumentWithAttr struct { - *doc.Document - - ParentID string `json:"parentId,omitempty"` - UnRead *bool `json:"unRead,omitempty"` - Mark *bool `json:"mark,omitempty"` } diff --git a/config/config.go 
b/config/config.go index fd59589..e6f7473 100644 --- a/config/config.go +++ b/config/config.go @@ -31,8 +31,8 @@ type Config struct { // plugins Plugins []string `json:"plugins,omitempty"` - // meilisearch - MeiliConfig MeiliConfig `json:"meiliConfig,omitempty"` + // docStore + DocStore DocStoreConfig `json:"docStore,omitempty"` // llm limit token LimitToken int `json:"limitToken,omitempty"` // used by summary, split input into mutil sub-docs summaried by llm separately. @@ -67,6 +67,16 @@ type MeiliConfig struct { AttrIndex string `json:"attrIndex,omitempty"` } +type DocStoreConfig struct { + Type string `json:"type"` + MeiliConfig MeiliConfig `json:"meiliConfig,omitempty"` + PostgresConfig PostgresConfig `json:"postgresConfig,omitempty"` +} + +type PostgresConfig struct { + DSN string `json:"dsn,omitempty"` +} + type LLMConfig struct { LLMType LLMType `json:"llmType"` Prompts map[string]string `json:"prompts,omitempty"` diff --git a/pkg/build/common/init.go b/pkg/build/common/init.go index 44ee9d8..24cd9a0 100644 --- a/pkg/build/common/init.go +++ b/pkg/build/common/init.go @@ -20,10 +20,10 @@ import ( "github.com/basenana/friday/config" "github.com/basenana/friday/pkg/build/withvector" "github.com/basenana/friday/pkg/friday" - "github.com/basenana/friday/pkg/store/vectorstore" - "github.com/basenana/friday/pkg/store/vectorstore/pgvector" - "github.com/basenana/friday/pkg/store/vectorstore/postgres" - "github.com/basenana/friday/pkg/store/vectorstore/redis" + "github.com/basenana/friday/pkg/store" + "github.com/basenana/friday/pkg/store/pgvector" + "github.com/basenana/friday/pkg/store/postgres" + "github.com/basenana/friday/pkg/store/redis" "github.com/basenana/friday/pkg/utils/logger" ) @@ -35,7 +35,7 @@ func NewFriday(conf *config.Config) (f *friday.Friday, err error) { log.SetDebug(conf.Debug) conf.Logger = log - var vectorStore vectorstore.VectorStore + var vectorStore store.VectorStore // init vector store if conf.VectorStoreConfig.VectorStoreType == 
config.VectorStoreRedis { if conf.VectorStoreConfig.EmbeddingDim == 0 { @@ -55,7 +55,7 @@ func NewFriday(conf *config.Config) (f *friday.Friday, err error) { return nil, err } } else if conf.VectorStoreConfig.VectorStoreType == config.VectorStorePostgres { - vectorStore, err = postgres.NewPostgresClient(conf.Logger, conf.VectorStoreConfig.VectorUrl) + vectorStore, err = postgres.NewPostgresClient(conf.VectorStoreConfig.VectorUrl) if err != nil { return nil, err } diff --git a/pkg/build/withvector/init.go b/pkg/build/withvector/init.go index 0e82edd..3bce078 100644 --- a/pkg/build/withvector/init.go +++ b/pkg/build/withvector/init.go @@ -30,11 +30,11 @@ import ( "github.com/basenana/friday/pkg/llm/client/glm-6b" openaiv1 "github.com/basenana/friday/pkg/llm/client/openai/v1" "github.com/basenana/friday/pkg/spliter" - "github.com/basenana/friday/pkg/store/vectorstore" + "github.com/basenana/friday/pkg/store" "github.com/basenana/friday/pkg/utils/logger" ) -func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorStore) (f *friday.Friday, err error) { +func NewFridayWithVector(conf *config.Config, vectorClient store.VectorStore) (f *friday.Friday, err error) { log := conf.Logger if conf.Logger == nil { log = logger.NewLogger("friday") diff --git a/pkg/dispatch/plugin/header.go b/pkg/dispatch/plugin/header.go index 4ffff22..76bf408 100644 --- a/pkg/dispatch/plugin/header.go +++ b/pkg/dispatch/plugin/header.go @@ -39,7 +39,7 @@ func (h *HeaderImgPlugin) Run(ctx context.Context, doc *doc.Document) error { var headerImgUrl string query, err := goquery.NewDocumentFromReader(bytes.NewReader([]byte(doc.Content))) if err != nil { - return fmt.Errorf("build doc query with id %s error: %s", doc.Id, err) + return fmt.Errorf("build doc query with id %d error: %s", doc.EntryId, err) } query.Find("img").EachWithBreak(func(i int, selection *goquery.Selection) bool { diff --git a/pkg/friday/friday.go b/pkg/friday/friday.go index e151359..5427a05 100644 --- 
a/pkg/friday/friday.go +++ b/pkg/friday/friday.go @@ -25,7 +25,7 @@ import ( "github.com/basenana/friday/pkg/models" "github.com/basenana/friday/pkg/models/vector" "github.com/basenana/friday/pkg/spliter" - "github.com/basenana/friday/pkg/store/vectorstore" + "github.com/basenana/friday/pkg/store" "github.com/basenana/friday/pkg/utils/logger" ) @@ -53,7 +53,7 @@ type Friday struct { Embedding embedding.Embedding - Vector vectorstore.VectorStore + Vector store.VectorStore VectorTopK *int Spliter spliter.Spliter diff --git a/pkg/friday/ingest_test.go b/pkg/friday/ingest_test.go index 295c25c..3e63259 100644 --- a/pkg/friday/ingest_test.go +++ b/pkg/friday/ingest_test.go @@ -25,7 +25,7 @@ import ( "github.com/basenana/friday/pkg/embedding" "github.com/basenana/friday/pkg/models/vector" "github.com/basenana/friday/pkg/spliter" - "github.com/basenana/friday/pkg/store/vectorstore" + "github.com/basenana/friday/pkg/store" "github.com/basenana/friday/pkg/utils/logger" ) @@ -59,7 +59,7 @@ var _ = Describe("TestIngest", func() { type FakeStore struct{} -var _ vectorstore.VectorStore = &FakeStore{} +var _ store.VectorStore = &FakeStore{} func (f FakeStore) Store(ctx context.Context, element *vector.Element, extra map[string]any) error { return nil diff --git a/pkg/friday/question_test.go b/pkg/friday/question_test.go index 35f2346..1b593a9 100644 --- a/pkg/friday/question_test.go +++ b/pkg/friday/question_test.go @@ -27,7 +27,7 @@ import ( "github.com/basenana/friday/pkg/llm/prompts" "github.com/basenana/friday/pkg/models/vector" "github.com/basenana/friday/pkg/spliter" - "github.com/basenana/friday/pkg/store/vectorstore" + "github.com/basenana/friday/pkg/store" "github.com/basenana/friday/pkg/utils/logger" ) @@ -203,7 +203,7 @@ var _ = Describe("TestQuestion", func() { type FakeQuestionStore struct{} -var _ vectorstore.VectorStore = &FakeQuestionStore{} +var _ store.VectorStore = &FakeQuestionStore{} func (f FakeQuestionStore) Store(ctx context.Context, element 
*vector.Element, extra map[string]any) error { return nil diff --git a/pkg/models/doc/attr.go b/pkg/models/doc/attr.go index 307e30d..2b2caef 100644 --- a/pkg/models/doc/attr.go +++ b/pkg/models/doc/attr.go @@ -15,47 +15,3 @@ */ package doc - -import "fmt" - -type DocumentAttr struct { - Id string `json:"id"` - Kind string `json:"kind"` - Namespace string `json:"namespace"` - EntryId string `json:"entryId"` - Key string `json:"key"` - Value interface{} `json:"value"` -} - -var ( - DocAttrFilterableAttrs = []string{"namespace", "entryId", "key", "id", "kind", "value"} - DocAttrSortAttrs = []string{"createdAt", "updatedAt"} -) - -var _ DocPtrInterface = &DocumentAttr{} - -func (d *DocumentAttr) ID() string { - return d.Id -} - -func (d *DocumentAttr) EntryID() string { - return d.EntryId -} - -func (d *DocumentAttr) Type() string { - return "attr" -} - -func (d *DocumentAttr) String() string { - return fmt.Sprintf("EntryId(%s) %s: %v", d.EntryId, d.Key, d.Value) -} - -type DocumentAttrList []*DocumentAttr - -func (d DocumentAttrList) String() string { - result := "" - for _, attr := range d { - result += fmt.Sprintf("EntryId(%s) %s: %v\n", attr.EntryId, attr.Key, attr.Value) - } - return result -} diff --git a/pkg/models/doc/document.go b/pkg/models/doc/document.go index 9f98bca..eb02471 100644 --- a/pkg/models/doc/document.go +++ b/pkg/models/doc/document.go @@ -18,62 +18,110 @@ package doc import ( "fmt" + "time" ) -var ( - DocFilterableAttrs = []string{"namespace", "id", "entryId", "kind", "name", "source", "webUrl", "createdAt", "updatedAt"} - DocSortAttrs = []string{"createdAt", "updatedAt", "name"} -) - -type DocPtrInterface interface { - ID() string - EntryID() string - Type() string - String() string -} - type Document struct { - Id string `json:"id"` - Kind string `json:"kind"` - Namespace string `json:"namespace"` - EntryId string `json:"entryId"` - Name string `json:"name"` - Source string `json:"source,omitempty"` - WebUrl string `json:"webUrl,omitempty"` 
- - Content string `json:"content"` - Summary string `json:"summary,omitempty"` - HeaderImage string `json:"headerImage,omitempty"` - SubContent string `json:"subContent,omitempty"` - - CreatedAt int64 `json:"createdAt,omitempty"` - UpdatedAt int64 `json:"updatedAt,omitempty"` + EntryId int64 `json:"entry_id"` + Name string `json:"name"` + Namespace string `json:"namespace"` + ParentEntryID *int64 `json:"parent_entry_id"` + Source string `json:"source"` + Content string `json:"content,omitempty"` + Summary string `json:"summary,omitempty"` + WebUrl string `json:"web_url,omitempty"` + HeaderImage string `json:"header_image,omitempty"` + SubContent string `json:"sub_content,omitempty"` + Marked *bool `json:"marked,omitempty"` + Unread *bool `json:"unread,omitempty"` + CreatedAt time.Time `json:"created_at"` + ChangedAt time.Time `json:"changed_at"` } -func (d *Document) ID() string { - return d.Id +type DocumentFilter struct { + Namespace string + Search string + FuzzyName string + ParentID *int64 + Source string + Marked *bool + Unread *bool + CreatedAtStart *time.Time + CreatedAtEnd *time.Time + ChangedAtStart *time.Time + ChangedAtEnd *time.Time + + // Pagination + Page int64 + PageSize int64 + Order DocumentOrder } -func (d *Document) EntryID() string { - return d.EntryId +func (f *DocumentFilter) String() string { + s := fmt.Sprintf("namespace: %s", f.Namespace) + if f.Search != "" { + s += fmt.Sprintf(", search: %s", f.Search) + } + if f.FuzzyName != "" { + s += fmt.Sprintf(", fuzzyName: %s", f.FuzzyName) + } + if f.ParentID != nil { + s += fmt.Sprintf(", parentID: %d", *f.ParentID) + } + if f.Source != "" { + s += fmt.Sprintf(", source: %s", f.Source) + } + if f.Marked != nil { + s += fmt.Sprintf(", marked: %v", *f.Marked) + } + if f.Unread != nil { + s += fmt.Sprintf(", unread: %v", *f.Unread) + } + if f.CreatedAtStart != nil { + s += fmt.Sprintf(", createdAtStart: %s", f.CreatedAtStart) + } + if f.CreatedAtEnd != nil { + s += fmt.Sprintf(", createdAtEnd: 
%s", f.CreatedAtEnd) + } + if f.ChangedAtStart != nil { + s += fmt.Sprintf(", changedAtStart: %s", f.ChangedAtStart) + } + if f.ChangedAtEnd != nil { + s += fmt.Sprintf(", changedAtEnd: %s", f.ChangedAtEnd) + } + s += fmt.Sprintf(", page: %d, pageSize: %d, sort: %s", f.Page, f.PageSize, f.Order.String()) + return s } -func (d *Document) Type() string { - return "document" +type DocumentOrder struct { + Order DocOrder + Desc bool } -func (d *Document) String() string { - return fmt.Sprintf("EntryId(%s) %s", d.EntryId, d.Name) +func (o DocumentOrder) String() string { + return fmt.Sprintf("order: %s, desc: %v", o.Order.String(), o.Desc) } -type DocumentList []*Document +type DocOrder int + +const ( + Name DocOrder = iota + Source + Marked + Unread + CreatedAt +) -func (d DocumentList) String() string { - result := "" - for _, doc := range d { - result += fmt.Sprintf("EntryId(%s) %s\n", doc.EntryId, doc.Name) +func (d DocOrder) String() string { + names := []string{ + "name", + "source", + "marked", + "unread", + "created_at", } - return result + if d < Name || d > CreatedAt { + return "" + } + return names[d] } - -var _ DocPtrInterface = &Document{} diff --git a/pkg/models/doc/query.go b/pkg/models/doc/query.go index 7c64404..2b2caef 100644 --- a/pkg/models/doc/query.go +++ b/pkg/models/doc/query.go @@ -15,103 +15,3 @@ */ package doc - -import ( - "encoding/json" - "fmt" - - "github.com/meilisearch/meilisearch-go" -) - -type DocumentQuery struct { - AttrQueries []*AttrQuery - - Search string - HitsPerPage int64 - Page int64 - Offset int64 - Limit int64 - Sort []Sort -} - -type Sort struct { - Attr string - Asc bool -} - -func (s *Sort) String() string { - if s.Asc { - return fmt.Sprintf("%s:asc", s.Attr) - } - return fmt.Sprintf("%s:desc", s.Attr) -} - -type DocumentAttrQuery struct { - AttrQueries []*AttrQuery -} - -func (q *DocumentAttrQuery) String() string { - result := "" - for _, aq := range q.AttrQueries { - result += aq.String() + " " - } - return result -} - 
-type AttrQuery struct { - Attr string - Option string - Value interface{} -} - -func (aq *AttrQuery) ToFilter() interface{} { - vs, _ := json.Marshal(aq.Value) - return fmt.Sprintf("%s %s %s", aq.Attr, aq.Option, vs) -} - -func (aq *AttrQuery) String() string { - return aq.ToFilter().(string) -} - -func (q *DocumentQuery) String() string { - filters := "" - for _, aq := range q.AttrQueries { - filters += aq.String() + " " - } - return fmt.Sprintf("search: [%s], attr query: [%s]", q.Search, filters) -} - -func (q *DocumentQuery) ToRequest() *meilisearch.SearchRequest { - // build filter - filter := []interface{}{} - for _, aq := range q.AttrQueries { - filter = append(filter, aq.ToFilter()) - } - sorts := []string{} - for _, s := range q.Sort { - sorts = append(sorts, s.String()) - } - - return &meilisearch.SearchRequest{ - Offset: q.Offset, - Limit: q.Limit, - Sort: sorts, - HitsPerPage: q.HitsPerPage, - Page: q.Page, - Query: q.Search, - Filter: filter, - } -} - -func (q *DocumentAttrQuery) ToRequest() *meilisearch.SearchRequest { - filter := []interface{}{} - for _, aq := range q.AttrQueries { - filter = append(filter, aq.ToFilter()) - } - return &meilisearch.SearchRequest{ - Filter: filter, - Limit: 10000, - HitsPerPage: 10000, - Query: "", - } -} diff --git a/pkg/models/pagination.go b/pkg/models/pagination.go new file mode 100644 index 0000000..70d52cf --- /dev/null +++ b/pkg/models/pagination.go @@ -0,0 +1,67 @@ +/* + Copyright 2024 NanaFS Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +*/ + +package models + +import ( + "context" +) + +const ( + PageKey = "page" + PageSizeKey = "pageSize" +) + +type Pagination struct { + Page int64 + PageSize int64 +} + +func (p *Pagination) Limit() int { + return int(p.PageSize) +} + +func (p *Pagination) Offset() int { + return int((p.Page - 1) * p.PageSize) +} + +func NewPagination(page, pageSize int64) *Pagination { + if page > 0 && pageSize > 0 { + return &Pagination{ + Page: page, + PageSize: pageSize, + } + } + return nil +} + +func GetPagination(ctx context.Context) *Pagination { + if ctx.Value(PageKey) != nil && ctx.Value(PageSizeKey) != nil { + return &Pagination{ + Page: ctx.Value(PageKey).(int64), + PageSize: ctx.Value(PageSizeKey).(int64), + } + } + return nil +} + +func WithPagination(ctx context.Context, page *Pagination) context.Context { + if page != nil { + ctx = context.WithValue(ctx, PageKey, page.Page) + ctx = context.WithValue(ctx, PageSizeKey, page.PageSize) + } + return ctx +} diff --git a/pkg/models/types.go b/pkg/models/types.go new file mode 100644 index 0000000..786de17 --- /dev/null +++ b/pkg/models/types.go @@ -0,0 +1,33 @@ +/* + Copyright 2024 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package models + +import "errors" + +var ( + ErrNotFound = errors.New("no record") + ErrNameTooLong = errors.New("name too long") + ErrIsExist = errors.New("record existed") + ErrNotEmpty = errors.New("group not empty") + ErrNoGroup = errors.New("not group") + ErrIsGroup = errors.New("this object is a group") + ErrNoAccess = errors.New("no access") + ErrNoPerm = errors.New("no permission") + ErrConflict = errors.New("operation conflict") + ErrUnsupported = errors.New("unsupported operation") + ErrNotEnable = errors.New("not enable") +) diff --git a/pkg/search/config.go b/pkg/search/config.go index 4d983bb..0d83c27 100644 --- a/pkg/search/config.go +++ b/pkg/search/config.go @@ -7,7 +7,7 @@ import ( "github.com/blevesearch/bleve/v2/index/upsidedown" - "github.com/basenana/friday/pkg/store/vectorstore/postgres" + "github.com/basenana/friday/pkg/store/postgres" ) func initConfigFile(dsn string) (string, error) { diff --git a/pkg/search/search.go b/pkg/search/search.go index 77fcca9..32afced 100644 --- a/pkg/search/search.go +++ b/pkg/search/search.go @@ -7,8 +7,7 @@ import ( "github.com/blevesearch/bleve/v2" "github.com/blevesearch/bleve/v2/index/upsidedown" - postgres2 "github.com/basenana/friday/pkg/store/vectorstore/postgres" - "github.com/basenana/friday/pkg/utils/logger" + "github.com/basenana/friday/pkg/store/postgres" ) var singleIndex bleve.Index @@ -20,7 +19,7 @@ func InitSearchEngine() error { return err } - pgCli, err := postgres2.NewPostgresClient(logger.NewLogger("database"), dsn) + pgCli, err := postgres.NewPostgresClient(dsn) if err != nil { return err } @@ -53,7 +52,7 @@ func InitSearchEngine() error { mapping.AddDocumentMapping("document", documentMapping) pgConfig := map[string]interface{}{"dsn": dsn} - index, err := bleve.NewUsing(fpath, mapping, upsidedown.Name, postgres2.PgKVStoreName, pgConfig) + index, err := bleve.NewUsing(fpath, mapping, upsidedown.Name, postgres.PgKVStoreName, pgConfig) if err != nil { return err } diff --git 
a/pkg/service/chain.go b/pkg/service/chain.go index 1eea671..64d9d9d 100644 --- a/pkg/service/chain.go +++ b/pkg/service/chain.go @@ -25,15 +25,18 @@ import ( "github.com/basenana/friday/config" "github.com/basenana/friday/pkg/dispatch" "github.com/basenana/friday/pkg/dispatch/plugin" + "github.com/basenana/friday/pkg/models" "github.com/basenana/friday/pkg/models/doc" - "github.com/basenana/friday/pkg/store/docstore" + "github.com/basenana/friday/pkg/store" + "github.com/basenana/friday/pkg/store/meili" + "github.com/basenana/friday/pkg/store/postgres" "github.com/basenana/friday/pkg/utils/logger" ) type Chain struct { - MeiliClient docstore.DocStoreInterface - Plugins []plugin.ChainPlugin - Log *zap.SugaredLogger + DocClient store.DocStoreInterface + Plugins []plugin.ChainPlugin + Log *zap.SugaredLogger } var ChainPool *dispatch.Pool @@ -44,28 +47,43 @@ func NewChain(conf config.Config) (*Chain, error) { plugins = append(plugins, plugin.DefaultRegisterer.Get(p)) } log := logger.NewLog("chain") - client, err := docstore.NewMeiliClient(conf) - if err != nil { - log.Errorf("new meili client error: %s", err) - return nil, err + var ( + client store.DocStoreInterface + err error + ) + switch conf.DocStore.Type { + case "meili": + client, err = meili.NewMeiliClient(conf) + if err != nil { + log.Errorf("new meili client error: %s", err) + return nil, err + } + case "postgres": + client, err = postgres.NewPostgresClient(conf.DocStore.PostgresConfig.DSN) + if err != nil { + log.Errorf("new postgres client error: %s", err) + return nil, err + } + default: + return nil, fmt.Errorf("unsupported docstore type: %s", conf.DocStore.Type) } return &Chain{ - MeiliClient: client, - Plugins: plugins, - Log: log, + DocClient: client, + Plugins: plugins, + Log: log, }, nil } -func (c *Chain) Store(ctx context.Context, document *doc.Document) error { - document.Kind = "document" +func (c *Chain) CreateDocument(ctx context.Context, document *doc.Document) error { return 
 ChainPool.Run(ctx, func(ctx context.Context) error { - c.Log.Debugf("store document: %+v", document.String()) - if d, err := c.GetDocument(ctx, document.Namespace, document.EntryId); err != nil { + ctx = c.WithNamespace(ctx, document.Namespace) + c.Log.Debugf("create document of entryId: %d", document.EntryId) + if d, err := c.GetDocument(ctx, document.Namespace, document.EntryId); err != nil && err != models.ErrNotFound { c.Log.Errorf("get document error: %s", err) return err } else if d != nil { - c.Log.Debugf("document already exists: %+v", d.String()) - return fmt.Errorf("document already exists: %+v", d.String()) + c.Log.Debugf("document already exists: %s", d.Name) + return fmt.Errorf("document already exists: %s", d.Name) } for _, plugin := range c.Plugins { err := plugin.Run(ctx, document) @@ -74,195 +92,50 @@ func (c *Chain) Store(ctx context.Context, document *doc.Document) error { return err } } - c.Log.Debugf("store document: %+v", document.String()) - return c.MeiliClient.Store(ctx, document) + c.Log.Debugf("create document: %+v", document.Name) + return c.DocClient.CreateDocument(ctx, document) }) } -func (c *Chain) StoreAttr(ctx context.Context, docAttr *doc.DocumentAttr) error { +func (c *Chain) UpdateDocument(ctx context.Context, document *doc.Document) error { return ChainPool.Run(ctx, func(ctx context.Context) error { - if err := c.MeiliClient.DeleteByFilter(ctx, doc.DocumentAttrQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: docAttr.Namespace, - }, - { - Attr: "key", - Option: "=", - Value: docAttr.Key, - }, - { - Attr: "entryId", - Option: "=", - Value: docAttr.EntryId, - }, - { - Attr: "kind", - Option: "=", - Value: "attr", - }}, - }); err != nil { - c.Log.Errorf("delete document attr error: %s", err) - return err - } - docAttr.Kind = "attr" - c.Log.Debugf("store attr: %+v", docAttr.String()) - return c.MeiliClient.Store(ctx, docAttr) + ctx = c.WithNamespace(ctx, document.Namespace) + 
c.Log.Debugf("update document of entryId: %d", document.EntryId) + return c.DocClient.UpdateDocument(ctx, document) }) } -func (c *Chain) GetDocument(ctx context.Context, namespace, entryId string) (*doc.Document, error) { - c.Log.Debugf("get document: namespace=%s, entryId=%s", namespace, entryId) - docs, err := c.MeiliClient.Search(ctx, &doc.DocumentQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: namespace, - }, - { - Attr: "entryId", - Option: "=", - Value: entryId, - }, - { - Attr: "kind", - Option: "=", - Value: "document", - }, - }, - Search: "", - HitsPerPage: 1, - Page: 1, - }) +func (c *Chain) GetDocument(ctx context.Context, namespace string, entryId int64) (*doc.Document, error) { + c.Log.Debugf("get document: namespace=%s, entryId=%d", namespace, entryId) + ctx = c.WithNamespace(ctx, namespace) + doc, err := c.DocClient.GetDocument(ctx, entryId) if err != nil { c.Log.Errorf("get document error: %s", err) return nil, err } - if len(docs) == 0 { - c.Log.Debugf("document not found: namespace=%s, entryId=%s", namespace, entryId) - return nil, nil - } - c.Log.Debugf("get document: %+v", docs[0].String()) - return docs[0], nil -} - -func (c *Chain) ListDocumentAttrs(ctx context.Context, namespace string, entryIds []string) (doc.DocumentAttrList, error) { - docAttrQuery := &doc.DocumentAttrQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: namespace, - }, - { - Attr: "entryId", - Option: "IN", - Value: entryIds, - }, - { - Attr: "kind", - Option: "=", - Value: "attr", - }, - }, - } - c.Log.Debugf("list document attrs: %+v", docAttrQuery.String()) - attrs, err := c.MeiliClient.FilterAttr(ctx, docAttrQuery) - if err != nil { - c.Log.Errorf("list document attrs error: %s", err) - return nil, err - } - c.Log.Debugf("list %d document attrs: %s", len(attrs), attrs.String()) - return attrs, nil -} - -func (c *Chain) GetDocumentAttrs(ctx context.Context, namespace, entryId string) 
([]*doc.DocumentAttr, error) { - docAttrQuery := &doc.DocumentAttrQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: namespace, - }, - { - Attr: "entryId", - Option: "=", - Value: entryId, - }, - { - Attr: "kind", - Option: "=", - Value: "attr", - }, - }, - } - c.Log.Debugf("get document attrs: %+v", docAttrQuery.String()) - attrs, err := c.MeiliClient.FilterAttr(ctx, docAttrQuery) - if err != nil { - c.Log.Errorf("get document attrs error: %s", err) - return nil, err - } - c.Log.Debugf("get %d document attrs: %s", len(attrs), attrs.String()) - return attrs, nil + c.Log.Debugf("get document of entryId: %d", entryId) + return doc, nil } -func (c *Chain) Search(ctx context.Context, query *doc.DocumentQuery, attrQueries []*doc.DocumentAttrQuery) ([]*doc.Document, error) { - attrs := doc.DocumentAttrList{} - for _, attrQuery := range attrQueries { - attrQuery.AttrQueries = append(attrQuery.AttrQueries, &doc.AttrQuery{ - Attr: "kind", - Option: "=", - Value: "attr", - }) - c.Log.Debugf("filter attr query: %+v", attrQuery.String()) - attr, err := c.MeiliClient.FilterAttr(ctx, attrQuery) - if err != nil { - return nil, err - } - attrs = append(attrs, attr...) 
- } - c.Log.Debugf("filter %d attrs: %s", len(attrs), attrs.String()) - ids := []string{} - for _, attr := range attrs { - ids = append(ids, attr.EntryId) - } - if len(ids) == 0 && len(attrQueries) != 0 { - return nil, nil - } - - query.AttrQueries = append(query.AttrQueries, &doc.AttrQuery{ - Attr: "kind", - Option: "=", - Value: "document", - }) - if len(ids) != 0 { - query.AttrQueries = append(query.AttrQueries, &doc.AttrQuery{ - Attr: "entryId", - Option: "IN", - Value: ids, - }) - } - c.Log.Debugf("search document query: %+v", query.String()) - docs, err := c.MeiliClient.Search(ctx, query) - if err != nil { - c.Log.Errorf("search document error: %s", err) - return nil, err - } - c.Log.Debugf("search %d documents: %s", len(docs), docs.String()) - return docs, nil +func (c *Chain) Search(ctx context.Context, filter *doc.DocumentFilter) ([]*doc.Document, error) { + ctx = c.WithNamespace(ctx, filter.Namespace) + c.Log.Debugf("search document: %+v", filter.String()) + return c.DocClient.FilterDocuments(ctx, filter) } -func (c *Chain) DeleteByFilter(ctx context.Context, queries doc.DocumentAttrQuery) error { +func (c *Chain) Delete(ctx context.Context, namespace string, entryId int64) error { + ctx = c.WithNamespace(ctx, namespace) return ChainPool.Run(ctx, func(ctx context.Context) error { - c.Log.Debugf("delete by filter: %+v", queries.String()) - err := c.MeiliClient.DeleteByFilter(ctx, queries) + c.Log.Debugf("delete document of entryId: %d", entryId) + err := c.DocClient.DeleteDocument(ctx, entryId) if err != nil { - c.Log.Errorf("delete by filter error: %s", err) + c.Log.Errorf("delete document of entryId %d error: %s", entryId, err) return err } return nil }) } + +func (c *Chain) WithNamespace(ctx context.Context, namespace string) context.Context { + return models.WithNamespace(ctx, models.NewNamespace(namespace)) +} diff --git a/pkg/service/chain_test.go b/pkg/service/chain_test.go index 7c9c0c0..cce87b5 100644 --- a/pkg/service/chain_test.go +++ 
b/pkg/service/chain_test.go @@ -18,6 +18,7 @@ package service_test import ( "context" + "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -27,152 +28,95 @@ import ( _ "github.com/basenana/friday/pkg/dispatch/plugin" "github.com/basenana/friday/pkg/models/doc" "github.com/basenana/friday/pkg/service" - "github.com/basenana/friday/pkg/store/docstore" + "github.com/basenana/friday/pkg/store/meili" "github.com/basenana/friday/pkg/utils/logger" ) var _ = Describe("Chain", func() { var ( Chain *service.Chain - parentId1 = "1" - parentId2 = "2" - entryId11 = "11" - entryId12 = "12" - entryId21 = "21" + parentId1 = int64(1) + parentId2 = int64(2) + entryId11 = int64(11) + entryId12 = int64(12) + entryId21 = int64(21) doc11 *doc.Document doc12 *doc.Document doc21 *doc.Document - attr11 *doc.DocumentAttr - attr12 *doc.DocumentAttr - attr21 *doc.DocumentAttr + t = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) ) BeforeEach(func() { service.ChainPool = dispatch.NewPool(10) logger.InitLog() Chain = &service.Chain{ - MeiliClient: &docstore.MockClient{}, - Plugins: []plugin.ChainPlugin{}, - Log: logger.NewLog("test"), + DocClient: &meili.MockClient{}, + Plugins: []plugin.ChainPlugin{}, + Log: logger.NewLog("test"), } for _, p := range plugin.DefaultRegisterer.Chains { Chain.Plugins = append(Chain.Plugins, p) } doc11 = &doc.Document{ - Id: "1", - Namespace: "test-ns", - EntryId: entryId11, - Name: "test-name-11", - Kind: "document", - Content: "

test

", - CreatedAt: 1733584543, - UpdatedAt: 1733584543, + Namespace: "test-ns", + EntryId: entryId11, + ParentEntryID: &parentId1, + Name: "test-name-11", + Content: "

test

", + CreatedAt: t, + ChangedAt: t, } doc12 = &doc.Document{ - Id: "2", - Namespace: "test-ns", - EntryId: entryId12, - Name: "test-name-12", - Kind: "document", - Content: "

test

", - CreatedAt: 1733584543, - UpdatedAt: 1733584543, + Namespace: "test-ns", + EntryId: entryId12, + ParentEntryID: &parentId1, + Name: "test-name-12", + Content: "

test

", + CreatedAt: t, + ChangedAt: t, } doc21 = &doc.Document{ - Id: "3", - Namespace: "test-ns", - EntryId: entryId21, - Name: "test-name-21", - Kind: "document", - Content: "

test

", - CreatedAt: 1733584543, - UpdatedAt: 1733584543, + Namespace: "test-ns", + EntryId: entryId21, + ParentEntryID: &parentId2, + Name: "test-name-21", + Content: "

test

", + CreatedAt: t, + ChangedAt: t, } - attr11 = &doc.DocumentAttr{ - Id: "4", - Namespace: "test-ns", - Kind: "attr", - EntryId: entryId11, - Key: "parentId", - Value: parentId1, - } - attr12 = &doc.DocumentAttr{ - Id: "5", - Namespace: "test-ns", - Kind: "attr", - EntryId: entryId12, - Key: "parentId", - Value: parentId1, - } - attr21 = &doc.DocumentAttr{ - Id: "6", - Namespace: "test-ns", - Kind: "attr", - EntryId: entryId21, - Key: "parentId", - Value: parentId2, - } - err := Chain.Store(context.TODO(), doc11) - Expect(err).Should(BeNil()) - err = Chain.Store(context.TODO(), doc12) - Expect(err).Should(BeNil()) - err = Chain.Store(context.TODO(), doc21) - Expect(err).Should(BeNil()) - err = Chain.StoreAttr(context.TODO(), attr11) + err := Chain.CreateDocument(context.TODO(), doc11) Expect(err).Should(BeNil()) - err = Chain.StoreAttr(context.TODO(), attr12) + err = Chain.CreateDocument(context.TODO(), doc12) Expect(err).Should(BeNil()) - err = Chain.StoreAttr(context.TODO(), attr21) + err = Chain.CreateDocument(context.TODO(), doc21) Expect(err).Should(BeNil()) }) Describe("documents", func() { Context("store document ", func() { It("store document should be successful", func() { - err := Chain.Store(context.TODO(), &doc.Document{Id: "10"}) + err := Chain.CreateDocument(context.TODO(), &doc.Document{EntryId: int64(30)}) Expect(err).Should(BeNil()) }) It("store document attr should be successful", func() { - err := Chain.StoreAttr(context.TODO(), &doc.DocumentAttr{Id: "11"}) + err := Chain.CreateDocument(context.TODO(), &doc.Document{EntryId: int64(31)}) Expect(err).Should(BeNil()) }) }) Context("search document", func() { It("search document should be successful", func() { - docs, err := Chain.Search(context.TODO(), &doc.DocumentQuery{ - AttrQueries: []*doc.AttrQuery{{ - Attr: "namespace", - Option: "=", - Value: "test-ns", - }}, - Search: "test", - }, []*doc.DocumentAttrQuery{}) + docs, err := Chain.Search(context.TODO(), &doc.DocumentFilter{ + Search: "test", + 
Namespace: "test-ns", + }) Expect(err).Should(BeNil()) Expect(docs).Should(HaveLen(3)) }) It("search document with attr should be successful", func() { - docs, err := Chain.Search(context.TODO(), &doc.DocumentQuery{ - AttrQueries: []*doc.AttrQuery{{ - Attr: "namespace", - Option: "=", - Value: "test-ns", - }}, - Search: "test", - }, []*doc.DocumentAttrQuery{ - { - AttrQueries: []*doc.AttrQuery{ - { - Attr: "parentId", - Option: "=", - Value: parentId1, - }, - { - Attr: "namespace", - Option: "=", - Value: "test-ns", - }, - }, - }, + docs, err := Chain.Search(context.TODO(), &doc.DocumentFilter{ + Search: "test", + Namespace: "test-ns", + ParentID: &parentId1, }) Expect(err).Should(BeNil()) Expect(docs).Should(HaveLen(2)) @@ -181,59 +125,31 @@ var _ = Describe("Chain", func() { Context("plugin should work", func() { It("header plugin should work", func() { doc3 := &doc.Document{ - Id: "100", Namespace: "test-ns", - EntryId: "100", + EntryId: int64(100), Name: "test-name-100", Content: "

test

", - CreatedAt: 1733584543, - UpdatedAt: 1733584543, + CreatedAt: t, + ChangedAt: t, } - err := Chain.Store(context.TODO(), doc3) + err := Chain.CreateDocument(context.TODO(), doc3) Expect(err).Should(BeNil()) Expect(doc3.HeaderImage).Should(Equal("http://abc")) }) }) Context("delete document", func() { It("delete document by filter should be successful", func() { - err := Chain.Store(context.TODO(), &doc.Document{ - Id: "12", - Namespace: "test-ns", - EntryId: "10", - Name: "test-name-10", - }) - Expect(err).Should(BeNil()) - err = Chain.StoreAttr(context.TODO(), &doc.DocumentAttr{ - Id: "13", + err := Chain.CreateDocument(context.TODO(), &doc.Document{ Namespace: "test-ns", - EntryId: "10", + EntryId: int64(40), + Name: "test-name-40", }) Expect(err).Should(BeNil()) - err = Chain.DeleteByFilter(context.TODO(), doc.DocumentAttrQuery{ - AttrQueries: []*doc.AttrQuery{ - { - Attr: "namespace", - Option: "=", - Value: "test-ns", - }, - { - Attr: "entryId", - Option: "=", - Value: "10", - }, - }}, - ) + err = Chain.Delete(context.TODO(), "test-ns", int64(40)) Expect(err).Should(BeNil()) - docs, err := Chain.Search(context.TODO(), &doc.DocumentQuery{ - AttrQueries: []*doc.AttrQuery{{ - Attr: "entryId", - Option: "=", - Value: "10", - }}, - Search: "test", - }, []*doc.DocumentAttrQuery{}) + docs, err := Chain.GetDocument(context.TODO(), "test-ns", int64(40)) Expect(err).Should(BeNil()) - Expect(docs).Should(HaveLen(0)) + Expect(docs).Should(BeNil()) }) }) }) diff --git a/pkg/store/vectorstore/db/entity.go b/pkg/store/db/entity.go similarity index 100% rename from pkg/store/vectorstore/db/entity.go rename to pkg/store/db/entity.go diff --git a/pkg/store/docstore/meili.go b/pkg/store/docstore/meili.go deleted file mode 100644 index c3df4fa..0000000 --- a/pkg/store/docstore/meili.go +++ /dev/null @@ -1,257 +0,0 @@ -/* - Copyright 2024 Friday Author. 
- - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package docstore - -import ( - "context" - "encoding/json" - "fmt" - "time" - - "github.com/meilisearch/meilisearch-go" - "go.uber.org/zap" - - "github.com/basenana/friday/config" - "github.com/basenana/friday/pkg/models/doc" - "github.com/basenana/friday/pkg/utils" - "github.com/basenana/friday/pkg/utils/logger" -) - -type MeiliClient struct { - log *zap.SugaredLogger - meiliUrl string - masterKey string - adminApiKey string - searchApiKey string - docIndex meilisearch.IndexManager - attrIndex meilisearch.IndexManager - client meilisearch.ServiceManager -} - -var _ DocStoreInterface = &MeiliClient{} - -func NewMeiliClient(conf config.Config) (DocStoreInterface, error) { - client := meilisearch.New(conf.MeiliConfig.MeiliUrl, meilisearch.WithAPIKey(conf.MeiliConfig.MasterKey)) - docIndex := client.Index(conf.MeiliConfig.DocIndex) - attrIndex := client.Index(conf.MeiliConfig.AttrIndex) - - log := logger.NewLog("meilisearch") - meiliClient := &MeiliClient{ - log: log, - meiliUrl: conf.MeiliConfig.MeiliUrl, - masterKey: conf.MeiliConfig.MasterKey, - adminApiKey: conf.MeiliConfig.AdminApiKey, - searchApiKey: conf.MeiliConfig.SearchApiKey, - docIndex: docIndex, - attrIndex: attrIndex, - client: client, - } - return meiliClient, meiliClient.init() -} - -func (c *MeiliClient) init() error { - attrs, err := c.docIndex.GetFilterableAttributes() - if err != nil { - return err - } - if !utils.Equal(doc.DocFilterableAttrs, attrs) { - t, 
err := c.docIndex.UpdateFilterableAttributes(&doc.DocFilterableAttrs) - if err != nil { - return err - } - if err = c.wait(context.TODO(), "document", t.TaskUID); err != nil { - return err - } - } - - sortAttrs := doc.DocSortAttrs - crtSortAttrs, err := c.docIndex.GetSortableAttributes() - if err != nil { - return err - } - if !utils.Equal(sortAttrs, crtSortAttrs) { - t, err := c.docIndex.UpdateSortableAttributes(&sortAttrs) - if err != nil { - return err - } - if err = c.wait(context.TODO(), "document", t.TaskUID); err != nil { - return err - } - } - - // attr index - attrAttrs, err := c.attrIndex.GetFilterableAttributes() - if err != nil { - return err - } - if !utils.Equal(doc.DocAttrFilterableAttrs, attrAttrs) { - t, err := c.docIndex.UpdateFilterableAttributes(&doc.DocAttrFilterableAttrs) - if err != nil { - return err - } - if err = c.wait(context.TODO(), "attr", t.TaskUID); err != nil { - return err - } - } - attrSortAttrs := doc.DocAttrSortAttrs - crtAttrSortAttrs, err := c.docIndex.GetSortableAttributes() - if err != nil { - return err - } - if !utils.Equal(attrSortAttrs, crtAttrSortAttrs) { - t, err := c.docIndex.UpdateSortableAttributes(&attrSortAttrs) - if err != nil { - return err - } - if err = c.wait(context.TODO(), "attr", t.TaskUID); err != nil { - return err - } - } - return nil -} - -func (c *MeiliClient) index(kind string) meilisearch.IndexManager { - if kind == "attr" { - return c.attrIndex - } - return c.docIndex -} - -func (c *MeiliClient) Store(ctx context.Context, docPtr doc.DocPtrInterface) error { - c.log.Debugf("store entryId %s %s: %s", docPtr.EntryID(), docPtr.Type(), docPtr.String()) - task, err := c.index(docPtr.Type()).AddDocuments(docPtr, "id") - if err != nil { - c.log.Error(err) - return err - } - if err := c.wait(ctx, docPtr.Type(), task.TaskUID); err != nil { - c.log.Errorf("store document with entryId %s error: %s", docPtr.EntryID(), err) - } - return nil -} - -func (c *MeiliClient) FilterAttr(ctx context.Context, query 
*doc.DocumentAttrQuery) (doc.DocumentAttrList, error) { - c.log.Debugf("query document attr : [%s]", query.String()) - rep, err := c.index("attr").Search("", query.ToRequest()) - if err != nil { - return nil, err - } - attrs := doc.DocumentAttrList{} - for _, hit := range rep.Hits { - b, _ := json.Marshal(hit) - attr := &doc.DocumentAttr{} - err = json.Unmarshal(b, &attr) - if err != nil { - c.log.Errorf("unmarshal document attr error: %s", err) - continue - } - attrs = append(attrs, attr) - } - return attrs, nil -} - -func (c *MeiliClient) Search(ctx context.Context, query *doc.DocumentQuery) (doc.DocumentList, error) { - c.log.Debugf("search document: [%s] query: [%s]", query.Search, query.String()) - rep, err := c.index("document").Search(query.Search, query.ToRequest()) - if err != nil { - return nil, err - } - documents := doc.DocumentList{} - for _, hit := range rep.Hits { - b, _ := json.Marshal(hit) - document := &doc.Document{} - err = json.Unmarshal(b, &document) - if err != nil { - c.log.Errorf("unmarshal document error: %s", err) - continue - } - documents = append(documents, document) - } - return documents, nil -} - -func (c *MeiliClient) Update(ctx context.Context, document *doc.Document) error { - c.log.Debugf("update document: %s", document.ID()) - t, err := c.index(document.Type()).UpdateDocuments(document) - if err != nil { - c.log.Error(err) - return err - } - if err := c.wait(ctx, document.Type(), t.TaskUID); err != nil { - c.log.Errorf("update document %s error: %s", document.ID, err) - } - return nil -} - -func (c *MeiliClient) Delete(ctx context.Context, docId string) error { - c.log.Debugf("delete document: %s", docId) - t, err := c.index("document").DeleteDocument(docId) - if err != nil { - c.log.Error(err) - return err - } - if err := c.wait(ctx, "document", t.TaskUID); err != nil { - c.log.Errorf("delete document %s error: %s", docId, err) - } - return nil -} - -func (c *MeiliClient) DeleteByFilter(ctx context.Context, aqs 
doc.DocumentAttrQuery) error { - c.log.Debugf("delete by filter: %+v", aqs.String()) - filter := []interface{}{} - for _, aq := range aqs.AttrQueries { - filter = append(filter, aq.ToFilter()) - } - - t, err := c.index("attr").DeleteDocumentsByFilter(filter) - if err != nil { - c.log.Error(err) - return err - } - if err := c.wait(ctx, "attr", t.TaskUID); err != nil { - c.log.Errorf("delete document by filter error: %s", err) - } - return nil -} - -func (c *MeiliClient) wait(ctx context.Context, kind string, taskUID int64) error { - t := time.NewTicker(100 * time.Millisecond) - defer t.Stop() - for { - select { - case <-ctx.Done(): - return fmt.Errorf("context timeout") - case <-t.C: - t, err := c.index(kind).GetTask(taskUID) - if err != nil { - c.log.Error(err) - return err - } - if t.Status == meilisearch.TaskStatusFailed { - err := fmt.Errorf("task %d failed: %s", taskUID, t.Error) - return err - } - if t.Status == meilisearch.TaskStatusCanceled { - err := fmt.Errorf("task %d canceled: %s", taskUID, t.Error) - return err - } - if t.Status == meilisearch.TaskStatusSucceeded { - return nil - } - } - } -} diff --git a/pkg/store/docstore/mock.go b/pkg/store/docstore/mock.go deleted file mode 100644 index a588a0e..0000000 --- a/pkg/store/docstore/mock.go +++ /dev/null @@ -1,232 +0,0 @@ -/* - Copyright 2024 Friday Author. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package docstore - -import ( - "context" - "reflect" - "strings" - - "github.com/basenana/friday/pkg/models/doc" -) - -type MockClient struct { - docs []*doc.Document - attrs []*doc.DocumentAttr -} - -var _ DocStoreInterface = &MockClient{} - -func (m *MockClient) Store(ctx context.Context, docPtr doc.DocPtrInterface) error { - if docPtr.Type() == "document" { - d := docPtr.(*doc.Document) - if m.docs == nil { - m.docs = []*doc.Document{} - } - m.docs = append(m.docs, d) - } - if docPtr.Type() == "attr" { - d := docPtr.(*doc.DocumentAttr) - if m.attrs == nil { - m.attrs = []*doc.DocumentAttr{} - } - m.attrs = append(m.attrs, d) - } - return nil -} - -func (m *MockClient) FilterAttr(ctx context.Context, query *doc.DocumentAttrQuery) (doc.DocumentAttrList, error) { - aq := query.AttrQueries - result := []*doc.DocumentAttr{} - for _, attr := range m.attrs { - matched := true - all := len(aq) - for _, q := range aq { - if q.Attr == "namespace" { - all -= 1 - if !match(q, attr.Namespace) { - matched = false - continue - } - } - if q.Attr == "entryId" { - all -= 1 - if !match(q, attr.EntryId) { - matched = false - continue - } - } - if attr.Key == q.Attr { - all -= 1 - if !match(q, attr.Value) { - matched = false - } - } - if q.Attr == "kind" { - all -= 1 - if !match(q, attr.Kind) { - matched = false - continue - } - } - } - if matched && all == 0 { - result = append(result, attr) - } - } - return result, nil -} - -func (m *MockClient) Search(ctx context.Context, query *doc.DocumentQuery) (doc.DocumentList, error) { - aq := query.AttrQueries - result := []*doc.Document{} - for _, d := range m.docs { - matched := true - all := len(aq) - for _, q := range aq { - if q.Attr == "entryId" { - all -= 1 - if !match(q, d.EntryId) { - matched = false - continue - } - } - if q.Attr == "namespace" { - all -= 1 - if !match(q, d.Namespace) { - matched = false - continue - } - } - if q.Attr == "id" { - all -= 1 - if !match(q, d.Id) { - matched = false - continue - } - } - if 
q.Attr == "kind" { - all -= 1 - if !match(q, d.Kind) { - matched = false - continue - } - } - } - if matched && all == 0 && strings.Contains(d.Content, query.Search) { - result = append(result, d) - } - } - return result, nil -} - -func (m *MockClient) DeleteByFilter(ctx context.Context, aqs doc.DocumentAttrQuery) error { - attrs := make(map[string]*doc.DocumentAttr) - for _, attr := range m.attrs { - attrs[attr.Id] = attr - } - docs := make(map[string]*doc.Document) - for _, d := range m.docs { - docs[d.Id] = d - } - for _, d := range m.docs { - matched := true - all := len(aqs.AttrQueries) - for _, q := range aqs.AttrQueries { - if q.Attr == "entryId" { - all -= 1 - if !match(q, d.EntryId) { - matched = false - continue - } - } - if q.Attr == "namespace" { - all -= 1 - if !match(q, d.Namespace) { - matched = false - continue - } - } - if q.Attr == "id" { - all -= 1 - if !match(q, d.Id) { - matched = false - continue - } - } - if q.Attr == "kind" { - all -= 1 - if !match(q, d.Kind) { - matched = false - continue - } - } - } - if matched && all == 0 { - delete(docs, d.Id) - } - } - for _, attr := range m.attrs { - matched := true - all := len(aqs.AttrQueries) - for _, aq := range aqs.AttrQueries { - if attr.Key == aq.Attr { - all -= 1 - if !match(aq, attr.Value) { - matched = false - } - } - if aq.Attr == "kind" { - all -= 1 - if !match(aq, attr.Kind) { - matched = false - continue - } - } - - } - if matched && all == 0 { - delete(attrs, attr.Id) - } - } - var attrsSlice []*doc.DocumentAttr - for _, v := range attrs { - attrsSlice = append(attrsSlice, v) - } - m.attrs = attrsSlice - var docsSlice []*doc.Document - for _, v := range docs { - docsSlice = append(docsSlice, v) - } - m.docs = docsSlice - return nil -} - -func match[T string | interface{}](aq *doc.AttrQuery, t T) bool { - if aq.Option == "=" && reflect.DeepEqual(t, aq.Value.(T)) { - return true - } - if aq.Option == "IN" { - value := aq.Value.([]T) - for _, v := range value { - if reflect.DeepEqual(t, v) 
{ - return true - } - } - } - return false -} diff --git a/pkg/store/interface.go b/pkg/store/interface.go new file mode 100644 index 0000000..138bf7b --- /dev/null +++ b/pkg/store/interface.go @@ -0,0 +1,38 @@ +/* + Copyright 2024 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package store + +import ( + "context" + + "github.com/basenana/friday/pkg/models/doc" + "github.com/basenana/friday/pkg/models/vector" +) + +type VectorStore interface { + Store(ctx context.Context, element *vector.Element, extra map[string]any) error + Search(ctx context.Context, query vector.VectorDocQuery, vectors []float32, k int) ([]*vector.Doc, error) + Get(ctx context.Context, oid int64, name string, group int) (*vector.Element, error) +} + +type DocStoreInterface interface { + CreateDocument(ctx context.Context, doc *doc.Document) error + UpdateDocument(ctx context.Context, doc *doc.Document) error + GetDocument(ctx context.Context, entryId int64) (*doc.Document, error) + FilterDocuments(ctx context.Context, filter *doc.DocumentFilter) ([]*doc.Document, error) + DeleteDocument(ctx context.Context, docId int64) error +} diff --git a/pkg/store/meili/meili.go b/pkg/store/meili/meili.go new file mode 100644 index 0000000..4b05e46 --- /dev/null +++ b/pkg/store/meili/meili.go @@ -0,0 +1,354 @@ +/* + Copyright 2024 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package meili + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/meilisearch/meilisearch-go" + "go.uber.org/zap" + + "github.com/basenana/friday/config" + "github.com/basenana/friday/pkg/models" + "github.com/basenana/friday/pkg/models/doc" + "github.com/basenana/friday/pkg/store" + "github.com/basenana/friday/pkg/utils" + "github.com/basenana/friday/pkg/utils/logger" +) + +type Client struct { + log *zap.SugaredLogger + meiliUrl string + masterKey string + adminApiKey string + searchApiKey string + docIndex meilisearch.IndexManager + attrIndex meilisearch.IndexManager + client meilisearch.ServiceManager +} + +var _ store.DocStoreInterface = &Client{} + +func NewMeiliClient(conf config.Config) (store.DocStoreInterface, error) { + client := meilisearch.New(conf.DocStore.MeiliConfig.MeiliUrl, meilisearch.WithAPIKey(conf.DocStore.MeiliConfig.MasterKey)) + docIndex := client.Index(conf.DocStore.MeiliConfig.DocIndex) + attrIndex := client.Index(conf.DocStore.MeiliConfig.AttrIndex) + + log := logger.NewLog("meilisearch") + meiliClient := &Client{ + log: log, + meiliUrl: conf.DocStore.MeiliConfig.MeiliUrl, + masterKey: conf.DocStore.MeiliConfig.MasterKey, + adminApiKey: conf.DocStore.MeiliConfig.AdminApiKey, + searchApiKey: conf.DocStore.MeiliConfig.SearchApiKey, + docIndex: docIndex, + attrIndex: attrIndex, + client: client, + } + return meiliClient, meiliClient.init() +} + +func (c *Client) init() error { + attrs, err := c.docIndex.GetFilterableAttributes() + if err != nil { + return err + } + if !utils.Equal(DocFilterableAttrs, attrs) { + 
t, err := c.docIndex.UpdateFilterableAttributes(&DocFilterableAttrs) + if err != nil { + return err + } + if err = c.wait(context.TODO(), "document", t.TaskUID); err != nil { + return err + } + } + + sortAttrs := DocSortAttrs + crtSortAttrs, err := c.docIndex.GetSortableAttributes() + if err != nil { + return err + } + if !utils.Equal(sortAttrs, crtSortAttrs) { + t, err := c.docIndex.UpdateSortableAttributes(&sortAttrs) + if err != nil { + return err + } + if err = c.wait(context.TODO(), "document", t.TaskUID); err != nil { + return err + } + } + + // attr index + attrAttrs, err := c.attrIndex.GetFilterableAttributes() + if err != nil { + return err + } + if !utils.Equal(DocAttrFilterableAttrs, attrAttrs) { + t, err := c.attrIndex.UpdateFilterableAttributes(&DocAttrFilterableAttrs) + if err != nil { + return err + } + if err = c.wait(context.TODO(), "attr", t.TaskUID); err != nil { + return err + } + } + attrSortAttrs := DocAttrSortAttrs + crtAttrSortAttrs, err := c.attrIndex.GetSortableAttributes() + if err != nil { + return err + } + if !utils.Equal(attrSortAttrs, crtAttrSortAttrs) { + t, err := c.attrIndex.UpdateSortableAttributes(&attrSortAttrs) + if err != nil { + return err + } + if err = c.wait(context.TODO(), "attr", t.TaskUID); err != nil { + return err + } + } + return nil +} + +func (c *Client) index(kind string) meilisearch.IndexManager { + if kind == "attr" { + return c.attrIndex + } + return c.docIndex +} + +func (c *Client) CreateDocument(ctx context.Context, doc *doc.Document) error { + newDoc := (&Document{}).FromModel(doc) + c.log.Debugf("store entryId %s", newDoc.EntryId) + task, err := c.index(newDoc.Kind).AddDocuments(newDoc, "id") + if err != nil { + c.log.Error(err) + return err + } + if err := c.wait(ctx, newDoc.Kind, task.TaskUID); err != nil { + c.log.Errorf("store document with entryId %s error: %s", newDoc.EntryId, err) + return err + } + + // store document attr + newAttrs := (&DocumentAttrList{}).FromModel(doc) + c.log.Debugf("store doc 
of entryId %d attrs: %s", doc.EntryId, newAttrs.String()) + t, err := c.index("attr").AddDocuments(newAttrs, "id") + if err != nil { + c.log.Error(err) + return err + } + if err := c.wait(ctx, "attr", t.TaskUID); err != nil { + c.log.Errorf("store document attr of entryId %d error: %s", doc.EntryId, err) + return err + } + return nil +} + +func (c *Client) UpdateDocument(ctx context.Context, doc *doc.Document) error { + // delete document attr + newAttrsQuery := (&DocumentAttrQuery{}).FromModel(doc) + c.log.Debugf("delete document attrs: %s", newAttrsQuery.String()) + + filter := []interface{}{} + for _, aq := range newAttrsQuery.AttrQueries { + filter = append(filter, aq.ToFilter()) + } + t, err := c.index("attr").DeleteDocumentsByFilter(filter) + if err != nil { + c.log.Error(err) + return err + } + if err = c.wait(ctx, "attr", t.TaskUID); err != nil { + c.log.Errorf("delete document by filter error: %s", err) + return err + } + // store document attr + newAttrs := (&DocumentAttrList{}).FromModel(doc) + c.log.Debugf("store doc of entryId %d attrs: %s", doc.EntryId, newAttrs.String()) + t, err = c.index("attr").AddDocuments(newAttrs, "id") + if err != nil { + c.log.Error(err) + return err + } + if err := c.wait(ctx, "attr", t.TaskUID); err != nil { + c.log.Errorf("store document attr of entryId %d error: %s", doc.EntryId, err) + return err + } + return nil +} + +func (c *Client) GetDocument(ctx context.Context, entryId int64) (*doc.Document, error) { + namespace := models.GetNamespace(ctx) + query := (&DocumentQuery{}).OfEntryId(namespace.String(), entryId) + c.log.Debugf("get document by entryId: %d", entryId) + rep, err := c.index("document").Search("", query.ToRequest()) + if err != nil { + return nil, err + } + if len(rep.Hits) == 0 { + return nil, nil + } + b, _ := json.Marshal(rep.Hits[0]) + document := &Document{} + err = json.Unmarshal(b, &document) + if err != nil { + return nil, err + } + + // get attrs + attrQuery := 
(&DocumentAttrQuery{}).OfEntryId(document.Namespace, document.EntryId) + c.log.Debugf("filter document attr: %s", attrQuery.String()) + attrRep, err := c.index("attr").Search("", attrQuery.ToRequest()) + if err != nil { + return nil, err + } + + attrs := make([]*DocumentAttr, 0) + for _, hit := range attrRep.Hits { + b, _ := json.Marshal(hit) + attr := &DocumentAttr{} + err = json.Unmarshal(b, &attr) + if err != nil { + c.log.Errorf("unmarshal document attr error: %s", err) + continue + } + attrs = append(attrs, attr) + } + return document.ToModel(attrs), nil +} + +func (c *Client) FilterDocuments(ctx context.Context, filter *doc.DocumentFilter) ([]*doc.Document, error) { + query := (&DocumentQuery{}).FromModel(filter) + if filter.ParentID != nil || filter.Unread != nil || filter.Marked != nil { + entryIds := make([]string, 0) + attrQuery := (&DocumentAttrQueries{}).FromFilter(filter) + for _, aq := range *attrQuery { + c.log.Debugf("filter document attr: %s", aq.String()) + attrRep, err := c.index("attr").Search("", aq.ToRequest()) + if err != nil { + return nil, err + } + + for _, hit := range attrRep.Hits { + b, _ := json.Marshal(hit) + attr := &DocumentAttr{} + err = json.Unmarshal(b, &attr) + if err != nil { + c.log.Errorf("unmarshal document attr error: %s", err) + continue + } + entryIds = append(entryIds, attr.EntryId) + } + } + if len(entryIds) != 0 { + query.AttrQueries = append(query.AttrQueries, &AttrQuery{ + Attr: "entryId", + Option: "IN", + Value: entryIds, + }) + } + } + + c.log.Debugf("search document: [%s] query: [%s]", query.Search, query.String()) + rep, err := c.index("document").Search(query.Search, query.ToRequest()) + if err != nil { + return nil, err + } + c.log.Debugf("query document attr : [%s]", query.String()) + + documents := make([]*doc.Document, 0) + for _, hit := range rep.Hits { + b, _ := json.Marshal(hit) + document := &Document{} + err = json.Unmarshal(b, &document) + if err != nil { + c.log.Errorf("unmarshal document error: %s", 
err) + continue + } + + // get attrs + attrQuery := (&DocumentAttrQuery{}).OfEntryId(document.Namespace, document.EntryId) + c.log.Debugf("filter document attr: %s", attrQuery.String()) + attrRep, err := c.index("attr").Search("", attrQuery.ToRequest()) + if err != nil { + return nil, err + } + + attrs := make([]*DocumentAttr, 0) + for _, hit := range attrRep.Hits { + b, _ := json.Marshal(hit) + attr := &DocumentAttr{} + err = json.Unmarshal(b, &attr) + if err != nil { + c.log.Errorf("unmarshal document attr error: %s", err) + continue + } + attrs = append(attrs, attr) + } + documents = append(documents, document.ToModel(attrs)) + } + return documents, nil +} + +func (c *Client) DeleteDocument(ctx context.Context, entryId int64) error { + c.log.Debugf("delete document by entryId: %d", entryId) + aq := &AttrQuery{ + Attr: "entryId", + Option: "=", + Value: fmt.Sprintf("%d", entryId), + } + t, err := c.index("attr").DeleteDocumentsByFilter(aq.ToFilter()) + if err != nil { + c.log.Error(err) + return err + } + if err := c.wait(ctx, "attr", t.TaskUID); err != nil { + c.log.Errorf("delete document by filter error: %s", err) + } + return nil +} + +func (c *Client) wait(ctx context.Context, kind string, taskUID int64) error { + t := time.NewTicker(100 * time.Millisecond) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return fmt.Errorf("context timeout") + case <-t.C: + t, err := c.index(kind).GetTask(taskUID) + if err != nil { + c.log.Error(err) + return err + } + if t.Status == meilisearch.TaskStatusFailed { + err := fmt.Errorf("task %d failed: %s", taskUID, t.Error) + return err + } + if t.Status == meilisearch.TaskStatusCanceled { + err := fmt.Errorf("task %d canceled: %s", taskUID, t.Error) + return err + } + if t.Status == meilisearch.TaskStatusSucceeded { + return nil + } + } + } +} diff --git a/pkg/store/meili/mock.go b/pkg/store/meili/mock.go new file mode 100644 index 0000000..dd5dce4 --- /dev/null +++ b/pkg/store/meili/mock.go @@ -0,0 +1,236 @@ +/* + 
Copyright 2024 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package meili + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/basenana/friday/pkg/models/doc" + "github.com/basenana/friday/pkg/store" +) + +type MockClient struct { + docs []*Document + attrs []*DocumentAttr +} + +var _ store.DocStoreInterface = &MockClient{} + +func (m *MockClient) CreateDocument(ctx context.Context, doc *doc.Document) error { + m.docs = append(m.docs, (&Document{}).FromModel(doc)) + newAttrs := (&DocumentAttrList{}).FromModel(doc) + m.attrs = append(m.attrs, *newAttrs...) 
+ return nil +} + +func (m *MockClient) UpdateDocument(ctx context.Context, doc *doc.Document) error { + aq := (&DocumentAttrQuery{}).FromModel(doc) + result := []*DocumentAttr{} + for _, attr := range m.attrs { + matched := true + all := len(aq.AttrQueries) + + for _, q := range aq.AttrQueries { + if q.Attr == "namespace" { + all -= 1 + if !match(q, attr.Namespace) { + matched = false + continue + } + } + if q.Attr == "entryId" { + all -= 1 + if !match(q, attr.EntryId) { + matched = false + continue + } + } + if attr.Key == q.Attr { + all -= 1 + if !match(q, attr.Value) { + matched = false + } + } + if q.Attr == "kind" { + all -= 1 + if !match(q, attr.Kind) { + matched = false + continue + } + } + } + if matched && all == 0 { + result = append(result, attr) + } + } + return nil +} + +func (m *MockClient) GetDocument(ctx context.Context, entryId int64) (*doc.Document, error) { + var res *Document + for _, d := range m.docs { + if d.EntryId == fmt.Sprintf("%d", entryId) { + res = d + break + } + } + attrs := make([]*DocumentAttr, 0) + for _, attr := range m.attrs { + if attr.EntryId == fmt.Sprintf("%d", entryId) { + attrs = append(attrs, attr) + } + } + if res != nil { + return res.ToModel(attrs), nil + } + return nil, nil +} + +func (m *MockClient) FilterDocuments(ctx context.Context, filter *doc.DocumentFilter) ([]*doc.Document, error) { + query := (&DocumentQuery{}).FromModel(filter) + if filter.ParentID != nil || filter.Unread != nil || filter.Marked != nil { + attrQuery := (&DocumentAttrQueries{}).FromFilter(filter) + entryId := make([]string, 0) + for _, aq := range *attrQuery { + for _, attr := range m.attrs { + all := len(aq.AttrQueries) + matched := true + for _, q := range aq.AttrQueries { + if q.Attr == "entryId" { + all -= 1 + if !match(q, attr.EntryId) { + matched = false + continue + } + } + if q.Attr == "namespace" { + all -= 1 + if !match(q, attr.Namespace) { + matched = false + continue + } + } + if q.Attr == "key" { + all -= 1 + if !match(q, 
attr.Key) { + matched = false + continue + } + } + if q.Attr == "value" { + all -= 1 + if !match(q, attr.Value) { + matched = false + continue + } + } + if q.Attr == "kind" { + all -= 1 + if !match(q, attr.Kind) { + matched = false + continue + } + } + } + if matched && all == 0 { + entryId = append(entryId, attr.EntryId) + } + } + } + if len(entryId) != 0 { + query.AttrQueries = append(query.AttrQueries, &AttrQuery{ + Attr: "entryId", + Option: "IN", + Value: entryId, + }) + } + } + + result := []*doc.Document{} + for _, d := range m.docs { + matched := true + all := len(query.AttrQueries) + for _, q := range query.AttrQueries { + if q.Attr == "entryId" { + all -= 1 + if !match(q, d.EntryId) { + matched = false + continue + } + } + if q.Attr == "namespace" { + all -= 1 + if !match(q, d.Namespace) { + matched = false + continue + } + } + if q.Attr == "id" { + all -= 1 + if !match(q, d.Id) { + matched = false + continue + } + } + if q.Attr == "kind" { + all -= 1 + if !match(q, d.Kind) { + matched = false + continue + } + } + } + if matched && all == 0 && strings.Contains(d.Content, query.Search) { + result = append(result, d.ToModel(nil)) + } + } + return result, nil +} + +func (m *MockClient) DeleteDocument(ctx context.Context, docId int64) error { + for i, d := range m.docs { + if d.EntryId == fmt.Sprintf("%d", docId) { + m.docs = append(m.docs[:i], m.docs[i+1:]...) + break + } + } + for i, attr := range m.attrs { + if attr.EntryId == fmt.Sprintf("%d", docId) { + m.attrs = append(m.attrs[:i], m.attrs[i+1:]...) 
+ break + } + } + return nil +} + +func match[T string | interface{}](aq *AttrQuery, t T) bool { + if aq.Option == "=" && reflect.DeepEqual(t, aq.Value.(T)) { + return true + } + if aq.Option == "IN" { + value := aq.Value.([]T) + for _, v := range value { + if reflect.DeepEqual(t, v) { + return true + } + } + } + return false +} diff --git a/pkg/store/meili/model.go b/pkg/store/meili/model.go new file mode 100644 index 0000000..606120a --- /dev/null +++ b/pkg/store/meili/model.go @@ -0,0 +1,552 @@ +/* + Copyright 2024 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package meili + +import ( + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/meilisearch/meilisearch-go" + + "github.com/basenana/friday/pkg/models/doc" +) + +var ( + DocFilterableAttrs = []string{"namespace", "id", "entryId", "kind", "name", "source", "webUrl", "createdAt", "updatedAt"} + DocSortAttrs = []string{"createdAt", "updatedAt", "name"} + DocAttrFilterableAttrs = []string{"namespace", "entryId", "key", "id", "kind", "value"} + DocAttrSortAttrs = []string{"createdAt", "updatedAt"} +) + +type DocPtrInterface interface { + ID() string + EntryID() string + Type() string + String() string +} + +type Document struct { + Id string `json:"id"` + Kind string `json:"kind"` + Namespace string `json:"namespace"` + EntryId string `json:"entryId"` + Name string `json:"name"` + Source string `json:"source,omitempty"` + WebUrl string `json:"webUrl,omitempty"` + + Content string `json:"content"` + Summary string `json:"summary,omitempty"` + HeaderImage string `json:"headerImage,omitempty"` + SubContent string `json:"subContent,omitempty"` + + CreatedAt int64 `json:"createdAt,omitempty"` + UpdatedAt int64 `json:"updatedAt,omitempty"` +} + +func (d *Document) ID() string { + return d.Id +} + +func (d *Document) EntryID() string { + return d.EntryId +} + +func (d *Document) Type() string { + return "document" +} + +func (d *Document) String() string { + return fmt.Sprintf("EntryId(%s) %s", d.EntryId, d.Name) +} + +func (d *Document) FromModel(doc *doc.Document) *Document { + d.Id = uuid.New().String() + d.Kind = "document" + d.Namespace = doc.Namespace + d.EntryId = fmt.Sprintf("%d", doc.EntryId) + d.Name = doc.Name + d.Source = doc.Source + d.WebUrl = doc.WebUrl + d.Content = doc.Content + d.Summary = doc.Summary + d.HeaderImage = doc.HeaderImage + d.SubContent = doc.SubContent + d.CreatedAt = doc.CreatedAt.Unix() + d.UpdatedAt = doc.ChangedAt.Unix() + return d +} + +func (d *Document) ToModel(attrs []*DocumentAttr) 
*doc.Document { + entryId, _ := strconv.Atoi(d.EntryId) + m := &doc.Document{ + EntryId: int64(entryId), + Name: d.Name, + Namespace: d.Namespace, + ParentEntryID: nil, + Source: d.Source, + Content: d.Content, + Summary: d.Summary, + WebUrl: d.WebUrl, + HeaderImage: d.HeaderImage, + SubContent: d.SubContent, + Marked: nil, + Unread: nil, + CreatedAt: time.Unix(d.CreatedAt, 0), + ChangedAt: time.Unix(d.UpdatedAt, 0), + } + + for _, attr := range attrs { + switch attr.Key { + case "parentId": + parentID, _ := strconv.Atoi(attr.Value.(string)) + pId := int64(parentID) + m.ParentEntryID = &pId + case "mark": + m.Marked = attr.Value.(*bool) + case "unRead": + m.Unread = attr.Value.(*bool) + } + } + return m +} + +type DocumentList []*Document + +func (d DocumentList) String() string { + result := "" + for _, doc := range d { + result += fmt.Sprintf("EntryId(%s) %s\n", doc.EntryId, doc.Name) + } + return result +} + +var _ DocPtrInterface = &Document{} + +type DocumentAttr struct { + Id string `json:"id"` + Kind string `json:"kind"` + Namespace string `json:"namespace"` + EntryId string `json:"entryId"` + Key string `json:"key"` + Value interface{} `json:"value"` +} + +var _ DocPtrInterface = &DocumentAttr{} + +func (d *DocumentAttr) ID() string { + return d.Id +} + +func (d *DocumentAttr) EntryID() string { + return d.EntryId +} + +func (d *DocumentAttr) Type() string { + return "attr" +} + +func (d *DocumentAttr) String() string { + return fmt.Sprintf("EntryId(%s) %s: %v", d.EntryId, d.Key, d.Value) +} + +type DocumentAttrList []*DocumentAttr + +func (d *DocumentAttrList) String() string { + result := "" + for _, attr := range *d { + result += fmt.Sprintf("EntryId(%s) %s: %v\n", attr.EntryId, attr.Key, attr.Value) + } + return result +} + +func (d *DocumentAttrList) FromModel(doc *doc.Document) *DocumentAttrList { + attrs := make([]*DocumentAttr, 0) + if doc.ParentEntryID != nil { + attrs = append(attrs, &DocumentAttr{ + Id: uuid.New().String(), + Kind: "attr", + 
Namespace: doc.Namespace, + EntryId: fmt.Sprintf("%d", doc.EntryId), + Key: "parentId", + Value: doc.ParentEntryID, + }) + } + if doc.Marked != nil { + attrs = append(attrs, &DocumentAttr{ + Id: uuid.New().String(), + Kind: "attr", + Namespace: doc.Namespace, + EntryId: fmt.Sprintf("%d", doc.EntryId), + Key: "mark", + Value: doc.Marked, + }) + } + if doc.Unread != nil { + attrs = append(attrs, &DocumentAttr{ + Id: uuid.New().String(), + Kind: "attr", + Namespace: doc.Namespace, + EntryId: fmt.Sprintf("%d", doc.EntryId), + Key: "unRead", + Value: doc.Unread, + }) + } + return (*DocumentAttrList)(&attrs) +} + +type DocumentQuery struct { + AttrQueries []*AttrQuery + + Search string + HitsPerPage int64 + Page int64 + Offset int64 + Limit int64 + Sort []Sort +} + +func (q *DocumentQuery) OfEntryId(namespace string, entryId int64) *DocumentQuery { + return &DocumentQuery{ + AttrQueries: []*AttrQuery{ + { + Attr: "namespace", + Option: "=", + Value: namespace, + }, + { + Attr: "entryId", + Option: "=", + Value: fmt.Sprintf("%d", entryId), + }, + { + Attr: "kind", + Option: "=", + Value: "document", + }, + }, + Search: "", + HitsPerPage: 1, + Page: 1, + } +} + +func (q *DocumentQuery) FromModel(query *doc.DocumentFilter) *DocumentQuery { + q.Search = query.Search + q.HitsPerPage = query.PageSize + q.Page = query.Page + q.Sort = []Sort{} + if query.Order.Order == doc.Name { + q.Sort = append(q.Sort, Sort{ + Attr: "name", + Asc: !query.Order.Desc, + }) + } + if query.Order.Order == doc.CreatedAt { + q.Sort = append(q.Sort, Sort{ + Attr: "createdAt", + Asc: !query.Order.Desc, + }) + } + q.AttrQueries = []*AttrQuery{{ + Attr: "kind", + Option: "=", + Value: "document", + }} + if query.Namespace != "" { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "namespace", + Option: "=", + Value: query.Namespace, + }) + } + if query.FuzzyName != "" { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "name", + Option: "CONTAINS", + Value: query.FuzzyName, + }) + } + 
if query.Source != "" { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "source", + Option: "=", + Value: query.Source, + }) + } + if query.CreatedAtStart != nil { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "createdAt", + Option: ">=", + Value: query.CreatedAtStart.Unix(), + }) + } + if query.CreatedAtEnd != nil { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "createdAt", + Option: "<=", + Value: query.CreatedAtEnd.Unix(), + }) + } + if query.ChangedAtStart != nil { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "updatedAt", + Option: ">=", + Value: query.ChangedAtStart.Unix(), + }) + } + if query.ChangedAtEnd != nil { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "updatedAt", + Option: "<=", + Value: query.ChangedAtEnd.Unix(), + }) + } + return q +} + +type Sort struct { + Attr string + Asc bool +} + +func (s *Sort) String() string { + if s.Asc { + return fmt.Sprintf("%s:asc", s.Attr) + } + return fmt.Sprintf("%s:desc", s.Attr) +} + +type DocumentAttrQuery struct { + AttrQueries []*AttrQuery +} + +func (q *DocumentAttrQuery) String() string { + result := "" + for _, aq := range q.AttrQueries { + result += aq.String() + " " + } + return result +} + +func (q *DocumentAttrQuery) FromModel(doc *doc.Document) *DocumentAttrQuery { + q.AttrQueries = []*AttrQuery{{ + Attr: "kind", + Option: "=", + Value: "attr", + }, { + Attr: "namespace", + Option: "=", + Value: doc.Namespace, + }} + if doc.ParentEntryID != nil { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "key", + Option: "=", + Value: "parentId", + }) + } + if doc.Marked != nil { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "key", + Option: "=", + Value: "mark", + }) + } + if doc.Unread != nil { + q.AttrQueries = append(q.AttrQueries, &AttrQuery{ + Attr: "key", + Option: "=", + Value: "unRead", + }) + } + return q +} + +func (q *DocumentAttrQuery) OfEntryId(namespace, entryId string) *DocumentAttrQuery { + q.AttrQueries 
= []*AttrQuery{ + { + Attr: "namespace", + Option: "=", + Value: namespace, + }, + { + Attr: "entryId", + Option: "=", + Value: entryId, + }, + { + Attr: "kind", + Option: "=", + Value: "attr", + }, + } + return q +} + +func (q *DocumentAttrQuery) ToRequest() *meilisearch.SearchRequest { + filter := []interface{}{} + for _, aq := range q.AttrQueries { + filter = append(filter, aq.ToFilter()) + } + return &meilisearch.SearchRequest{ + Filter: filter, + Limit: 10000, + HitsPerPage: 10000, + Query: "", + } +} + +type DocumentAttrQueries []*DocumentAttrQuery + +func (q *DocumentAttrQueries) FromFilter(query *doc.DocumentFilter) *DocumentAttrQueries { + attrQueries := make([]*DocumentAttrQuery, 0) + if query.ParentID != nil { + attrQueries = append(attrQueries, &DocumentAttrQuery{ + AttrQueries: []*AttrQuery{ + { + Attr: "namespace", + Option: "=", + Value: query.Namespace, + }, + { + Attr: "kind", + Option: "=", + Value: "attr", + }, + { + Attr: "key", + Option: "=", + Value: "parentId", + }, + { + Attr: "value", + Option: "=", + Value: query.ParentID, + }, + }, + }) + } + if query.Marked != nil { + attrQueries = append(attrQueries, &DocumentAttrQuery{ + AttrQueries: []*AttrQuery{ + { + Attr: "namespace", + Option: "=", + Value: query.Namespace, + }, + { + Attr: "kind", + Option: "=", + Value: "attr", + }, + { + Attr: "key", + Option: "=", + Value: "mark", + }, + { + Attr: "value", + Option: "=", + Value: query.Marked, + }, + }, + }) + } + if query.Unread != nil { + attrQueries = append(attrQueries, &DocumentAttrQuery{ + AttrQueries: []*AttrQuery{ + { + Attr: "namespace", + Option: "=", + Value: query.Namespace, + }, + { + Attr: "kind", + Option: "=", + Value: "attr", + }, + { + Attr: "key", + Option: "=", + Value: "unRead", + }, + { + Attr: "value", + Option: "=", + Value: query.Unread, + }, + }, + }) + } + return (*DocumentAttrQueries)(&attrQueries) +} + +func (q *DocumentAttrQueries) String() string { + result := "" + for _, attrQuery := range *q { + result += 
attrQuery.String() + " " + } + return result +} + +type AttrQuery struct { + Attr string + Option string + Value interface{} +} + +func (aq *AttrQuery) ToFilter() interface{} { + vs, _ := json.Marshal(aq.Value) + return fmt.Sprintf("%s %s %s", aq.Attr, aq.Option, vs) +} + +func (aq *AttrQuery) String() string { + return aq.ToFilter().(string) +} + +func (q *DocumentQuery) String() string { + filters := "" + for _, aq := range q.AttrQueries { + filters += aq.String() + " " + } + return fmt.Sprintf("search: [%s], attr query: [%s]", q.Search, filters) +} + +func (q *DocumentQuery) ToRequest() *meilisearch.SearchRequest { + // build filter + filter := []interface{}{} + for _, aq := range q.AttrQueries { + filter = append(filter, aq.ToFilter()) + } + sorts := []string{} + for _, s := range q.Sort { + sorts = append(sorts, s.String()) + } + + return &meilisearch.SearchRequest{ + Offset: q.Offset, + Limit: q.Limit, + Sort: sorts, + HitsPerPage: q.HitsPerPage, + Page: q.Page, + Query: q.Search, + Filter: filter, + } +} diff --git a/pkg/store/vectorstore/pgvector/migrate.go b/pkg/store/pgvector/migrate.go similarity index 100% rename from pkg/store/vectorstore/pgvector/migrate.go rename to pkg/store/pgvector/migrate.go diff --git a/pkg/store/vectorstore/pgvector/model.go b/pkg/store/pgvector/model.go similarity index 100% rename from pkg/store/vectorstore/pgvector/model.go rename to pkg/store/pgvector/model.go diff --git a/pkg/store/vectorstore/pgvector/pgvector.go b/pkg/store/pgvector/pgvector.go similarity index 95% rename from pkg/store/vectorstore/pgvector/pgvector.go rename to pkg/store/pgvector/pgvector.go index 420d35c..ef84a12 100644 --- a/pkg/store/vectorstore/pgvector/pgvector.go +++ b/pkg/store/pgvector/pgvector.go @@ -27,8 +27,9 @@ import ( "gorm.io/gorm" "github.com/basenana/friday/pkg/models/vector" - "github.com/basenana/friday/pkg/store/vectorstore" - "github.com/basenana/friday/pkg/store/vectorstore/db" + "github.com/basenana/friday/pkg/store" + 
"github.com/basenana/friday/pkg/store/db" + "github.com/basenana/friday/pkg/store/utils" "github.com/basenana/friday/pkg/utils/logger" ) @@ -37,13 +38,13 @@ type PgVectorClient struct { dEntity *db.Entity } -var _ vectorstore.VectorStore = &PgVectorClient{} +var _ store.VectorStore = &PgVectorClient{} func NewPgVectorClient(log logger.Logger, postgresUrl string) (*PgVectorClient, error) { if log == nil { log = logger.NewLogger("database") } - dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger(log)}) + dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: utils.NewDbLogger()}) if err != nil { panic(err) } diff --git a/pkg/store/vectorstore/postgres/bleve.go b/pkg/store/postgres/bleve.go similarity index 97% rename from pkg/store/vectorstore/postgres/bleve.go rename to pkg/store/postgres/bleve.go index cf4f059..2e6b576 100644 --- a/pkg/store/vectorstore/postgres/bleve.go +++ b/pkg/store/postgres/bleve.go @@ -9,8 +9,6 @@ import ( "github.com/blevesearch/bleve/v2/registry" "github.com/blevesearch/upsidedown_store_api" "gorm.io/gorm" - - "github.com/basenana/friday/pkg/utils/logger" ) const ( @@ -26,7 +24,7 @@ func pgKVStoreConstructor(mo store.MergeOperator, config map[string]interface{}) if !ok { return nil, fmt.Errorf("dsn not found") } - pgCli, err := NewPostgresClient(logger.NewLogger("bleve"), dsnStr.(string)) + pgCli, err := NewPostgresClient(dsnStr.(string)) if err != nil { return nil, err } diff --git a/pkg/store/postgres/document.go b/pkg/store/postgres/document.go new file mode 100644 index 0000000..92d1b76 --- /dev/null +++ b/pkg/store/postgres/document.go @@ -0,0 +1,179 @@ +/* + Copyright 2024 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package postgres + +import ( + "context" + "errors" + "runtime/trace" + + "gorm.io/gorm" + + "github.com/basenana/friday/pkg/models" + "github.com/basenana/friday/pkg/models/doc" + "github.com/basenana/friday/pkg/store" + "github.com/basenana/friday/pkg/store/utils" +) + +var _ store.DocStoreInterface = &PostgresClient{} + +func (p *PostgresClient) CreateDocument(ctx context.Context, doc *doc.Document) error { + defer trace.StartRegion(ctx, "metastore.sql.SaveDocument").End() + err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + docMod := &Document{} + res := tx.Where("oid = ?", doc.EntryId).First(docMod) + if res.Error != nil { + if errors.Is(res.Error, gorm.ErrRecordNotFound) { + docMod = docMod.From(doc) + res = tx.Create(docMod) + return res.Error + } + return res.Error + } + docMod = docMod.From(doc) + res = tx.Save(docMod) + return res.Error + }) + if err != nil { + return utils.SqlError2Error(err) + } + return nil +} + +func (p *PostgresClient) UpdateDocument(ctx context.Context, doc *doc.Document) error { + defer trace.StartRegion(ctx, "metastore.sql.SaveDocument").End() + err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + docMod := &Document{} + res := tx.Where("oid = ?", doc.EntryId).First(docMod) + if res.Error != nil { + if errors.Is(res.Error, gorm.ErrRecordNotFound) { + docMod = docMod.From(doc) + res = tx.Create(docMod) + return res.Error + } + return res.Error + } + docMod = docMod.UpdateFrom(doc) + res = tx.Save(docMod) + return res.Error + }) + if err != nil { + return utils.SqlError2Error(err) + 
} + return nil +} + +func (p *PostgresClient) GetDocument(ctx context.Context, entryId int64) (*doc.Document, error) { + defer trace.StartRegion(ctx, "metastore.sql.GetDocument").End() + doc := &Document{} + res := p.dEntity.WithNamespace(ctx).Where("oid = ?", entryId).First(doc) + if res.Error != nil { + return nil, utils.SqlError2Error(res.Error) + } + return doc.To(), nil +} + +func (p *PostgresClient) FilterDocuments(ctx context.Context, filter *doc.DocumentFilter) ([]*doc.Document, error) { + defer trace.StartRegion(ctx, "metastore.sql.ListDocument").End() + docList := make([]Document, 0) + q := p.WithNamespace(ctx) + if page := models.GetPagination(ctx); page != nil { + q = q.Offset(page.Offset()).Limit(page.Limit()) + } + res := docOrder(docQueryFilter(q, filter), &filter.Order).Find(&docList) + if res.Error != nil { + return nil, utils.SqlError2Error(res.Error) + } + + result := make([]*doc.Document, len(docList)) + for i, doc := range docList { + result[i] = doc.To() + } + return result, nil +} + +func (p *PostgresClient) DeleteDocument(ctx context.Context, entryId int64) error { + defer trace.StartRegion(ctx, "metastore.sql.DeleteDocument").End() + err := p.dEntity.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + res := namespaceQuery(ctx, tx).Where("oid = ?", entryId).Delete(&Document{}) + return res.Error + }) + return utils.SqlError2Error(err) +} + +func docQueryFilter(tx *gorm.DB, filter *doc.DocumentFilter) *gorm.DB { + if filter.ParentID != nil { + tx = tx.Where("document.parent_entry_id = ?", filter.ParentID) + } + if filter.Marked != nil { + tx = tx.Where("document.marked = ?", *filter.Marked) + } + if filter.Unread != nil { + tx = tx.Where("document.unread = ?", *filter.Unread) + } + if filter.Source != "" { + tx = tx.Where("document.source = ?", filter.Source) + } + if filter.CreatedAtStart != nil { + tx = tx.Where("document.created_at >= ?", *filter.CreatedAtStart) + } + if filter.CreatedAtEnd != nil { + tx = 
tx.Where("document.created_at < ?", *filter.CreatedAtEnd) + } + if filter.ChangedAtStart != nil { + tx = tx.Where("document.changed_at >= ?", *filter.ChangedAtStart) + } + if filter.ChangedAtEnd != nil { + tx = tx.Where("document.changed_at < ?", *filter.ChangedAtEnd) + } + if filter.FuzzyName != "" { + tx = tx.Where("document.name LIKE ?", "%"+filter.FuzzyName+"%") + } + if filter.Search != "" { + tx = tx.Where("document.content LIKE ?", "%"+filter.Search+"%") + } + return tx +} + +func docOrder(tx *gorm.DB, order *doc.DocumentOrder) *gorm.DB { + if order != nil { + orderStr := order.Order.String() + if order.Desc { + orderStr += " DESC" + } + tx = tx.Order(orderStr) + } else { + tx = tx.Order("created_at DESC") + } + return tx +} + +func namespaceQuery(ctx context.Context, tx *gorm.DB) *gorm.DB { + ns := models.GetNamespace(ctx) + if ns.String() == models.DefaultNamespaceValue { + return tx + } + return tx.Where("namespace = ?", ns.String()) +} + +func (p *PostgresClient) WithNamespace(ctx context.Context) *gorm.DB { + ns := models.GetNamespace(ctx) + if ns.String() == models.DefaultNamespaceValue { + return p.dEntity.WithContext(ctx) + } + return p.dEntity.WithContext(ctx).Where("namespace = ?", ns.String()) +} diff --git a/pkg/store/vectorstore/postgres/migrate.go b/pkg/store/postgres/migrate.go similarity index 78% rename from pkg/store/vectorstore/postgres/migrate.go rename to pkg/store/postgres/migrate.go index 152c788..98bbe18 100644 --- a/pkg/store/vectorstore/postgres/migrate.go +++ b/pkg/store/postgres/migrate.go @@ -50,6 +50,23 @@ func buildMigrations() []*gormigrate.Migration { }, Rollback: func(db *gorm.DB) error { return nil }, }, + { + ID: "2024121501", + Migrate: func(db *gorm.DB) error { + err := db.AutoMigrate( + &Document{}, + ) + if err != nil { + return err + } + //_ = db.Exec(`CREATE INDEX name ON table USING gist(content);`) + //_ = db.Exec("CREATE INDEX idx_doc_content ON document USING GIN (content gin_trgm_ops);") + return nil + }, + 
Rollback: func(db *gorm.DB) error { + return nil + }, + }, } } diff --git a/pkg/store/vectorstore/postgres/model.go b/pkg/store/postgres/model.go similarity index 54% rename from pkg/store/vectorstore/postgres/model.go rename to pkg/store/postgres/model.go index 5cdd517..fa48e00 100644 --- a/pkg/store/vectorstore/postgres/model.go +++ b/pkg/store/postgres/model.go @@ -20,6 +20,7 @@ import ( "encoding/json" "time" + "github.com/basenana/friday/pkg/models/doc" "github.com/basenana/friday/pkg/models/vector" ) @@ -112,3 +113,79 @@ type BleveKV struct { func (v *BleveKV) TableName() string { return "friday_blevekv" } + +type Document struct { + ID int64 `gorm:"column:id;primaryKey"` + OID int64 `gorm:"column:oid;index:doc_oid"` + Name string `gorm:"column:name;index:doc_name"` + Namespace string `gorm:"column:namespace;index:doc_ns"` + Source string `gorm:"column:source;index:doc_source"` + ParentEntryID *int64 `gorm:"column:parent_entry_id;index:doc_parent_entry_id"` + Keywords string `gorm:"column:keywords"` + Content string `gorm:"column:content"` + Summary string `gorm:"column:summary"` + HeaderImage string `gorm:"column:header_image"` + SubContent string `gorm:"column:sub_content"` + Marked bool `gorm:"column:marked;index:doc_is_marked"` + Unread bool `gorm:"column:unread;index:doc_is_unread"` + CreatedAt time.Time `gorm:"column:created_at"` + ChangedAt time.Time `gorm:"column:changed_at"` +} + +func (d *Document) TableName() string { + return "document" +} + +func (d *Document) From(document *doc.Document) *Document { + d.ID = document.EntryId + d.OID = document.EntryId + d.Name = document.Name + d.Namespace = document.Namespace + d.ParentEntryID = document.ParentEntryID + d.Source = document.Source + d.Content = document.Content + d.Summary = document.Summary + d.CreatedAt = document.CreatedAt + d.ChangedAt = document.ChangedAt + if document.Marked != nil { + d.Marked = *document.Marked + } + if document.Unread != nil { + d.Unread = *document.Unread + } + 
d.HeaderImage = document.HeaderImage + d.SubContent = document.SubContent + return d +} + +func (d *Document) UpdateFrom(document *doc.Document) *Document { + if document.Unread != nil { + d.Unread = *document.Unread + } + if document.Marked != nil { + d.Marked = *document.Marked + } + if document.ParentEntryID != nil { + d.ParentEntryID = document.ParentEntryID + } + return d +} + +func (d *Document) To() *doc.Document { + result := &doc.Document{ + EntryId: d.OID, + Name: d.Name, + Namespace: d.Namespace, + ParentEntryID: d.ParentEntryID, + Source: d.Source, + Content: d.Content, + Summary: d.Summary, + SubContent: d.SubContent, + HeaderImage: d.HeaderImage, + Marked: &d.Marked, + Unread: &d.Unread, + CreatedAt: d.CreatedAt, + ChangedAt: d.ChangedAt, + } + return result +} diff --git a/pkg/store/postgres/postgres.go b/pkg/store/postgres/postgres.go new file mode 100644 index 0000000..661ca0c --- /dev/null +++ b/pkg/store/postgres/postgres.go @@ -0,0 +1,67 @@ +/* + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/
+
+package postgres
+
+import (
+	"time"
+
+	"go.uber.org/zap"
+	"gorm.io/driver/postgres"
+	"gorm.io/gorm"
+
+	"github.com/basenana/friday/pkg/store/db"
+	"github.com/basenana/friday/pkg/store/utils"
+	"github.com/basenana/friday/pkg/utils/logger"
+)
+
+const defaultNamespace = "global"
+
+type PostgresClient struct {
+	log     *zap.SugaredLogger
+	dEntity *db.Entity
+}
+// NewPostgresClient dials postgresUrl, configures the connection pool, runs migrations and returns a ready client; all failures (including a bad DSN) are returned as errors, never panicked.
+func NewPostgresClient(postgresUrl string) (*PostgresClient, error) {
+	log := logger.NewLog("postgres")
+	dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: utils.NewDbLogger()})
+	if err != nil {
+		return nil, err
+	}
+
+	dbConn, err := dbObj.DB()
+	if err != nil {
+		return nil, err
+	}
+
+	dbConn.SetMaxIdleConns(5)
+	dbConn.SetMaxOpenConns(50)
+	dbConn.SetConnMaxLifetime(time.Hour)
+
+	if err = dbConn.Ping(); err != nil {
+		return nil, err
+	}
+
+	dbEnt, err := db.NewDbEntity(dbObj, Migrate)
+	if err != nil {
+		return nil, err
+	}
+
+	return &PostgresClient{
+		log:     log,
+		dEntity: dbEnt,
+	}, nil
+}
diff --git a/pkg/store/vectorstore/postgres/postgres.go b/pkg/store/postgres/vector.go
similarity index 68%
rename from pkg/store/vectorstore/postgres/postgres.go
rename to pkg/store/postgres/vector.go
index a5ef56a..9d5f367 100644
--- a/pkg/store/vectorstore/postgres/postgres.go
+++ b/pkg/store/postgres/vector.go
@@ -1,5 +1,5 @@
 /*
-  Copyright 2023 Friday Author.
+  Copyright 2024 Friday Author.
 
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
@@ -24,56 +24,13 @@ import ( "time" "github.com/cdipaolo/goml/base" - "gorm.io/driver/postgres" "gorm.io/gorm" - "github.com/basenana/friday/pkg/models" "github.com/basenana/friday/pkg/models/vector" - "github.com/basenana/friday/pkg/store/vectorstore" - "github.com/basenana/friday/pkg/store/vectorstore/db" - "github.com/basenana/friday/pkg/utils/logger" + "github.com/basenana/friday/pkg/store" + "github.com/basenana/friday/pkg/utils" ) -const defaultNamespace = "global" - -type PostgresClient struct { - log logger.Logger - dEntity *db.Entity -} - -func NewPostgresClient(log logger.Logger, postgresUrl string) (*PostgresClient, error) { - if log == nil { - log = logger.NewLogger("database") - } - dbObj, err := gorm.Open(postgres.Open(postgresUrl), &gorm.Config{Logger: logger.NewDbLogger(log)}) - if err != nil { - panic(err) - } - - dbConn, err := dbObj.DB() - if err != nil { - return nil, err - } - - dbConn.SetMaxIdleConns(5) - dbConn.SetMaxOpenConns(50) - dbConn.SetConnMaxLifetime(time.Hour) - - if err = dbConn.Ping(); err != nil { - return nil, err - } - - dbEnt, err := db.NewDbEntity(dbObj, Migrate) - if err != nil { - return nil, err - } - - return &PostgresClient{ - log: log, - dEntity: dbEnt, - }, nil -} - func (p *PostgresClient) Store(ctx context.Context, element *vector.Element, extra map[string]any) error { namespace := ctx.Value("namespace") if namespace == nil { @@ -149,7 +106,7 @@ func (p *PostgresClient) Search(ctx context.Context, query vector.VectorDocQuery } // knn search - dists := distances{} + dists := utils.Distances{} for _, index := range existIndexes { var vector []float64 err := json.Unmarshal([]byte(index.Vector), &vector) @@ -157,9 +114,9 @@ func (p *PostgresClient) Search(ctx context.Context, query vector.VectorDocQuery return nil, err } - dists = append(dists, distance{ - Index: index, - dist: base.EuclideanDistance(vector, vectors64), + dists = append(dists, utils.Distance{ + Object: index, + Dist: base.EuclideanDistance(vector, 
vectors64), }) } @@ -171,7 +128,8 @@ func (p *PostgresClient) Search(ctx context.Context, query vector.VectorDocQuery } results := make([]*vector.Doc, 0) for _, index := range minKIndexes { - results = append(results, index.ToDoc()) + i := index.Object.(Index) + results = append(results, i.ToDoc()) } return results, nil @@ -193,7 +151,7 @@ func (p *PostgresClient) Get(ctx context.Context, oid int64, name string, group return vModel.To() } -var _ vectorstore.VectorStore = &PostgresClient{} +var _ store.VectorStore = &PostgresClient{} func (p *PostgresClient) Inited(ctx context.Context) (bool, error) { var count int64 @@ -204,30 +162,3 @@ func (p *PostgresClient) Inited(ctx context.Context) (bool, error) { return count > 0, nil } - -type distance struct { - Index - dist float64 -} - -type distances []distance - -func (d distances) Len() int { - return len(d) -} - -func (d distances) Less(i, j int) bool { - return d[i].dist < d[j].dist -} - -func (d distances) Swap(i, j int) { - d[i], d[j] = d[j], d[i] -} - -func namespaceQuery(ctx context.Context, tx *gorm.DB) *gorm.DB { - ns := models.GetNamespace(ctx) - if ns.String() == models.DefaultNamespaceValue { - return tx - } - return tx.Where("namespace = ?", ns.String()) -} diff --git a/pkg/store/vectorstore/redis/redis.go b/pkg/store/redis/redis.go similarity index 94% rename from pkg/store/vectorstore/redis/redis.go rename to pkg/store/redis/redis.go index 24179a6..c6cf1f7 100644 --- a/pkg/store/vectorstore/redis/redis.go +++ b/pkg/store/redis/redis.go @@ -25,7 +25,7 @@ import ( "github.com/redis/rueidis" "github.com/basenana/friday/pkg/models/vector" - "github.com/basenana/friday/pkg/store/vectorstore" + "github.com/basenana/friday/pkg/store" "github.com/basenana/friday/pkg/utils/files" "github.com/basenana/friday/pkg/utils/logger" ) @@ -43,17 +43,17 @@ type RedisClient struct { dim int } -var _ vectorstore.VectorStore = &RedisClient{} +var _ store.VectorStore = &RedisClient{} -func NewRedisClientWithDim(redisUrl 
string, dim int) (vectorstore.VectorStore, error) { +func NewRedisClientWithDim(redisUrl string, dim int) (store.VectorStore, error) { return newRedisClient(redisUrl, EmbeddingPrefix, EmbeddingIndex, dim) } -func NewRedisClient(redisUrl string) (vectorstore.VectorStore, error) { +func NewRedisClient(redisUrl string) (store.VectorStore, error) { return newRedisClient(redisUrl, EmbeddingPrefix, EmbeddingIndex, 1536) } -func newRedisClient(redisUrl string, prefix, index string, embeddingDim int) (vectorstore.VectorStore, error) { +func newRedisClient(redisUrl string, prefix, index string, embeddingDim int) (store.VectorStore, error) { client, err := rueidis.NewClient(rueidis.ClientOption{InitAddress: []string{redisUrl}}) if err != nil { return nil, err diff --git a/pkg/store/utils/utils.go b/pkg/store/utils/utils.go new file mode 100644 index 0000000..8067cd0 --- /dev/null +++ b/pkg/store/utils/utils.go @@ -0,0 +1,74 @@ +/* + Copyright 2024 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/
+
+package utils
+
+import (
+	"context"
+	"errors"
+	"time"
+
+	"go.uber.org/zap"
+	"gorm.io/gorm"
+
+	glogger "gorm.io/gorm/logger"
+
+	"github.com/basenana/friday/pkg/models"
+	"github.com/basenana/friday/pkg/utils/logger"
+)
+
+// SqlError2Error converts database-layer errors to friday model errors; errors.Is also matches wrapped errors.
+func SqlError2Error(err error) error {
+	if errors.Is(err, gorm.ErrRecordNotFound) {
+		return models.ErrNotFound
+	}
+	return err
+}
+// Logger adapts a zap.SugaredLogger to gorm's glogger.Interface.
+type Logger struct {
+	*zap.SugaredLogger
+}
+
+func (l *Logger) LogMode(level glogger.LogLevel) glogger.Interface {
+	return l
+}
+
+func (l *Logger) Info(ctx context.Context, s string, i ...interface{}) {
+	l.Infof(s, i...)
+}
+
+func (l *Logger) Warn(ctx context.Context, s string, i ...interface{}) {
+	l.Warnf(s, i...)
+}
+
+func (l *Logger) Error(ctx context.Context, s string, i ...interface{}) {
+	l.Errorf(s, i...)
+}
+
+func (l *Logger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
+	sqlContent, rows := fc()
+	l.Debugw("trace sql", "sql", sqlContent, "rows", rows, "err", err)
+	switch {
+	case err != nil && !errors.Is(err, gorm.ErrRecordNotFound) && !errors.Is(err, context.Canceled):
+		l.Warnw("trace error", "sql", sqlContent, "rows", rows, "err", err)
+	case time.Since(begin) > time.Second:
+		l.Infow("slow sql", "sql", sqlContent, "rows", rows, "cost", time.Since(begin).Seconds())
+	}
+}
+// NewDbLogger returns a gorm logger backed by the shared "database" zap logger.
+func NewDbLogger() *Logger {
+	return &Logger{SugaredLogger: logger.NewLog("database")}
+}
diff --git a/pkg/store/vectorstore/interface.go b/pkg/store/vectorstore/interface.go
deleted file mode 100644
index f38a19b..0000000
--- a/pkg/store/vectorstore/interface.go
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright 2023 friday
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package vectorstore - -import ( - "context" - - "github.com/basenana/friday/pkg/models/vector" -) - -type VectorStore interface { - Store(ctx context.Context, element *vector.Element, extra map[string]any) error - Search(ctx context.Context, query vector.VectorDocQuery, vectors []float32, k int) ([]*vector.Doc, error) - Get(ctx context.Context, oid int64, name string, group int) (*vector.Element, error) -} diff --git a/pkg/store/docstore/interface.go b/pkg/utils/knn.go similarity index 57% rename from pkg/store/docstore/interface.go rename to pkg/utils/knn.go index 282cc07..b03e53f 100644 --- a/pkg/store/docstore/interface.go +++ b/pkg/utils/knn.go @@ -14,17 +14,23 @@ limitations under the License. 
*/ -package docstore +package utils -import ( - "context" +type Distance struct { + Object interface{} + Dist float64 +} + +type Distances []Distance - "github.com/basenana/friday/pkg/models/doc" -) +func (d Distances) Len() int { + return len(d) +} + +func (d Distances) Less(i, j int) bool { + return d[i].Dist < d[j].Dist +} -type DocStoreInterface interface { - Store(ctx context.Context, docPtr doc.DocPtrInterface) error - FilterAttr(ctx context.Context, query *doc.DocumentAttrQuery) (doc.DocumentAttrList, error) - Search(ctx context.Context, query *doc.DocumentQuery) (doc.DocumentList, error) - DeleteByFilter(ctx context.Context, aqs doc.DocumentAttrQuery) error +func (d Distances) Swap(i, j int) { + d[i], d[j] = d[j], d[i] } diff --git a/pkg/utils/logger/dblogger.go b/pkg/utils/logger/dblogger.go deleted file mode 100644 index 0afe201..0000000 --- a/pkg/utils/logger/dblogger.go +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2023 Friday Author. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package logger - -import ( - "context" - "time" - - "gorm.io/gorm" - glogger "gorm.io/gorm/logger" -) - -type DBLogger struct { - Logger -} - -func (l *DBLogger) LogMode(level glogger.LogLevel) glogger.Interface { - return l -} - -func (l *DBLogger) Info(ctx context.Context, s string, i ...interface{}) { - l.Infof(s, i...) -} - -func (l *DBLogger) Warn(ctx context.Context, s string, i ...interface{}) { - l.Warnf(s, i...) 
-} - -func (l *DBLogger) Error(ctx context.Context, s string, i ...interface{}) { - l.Errorf(s, i...) -} - -func (l *DBLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { - sqlContent, rows := fc() - l.Debugf("trace sql: %s\nrows: %d, err: %v", sqlContent, rows, err) - switch { - case err != nil && err != gorm.ErrRecordNotFound && err != context.Canceled: - l.Debugf("trace error, sql: %s\nrows: %s, err: %v", sqlContent, rows, err) - case time.Since(begin) > time.Second: - l.Infof("slow sql, sql: %s\nrows: %s, err: %v", sqlContent, rows, err) - } -} - -func NewDbLogger(log Logger) *DBLogger { - return &DBLogger{log} -}