Skip to content

Commit

Permalink
fix: namespace in sql
Browse files Browse the repository at this point in the history
Signed-off-by: zwwhdls <zww@hdls.me>
  • Loading branch information
zwwhdls committed May 29, 2024
1 parent 962b40b commit d8f6efb
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 20 deletions.
4 changes: 2 additions & 2 deletions pkg/friday/friday.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (f *Friday) WithContext(ctx context.Context) *Friday {
return t
}

func (f *Friday) Namespace(namespace string) *Friday {
f.statement.context = context.WithValue(f.statement.context, "namespace", namespace)
func (f *Friday) Namespace(namespace *models.Namespace) *Friday {
f.statement.context = models.WithNamespace(f.statement.context, namespace)
return f
}
4 changes: 2 additions & 2 deletions pkg/friday/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ func (f *Friday) Chat(res *ChatState) *Friday {
}

func (f *Friday) generateSystemInfo() string {
systemTemplate := "基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n"
systemTemplate := "你是一位知识渊博的文字工作者,负责帮用户阅读文章,基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n"
if f.statement.Summary != "" {
systemTemplate += "\n这是文章简介: {{ .Summary }}\n"
}
if f.statement.Info != "" {
systemTemplate += "\n这是已知内容: {{ .Info }}\n"
systemTemplate += "\n这是相关的已知内容: {{ .Info }}\n"
}
if f.statement.HistorySummary != "" {
systemTemplate += "\n这是历史聊天总结作为前情提要: {{ .HistorySummary }}\n"
Expand Down
51 changes: 51 additions & 0 deletions pkg/models/namespace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
Copyright 2023 Friday Author.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package models

import "context"

const (
NamespaceKey = "namespace"
DefaultNamespaceValue = "global" // TODO: using 'public'
GlobalNamespaceValue = "global"
)

type Namespace struct {
name string
}

func NewNamespace(name string) *Namespace {
return &Namespace{name: name}
}

func (n *Namespace) String() string {
return n.name
}

func GetNamespace(ctx context.Context) (ns *Namespace) {
ns = &Namespace{
name: DefaultNamespaceValue,
}
if ctx.Value(NamespaceKey) != nil {
ns.name = ctx.Value(NamespaceKey).(string)
}
return
}

func WithNamespace(ctx context.Context, ns *Namespace) context.Context {
return context.WithValue(ctx, NamespaceKey, ns.String())
}
12 changes: 12 additions & 0 deletions pkg/vectorstore/db/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
package db

import (
"context"

"gorm.io/gorm"

"github.com/basenana/friday/pkg/models"
)

type Entity struct {
Expand All @@ -31,3 +35,11 @@ func NewDbEntity(db *gorm.DB, migrate func(db *gorm.DB) error) (*Entity, error)
}
return ent, nil
}

func (e *Entity) WithNamespace(ctx context.Context) *gorm.DB {
ns := models.GetNamespace(ctx)
if ns.String() == models.DefaultNamespaceValue {
return e.WithContext(ctx)
}
return e.WithContext(ctx).Where("namespace = ?", ns.String())
}
29 changes: 13 additions & 16 deletions pkg/vectorstore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,14 @@ func (p *PostgresClient) Store(ctx context.Context, element *models.Element, ext
}

func (p *PostgresClient) Search(ctx context.Context, query models.DocQuery, vectors []float32, k int) ([]*models.Doc, error) {
namespace := ctx.Value("namespace")
if namespace == nil {
namespace = defaultNamespace
}
vectors64 := make([]float64, 0)
for _, v := range vectors {
vectors64 = append(vectors64, float64(v))
}
// query from db
existIndexes := make([]Index, 0)
var res *gorm.DB

res = p.dEntity.WithContext(ctx)
res = res.Where("namespace = ?", namespace)
res := p.dEntity.WithNamespace(ctx)
if query.ParentId != 0 {
res = res.Where("parent_entry_id = ?", query.ParentId)
}
Expand Down Expand Up @@ -183,17 +177,12 @@ func (p *PostgresClient) Search(ctx context.Context, query models.DocQuery, vect
}

func (p *PostgresClient) Get(ctx context.Context, oid int64, name string, group int) (*models.Element, error) {
namespace := ctx.Value("namespace")
if namespace == nil {
namespace = defaultNamespace
}
vModel := &Index{}
var res *gorm.DB
if oid == 0 {
res = p.dEntity.WithContext(ctx).Where("namespace = ? AND name = ? AND idx_group = ?", namespace, name, group).First(vModel)
} else {
res = p.dEntity.WithContext(ctx).Where("namespace = ? AND name = ? AND oid = ? AND idx_group = ?", namespace, name, oid, group).First(vModel)
tx := p.dEntity.WithNamespace(ctx).Where("name = ? AND idx_group = ?", name, group)
if oid != 0 {
tx = tx.Where("oid = ?", oid)
}
res := tx.First(vModel)
if res.Error != nil {
if res.Error == gorm.ErrRecordNotFound {
return nil, nil
Expand Down Expand Up @@ -233,3 +222,11 @@ func (d distances) Less(i, j int) bool {
func (d distances) Swap(i, j int) {
d[i], d[j] = d[j], d[i]
}

func namespaceQuery(ctx context.Context, tx *gorm.DB) *gorm.DB {
ns := models.GetNamespace(ctx)
if ns.String() == models.DefaultNamespaceValue {
return tx
}
return tx.Where("namespace = ?", ns.String())
}

0 comments on commit d8f6efb

Please sign in to comment.