From d8f6efb8f215c4acf79167d49b3777ae7390633e Mon Sep 17 00:00:00 2001 From: zwwhdls Date: Wed, 29 May 2024 11:45:49 +0800 Subject: [PATCH] fix: namespace in sql Signed-off-by: zwwhdls --- pkg/friday/friday.go | 4 +-- pkg/friday/question.go | 4 +-- pkg/models/namespace.go | 51 ++++++++++++++++++++++++++++ pkg/vectorstore/db/entity.go | 12 +++++++ pkg/vectorstore/postgres/postgres.go | 29 +++++++--------- 5 files changed, 80 insertions(+), 20 deletions(-) create mode 100644 pkg/models/namespace.go diff --git a/pkg/friday/friday.go b/pkg/friday/friday.go index c10db34..2cafa9f 100644 --- a/pkg/friday/friday.go +++ b/pkg/friday/friday.go @@ -117,7 +117,7 @@ func (f *Friday) WithContext(ctx context.Context) *Friday { return t } -func (f *Friday) Namespace(namespace string) *Friday { - f.statement.context = context.WithValue(f.statement.context, "namespace", namespace) +func (f *Friday) Namespace(namespace *models.Namespace) *Friday { + f.statement.context = models.WithNamespace(f.statement.context, namespace) return f } diff --git a/pkg/friday/question.go b/pkg/friday/question.go index 643c5be..8f7e1c1 100644 --- a/pkg/friday/question.go +++ b/pkg/friday/question.go @@ -123,12 +123,12 @@ func (f *Friday) Chat(res *ChatState) *Friday { } func (f *Friday) generateSystemInfo() string { - systemTemplate := "基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n" + systemTemplate := "你是一位知识渊博的文字工作者,负责帮用户阅读文章,基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n" if f.statement.Summary != "" { systemTemplate += "\n这是文章简介: {{ .Summary }}\n" } if f.statement.Info != "" { - systemTemplate += "\n这是已知内容: {{ .Info }}\n" + systemTemplate += "\n这是相关的已知内容: {{ .Info }}\n" } if f.statement.HistorySummary != "" { systemTemplate += "\n这是历史聊天总结作为前情提要: {{ .HistorySummary }}\n" diff --git a/pkg/models/namespace.go b/pkg/models/namespace.go new file mode 100644 index 0000000..61115bf --- /dev/null +++ b/pkg/models/namespace.go @@ -0,0 +1,51 @@ +/* + Copyright 2023 Friday Author. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package models + +import "context" + +const ( + NamespaceKey = "namespace" + DefaultNamespaceValue = "global" // TODO: using 'public' + GlobalNamespaceValue = "global" +) + +type Namespace struct { + name string +} + +func NewNamespace(name string) *Namespace { + return &Namespace{name: name} +} + +func (n *Namespace) String() string { + return n.name +} + +func GetNamespace(ctx context.Context) (ns *Namespace) { + ns = &Namespace{ + name: DefaultNamespaceValue, + } + if ctx.Value(NamespaceKey) != nil { + ns.name = ctx.Value(NamespaceKey).(string) + } + return +} + +func WithNamespace(ctx context.Context, ns *Namespace) context.Context { + return context.WithValue(ctx, NamespaceKey, ns.String()) +} diff --git a/pkg/vectorstore/db/entity.go b/pkg/vectorstore/db/entity.go index cecddf5..2401560 100644 --- a/pkg/vectorstore/db/entity.go +++ b/pkg/vectorstore/db/entity.go @@ -17,7 +17,11 @@ package db import ( + "context" + "gorm.io/gorm" + + "github.com/basenana/friday/pkg/models" ) type Entity struct { @@ -31,3 +35,11 @@ func NewDbEntity(db *gorm.DB, migrate func(db *gorm.DB) error) (*Entity, error) } return ent, nil } + +func (e *Entity) WithNamespace(ctx context.Context) *gorm.DB { + ns := models.GetNamespace(ctx) + if ns.String() == models.DefaultNamespaceValue { + return e.WithContext(ctx) + } + return e.WithContext(ctx).Where("namespace = ?", ns.String()) +} diff --git a/pkg/vectorstore/postgres/postgres.go b/pkg/vectorstore/postgres/postgres.go index 2e9dc6f..86ec2c0 100644 --- a/pkg/vectorstore/postgres/postgres.go +++ b/pkg/vectorstore/postgres/postgres.go @@ -128,20 +128,14 @@ func (p *PostgresClient) Store(ctx context.Context, element *models.Element, ext } func (p *PostgresClient) Search(ctx context.Context, query models.DocQuery, vectors []float32, k int) ([]*models.Doc, error) { - namespace := ctx.Value("namespace") - if namespace == nil { - namespace = defaultNamespace - } vectors64 := make([]float64, 0) for _, v := range vectors { vectors64 = append(vectors64, float64(v)) } // query from db existIndexes := make([]Index, 0) - var res *gorm.DB - res = p.dEntity.WithContext(ctx) - res = res.Where("namespace = ?", namespace) + res := p.dEntity.WithNamespace(ctx) if query.ParentId != 0 { res = res.Where("parent_entry_id = ?", query.ParentId) } @@ -183,17 +177,12 @@ func (p *PostgresClient) Search(ctx context.Context, query models.DocQuery, vect } func (p *PostgresClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) { - namespace := ctx.Value("namespace") - if namespace == nil { - namespace = defaultNamespace - } vModel := &Index{} - var res *gorm.DB - if oid == 0 { - res = p.dEntity.WithContext(ctx).Where("namespace = ? AND name = ? AND idx_group = ?", namespace, name, group).First(vModel) - } else { - res = p.dEntity.WithContext(ctx).Where("namespace = ? AND name = ? AND oid = ? AND idx_group = ?", namespace, name, oid, group).First(vModel) + tx := p.dEntity.WithNamespace(ctx).Where("name = ? AND idx_group = ?", name, group) + if oid != 0 { + tx = tx.Where("oid = ?", oid) } + res := tx.First(vModel) if res.Error != nil { if res.Error == gorm.ErrRecordNotFound { return nil, nil @@ -233,3 +222,11 @@ func (d distances) Less(i, j int) bool { func (d distances) Swap(i, j int) { d[i], d[j] = d[j], d[i] } + +func namespaceQuery(ctx context.Context, tx *gorm.DB) *gorm.DB { + ns := models.GetNamespace(ctx) + if ns.String() == models.DefaultNamespaceValue { + return tx + } + return tx.Where("namespace = ?", ns.String()) +}