Skip to content

Commit

Permalink
fix: get more accurate embedder segmentation
Browse files Browse the repository at this point in the history
Signed-off-by: Abirdcfly <fp544037857@gmail.com>
  • Loading branch information
Abirdcfly committed Dec 11, 2023
1 parent b175755 commit 5174be0
Show file tree
Hide file tree
Showing 19 changed files with 260 additions and 186 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ COPY go.mod go.mod
COPY go.sum go.sum
# cache deps before building and copying source so that we don't need to re-download as much
# and so that source changes don't invalidate our downloaded layer
RUN go env -w GOPROXY=https://goproxy.cn,direct
#RUN go env -w GOPROXY=https://goproxy.cn,direct
RUN go mod download

# Copy the go source
Expand Down
44 changes: 25 additions & 19 deletions api/base/v1alpha1/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,39 @@ import (
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/dynamic"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeagi/arcadia/pkg/utils"
)

func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client) (string, error) {
func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client, cli dynamic.Interface) (string, error) {
if e.Spec.Enpoint == nil || e.Spec.Enpoint.AuthSecret == nil {
return "", nil
}
authSecret := &corev1.Secret{}
err := c.Get(ctx, types.NamespacedName{Name: e.Spec.Enpoint.AuthSecret.Name, Namespace: e.Namespace}, authSecret)
if err != nil {
if err := utils.ValidClient(c, cli); err != nil {
return "", err
}
return string(authSecret.Data["apiKey"]), nil
}

func (e Embedder) AuthAPIKeyByDynamicCli(ctx context.Context, cli dynamic.Interface) (string, error) {
if e.Spec.Enpoint == nil || e.Spec.Enpoint.AuthSecret == nil {
return "", nil
}
authSecret := &corev1.Secret{}
obj, err := cli.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}).
Namespace(e.GetNamespace()).Get(ctx, e.Spec.Enpoint.AuthSecret.Name, metav1.GetOptions{})
if err != nil {
return "", err
}
err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), authSecret)
if err != nil {
return "", err
if c != nil {
if err := c.Get(ctx, types.NamespacedName{Name: e.Spec.Enpoint.AuthSecret.Name, Namespace: e.Namespace}, authSecret); err != nil {
return "", err
}
} else {
obj, err := cli.Resource(schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}).
Namespace(e.GetNamespace()).Get(ctx, e.Spec.Enpoint.AuthSecret.Name, metav1.GetOptions{})
if err != nil {
return "", err
}
err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), authSecret)
if err != nil {
return "", err
}
}
return string(authSecret.Data["apiKey"]), nil
}

// EmbeddingType names the kind of embedding service; it is the type of
// the EmbedderSpec.Type field.
type EmbeddingType string

// Embedding service type values declared for EmbeddingType.
const (
// OpenAI is the EmbeddingType value "openai".
OpenAI EmbeddingType = "openai"
// ZhiPuAI is the EmbeddingType value "zhipuai".
ZhiPuAI EmbeddingType = "zhipuai"
)
4 changes: 1 addition & 3 deletions api/base/v1alpha1/embedder_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package v1alpha1

import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeagi/arcadia/pkg/embeddings"
)

// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
Expand All @@ -30,7 +28,7 @@ type EmbedderSpec struct {
CommonSpec `json:",inline"`

// ServiceType indicates the source type of embedding service
Type embeddings.EmbeddingType `json:"type,omitempty"`
Type EmbeddingType `json:"type,omitempty"`

// Provider defines the provider info which provide this embedder service
Provider `json:"provider,omitempty"`
Expand Down
3 changes: 1 addition & 2 deletions api/base/v1alpha1/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/llms"
)

Expand Down Expand Up @@ -108,7 +107,7 @@ func (worker Worker) BuildEmbedder() *Embedder {
Creator: worker.Spec.Creator,
Description: "Embedder created by Worker(OpenAI compatible)",
},
Type: embeddings.OpenAI,
Type: OpenAI,
Provider: Provider{
Worker: &TypedObjectReference{
Kind: "Worker",
Expand Down
7 changes: 3 additions & 4 deletions controllers/embedder_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
"sigs.k8s.io/controller-runtime/pkg/predicate"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/llms/openai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
)
Expand Down Expand Up @@ -156,20 +155,20 @@ func (r *EmbedderReconciler) check3rdPartyEmbedder(ctx context.Context, logger l
var msg string

// Check Auth availability
apiKey, err := instance.AuthAPIKey(ctx, r.Client)
apiKey, err := instance.AuthAPIKey(ctx, r.Client, nil)
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}

switch instance.Spec.Type {
case embeddings.ZhiPuAI:
case arcadiav1alpha1.ZhiPuAI:
embedClient := zhipuai.NewZhiPuAI(apiKey)
res, err := embedClient.Validate()
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
msg = res.String()
case embeddings.OpenAI:
case arcadiav1alpha1.OpenAI:
embedClient := openai.NewOpenAI(apiKey, instance.Spec.Enpoint.URL)
res, err := embedClient.Validate()
if err != nil {
Expand Down
79 changes: 22 additions & 57 deletions controllers/knowledgebase_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ import (
"github.com/go-logr/logr"
"github.com/minio/minio-go/v7"
"github.com/tmc/langchaingo/documentloaders"
langchainembeddings "github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/textsplitter"
"github.com/tmc/langchaingo/vectorstores/chroma"
apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand All @@ -46,8 +45,6 @@ import (
"github.com/kubeagi/arcadia/pkg/datasource"
pkgdocumentloaders "github.com/kubeagi/arcadia/pkg/documentloaders"
"github.com/kubeagi/arcadia/pkg/embeddings"
zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
"github.com/kubeagi/arcadia/pkg/utils"
)

Expand Down Expand Up @@ -253,7 +250,7 @@ func (r *KnowledgeBaseReconciler) reconcileFileGroup(ctx context.Context, log lo
return errDataSourceNotReady
}

system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -383,69 +380,32 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge
if !store.Status.IsReady() {
return errVectorStoreNotReady
}
var em langchainembeddings.Embedder
switch embedder.Spec.Provider.GetType() {
case arcadiav1alpha1.ProviderType3rdParty:
switch embedder.Spec.Type { // nolint: gocritic
case embeddings.ZhiPuAI:
apiKey, err := embedder.AuthAPIKey(ctx, r.Client)
if err != nil {
return err
}
em, err = zhipuaiembeddings.NewZhiPuAI(
zhipuaiembeddings.WithClient(*zhipuai.NewZhiPuAI(apiKey)),
)
if err != nil {
return err
}
}
case arcadiav1alpha1.ProviderTypeWorker:
gatway, err := config.GetGateway(ctx, r.Client)
if err != nil {
return err
}
if gatway == nil {
return fmt.Errorf("global config gateway not found")
}
refWorker := embedder.Spec.Worker
if refWorker == nil {
return fmt.Errorf("embedder.spec.worker not defined")
}
worker := &arcadiav1alpha1.Worker{}
if err := r.Client.Get(ctx, types.NamespacedName{Namespace: refWorker.GetNamespace(), Name: refWorker.Name}, worker); err != nil {
return err
}
refModel := worker.Spec.Model
if refModel == nil {
return fmt.Errorf("worker.spec.model not defined")
}
modelName := worker.MakeRegistrationModelName()
llm, err := openai.New(openai.WithModel(modelName), openai.WithBaseURL(gatway.APIServer), openai.WithToken("fake"))
if err != nil {
return err
}
em, err = langchainembeddings.NewEmbedder(llm)
if err != nil {
return err
}
em, err := embeddings.GetLangchainEmbedder(ctx, embedder, r.Client, nil)
if err != nil {
return err
}
data, err := io.ReadAll(file) // TODO Load large files in pieces to save memory
// TODO Line or single line byte exceeds embedder limit
if err != nil {
return err
}
dataReader := bytes.NewReader(data)
var documents []schema.Document
var loader documentloaders.Loader
switch filepath.Ext(fileName) {
case "txt":
case ".txt":
loader = documentloaders.NewText(dataReader)
case "csv":
case ".csv":
if v == arcadiav1alpha1.ObjectTypeQA {
loader = pkgdocumentloaders.NewQACSV(dataReader, fileName, "q", "a")
documents, err = loader.Load(ctx)
if err != nil {
return err
}
} else {
loader = documentloaders.NewCSV(dataReader)
}
case "html", "htm":
case ".html", ".htm":
loader = documentloaders.NewHTML(dataReader)
default:
loader = documentloaders.NewText(dataReader)
Expand Down Expand Up @@ -475,11 +435,15 @@ func (r *KnowledgeBaseReconciler) handleFile(ctx context.Context, log logr.Logge
// )
//}

documents, err := loader.LoadAndSplit(ctx, split)
if err != nil {
return err
if len(documents) == 0 {
documents, err = loader.LoadAndSplit(ctx, split)
if err != nil {
return err
}
}
for i, doc := range documents {
log.Info(fmt.Sprintf("document[%d]: embedding:%s, metadata:%v", i, doc.PageContent, doc.Metadata))
}

switch store.Spec.Type() { // nolint: gocritic
case arcadiav1alpha1.VectorStoreTypeChroma:
s, err := chroma.New(
Expand Down Expand Up @@ -511,6 +475,7 @@ func (r *KnowledgeBaseReconciler) reconcileDelete(ctx context.Context, log logr.
chroma.WithChromaURL(vectorStore.Spec.Enpoint.URL),
chroma.WithDistanceFunction(vectorStore.Spec.Chroma.DistanceFunction),
chroma.WithNameSpace(kb.VectorStoreCollectionName()),
chroma.WithOpenAiAPIKey("fake"),
)
if err != nil {
log.Error(err, "reconcile delete: init vector store error, may leave garbage data")
Expand Down
4 changes: 2 additions & 2 deletions controllers/model_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (r *ModelReconciler) CheckModel(ctx context.Context, logger logr.Logger, in
var ds datasource.Datasource
var info any

system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return r.UpdateStatus(ctx, instance, err)
}
Expand Down Expand Up @@ -213,7 +213,7 @@ func (r *ModelReconciler) RemoveModel(ctx context.Context, logger logr.Logger, i
var ds datasource.Datasource
var info any

system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return r.UpdateStatus(ctx, instance, err)
}
Expand Down
2 changes: 1 addition & 1 deletion controllers/namespace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (r *NamespaceReconciler) SetupWithManager(mgr ctrl.Manager) error {
}

func (r *NamespaceReconciler) ossClient(ctx context.Context) (*datasource.OSS, error) {
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client)
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
klog.Errorf("get system datasource error %s", err)
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions controllers/versioneddataset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (r *VersionedDatasetReconciler) preUpdate(ctx context.Context, logger logr.
func (r *VersionedDatasetReconciler) checkStatus(ctx context.Context, logger logr.Logger, instance *v1alpha1.VersionedDataset) (bool, []v1alpha1.FileStatus, error) {
// TODO: Currently, we think there is only one default minio environment,
// so we get the minio client directly through the configuration.
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client)
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
logger.Error(err, "Failed to get system datasource")
return false, nil, err
Expand All @@ -232,7 +232,7 @@ func (r *VersionedDatasetReconciler) checkStatus(ctx context.Context, logger log
}

func (r *VersionedDatasetReconciler) removeBucketFiles(ctx context.Context, logger logr.Logger, instance *v1alpha1.VersionedDataset) error {
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client)
systemDatasource, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
logger.Error(err, "Failed to get system datasource")
return err
Expand Down
2 changes: 1 addition & 1 deletion controllers/worker_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (r *WorkerReconciler) Initialize(ctx context.Context, logger logr.Logger, i

func (r *WorkerReconciler) reconcile(ctx context.Context, logger logr.Logger, worker *arcadiav1alpha1.Worker) error {
// reconcile worker instance
system, err := config.GetSystemDatasource(ctx, r.Client)
system, err := config.GetSystemDatasource(ctx, r.Client, nil)
if err != nil {
return fmt.Errorf("failed to get system datasource with %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions graphql-server/go-server/pkg/embedder/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/kubeagi/arcadia/graphql-server/go-server/graph/generated"
"github.com/kubeagi/arcadia/graphql-server/go-server/pkg/common"
graphqlutils "github.com/kubeagi/arcadia/graphql-server/go-server/pkg/utils"
"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/utils"
)

Expand Down Expand Up @@ -104,7 +103,7 @@ func CreateEmbedder(ctx context.Context, c dynamic.Interface, input generated.Cr
URL: input.Endpointinput.URL,
},
},
Type: embeddings.EmbeddingType(servicetype),
Type: v1alpha1.EmbeddingType(servicetype),
},
}

Expand Down
12 changes: 11 additions & 1 deletion pkg/application/chain/retrievalqachain.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"k8s.io/klog/v2"

"github.com/kubeagi/arcadia/pkg/application/base"
appretriever "github.com/kubeagi/arcadia/pkg/application/retriever"
)

type RetrievalQAChain struct {
Expand Down Expand Up @@ -69,7 +70,16 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ dynamic.Interface, args ma
}

llmChain := chains.NewLLMChain(llm, prompt)
chain := chains.NewRetrievalQA(chains.NewStuffDocuments(llmChain), retriever)
var baseChain chains.Chain
if _, ok := v3.(*appretriever.KnowledgeBaseRetriever); ok {
baseChain = appretriever.NewStuffDocuments(llmChain)
klog.Infoln("!!!TODO")
} else {
baseChain = chains.NewStuffDocuments(llmChain)
klog.Infoln("???TODO")
klog.Infof("%#v", v3)
}
chain := chains.NewRetrievalQA(baseChain, retriever)
l.RetrievalQA = chain
args["query"] = args["question"]
var out string
Expand Down
Loading

0 comments on commit 5174be0

Please sign in to comment.