From 90bca8c7de8f1b1f7cf199d38e2fed88658212d4 Mon Sep 17 00:00:00 2001 From: Abirdcfly Date: Thu, 29 Feb 2024 09:24:22 +0800 Subject: [PATCH] feat: add rerank Signed-off-by: Abirdcfly --- Dockerfile.rerank-mock | 23 ++ .../v1alpha1/knowledgebaseretriever_types.go | 2 +- .../v1alpha1/rerankretriever_types.go | 80 ++++++ .../v1alpha1/zz_generated.deepcopy.go | 92 +++++++ apiserver/docs/docs.go | 5 + apiserver/docs/swagger.json | 5 + apiserver/docs/swagger.yaml | 4 + apiserver/pkg/chat/chat_docs.go | 10 +- ...gi.k8s.com.cn_knowledgebaseretrievers.yaml | 2 +- ...a.kubeagi.k8s.com.cn_rerankretrievers.yaml | 120 +++++++++ config/rbac/role.yaml | 26 ++ ...qachain_knowledgebase_pgvector_rerank.yaml | 98 +++++++ .../retriever/rerank_retriever_controller.go | 163 ++++++++++++ deploy/charts/arcadia/Chart.yaml | 2 +- ...gi.k8s.com.cn_knowledgebaseretrievers.yaml | 2 +- ...a.kubeagi.k8s.com.cn_rerankretrievers.yaml | 120 +++++++++ deploy/charts/arcadia/templates/rbac.yaml | 26 ++ main.go | 7 + pkg/appruntime/agent/executor.go | 6 +- pkg/appruntime/agent/streamhandler.go | 10 +- pkg/appruntime/app_runtime.go | 23 +- pkg/appruntime/base/keyword.go | 28 ++ pkg/appruntime/base/node.go | 6 + pkg/appruntime/chain/apichain.go | 8 +- pkg/appruntime/chain/common.go | 9 +- pkg/appruntime/chain/llmchain.go | 14 +- pkg/appruntime/chain/mpchain.go | 6 +- pkg/appruntime/chain/retrievalqachain.go | 22 +- pkg/appruntime/llm/llm.go | 2 +- pkg/appruntime/retriever/common.go | 159 +++++++++++ pkg/appruntime/retriever/fakeretiever.go | 33 +++ .../retriever/knowledgebaseretriever.go | 251 +++--------------- pkg/appruntime/retriever/rerankretriever.go | 160 +++++++++++ pkg/config/config.go | 13 + pkg/config/config_type.go | 3 + tests/example-test.sh | 28 +- tests/rerank-mock/deploy-svc.yaml | 36 +++ tests/rerank-mock/main.go | 71 +++++ 38 files changed, 1403 insertions(+), 272 deletions(-) create mode 100644 Dockerfile.rerank-mock create mode 100644 
api/app-node/retriever/v1alpha1/rerankretriever_types.go create mode 100644 config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml create mode 100644 config/samples/app_retrievalqachain_knowledgebase_pgvector_rerank.yaml create mode 100644 controllers/app-node/retriever/rerank_retriever_controller.go create mode 100644 deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml create mode 100644 pkg/appruntime/base/keyword.go create mode 100644 pkg/appruntime/retriever/common.go create mode 100644 pkg/appruntime/retriever/fakeretiever.go create mode 100644 pkg/appruntime/retriever/rerankretriever.go create mode 100644 tests/rerank-mock/deploy-svc.yaml create mode 100644 tests/rerank-mock/main.go diff --git a/Dockerfile.rerank-mock b/Dockerfile.rerank-mock new file mode 100644 index 000000000..c6aabd733 --- /dev/null +++ b/Dockerfile.rerank-mock @@ -0,0 +1,23 @@ +# Build the rerank-mock binary +ARG GO_VER=1.21 +FROM golang:${GO_VER} as builder +ARG GOPROXY=https://goproxy.cn,direct +WORKDIR /workspace +# Copy the whole source tree +COPY . . +# download deps in their own layer; the whole tree is copied above, so this layer +# is re-run whenever any source file changes +RUN go env -w GOPROXY=${GOPROXY} +RUN go mod download + + +# Build +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o rerank-mock tests/rerank-mock/main.go +# Use distroless as minimal base image to package the rerank-mock binary +# Refer to https://github.com/GoogleContainerTools/distroless for more details +FROM gcr.io/distroless/static:nonroot +WORKDIR / +COPY --from=builder /workspace/rerank-mock . 
+USER 65532:65532 + +ENTRYPOINT ["/rerank-mock"] diff --git a/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go b/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go index 07ab06595..5453c09e8 100644 --- a/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go +++ b/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go @@ -38,7 +38,7 @@ type CommonRetrieverConfig struct { // NumDocuments is the max number of documents to return. // +kubebuilder:default=5 // +kubebuilder:validation:Minimum=1 - // +kubebuilder:validation:Maximum=10 + // +kubebuilder:validation:Maximum=50 NumDocuments int `json:"numDocuments,omitempty"` // DocNullReturn is the return statement when the query result is empty from the retriever. // +kubebuilder:default="未找到您询问的内容,请详细描述您的问题" diff --git a/api/app-node/retriever/v1alpha1/rerankretriever_types.go b/api/app-node/retriever/v1alpha1/rerankretriever_types.go new file mode 100644 index 000000000..a6f98faf6 --- /dev/null +++ b/api/app-node/retriever/v1alpha1/rerankretriever_types.go @@ -0,0 +1,80 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + node "github.com/kubeagi/arcadia/api/app-node" + "github.com/kubeagi/arcadia/api/base/v1alpha1" +) + +// RerankRetrieverSpec defines the desired state of RerankRetriever +type RerankRetrieverSpec struct { + v1alpha1.CommonSpec `json:",inline"` + CommonRetrieverConfig `json:",inline"` + // the endpoint of the rerank + // TODO: should change to model or worker + Endpoint string `json:"endpoint,omitempty"` +} + +// RerankRetrieverStatus defines the observed state of RerankRetriever +type RerankRetrieverStatus struct { + // ObservedGeneration is the last observed generation. + // +optional + ObservedGeneration int64 `json:"observedGeneration,omitempty"` + + // ConditionedStatus is the current status + v1alpha1.ConditionedStatus `json:",inline"` +} + +//+kubebuilder:object:root=true +//+kubebuilder:subresource:status + +// RerankRetriever is the Schema for the RerankRetriever API +type RerankRetriever struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec RerankRetrieverSpec `json:"spec,omitempty"` + Status RerankRetrieverStatus `json:"status,omitempty"` +} + +//+kubebuilder:object:root=true + +// RerankRetrieverList contains a list of RerankRetriever +type RerankRetrieverList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []RerankRetriever `json:"items"` +} + +func init() { + SchemeBuilder.Register(&RerankRetriever{}, &RerankRetrieverList{}) +} + +var _ node.Node = (*RerankRetriever)(nil) + +func (c *RerankRetriever) SetRef() { + annotations := node.SetRefAnnotations(c.GetAnnotations(), []node.Ref{node.RetrieverRef.Len(1)}, []node.Ref{node.RetrievalQAChainRef.Len(1)}) + if c.GetAnnotations() == nil { + c.SetAnnotations(annotations) + } + for k, v := range annotations { + c.Annotations[k] = v + } +} diff --git a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go 
b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go index caed444bf..2884912ae 100644 --- a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go @@ -131,3 +131,95 @@ func (in *KnowledgeBaseRetrieverStatus) DeepCopy() *KnowledgeBaseRetrieverStatus in.DeepCopyInto(out) return out } + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RerankRetriever) DeepCopyInto(out *RerankRetriever) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + out.Spec = in.Spec + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RerankRetriever. +func (in *RerankRetriever) DeepCopy() *RerankRetriever { + if in == nil { + return nil + } + out := new(RerankRetriever) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *RerankRetriever) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RerankRetrieverList) DeepCopyInto(out *RerankRetrieverList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]RerankRetriever, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RerankRetrieverList. 
+func (in *RerankRetrieverList) DeepCopy() *RerankRetrieverList { + if in == nil { + return nil + } + out := new(RerankRetrieverList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *RerankRetrieverList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RerankRetrieverSpec) DeepCopyInto(out *RerankRetrieverSpec) { + *out = *in + out.CommonSpec = in.CommonSpec + out.CommonRetrieverConfig = in.CommonRetrieverConfig +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RerankRetrieverSpec. +func (in *RerankRetrieverSpec) DeepCopy() *RerankRetrieverSpec { + if in == nil { + return nil + } + out := new(RerankRetrieverSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RerankRetrieverStatus) DeepCopyInto(out *RerankRetrieverStatus) { + *out = *in + in.ConditionedStatus.DeepCopyInto(&out.ConditionedStatus) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RerankRetrieverStatus. +func (in *RerankRetrieverStatus) DeepCopy() *RerankRetrieverStatus { + if in == nil { + return nil + } + out := new(RerankRetrieverStatus) + in.DeepCopyInto(out) + return out +} diff --git a/apiserver/docs/docs.go b/apiserver/docs/docs.go index 820199333..d6406564a 100644 --- a/apiserver/docs/docs.go +++ b/apiserver/docs/docs.go @@ -1728,6 +1728,11 @@ const docTemplate = `{ "type": "string", "example": "q: 旷工最小计算单位为多少天?" 
}, + "rerank_score": { + "description": "RerankScore", + "type": "number", + "example": 0.58124 + }, "score": { "description": "vector search score", "type": "number", diff --git a/apiserver/docs/swagger.json b/apiserver/docs/swagger.json index 739a0054c..faceb12f8 100644 --- a/apiserver/docs/swagger.json +++ b/apiserver/docs/swagger.json @@ -1722,6 +1722,11 @@ "type": "string", "example": "q: 旷工最小计算单位为多少天?" }, + "rerank_score": { + "description": "RerankScore", + "type": "number", + "example": 0.58124 + }, "score": { "description": "vector search score", "type": "number", diff --git a/apiserver/docs/swagger.yaml b/apiserver/docs/swagger.yaml index 64d844c13..cc06df0d5 100644 --- a/apiserver/docs/swagger.yaml +++ b/apiserver/docs/swagger.yaml @@ -290,6 +290,10 @@ definitions: description: Question row example: 'q: 旷工最小计算单位为多少天?' type: string + rerank_score: + description: RerankScore + example: 0.58124 + type: number score: description: vector search score example: 0.34 diff --git a/apiserver/pkg/chat/chat_docs.go b/apiserver/pkg/chat/chat_docs.go index bbb7909ac..fe3558e45 100644 --- a/apiserver/pkg/chat/chat_docs.go +++ b/apiserver/pkg/chat/chat_docs.go @@ -306,10 +306,10 @@ func (cs *ChatServer) GenerateSingleDocSummary(ctx context.Context, req Conversa return "", ErrNoLLMProvidedInApplication } out := map[string]any{ - "question": req.Query, - "_answer_stream": respStream, - "llm": llm, - "documents": documents, + "question": req.Query, + runtimebase.OutputAnserStreamChanKeyInArg: respStream, + "llm": llm, + "documents": documents, } if req.ResponseMode == "streaming" { out["_need_stream"] = true @@ -323,7 +323,7 @@ func (cs *ChatServer) GenerateSingleDocSummary(ctx context.Context, req Conversa if err != nil { return "", fmt.Errorf("failed to generate summary due to %s", err.Error()) } - a, ok := out["_answer"] + a, ok := out[runtimebase.OutputAnserKeyInArg] if !ok { return "", errors.New("empty answer") } diff --git 
a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml index 881b26f73..9610b89e2 100644 --- a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml +++ b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml @@ -53,7 +53,7 @@ spec: numDocuments: default: 5 description: NumDocuments is the max number of documents to return. - maximum: 10 + maximum: 50 minimum: 1 type: integer scoreThreshold: diff --git a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml new file mode 100644 index 000000000..3a476efe3 --- /dev/null +++ b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml @@ -0,0 +1,120 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: rerankretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: RerankRetriever + listKind: RerankRetrieverList + plural: rerankretrievers + singular: rerankretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: RerankRetriever is the Schema for the RerankRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. 
Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: RerankRetrieverSpec defines the desired state of RerankRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + docNullReturn: + default: 未找到您询问的内容,请详细描述您的问题 + description: DocNullReturn is the return statement when the query + result is empty from the retriever. + type: string + endpoint: + description: 'the endpoint of the rerank TODO: should change to model + or worker' + type: string + numDocuments: + default: 5 + description: NumDocuments is the max number of documents to return. + maximum: 50 + minimum: 1 + type: integer + scoreThreshold: + default: 0.3 + description: ScoreThreshold is the cosine distance float score threshold. + Lower score represents more similarity. + maximum: 1 + minimum: 0 + type: number + type: object + status: + description: RerankRetrieverStatus defines the observed state of RerankRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. 
+ type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index 332e93275..5d4fb7163 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -653,6 +653,32 @@ rules: - get - patch - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - rerankretrievers + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - rerankretrievers/finalizers + verbs: + - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - rerankretrievers/status + verbs: + - get + - patch + - update - apiGroups: - storage.k8s.io resources: diff --git a/config/samples/app_retrievalqachain_knowledgebase_pgvector_rerank.yaml b/config/samples/app_retrievalqachain_knowledgebase_pgvector_rerank.yaml new file mode 100644 index 000000000..bf1a6fa11 --- /dev/null +++ b/config/samples/app_retrievalqachain_knowledgebase_pgvector_rerank.yaml @@ -0,0 +1,98 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Application +metadata: + name: base-chat-with-knowledgebase-pgvector-rerank + namespace: arcadia +spec: + displayName: "知识库应用" + description: "最简单的和知识库对话的应用" + prologue: "Welcome to talk to the KnowledgeBase!🤖" + nodes: + - name: Input + displayName: "用户输入" + description: "用户输入节点,必须" + ref: + kind: Input + name: Input + nextNodeName: ["prompt-node"] + - name: prompt-node + displayName: 
"prompt" + description: "设定prompt,template中可以使用{{xx}}来替换变量" + ref: + apiGroup: prompt.arcadia.kubeagi.k8s.com.cn + kind: Prompt + name: base-chat-with-knowledgebase + nextNodeName: ["chain-node"] + - name: llm-node + displayName: "zhipu大模型服务" + description: "设定大模型的访问信息" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: LLM + name: app-shared-llm-service + nextNodeName: ["chain-node"] + - name: knowledgebase-node + displayName: "使用的知识库" + description: "要用哪个知识库" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBase + name: knowledgebase-sample-pgvector + nextNodeName: ["retriever-node"] + - name: retriever-node + displayName: "从知识库提取信息的retriever" + description: "连接应用和知识库" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: KnowledgeBaseRetriever + name: base-chat-with-knowledgebase-pgvector-1 + nextNodeName: ["rerank-retriever-node"] + - name: rerank-retriever-node + displayName: "rerank retriever" + description: "重排" + ref: + apiGroup: retriever.arcadia.kubeagi.k8s.com.cn + kind: RerankRetriever + name: base-chat-with-knowledgebase-pgvector-rerank + nextNodeName: ["chain-node"] + - name: chain-node + displayName: "RetrievalQA chain" + description: "chain是langchain的核心概念,RetrievalQAChain用于从 retriever 中提取信息,供llm调用" + ref: + apiGroup: chain.arcadia.kubeagi.k8s.com.cn + kind: RetrievalQAChain + name: base-chat-with-knowledgebase + nextNodeName: ["Output"] + - name: Output + displayName: "最终输出" + description: "最终输出节点,必须" + ref: + kind: Output + name: Output +--- +apiVersion: retriever.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: KnowledgeBaseRetriever +metadata: + name: base-chat-with-knowledgebase-pgvector-1 + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"KnowledgeBase","group":"arcadia.kubeagi.k8s.com.cn","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"kind":"RetrievalQAChain","group":"chain.arcadia.kubeagi.k8s.com.cn","length":1}]' +spec: + displayName: "从知识库获取信息的Retriever" + 
scoreThreshold: 0.3 + numDocuments: 50 +--- +apiVersion: retriever.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: RerankRetriever +metadata: + name: base-chat-with-knowledgebase-pgvector-rerank + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"KnowledgeBase","group":"arcadia.kubeagi.k8s.com.cn","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"kind":"RetrievalQAChain","group":"chain.arcadia.kubeagi.k8s.com.cn","length":1}]' +spec: + displayName: "rerank Retriever" + scoreThreshold: 0.1 + numDocuments: 3 + endpoint: "http://rerank-mock.default.svc:8123/rerank" diff --git a/controllers/app-node/retriever/rerank_retriever_controller.go b/controllers/app-node/retriever/rerank_retriever_controller.go new file mode 100644 index 000000000..0e6b2a39f --- /dev/null +++ b/controllers/app-node/retriever/rerank_retriever_controller.go @@ -0,0 +1,163 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package chain + +import ( + "context" + "reflect" + + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + + api "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" + arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1" + appnode "github.com/kubeagi/arcadia/controllers/app-node" + "github.com/kubeagi/arcadia/pkg/config" +) + +// RerankRetrieverReconciler reconciles a RerankRetriever object +type RerankRetrieverReconciler struct { + client.Client + Scheme *runtime.Scheme +} + +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=rerankretrievers,verbs=get;list;watch;create;update;patch;delete +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=rerankretrievers/status,verbs=get;update;patch +//+kubebuilder:rbac:groups=retriever.arcadia.kubeagi.k8s.com.cn,resources=rerankretrievers/finalizers,verbs=update + +// Reconcile is part of the main kubernetes reconciliation loop which aims to +// move the current state of the cluster closer to the desired state. +// For more details, check Reconcile and its Result here: +// - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.12.2/pkg/reconcile +func (r *RerankRetrieverReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := ctrl.LoggerFrom(ctx) + log.V(5).Info("Start RerankRetriever Reconcile") + instance := &api.RerankRetriever{} + if err := r.Get(ctx, req.NamespacedName, instance); err != nil { + // There's no need to requeue if the resource no longer exists. + // Otherwise, we'll be requeued implicitly because we return an error. 
+ log.V(1).Info("Failed to get RerankRetriever") + return ctrl.Result{}, client.IgnoreNotFound(err) + } + log = log.WithValues("Generation", instance.GetGeneration(), "ObservedGeneration", instance.Status.ObservedGeneration, "creator", instance.Spec.Creator) + log.V(5).Info("Get RerankRetriever instance") + + // Add a finalizer. Then, we can define some operations which should + // occur before the RerankRetriever is deleted. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/finalizers + if newAdded := controllerutil.AddFinalizer(instance, arcadiav1alpha1.Finalizer); newAdded { + log.Info("Try to add Finalizer for RerankRetriever") + if err := r.Update(ctx, instance); err != nil { + log.Error(err, "Failed to update RerankRetriever to add finalizer, will try again later") + return ctrl.Result{}, err + } + log.Info("Adding Finalizer for RerankRetriever done") + return ctrl.Result{}, nil + } + + // Check if the RerankRetriever instance is marked to be deleted, which is + // indicated by the deletion timestamp being set. + if instance.GetDeletionTimestamp() != nil && controllerutil.ContainsFinalizer(instance, arcadiav1alpha1.Finalizer) { + log.Info("Performing Finalizer Operations for RerankRetriever before delete CR") + // TODO perform the finalizer operations here, for example: remove vectorstore data? + log.Info("Removing Finalizer for RerankRetriever after successfully performing the operations") + controllerutil.RemoveFinalizer(instance, arcadiav1alpha1.Finalizer) + if err := r.Update(ctx, instance); err != nil { + log.Error(err, "Failed to remove the finalizer for RerankRetriever") + return ctrl.Result{}, err + } + log.Info("Remove RerankRetriever done") + return ctrl.Result{}, nil + } + + instance, result, err := r.reconcile(ctx, log, instance) + + // Update status after reconciliation. 
+ if updateStatusErr := r.patchStatus(ctx, instance); updateStatusErr != nil { + log.Error(updateStatusErr, "unable to update status after reconciliation") + return ctrl.Result{Requeue: true}, updateStatusErr + } + + return result, err +} + +func (r *RerankRetrieverReconciler) reconcile(ctx context.Context, log logr.Logger, instance *api.RerankRetriever) (*api.RerankRetriever, ctrl.Result, error) { + // Observe generation change + if instance.Status.ObservedGeneration != instance.Generation { + instance.Status.ObservedGeneration = instance.Generation + r.setCondition(instance, instance.Status.WaitingCompleteCondition()...) + if updateStatusErr := r.patchStatus(ctx, instance); updateStatusErr != nil { + log.Error(updateStatusErr, "unable to update status after generation update") + return instance, ctrl.Result{Requeue: true}, updateStatusErr + } + } + if instance.Spec.Endpoint == "" { + endpoint, err := config.GetDefaultRerankEndpoint(ctx, r.Client) + if err == nil && endpoint != "" { + instanceNew := instance.DeepCopy() + instanceNew.Spec.Endpoint = endpoint + err := r.Patch(ctx, instanceNew, client.MergeFrom(instance)) + if err != nil { + return instance, ctrl.Result{Requeue: true}, err + } + instance = instanceNew + } + } + + if instance.Status.IsReady() { + return instance, ctrl.Result{}, nil + } + // Note: should change here + // TODO: we should do more checks later. For example: + // LLM status + // Prompt status + if err := appnode.CheckAndUpdateAnnotation(ctx, log, r.Client, instance); err != nil { + instance.Status.SetConditions(instance.Status.ErrorCondition(err.Error())...) + } else { + instance.Status.SetConditions(instance.Status.ReadyCondition()...) 
+ } + + return instance, ctrl.Result{}, nil +} + +func (r *RerankRetrieverReconciler) patchStatus(ctx context.Context, instance *api.RerankRetriever) error { + latest := &api.RerankRetriever{} + if err := r.Client.Get(ctx, client.ObjectKeyFromObject(instance), latest); err != nil { + return err + } + if reflect.DeepEqual(instance.Status, latest.Status) { + return nil + } + patch := client.MergeFrom(latest.DeepCopy()) + latest.Status = instance.Status + return r.Client.Status().Patch(ctx, latest, patch, client.FieldOwner("RerankRetriever-controller")) +} + +// SetupWithManager sets up the controller with the Manager. +func (r *RerankRetrieverReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&api.RerankRetriever{}). + Complete(r) +} + +func (r *RerankRetrieverReconciler) setCondition(instance *api.RerankRetriever, condition ...arcadiav1alpha1.Condition) *api.RerankRetriever { + instance.Status.SetConditions(condition...) + return instance +} diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index b14b51838..c23acdf12 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(Also a KubeBB Component) for KubeAGI Arcadia type: application -version: 0.3.3 +version: 0.3.4 appVersion: "0.2.0" keywords: diff --git a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml index 881b26f73..9610b89e2 100644 --- a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml +++ b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml @@ -53,7 +53,7 @@ spec: numDocuments: default: 5 description: NumDocuments is the max number of documents to return. 
- maximum: 10 + maximum: 50 minimum: 1 type: integer scoreThreshold: diff --git a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml new file mode 100644 index 000000000..3a476efe3 --- /dev/null +++ b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_rerankretrievers.yaml @@ -0,0 +1,120 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.9.2 + creationTimestamp: null + name: rerankretrievers.retriever.arcadia.kubeagi.k8s.com.cn +spec: + group: retriever.arcadia.kubeagi.k8s.com.cn + names: + kind: RerankRetriever + listKind: RerankRetrieverList + plural: rerankretrievers + singular: rerankretriever + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: RerankRetriever is the Schema for the RerankRetriever API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. 
More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: RerankRetrieverSpec defines the desired state of RerankRetriever + properties: + creator: + description: Creator defines datasource creator (AUTO-FILLED by webhook) + type: string + description: + description: Description defines datasource description + type: string + displayName: + description: DisplayName defines datasource display name + type: string + docNullReturn: + default: 未找到您询问的内容,请详细描述您的问题 + description: DocNullReturn is the return statement when the query + result is empty from the retriever. + type: string + endpoint: + description: 'the endpoint of the rerank TODO: should change to model + or worker' + type: string + numDocuments: + default: 5 + description: NumDocuments is the max number of documents to return. + maximum: 50 + minimum: 1 + type: integer + scoreThreshold: + default: 0.3 + description: ScoreThreshold is the cosine distance float score threshold. + Lower score represents more similarity. + maximum: 1 + minimum: 0 + type: number + type: object + status: + description: RerankRetrieverStatus defines the observed state of RerankRetriever + properties: + conditions: + description: Conditions of the resource. + items: + description: A Condition that may apply to a resource. + properties: + lastSuccessfulTime: + description: LastSuccessfulTime is repository Last Successful + Update Time + format: date-time + type: string + lastTransitionTime: + description: LastTransitionTime is the last time this condition + transitioned from one status to another. + format: date-time + type: string + message: + description: A Message containing details about this condition's + last transition from one status to another, if any. + type: string + reason: + description: A Reason for this condition's last transition from + one status to another. 
+ type: string + status: + description: Status of this condition; is it currently True, + False, or Unknown + type: string + type: + description: Type of this condition. At most one of each condition + type may apply to a resource at any point in time. + type: string + required: + - lastTransitionTime + - reason + - status + - type + type: object + type: array + observedGeneration: + description: ObservedGeneration is the last observed generation. + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/deploy/charts/arcadia/templates/rbac.yaml b/deploy/charts/arcadia/templates/rbac.yaml index 500657ece..e55fd19ed 100644 --- a/deploy/charts/arcadia/templates/rbac.yaml +++ b/deploy/charts/arcadia/templates/rbac.yaml @@ -670,6 +670,32 @@ rules: - get - patch - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - rerankretrievers + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - rerankretrievers/finalizers + verbs: + - update +- apiGroups: + - retriever.arcadia.kubeagi.k8s.com.cn + resources: + - rerankretrievers/status + verbs: + - get + - patch + - update - apiGroups: - storage.k8s.io resources: diff --git a/main.go b/main.go index 5ac1b0133..c916099a3 100644 --- a/main.go +++ b/main.go @@ -264,6 +264,13 @@ func main() { setupLog.Error(err, "unable to create controller", "controller", "KnowledgeBaseRetriever") os.Exit(1) } + if err = (&retrievertrollers.RerankRetrieverReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + }).SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "RerankRetriever") + os.Exit(1) + } if err = (&promptcontrollers.PromptReconciler{ Client: mgr.GetClient(), Scheme: mgr.GetScheme(), diff --git a/pkg/appruntime/agent/executor.go b/pkg/appruntime/agent/executor.go index 
f086ed140..779b68ecb 100644 --- a/pkg/appruntime/agent/executor.go +++ b/pkg/appruntime/agent/executor.go @@ -48,7 +48,7 @@ func NewExecutor(baseNode base.BaseNode) *Executor { } func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]any) (map[string]any, error) { - v1, ok := args["llm"] + v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return args, errors.New("no llm") } @@ -122,7 +122,7 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a agents.WithMaxIterations(instance.Spec.Options.MaxIterations)(o) // Only show tool action in the streaming output if configured if instance.Spec.Options.ShowToolAction { - if needStream, ok := args["_need_stream"].(bool); ok && needStream { + if needStream, ok := args[base.InputIsNeedStreamKeyInArg].(bool); ok && needStream { streamHandler := StreamHandler{callbacks.SimpleHandler{}, args} agents.WithCallbacksHandler(streamHandler)(o) } @@ -139,6 +139,6 @@ func (p *Executor) Run(ctx context.Context, cli client.Client, args map[string]a return args, fmt.Errorf("error when call agent: %w", err) } klog.FromContext(ctx).V(5).Info("use agent, blocking out:", response["output"]) - args["_answer"] = response["output"] + args[base.OutputAnserKeyInArg] = response["output"] return args, nil } diff --git a/pkg/appruntime/agent/streamhandler.go b/pkg/appruntime/agent/streamhandler.go index 5913c13e3..c8f704740 100644 --- a/pkg/appruntime/agent/streamhandler.go +++ b/pkg/appruntime/agent/streamhandler.go @@ -22,6 +22,8 @@ import ( "github.com/tmc/langchaingo/callbacks" "k8s.io/klog/v2" + + "github.com/kubeagi/arcadia/pkg/appruntime/base" ) // StreamHandler is a callback handler that prints to the standard output streaming. 
@@ -34,13 +36,13 @@ var _ callbacks.Handler = StreamHandler{} func (handler StreamHandler) HandleStreamingFunc(ctx context.Context, chunk []byte) { logger := klog.FromContext(ctx) - if _, ok := handler.args["_answer_stream"]; !ok { + if _, ok := handler.args[base.OutputAnserStreamChanKeyInArg]; !ok { logger.Info("no _answer_stream found, create a new one") - handler.args["_answer_stream"] = make(chan string) + handler.args[base.OutputAnserStreamChanKeyInArg] = make(chan string) } - streamChan, ok := handler.args["_answer_stream"].(chan string) + streamChan, ok := handler.args[base.OutputAnserStreamChanKeyInArg].(chan string) if !ok { - err := fmt.Errorf("answer_stream is not chan string, but %T", handler.args["_answer_stream"]) + err := fmt.Errorf("answer_stream is not chan string, but %T", handler.args[base.OutputAnserStreamChanKeyInArg]) logger.Error(err, "answer_stream is not chan string") return } diff --git a/pkg/appruntime/app_runtime.go b/pkg/appruntime/app_runtime.go index 6bcd396f2..6bdf6e1cb 100644 --- a/pkg/appruntime/app_runtime.go +++ b/pkg/appruntime/app_runtime.go @@ -139,16 +139,14 @@ func (a *Application) Init(ctx context.Context, cli client.Client) (err error) { func (a *Application) Run(ctx context.Context, cli client.Client, respStream chan string, input Input) (output Output, err error) { out := map[string]any{ - "question": input.Question, - "files": input.Files, - "_answer_stream": respStream, - "_history": input.History, + base.InputQuestionKeyInArg: input.Question, + "files": input.Files, + base.OutputAnserStreamChanKeyInArg: respStream, + base.InputIsNeedStreamKeyInArg: input.NeedStream, + base.LangchaingoChatMessageHistoryKeyInArg: input.History, // Use an empty context before run "context": "", } - if input.NeedStream { - out["_need_stream"] = true - } visited := make(map[string]bool) waitRunningNodes := list.New() for _, v := range a.StartingNodes { @@ -170,6 +168,10 @@ func (a *Application) Run(ctx context.Context, cli client.Client, 
respStream cha } klog.FromContext(ctx).V(3).Info(fmt.Sprintf("try to run node:%s", e.Name())) if out, err = e.Run(ctx, cli, out); err != nil { + var er *base.AppStopEarlyError + if errors.As(err, &er) { + return Output{Answer: er.Msg}, nil + } return Output{}, fmt.Errorf("run node %s: %w", e.Name(), err) } defer e.Cleanup() @@ -179,12 +181,12 @@ func (a *Application) Run(ctx context.Context, cli client.Client, respStream cha waitRunningNodes.PushBack(n) } } - if a, ok := out["_answer"]; ok { + if a, ok := out[base.OutputAnserKeyInArg]; ok { if answer, ok := a.(string); ok && len(answer) > 0 { output = Output{Answer: answer} } } - if a, ok := out["_references"]; ok { + if a, ok := out[base.RuntimeRetrieverReferencesKeyInArg]; ok { if references, ok := a.([]retriever.Reference); ok && len(references) > 0 { output.References = references } @@ -224,6 +226,9 @@ func InitNode(ctx context.Context, appNamespace, name string, ref arcadiav1alpha case "knowledgebaseretriever": logger.V(3).Info("initnode knowledgebaseretriever") return retriever.NewKnowledgeBaseRetriever(baseNode), nil + case "rerankretriever": + logger.V(3).Info("initnode rerankretriever") + return retriever.NewRerankRetriever(baseNode), nil default: return nil, err } diff --git a/pkg/appruntime/base/keyword.go b/pkg/appruntime/base/keyword.go new file mode 100644 index 000000000..672bac86c --- /dev/null +++ b/pkg/appruntime/base/keyword.go @@ -0,0 +1,28 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package base + +const ( + InputQuestionKeyInArg = "question" + InputIsNeedStreamKeyInArg = "_need_stream" + LangchaingoChatMessageHistoryKeyInArg = "_history" + OutputAnserKeyInArg = "_answer" + OutputAnserStreamChanKeyInArg = "_answer_stream" + RuntimeRetrieverReferencesKeyInArg = "_references" + LangchaingoRetrieverKeyInArg = "retriever" + LangchaingoLLMKeyInArg = "llm" +) diff --git a/pkg/appruntime/base/node.go b/pkg/appruntime/base/node.go index 68f2dba01..694cf8ec7 100644 --- a/pkg/appruntime/base/node.go +++ b/pkg/appruntime/base/node.go @@ -121,3 +121,9 @@ func (c *BaseNode) Ready() (bool, string) { func (c *BaseNode) Cleanup() { } + +type AppStopEarlyError struct { + Msg string +} + +func (e *AppStopEarlyError) Error() string { return e.Msg } diff --git a/pkg/appruntime/chain/apichain.go b/pkg/appruntime/chain/apichain.go index f457e8b4e..cb961bce9 100644 --- a/pkg/appruntime/chain/apichain.go +++ b/pkg/appruntime/chain/apichain.go @@ -54,7 +54,7 @@ func (l *APIChain) Init(ctx context.Context, cli client.Client, _ map[string]any } func (l *APIChain) Run(ctx context.Context, _ client.Client, args map[string]any) (map[string]any, error) { - v1, ok := args["llm"] + v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return args, errors.New("no llm") } @@ -75,7 +75,7 @@ func (l *APIChain) Run(ctx context.Context, _ client.Client, args map[string]any return args, fmt.Errorf("can't format prompt: %w", err) } args["input"] = p.String() - v3, ok := args["_history"] + v3, ok := args[base.LangchaingoChatMessageHistoryKeyInArg] if !ok { return args, errors.New("no history") } @@ -98,7 +98,7 @@ func (l *APIChain) Run(ctx context.Context, _ client.Client, args map[string]any args["api_docs"] = apiDoc var out string needStream := false - needStream, ok = args["_need_stream"].(bool) + needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool) if ok && needStream { options = append(options, chains.WithStreamingFunc(stream(args))) out, err = chains.Predict(ctx, 
l.APIChain, args, options...) @@ -112,7 +112,7 @@ func (l *APIChain) Run(ctx context.Context, _ client.Client, args map[string]any out, err = handleNoErrNoOut(ctx, needStream, out, err, l.APIChain, args, options) klog.FromContext(ctx).V(5).Info("use apichain, blocking out:" + out) if err == nil { - args["_answer"] = out + args[base.OutputAnserKeyInArg] = out return args, nil } return args, fmt.Errorf("apichain run error: %w", err) diff --git a/pkg/appruntime/chain/common.go b/pkg/appruntime/chain/common.go index bfd46bae2..13e42ac6a 100644 --- a/pkg/appruntime/chain/common.go +++ b/pkg/appruntime/chain/common.go @@ -30,6 +30,7 @@ import ( agent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/appruntime/retriever" "github.com/kubeagi/arcadia/pkg/tools/bingsearch" ) @@ -37,13 +38,13 @@ import ( func stream(res map[string]any) func(ctx context.Context, chunk []byte) error { return func(ctx context.Context, chunk []byte) error { logger := klog.FromContext(ctx) - if _, ok := res["_answer_stream"]; !ok { + if _, ok := res[base.OutputAnserStreamChanKeyInArg]; !ok { logger.Info("no _answer_stream found, create a new one") - res["_answer_stream"] = make(chan string) + res[base.OutputAnserStreamChanKeyInArg] = make(chan string) } - streamChan, ok := res["_answer_stream"].(chan string) + streamChan, ok := res[base.OutputAnserStreamChanKeyInArg].(chan string) if !ok { - err := fmt.Errorf("answer_stream is not chan string, but %T", res["_answer_stream"]) + err := fmt.Errorf("answer_stream is not chan string, but %T", res[base.OutputAnserStreamChanKeyInArg]) logger.Error(err, "answer_stream is not chan string") return err } diff --git a/pkg/appruntime/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go index ad82d99dd..47bebabb8 100644 --- a/pkg/appruntime/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -56,7 +56,7 @@ 
func (l *LLMChain) Init(ctx context.Context, cli client.Client, _ map[string]any } func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any) (outArgs map[string]any, err error) { - v1, ok := args["llm"] + v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return args, errors.New("no llm") } @@ -75,7 +75,7 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any // _history is optional // if set ,only ChatMessageHistory allowed var history langchaingoschema.ChatMessageHistory - if v3, ok := args["_history"]; ok && v3 != nil { + if v3, ok := args[base.LangchaingoChatMessageHistoryKeyInArg]; ok && v3 != nil { history, ok = v3.(langchaingoschema.ChatMessageHistory) if !ok { return args, errors.New("history not memory.ChatMessageHistory") @@ -84,9 +84,9 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any instance := l.Instance options := GetChainOptions(instance.Spec.CommonChainConfig) // Add the answer to the context if it's not empty - if args["_answer"] != nil { - klog.Infoln("get answer from upstream:", args["_answer"]) - args["context"] = fmt.Sprintf("%s\n%s", args["context"], args["_answer"]) + if args[base.OutputAnserKeyInArg] != nil { + klog.Infoln("get answer from upstream:", args[base.OutputAnserKeyInArg]) + args["context"] = fmt.Sprintf("%s\n%s", args["context"], args[base.OutputAnserKeyInArg]) } args = runTools(ctx, args, instance.Spec.Tools) chain := chains.NewLLMChain(llm, prompt) @@ -97,7 +97,7 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any var out string needStream := false - needStream, ok = args["_need_stream"].(bool) + needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool) if ok && needStream { options = append(options, chains.WithStreamingFunc(stream(args))) out, err = chains.Predict(ctx, l.LLMChain, args, options...) 
@@ -111,7 +111,7 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any out, err = handleNoErrNoOut(ctx, needStream, out, err, l.LLMChain, args, options) klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out) if err == nil { - args["_answer"] = out + args[base.OutputAnserKeyInArg] = out return args, nil } return args, fmt.Errorf("llmchain run error: %w", err) diff --git a/pkg/appruntime/chain/mpchain.go b/pkg/appruntime/chain/mpchain.go index 05a1644f8..5731391b4 100644 --- a/pkg/appruntime/chain/mpchain.go +++ b/pkg/appruntime/chain/mpchain.go @@ -89,7 +89,7 @@ func (l *MapReduceChain) Init(ctx context.Context, cli client.Client, args map[s return errors.New("no arguments provided for MapReduceChain") } // initialize the LLM - v1, ok := args["llm"] + v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return errors.New("no llm") } @@ -165,7 +165,7 @@ func (l *MapReduceChain) Run(ctx context.Context, cli client.Client, args map[st // run LLMChain needStream := false - needStream, ok = args["_need_stream"].(bool) + needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool) if ok && needStream { l.chainCallOptions = append(l.chainCallOptions, chains.WithStreamingFunc(stream(args))) } @@ -175,7 +175,7 @@ func (l *MapReduceChain) Run(ctx context.Context, cli client.Client, args map[st out, err = handleNoErrNoOut(ctx, needStream, out, err, l.LLMChain, args, l.chainCallOptions) klog.FromContext(ctx).V(5).Info("use MapReduceChain, blocking out:" + out) if err == nil { - args["_answer"] = out + args[base.OutputAnserKeyInArg] = out return args, nil } return args, fmt.Errorf("mapreaducechain run error: %w", err) diff --git a/pkg/appruntime/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go index 0653de904..47e4e26ea 100644 --- a/pkg/appruntime/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -31,7 +31,6 @@ import ( "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" 
"github.com/kubeagi/arcadia/pkg/appruntime/base" - appretriever "github.com/kubeagi/arcadia/pkg/appruntime/retriever" ) type RetrievalQAChain struct { @@ -57,7 +56,7 @@ func (l *RetrievalQAChain) Init(ctx context.Context, cli client.Client, _ map[st } func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[string]any) (outArgs map[string]any, err error) { - v1, ok := args["llm"] + v1, ok := args[base.LangchaingoLLMKeyInArg] if !ok { return args, errors.New("no llm") } @@ -81,7 +80,7 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st if !ok { return args, errors.New("retriever not schema.Retriever") } - v4, ok := args["_history"] + v4, ok := args[base.LangchaingoChatMessageHistoryKeyInArg] if !ok { return args, errors.New("no history") } @@ -98,20 +97,12 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st if history != nil { llmChain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "") } - var baseChain chains.Chain - var stuffDocuments *appretriever.KnowledgeBaseStuffDocuments - if knowledgeBaseRetriever, ok := v3.(*appretriever.KnowledgeBaseRetriever); ok { - stuffDocuments = appretriever.NewStuffDocuments(llmChain, knowledgeBaseRetriever.DocNullReturn) - baseChain = stuffDocuments - } else { - baseChain = chains.NewStuffDocuments(llmChain) - } - chain := chains.NewConversationalRetrievalQA(baseChain, chains.LoadCondenseQuestionGenerator(llm), retriever, getMemory(llm, instance.Spec.Memory, history, "", "")) + chain := chains.NewConversationalRetrievalQA(chains.NewStuffDocuments(llmChain), chains.LoadCondenseQuestionGenerator(llm), retriever, getMemory(llm, instance.Spec.Memory, history, "", "")) l.ConversationalRetrievalQA = chain args["query"] = args["question"] var out string needStream := false - needStream, ok = args["_need_stream"].(bool) + needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool) if ok && needStream { options = append(options, 
chains.WithStreamingFunc(stream(args))) out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...) @@ -122,13 +113,10 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args) } } - if stuffDocuments != nil && len(stuffDocuments.References) > 0 { - args = appretriever.AddReferencesToArgs(args, stuffDocuments.References) - } out, err = handleNoErrNoOut(ctx, needStream, out, err, l.ConversationalRetrievalQA, args, options) klog.FromContext(ctx).V(5).Info("use retrievalqachain, blocking out:" + out) if err == nil { - args["_answer"] = out + args[base.OutputAnserKeyInArg] = out return args, nil } return args, fmt.Errorf("retrievalqachain run error: %w", err) diff --git a/pkg/appruntime/llm/llm.go b/pkg/appruntime/llm/llm.go index 8db6c8ad8..da8569fe0 100644 --- a/pkg/appruntime/llm/llm.go +++ b/pkg/appruntime/llm/llm.go @@ -57,7 +57,7 @@ func (z *LLM) Init(ctx context.Context, cli client.Client, _ map[string]any) err } func (z *LLM) Run(ctx context.Context, _ client.Client, args map[string]any) (map[string]any, error) { - args["llm"] = z + args[base.LangchaingoLLMKeyInArg] = z logger := klog.FromContext(ctx) logger.Info("use llm", "name", z.Ref.Name, "namespace", z.RefNamespace()) return args, nil diff --git a/pkg/appruntime/retriever/common.go b/pkg/appruntime/retriever/common.go new file mode 100644 index 000000000..d8c2b1cb0 --- /dev/null +++ b/pkg/appruntime/retriever/common.go @@ -0,0 +1,159 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retriever + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + langchaingoschema "github.com/tmc/langchaingo/schema" + "k8s.io/klog/v2" + + "github.com/kubeagi/arcadia/pkg/appruntime/base" + "github.com/kubeagi/arcadia/pkg/documentloaders" +) + +type Reference struct { + // Question row + Question string `json:"question" example:"q: 旷工最小计算单位为多少天?"` + // Answer row + Answer string `json:"answer" example:"旷工最小计算单位为 0.5 天。"` + // vector search score + Score float32 `json:"score" example:"0.34"` + // the qa file fullpath + QAFilePath string `json:"qa_file_path" example:"dataset/dataset-playground/v1/qa.csv"` + // line number in the qa file + QALineNumber int `json:"qa_line_number" example:"7"` + // source file name, only file name, not full path + FileName string `json:"file_name" example:"员工考勤管理制度-2023.pdf"` + // page number in the source file + PageNumber int `json:"page_number" example:"1"` + // related content in the source file or in webpage + Content string `json:"content" example:"旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。"` + // Title of the webpage + Title string `json:"title,omitempty" example:"开始使用 Microsoft 帐户 – Microsoft"` + // URL of the webpage + URL string `json:"url,omitempty" example:"https://www.microsoft.com/zh-cn/welcome"` + // RerankScore + RerankScore float32 `json:"rerank_score,omitempty" example:"0.58124"` + Metadata map[string]any `json:"-"` +} + +func (reference Reference) String() string { + bytes, err := json.Marshal(&reference) + if err != nil { + return "" + } + return string(bytes) +} + +func (reference Reference) SimpleString() string { + return fmt.Sprintf("%s %s", reference.Question, reference.Answer) +} + +func AddReferencesToArgs(args map[string]any, refs []Reference) map[string]any { + if len(refs) == 0 { + return args + } + old, exist := args[base.RuntimeRetrieverReferencesKeyInArg] + 
if exist { + oldRefs := old.([]Reference) + args[base.RuntimeRetrieverReferencesKeyInArg] = append(oldRefs, refs...) + return args + } + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + return args +} + +func ConvertDocuments(ctx context.Context, docs []langchaingoschema.Document, retrieverName string) (newDocs []langchaingoschema.Document, refs []Reference) { + logger := klog.FromContext(ctx) + docLen := len(docs) + logger.V(3).Info(fmt.Sprintf("get data from retriever: %s, total numbers: %d\n", retrieverName, docLen)) + refs = make([]Reference, 0, docLen) + for k, doc := range docs { + logger.V(3).Info(fmt.Sprintf("related doc[%d] raw text: %s, raw score: %f\n", k, doc.PageContent, doc.Score)) + for key, v := range doc.Metadata { + if str, ok := v.([]byte); ok { + logger.V(3).Info(fmt.Sprintf("related doc[%d] metadata[%s]: %s\n", k, key, string(str))) + } else { + logger.V(3).Info(fmt.Sprintf("related doc[%d] metadata[%s]: %#v\n", k, key, v)) + } + } + // chroma will get []byte, pgvector will get string... 
+ answer, ok := doc.Metadata[documentloaders.AnswerCol].(string) + if !ok { + if a, ok := doc.Metadata[documentloaders.AnswerCol].([]byte); ok { + answer = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") + } + } + pageContent := doc.PageContent + if retrieverName == "knowledgebase" { + if len(answer) != 0 { + doc.PageContent = doc.PageContent + "\na: " + answer + } + } + + qafilepath, ok := doc.Metadata[documentloaders.QAFileName].(string) + if !ok { + if a, ok := doc.Metadata[documentloaders.QAFileName].([]byte); ok { + qafilepath = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") + } + } + lineNumber, ok := doc.Metadata[documentloaders.LineNumber].(string) + if !ok { + if a, ok := doc.Metadata[documentloaders.LineNumber].([]byte); ok { + lineNumber = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") + } + } + line, _ := strconv.Atoi(lineNumber) + filename, ok := doc.Metadata[documentloaders.FileNameCol].(string) + if !ok { + if a, ok := doc.Metadata[documentloaders.FileNameCol].([]byte); ok { + filename = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") + } + } + pageNumber, ok := doc.Metadata[documentloaders.PageNumberCol].(string) + if !ok { + if a, ok := doc.Metadata[documentloaders.PageNumberCol].([]byte); ok { + pageNumber = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") + } + } + page, _ := strconv.Atoi(pageNumber) + content, ok := doc.Metadata[documentloaders.ChunkContentCol].(string) + if !ok { + if a, ok := doc.Metadata[documentloaders.ChunkContentCol].([]byte); ok { + content = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") + } + } + refs = append(refs, Reference{ + Question: pageContent, + Answer: answer, + Score: doc.Score, + QAFilePath: qafilepath, + QALineNumber: line, + FileName: filename, + PageNumber: page, + Content: content, + Metadata: doc.Metadata, + }) + docs[k] = doc + } + return docs, refs +} diff --git a/pkg/appruntime/retriever/fakeretiever.go 
b/pkg/appruntime/retriever/fakeretiever.go new file mode 100644 index 000000000..5cee71c67 --- /dev/null +++ b/pkg/appruntime/retriever/fakeretiever.go @@ -0,0 +1,33 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package retriever + +import ( + "context" + + langchaingoschema "github.com/tmc/langchaingo/schema" +) + +var _ langchaingoschema.Retriever = &Fakeretriever{} + +type Fakeretriever struct { + Docs []langchaingoschema.Document +} + +func (f *Fakeretriever) GetRelevantDocuments(context.Context, string) ([]langchaingoschema.Document, error) { + return f.Docs, nil +} diff --git a/pkg/appruntime/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go index f3334c26e..a9f2ed809 100644 --- a/pkg/appruntime/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -18,14 +18,9 @@ package retriever import ( "context" - "encoding/json" + "errors" "fmt" - "strconv" - "strings" - "github.com/tmc/langchaingo/callbacks" - "github.com/tmc/langchaingo/chains" - langchaingoschema "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/vectorstores" "k8s.io/apimachinery/pkg/types" "k8s.io/klog/v2" @@ -35,73 +30,19 @@ import ( "github.com/kubeagi/arcadia/api/base/v1alpha1" "github.com/kubeagi/arcadia/pkg/appruntime/base" "github.com/kubeagi/arcadia/pkg/appruntime/log" - "github.com/kubeagi/arcadia/pkg/documentloaders" "github.com/kubeagi/arcadia/pkg/langchainwrap" pkgvectorstore 
"github.com/kubeagi/arcadia/pkg/vectorstore" ) -type Reference struct { - // Question row - Question string `json:"question" example:"q: 旷工最小计算单位为多少天?"` - // Answer row - Answer string `json:"answer" example:"旷工最小计算单位为 0.5 天。"` - // vector search score - Score float32 `json:"score" example:"0.34"` - // the qa file fullpath - QAFilePath string `json:"qa_file_path" example:"dataset/dataset-playground/v1/qa.csv"` - // line number in the qa file - QALineNumber int `json:"qa_line_number" example:"7"` - // source file name, only file name, not full path - FileName string `json:"file_name" example:"员工考勤管理制度-2023.pdf"` - // page number in the source file - PageNumber int `json:"page_number" example:"1"` - // related content in the source file or in webpage - Content string `json:"content" example:"旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。"` - // Title of the webpage - Title string `json:"title,omitempty" example:"开始使用 Microsoft 帐户 – Microsoft"` - // URL of the webpage - URL string `json:"url,omitempty" example:"https://www.microsoft.com/zh-cn/welcome"` -} - -func (reference Reference) String() string { - bytes, err := json.Marshal(&reference) - if err != nil { - return "" - } - return string(bytes) -} - -func (reference Reference) SimpleString() string { - return fmt.Sprintf("%s %s", reference.Question, reference.Answer) -} - -func AddReferencesToArgs(args map[string]any, refs []Reference) map[string]any { - if len(refs) == 0 { - return args - } - old, exist := args["_references"] - if exist { - oldRefs := old.([]Reference) - args["_references"] = append(oldRefs, refs...) 
- return args - } - args["_references"] = refs - return args -} - type KnowledgeBaseRetriever struct { - langchaingoschema.Retriever base.BaseNode - DocNullReturn string - Instance *apiretriever.KnowledgeBaseRetriever - Finish func() + Instance *apiretriever.KnowledgeBaseRetriever + Finish func() } func NewKnowledgeBaseRetriever(baseNode base.BaseNode) *KnowledgeBaseRetriever { return &KnowledgeBaseRetriever{ - Retriever: nil, - BaseNode: baseNode, - DocNullReturn: "", + BaseNode: baseNode, } } @@ -116,7 +57,6 @@ func (l *KnowledgeBaseRetriever) Init(ctx context.Context, cli client.Client, _ func (l *KnowledgeBaseRetriever) Run(ctx context.Context, cli client.Client, args map[string]any) (map[string]any, error) { instance := l.Instance - l.DocNullReturn = instance.Spec.DocNullReturn var knowledgebaseName, knowledgebaseNamespace string for _, n := range l.BaseNode.GetPrevNode() { @@ -159,160 +99,57 @@ func (l *KnowledgeBaseRetriever) Run(ctx context.Context, cli client.Client, arg return nil, err } logger := klog.FromContext(ctx) - logger.V(3).Info(fmt.Sprintf("retriever created with scorethreshold: %f", instance.Spec.ScoreThreshold)) + logger.V(3).Info(fmt.Sprintf("retriever created[scorethreshold: %f][num: %d]", instance.Spec.ScoreThreshold, instance.Spec.NumDocuments)) retriever := vectorstores.ToRetriever(s, instance.Spec.NumDocuments, vectorstores.WithScoreThreshold(instance.Spec.ScoreThreshold)) retriever.CallbacksHandler = log.KLogHandler{LogLevel: 3} - l.Retriever = retriever - args["retriever"] = l - return args, nil -} - -func (l *KnowledgeBaseRetriever) Ready() (isReady bool, msg string) { - return l.Instance.Status.IsReadyOrGetReadyMessage() -} -func (l *KnowledgeBaseRetriever) Cleanup() { - if l.Finish != nil { - l.Finish() + question, ok := args["question"] + if !ok { + return nil, errors.New("no question in args") } -} - -// KnowledgeBaseStuffDocuments is similar to chains.StuffDocuments but with new joinDocuments method -type 
KnowledgeBaseStuffDocuments struct { - chains.StuffDocuments - isDocNullReturn bool - DocNullReturn string - callbacks.SimpleHandler - References []Reference -} - -func (c *KnowledgeBaseStuffDocuments) GetCallbackHandler() callbacks.Handler { - return c -} - -var ( - _ chains.Chain = &KnowledgeBaseStuffDocuments{} - _ callbacks.Handler = &KnowledgeBaseStuffDocuments{} - _ callbacks.HandlerHaver = &KnowledgeBaseStuffDocuments{} -) - -func (c *KnowledgeBaseStuffDocuments) joinDocuments(ctx context.Context, docs []langchaingoschema.Document) string { - logger := klog.FromContext(ctx) - var text string - docLen := len(docs) - for k, doc := range docs { - logger.V(3).Info(fmt.Sprintf("KnowledgeBaseRetriever: related doc[%d] raw text: %s, raw score: %f\n", k, doc.PageContent, doc.Score)) - for key, v := range doc.Metadata { - if str, ok := v.([]byte); ok { - logger.V(3).Info(fmt.Sprintf("KnowledgeBaseRetriever: related doc[%d] metadata[%s]: %s\n", k, key, string(str))) - } else { - logger.V(3).Info(fmt.Sprintf("KnowledgeBaseRetriever: related doc[%d] metadata[%s]: %#v\n", k, key, v)) - } - } - // chroma will get []byte, pgvector will get string... 
- answer, ok := doc.Metadata[documentloaders.AnswerCol].(string) - if !ok { - if a, ok := doc.Metadata[documentloaders.AnswerCol].([]byte); ok { - answer = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") - } - } - - text += doc.PageContent - if len(answer) != 0 { - text = text + "\na: " + answer - } - if k != docLen-1 { - text += c.Separator - } - qafilepath, ok := doc.Metadata[documentloaders.QAFileName].(string) - if !ok { - if a, ok := doc.Metadata[documentloaders.QAFileName].([]byte); ok { - qafilepath = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") - } - } - lineNumber, ok := doc.Metadata[documentloaders.LineNumber].(string) - if !ok { - if a, ok := doc.Metadata[documentloaders.LineNumber].([]byte); ok { - lineNumber = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") - } - } - line, _ := strconv.Atoi(lineNumber) - filename, ok := doc.Metadata[documentloaders.FileNameCol].(string) - if !ok { - if a, ok := doc.Metadata[documentloaders.FileNameCol].([]byte); ok { - filename = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") - } - } - pageNumber, ok := doc.Metadata[documentloaders.PageNumberCol].(string) - if !ok { - if a, ok := doc.Metadata[documentloaders.PageNumberCol].([]byte); ok { - pageNumber = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") + query, ok := question.(string) + if !ok { + return nil, errors.New("question not string") + } + docs, err := retriever.GetRelevantDocuments(ctx, query) + if err != nil { + return nil, fmt.Errorf("can't get relevant documents: %w", err) + } + if len(docs) == 0 && instance.Spec.DocNullReturn != "" { + val, ok := args[base.InputIsNeedStreamKeyInArg] + if ok { + needStream, _ := val.(bool) + if needStream { + streamChan, ok := args[base.OutputAnserStreamChanKeyInArg].(chan string) + if ok && streamChan != nil { + go func() { + streamChan <- instance.Spec.DocNullReturn + }() + } } } - page, _ := strconv.Atoi(pageNumber) - content, ok := 
doc.Metadata[documentloaders.ChunkContentCol].(string) - if !ok { - if a, ok := doc.Metadata[documentloaders.ChunkContentCol].([]byte); ok { - content = strings.TrimPrefix(strings.TrimSuffix(string(a), "\""), "\"") - } + return nil, &base.AppStopEarlyError{Msg: instance.Spec.DocNullReturn} + } + // pgvector get score means vector distance, similarity = 1 - vector distance + // chroma get score means similarity + // we want similarity finally. + if vectorStore.Spec.Type() == v1alpha1.VectorStoreTypePGVector { + for i := range docs { + docs[i].Score = 1 - docs[i].Score } - c.References = append(c.References, Reference{ - Question: doc.PageContent, - Answer: answer, - Score: 1 - doc.Score, // for pgvector - QAFilePath: qafilepath, - QALineNumber: line, - FileName: filename, - PageNumber: page, - Content: content, - }) - } - logger.V(3).Info(fmt.Sprintf("KnowledgeBaseRetriever: finally get related text: %s\n", text)) - if len(text) == 0 { - c.isDocNullReturn = true - } - return text -} - -func NewStuffDocuments(llmChain *chains.LLMChain, docNullReturn string) *KnowledgeBaseStuffDocuments { - return &KnowledgeBaseStuffDocuments{ - StuffDocuments: chains.NewStuffDocuments(llmChain), - DocNullReturn: docNullReturn, - References: make([]Reference, 0, 5), } + docs, refs := ConvertDocuments(ctx, docs, "knowledgebase") + args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: docs} + args[base.RuntimeRetrieverReferencesKeyInArg] = refs + return args, nil } -func (c *KnowledgeBaseStuffDocuments) Call(ctx context.Context, values map[string]any, options ...chains.ChainCallOption) (map[string]any, error) { - docs, ok := values[c.InputKey].([]langchaingoschema.Document) - if !ok { - return nil, fmt.Errorf("%w: %w", chains.ErrInvalidInputValues, chains.ErrInputValuesWrongType) - } - - inputValues := make(map[string]any) - for key, value := range values { - inputValues[key] = value - } - - inputValues[c.DocumentVariableName] = c.joinDocuments(ctx, docs) - return 
chains.Call(ctx, c.LLMChain, inputValues, options...) -} - -func (c KnowledgeBaseStuffDocuments) GetMemory() langchaingoschema.Memory { - return c.StuffDocuments.GetMemory() -} - -func (c KnowledgeBaseStuffDocuments) GetInputKeys() []string { - return c.StuffDocuments.GetInputKeys() -} - -func (c KnowledgeBaseStuffDocuments) GetOutputKeys() []string { - return c.StuffDocuments.GetOutputKeys() +func (l *KnowledgeBaseRetriever) Ready() (isReady bool, msg string) { + return l.Instance.Status.IsReadyOrGetReadyMessage() } -func (c KnowledgeBaseStuffDocuments) HandleChainEnd(ctx context.Context, outputValues map[string]any) { - if !c.isDocNullReturn { - return +func (l *KnowledgeBaseRetriever) Cleanup() { + if l.Finish != nil { + l.Finish() } - klog.FromContext(ctx).Info(fmt.Sprintf("raw llmChain output: %s, but there is no doc return, so set output to %s\n", outputValues[c.LLMChain.OutputKey], c.DocNullReturn)) - outputValues[c.LLMChain.OutputKey] = c.DocNullReturn } diff --git a/pkg/appruntime/retriever/rerankretriever.go b/pkg/appruntime/retriever/rerankretriever.go new file mode 100644 index 000000000..16ca2104b --- /dev/null +++ b/pkg/appruntime/retriever/rerankretriever.go @@ -0,0 +1,160 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package retriever
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"reflect"
+	"sort"
+
+	langchainschema "github.com/tmc/langchaingo/schema"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1"
+	"github.com/kubeagi/arcadia/pkg/appruntime/base"
+)
+
+// RerankRetriever is an app-runtime node that re-scores the references
+// produced by an upstream retriever by calling an external rerank endpoint.
+type RerankRetriever struct {
+	base.BaseNode
+	// Instance is the RerankRetriever resource loaded from the cluster by Init.
+	Instance *apiretriever.RerankRetriever
+}
+
+// NewRerankRetriever wraps baseNode in a RerankRetriever. Instance stays nil
+// until Init is called.
+func NewRerankRetriever(baseNode base.BaseNode) *RerankRetriever {
+	return &RerankRetriever{
+		BaseNode: baseNode,
+	}
+}
+
+// Init loads the RerankRetriever resource referenced by the node (namespace
+// from RefNamespace, name from Ref.Name) and stores it in l.Instance.
+func (l *RerankRetriever) Init(ctx context.Context, cli client.Client, _ map[string]any) error {
+	instance := &apiretriever.RerankRetriever{}
+	if err := cli.Get(ctx, types.NamespacedName{Namespace: l.RefNamespace(), Name: l.BaseNode.Ref.Name}, instance); err != nil {
+		return fmt.Errorf("can't find the rerank retriever in cluster: %w", err)
+	}
+	l.Instance = instance
+	return nil
+}
+
+// Run reranks the references found in args.
+//
+// It reads the references (base.RuntimeRetrieverReferencesKeyInArg), the
+// question (base.InputQuestionKeyInArg) and the upstream retriever
+// (base.LangchaingoRetrieverKeyInArg) from args, posts the question and one
+// passage per reference to Spec.Endpoint, and expects a JSON array holding one
+// score per passage in the same order. References are then sorted by
+// descending rerank score and cut by Spec.ScoreThreshold and
+// Spec.NumDocuments (each only applied when > 0). Finally both the reference
+// list and the retriever stored in args are replaced by the filtered,
+// reordered results.
+func (l *RerankRetriever) Run(ctx context.Context, cli client.Client, args map[string]any) (map[string]any, error) {
+	refs, ok := args[base.RuntimeRetrieverReferencesKeyInArg]
+	if !ok {
+		return args, errors.New("no refs in args")
+	}
+	references, ok := refs.([]Reference)
+	if !ok || len(references) == 0 {
+		return args, errors.New("empty references")
+	}
+	q, ok := args[base.InputQuestionKeyInArg]
+	if !ok {
+		return args, errors.New("no question in args")
+	}
+	query, ok := q.(string)
+	if !ok || len(query) == 0 {
+		return args, errors.New("empty question")
+	}
+	body := RerankRequestBody{
+		Query:    query,
+		Passages: make([]string, len(references)),
+	}
+	for i := range references {
+		// Prefer the question (plus the answer, if there is one) as the passage.
+		if references[i].Question != "" {
+			body.Passages[i] = references[i].Question
+			if references[i].Answer != "" {
+				body.Passages[i] += "\n" + references[i].Answer
+			}
+		} else {
+			// Otherwise fall back to the raw content.
+			body.Passages[i] = references[i].Content
+		}
+	}
+	reqBytes, err := json.Marshal(body)
+	if err != nil {
+		return nil, fmt.Errorf("request json marshal failed: %w", err)
+	}
+	request, err := http.NewRequestWithContext(ctx, http.MethodPost, l.Instance.Spec.Endpoint, bytes.NewBuffer(reqBytes))
+	if err != nil {
+		return nil, fmt.Errorf("request failed: %w", err)
+	}
+	// The body is JSON; say so explicitly for servers that check the media type.
+	request.Header.Set("Content-Type", "application/json")
+
+	response, err := http.DefaultClient.Do(request)
+	if err != nil {
+		return nil, fmt.Errorf("get resp err: %w", err)
+	}
+	defer response.Body.Close()
+
+	code := response.StatusCode
+	// Do not try to interpret an error page as a score list.
+	if code < http.StatusOK || code >= http.StatusMultipleChoices {
+		return nil, fmt.Errorf("rerank endpoint returned unexpected http status code:%d", code)
+	}
+	resp := make([]float32, 0)
+	if err := json.NewDecoder(response.Body).Decode(&resp); err != nil {
+		return nil, fmt.Errorf("parse json resp get err:%w, http status code:%d", err, code)
+	}
+	// Scores are matched to passages by index; a short answer would otherwise
+	// panic with an index-out-of-range below.
+	if len(resp) < len(references) {
+		return nil, fmt.Errorf("rerank endpoint returned %d scores for %d passages", len(resp), len(references))
+	}
+
+	for i := range references {
+		references[i].RerankScore = resp[i]
+	}
+	// Highest rerank score first; the stable sort keeps the upstream order for ties.
+	sort.SliceStable(references, func(i, j int) bool {
+		return references[i].RerankScore > references[j].RerankScore
+	})
+	newRef := make([]Reference, 0, len(references))
+	for i := range references {
+		// references is sorted, so everything after the first miss is below the threshold too.
+		if l.Instance.Spec.ScoreThreshold > 0 && references[i].RerankScore < l.Instance.Spec.ScoreThreshold {
+			break
+		}
+		if l.Instance.Spec.NumDocuments > 0 && len(newRef) >= l.Instance.Spec.NumDocuments {
+			break
+		}
+		newRef = append(newRef, references[i])
+	}
+	args[base.RuntimeRetrieverReferencesKeyInArg] = newRef
+
+	v, ok := args[base.LangchaingoRetrieverKeyInArg]
+	if !ok {
+		return args, errors.New("no retriever")
+	}
+	retriever, ok := v.(langchainschema.Retriever)
+	if !ok {
+		return args, errors.New("retriever not schema.Retriever")
+	}
+	docs, err := retriever.GetRelevantDocuments(ctx, query)
+	if err != nil {
+		return args, fmt.Errorf("get relevant documents failed: %w", err)
+	}
+	// Rebuild the document list in reranked order. A document belongs to a
+	// reference when score and metadata are equal; both must be read from
+	// docs[j], the candidate being inspected. Each document is consumed at
+	// most once so identical entries cannot be emitted several times.
+	newDocs := make([]langchainschema.Document, 0, len(newRef))
+	used := make([]bool, len(docs))
+	for i := range newRef {
+		for j := range docs {
+			if used[j] || newRef[i].Score != docs[j].Score || !reflect.DeepEqual(newRef[i].Metadata, docs[j].Metadata) {
+				continue
+			}
+			used[j] = true
+			newDocs = append(newDocs, docs[j])
+			break
+		}
+	}
+	args[base.LangchaingoRetrieverKeyInArg] = &Fakeretriever{Docs: newDocs}
+	return args, nil
+}
+
+// Ready reports the readiness recorded in the status of the loaded resource.
+func (l *RerankRetriever) Ready() (isReady bool, msg string) {
+	return l.Instance.Status.IsReadyOrGetReadyMessage()
+}
+
+// RerankRequestBody is the JSON payload posted to the rerank endpoint. The
+// query is sent under the key "question" and the passages under "answers".
+type RerankRequestBody struct {
+	Query    string   `json:"question"`
+	Passages []string `json:"answers"`
+}
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 10a245df7..773628146 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -43,6 +43,7 @@ var (
 	ErrNoConfigVectorstore = fmt.Errorf("config Vectorstore in comfigmap is not found")
 	ErrNoConfigStreamlit   = fmt.Errorf("config Streamlit in comfigmap is not found")
 	ErrNoConfigRayClusters = fmt.Errorf("config RayClusters in comfigmap is not found")
+	ErrNoConfigRerank      = fmt.Errorf("config rerankDefaultEndpoint in comfigmap is not found")
 )
 
 func getDatasource(ctx context.Context, ref arcadiav1alpha1.TypedObjectReference, c client.Client) (ds *arcadiav1alpha1.Datasource, err error) {
@@ -149,3 +150,15 @@ func GetRayClusters(ctx context.Context, c client.Client) ([]RayCluster, error)
 	}
 	return config.RayClusters, nil
 }
+
+// GetDefaultRerankEndpoint returns config.RerankDefaultEndpoint, or ErrNoConfigRerank when it is empty.
+func GetDefaultRerankEndpoint(ctx context.Context, c client.Client) (string, error) {
+	config, err := GetConfig(ctx, c)
+	if err != nil {
+		return "", err
+	}
+	if config.RerankDefaultEndpoint == "" {
+		return "", ErrNoConfigRerank
+	}
+	return config.RerankDefaultEndpoint, nil
+}
diff --git a/pkg/config/config_type.go b/pkg/config/config_type.go
index f0b823ee4..f0141aa41 100644
--- a/pkg/config/config_type.go
+++ b/pkg/config/config_type.go
@@ -42,6 +42,9 @@ type Config struct {
 
 	// Resource pool managed by Ray cluster
 	RayClusters []RayCluster `json:"rayClusters,omitempty"`
+
+	// the default endpoint of rerank
+	RerankDefaultEndpoint string `json:"rerankDefaultEndpoint,omitempty"`
 }
 
 // Gateway defines the way to access llm apis host by Arcadia
diff --git a/tests/example-test.sh b/tests/example-test.sh index f2724e3cc..6b2ffe310 100755 --- 
a/tests/example-test.sh +++ b/tests/example-test.sh @@ -46,7 +46,7 @@ function debugInfo { fi if [[ $GITHUB_ACTIONS == "true" ]]; then warning "debugInfo start 🧐" - mkdir -p $LOG_DIR + mkdir -p $LOG_DIR || true warning "1. Try to get all resources " kubectl api-resources --verbs=list -o name | xargs -n 1 kubectl get -A --ignore-not-found=true --show-kind=true >$LOG_DIR/get-all-resources-list.log @@ -187,7 +187,7 @@ function getRespInAppChat() { START_TIME=$(date +%s) while true; do data=$(jq -n --arg appname "$appname" --arg query "$query" --arg namespace "$namespace" --arg conversationID "$conversationID" '{"query":$query,"response_mode":"blocking","conversation_id":$conversationID,"app_name":$appname, "app_namespace":$namespace}') - resp=$(curl -s -XPOST http://127.0.0.1:8081/chat --data "$data") + resp=$(curl --max-time $TimeoutSeconds -s -XPOST http://127.0.0.1:8081/chat --data "$data") ai_data=$(echo $resp | jq -r '.message') references=$(echo $resp | jq -r '.references') if [ -z "$ai_data" ] || [ "$ai_data" = "null" ]; then @@ -382,7 +382,7 @@ waitCRDStatusReady "Application" "arcadia" "base-chat-with-knowledgebase" sleep 3 getRespInAppChat "base-chat-with-knowledgebase" "arcadia" "公司的考勤管理制度适用于哪些人员?" "" "true" info "8.2.1.2 When no related doc is found, return retriever.spec.docNullReturn info" -getRespInAppChat "base-chat-with-knowledgebase" "arcadia" "飞天的主演是谁?" "" "false" +getRespInAppChat "base-chat-with-knowledgebase" "arcadia" "飞天的主演是谁?" "" "true" expected=$(kubectl get knowledgebaseretrievers -n arcadia base-chat-with-knowledgebase -o json | jq -r .spec.docNullReturn) if [[ $ai_data != $expected ]]; then echo "when no related doc is found, return retriever.spec.docNullReturn info should be:"$expected ", but resp:"$resp @@ -395,7 +395,27 @@ waitCRDStatusReady "Application" "arcadia" "base-chat-with-knowledgebase-pgvecto sleep 3 getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "公司的考勤管理制度适用于哪些人员?" 
"" "true" info "8.2.2.2 When no related doc is found, return retriever.spec.docNullReturn info" -getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "飞天的主演是谁?" "" "false" +getRespInAppChat "base-chat-with-knowledgebase-pgvector" "arcadia" "飞天的主演是谁?" "" "true" +expected=$(kubectl get knowledgebaseretrievers -n arcadia base-chat-with-knowledgebase -o json | jq -r .spec.docNullReturn) +if [[ $ai_data != $expected ]]; then + echo "when no related doc is found, return retriever.spec.docNullReturn info should be:"$expected ", but resp:"$resp + exit 1 +fi + +info "8.2.3 QA app using knowledgebase base on pgvector and rerank" +if [[ $GITHUB_ACTIONS == "true" ]]; then + docker build --build-arg="GOPROXY=https://proxy.golang.org/,direct" -t controller:rerank-mock -f Dockerfile.rerank-mock . +else + docker build -t controller:rerank-mock -f Dockerfile.rerank-mock . +fi +kind load docker-image controller:rerank-mock --name=$KindName +kubectl apply -f tests/rerank-mock/deploy-svc.yaml +kubectl apply -f config/samples/app_retrievalqachain_knowledgebase_pgvector_rerank.yaml +waitCRDStatusReady "Application" "arcadia" "base-chat-with-knowledgebase-pgvector-rerank" +sleep 3 +getRespInAppChat "base-chat-with-knowledgebase-pgvector-rerank" "arcadia" "公司的考勤管理制度适用于哪些人员?" "" "true" +info "8.2.3.2 When no related doc is found, return retriever.spec.docNullReturn info" +getRespInAppChat "base-chat-with-knowledgebase-pgvector-rerank" "arcadia" "飞天的主演是谁?" 
"" "true" expected=$(kubectl get knowledgebaseretrievers -n arcadia base-chat-with-knowledgebase -o json | jq -r .spec.docNullReturn) if [[ $ai_data != $expected ]]; then echo "when no related doc is found, return retriever.spec.docNullReturn info should be:"$expected ", but resp:"$resp diff --git a/tests/rerank-mock/deploy-svc.yaml b/tests/rerank-mock/deploy-svc.yaml new file mode 100644 index 000000000..c9f72d06a --- /dev/null +++ b/tests/rerank-mock/deploy-svc.yaml @@ -0,0 +1,36 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: rerank-mock +spec: + replicas: 1 + selector: + matchLabels: + app: rerank-mock + template: + metadata: + labels: + app: rerank-mock + spec: + containers: + - image: controller:rerank-mock + imagePullPolicy: IfNotPresent + name: rerank-mock + ports: + - containerPort: 8123 + protocol: TCP + resources: {} +--- +apiVersion: v1 +kind: Service +metadata: + name: rerank-mock +spec: + ports: + - name: api + port: 8123 + protocol: TCP + targetPort: 8123 + selector: + app: rerank-mock + type: ClusterIP diff --git a/tests/rerank-mock/main.go b/tests/rerank-mock/main.go new file mode 100644 index 000000000..b4463e7f6 --- /dev/null +++ b/tests/rerank-mock/main.go @@ -0,0 +1,71 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package main
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"time"
+
+	"github.com/kubeagi/arcadia/pkg/appruntime/retriever"
+)
+
+// rerank is a mock rerank endpoint used by the e2e tests. It accepts a POST on
+// /rerank carrying a retriever.RerankRequestBody and answers with a JSON array
+// holding one score per passage: 1 for the first passage, decreasing by 0.01
+// for each following one, so the incoming order is always preserved.
+func rerank(w http.ResponseWriter, r *http.Request) {
+	if r.URL.Path != "/rerank" {
+		http.Error(w, "404 not found.", http.StatusNotFound)
+		return
+	}
+
+	if r.Method != http.MethodPost {
+		// A wrong method is 405 (with the Allow header), not a malformed request.
+		w.Header().Set("Allow", http.MethodPost)
+		http.Error(w, "only POST methods are supported.", http.StatusMethodNotAllowed)
+		return
+	}
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+	data := &retriever.RerankRequestBody{}
+	if err := json.Unmarshal(body, data); err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	resp := make([]float32, len(data.Passages))
+	for i := range resp {
+		resp[i] = 1 - float32(i)*0.01
+	}
+	respData, err := json.Marshal(resp)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	if _, err := w.Write(respData); err != nil {
+		// The header has already been sent, so the failure can only be logged.
+		fmt.Printf("write rerank response failed: %v\n", err)
+	}
+}
+
+// main serves the mock rerank endpoint on :8123 until the server fails.
+func main() {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/rerank", rerank)
+
+	// Timeouts keep a stalled client from holding a connection open forever.
+	server := &http.Server{
+		Addr:              ":8123",
+		Handler:           mux,
+		ReadHeaderTimeout: 10 * time.Second,
+		ReadTimeout:       30 * time.Second,
+		WriteTimeout:      30 * time.Second,
+	}
+
+	fmt.Println("Starting server for testing rerank...")
+	if err := server.ListenAndServe(); err != nil {
+		panic(err)
+	}
+}