diff --git a/apiserver/config/config.go b/apiserver/config/config.go
index 1c446347f..3d1f2260a 100644
--- a/apiserver/config/config.go
+++ b/apiserver/config/config.go
@@ -19,8 +19,14 @@ import (
"flag"
"os"
+ v1 "k8s.io/api/core/v1"
+ "k8s.io/apimachinery/pkg/runtime"
+ utilruntime "k8s.io/apimachinery/pkg/util/runtime"
+ clientgoscheme "k8s.io/client-go/kubernetes/scheme"
"k8s.io/klog/v2"
+ "github.com/kubeagi/arcadia/api/base/v1alpha1"
+ evaluationarcadiav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1"
"github.com/kubeagi/arcadia/apiserver/pkg/dataprocessing"
)
@@ -39,6 +45,8 @@ type ServerConfig struct {
IssuerURL, MasterURL, ClientID, ClientSecret string
DataProcessURL string
+
+ Scheme *runtime.Scheme
}
func NewServerFlags() ServerConfig {
@@ -58,6 +66,12 @@ func NewServerFlags() ServerConfig {
klog.InitFlags(nil)
flag.Parse()
+ s.Scheme = runtime.NewScheme()
+ utilruntime.Must(clientgoscheme.AddToScheme(s.Scheme))
+ utilruntime.Must(v1.AddToScheme(s.Scheme))
+ utilruntime.Must(evaluationarcadiav1alpha1.AddToScheme(s.Scheme))
+ utilruntime.Must(v1alpha1.AddToScheme(s.Scheme))
+
dataprocessing.Init(s.DataProcessURL)
return *s
}
diff --git a/apiserver/docs/docs.go b/apiserver/docs/docs.go
index 6aa44aa32..3c28a9fb4 100644
--- a/apiserver/docs/docs.go
+++ b/apiserver/docs/docs.go
@@ -1128,6 +1128,158 @@ const docTemplate = `{
}
}
}
+ },
+ "/rags/detail": {
+ "get": {
+ "description": "Get detail data of a rag",
+ "consumes": [
+ "application/json"
+ ],
+ "produces": [
+ "application/json"
+ ],
+ "tags": [
+ "RAG"
+ ],
+ "summary": "Get detail data of a rag",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "rag name",
+ "name": "ragName",
+ "in": "query",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Name of the bucket",
+ "name": "namespace",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "application name",
+ "name": "appName",
+ "in": "query",
+ "required": true
+ },
+ {
+ "type": "integer",
+ "description": "default is 1",
+ "name": "page",
+ "in": "query"
+ },
+ {
+ "type": "string",
+ "description": "default is 10",
+ "name": "size",
+ "in": "query"
+ },
+ {
+ "type": "string",
+ "description": "rag metrcis",
+ "name": "sortBy",
+ "in": "query"
+ },
+ {
+ "type": "string",
+ "description": "desc or asc",
+ "name": "order",
+ "in": "query"
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "OK",
+ "schema": {
+ "$ref": "#/definitions/rag.ReportDetail"
+ }
+ },
+ "400": {
+ "description": "Bad Request",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ }
+ }
+ }
+ },
+ "/rags/report": {
+ "get": {
+ "description": "Get a summary of rag",
+ "consumes": [
+ "application/json"
+ ],
+ "produces": [
+ "application/json"
+ ],
+ "tags": [
+ "RAG"
+ ],
+ "summary": "Get a summary of rag",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "rag name",
+ "name": "ragName",
+ "in": "query",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Name of the bucket",
+ "name": "namespace",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "application name",
+ "name": "appName",
+ "in": "query",
+ "required": true
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "OK",
+ "schema": {
+ "$ref": "#/definitions/rag.Report"
+ }
+ },
+ "400": {
+ "description": "Bad Request",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ }
+ }
+ }
}
},
"definitions": {
@@ -1392,6 +1544,118 @@ const docTemplate = `{
}
}
},
+ "rag.RadarData": {
+ "type": "object",
+ "properties": {
+ "color": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string"
+ },
+ "value": {
+ "type": "number"
+ }
+ }
+ },
+ "rag.Report": {
+ "type": "object",
+ "properties": {
+ "radarChart": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/rag.RadarData"
+ }
+ },
+ "scatterChart": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/rag.ScatterData"
+ }
+ },
+ "summary": {
+ "description": "TODO",
+ "type": "string"
+ },
+ "totalScore": {
+ "$ref": "#/definitions/rag.TotalScoreData"
+ }
+ }
+ },
+ "rag.ReportDetail": {
+ "type": "object",
+ "properties": {
+ "data": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/rag.ReportLine"
+ }
+ },
+ "total": {
+ "type": "integer"
+ }
+ }
+ },
+ "rag.ReportLine": {
+ "type": "object",
+ "properties": {
+ "answer": {
+ "type": "string"
+ },
+ "contexts": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "costTime": {
+ "type": "number"
+ },
+ "data": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "number"
+ }
+ },
+ "groundTruths": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "question": {
+ "type": "string"
+ },
+ "totalScore": {
+ "type": "number"
+ }
+ }
+ },
+ "rag.ScatterData": {
+ "type": "object",
+ "properties": {
+ "color": {
+ "type": "string"
+ },
+ "score": {
+ "type": "number"
+ },
+ "type": {
+ "type": "string"
+ }
+ }
+ },
+ "rag.TotalScoreData": {
+ "type": "object",
+ "properties": {
+ "color": {
+ "type": "string"
+ },
+ "score": {
+ "type": "number"
+ }
+ }
+ },
"retriever.Reference": {
"type": "object",
"properties": {
diff --git a/apiserver/docs/swagger.json b/apiserver/docs/swagger.json
index d5aeade74..a0ceab04c 100644
--- a/apiserver/docs/swagger.json
+++ b/apiserver/docs/swagger.json
@@ -1122,6 +1122,158 @@
}
}
}
+ },
+ "/rags/detail": {
+ "get": {
+ "description": "Get detail data of a rag",
+ "consumes": [
+ "application/json"
+ ],
+ "produces": [
+ "application/json"
+ ],
+ "tags": [
+ "RAG"
+ ],
+ "summary": "Get detail data of a rag",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "rag name",
+ "name": "ragName",
+ "in": "query",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Name of the bucket",
+ "name": "namespace",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "application name",
+ "name": "appName",
+ "in": "query",
+ "required": true
+ },
+ {
+ "type": "integer",
+ "description": "default is 1",
+ "name": "page",
+ "in": "query"
+ },
+ {
+ "type": "string",
+ "description": "default is 10",
+ "name": "size",
+ "in": "query"
+ },
+ {
+ "type": "string",
+ "description": "rag metrcis",
+ "name": "sortBy",
+ "in": "query"
+ },
+ {
+ "type": "string",
+ "description": "desc or asc",
+ "name": "order",
+ "in": "query"
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "OK",
+ "schema": {
+ "$ref": "#/definitions/rag.ReportDetail"
+ }
+ },
+ "400": {
+ "description": "Bad Request",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ }
+ }
+ }
+ },
+ "/rags/report": {
+ "get": {
+ "description": "Get a summary of rag",
+ "consumes": [
+ "application/json"
+ ],
+ "produces": [
+ "application/json"
+ ],
+ "tags": [
+ "RAG"
+ ],
+ "summary": "Get a summary of rag",
+ "parameters": [
+ {
+ "type": "string",
+ "description": "rag name",
+ "name": "ragName",
+ "in": "query",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "Name of the bucket",
+ "name": "namespace",
+ "in": "header",
+ "required": true
+ },
+ {
+ "type": "string",
+ "description": "application name",
+ "name": "appName",
+ "in": "query",
+ "required": true
+ }
+ ],
+ "responses": {
+ "200": {
+ "description": "OK",
+ "schema": {
+ "$ref": "#/definitions/rag.Report"
+ }
+ },
+ "400": {
+ "description": "Bad Request",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ },
+ "500": {
+ "description": "Internal Server Error",
+ "schema": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ }
+ }
+ }
}
},
"definitions": {
@@ -1386,6 +1538,118 @@
}
}
},
+ "rag.RadarData": {
+ "type": "object",
+ "properties": {
+ "color": {
+ "type": "string"
+ },
+ "type": {
+ "type": "string"
+ },
+ "value": {
+ "type": "number"
+ }
+ }
+ },
+ "rag.Report": {
+ "type": "object",
+ "properties": {
+ "radarChart": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/rag.RadarData"
+ }
+ },
+ "scatterChart": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/rag.ScatterData"
+ }
+ },
+ "summary": {
+ "description": "TODO",
+ "type": "string"
+ },
+ "totalScore": {
+ "$ref": "#/definitions/rag.TotalScoreData"
+ }
+ }
+ },
+ "rag.ReportDetail": {
+ "type": "object",
+ "properties": {
+ "data": {
+ "type": "array",
+ "items": {
+ "$ref": "#/definitions/rag.ReportLine"
+ }
+ },
+ "total": {
+ "type": "integer"
+ }
+ }
+ },
+ "rag.ReportLine": {
+ "type": "object",
+ "properties": {
+ "answer": {
+ "type": "string"
+ },
+ "contexts": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "costTime": {
+ "type": "number"
+ },
+ "data": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "number"
+ }
+ },
+ "groundTruths": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ "question": {
+ "type": "string"
+ },
+ "totalScore": {
+ "type": "number"
+ }
+ }
+ },
+ "rag.ScatterData": {
+ "type": "object",
+ "properties": {
+ "color": {
+ "type": "string"
+ },
+ "score": {
+ "type": "number"
+ },
+ "type": {
+ "type": "string"
+ }
+ }
+ },
+ "rag.TotalScoreData": {
+ "type": "object",
+ "properties": {
+ "color": {
+ "type": "string"
+ },
+ "score": {
+ "type": "number"
+ }
+ }
+ },
"retriever.Reference": {
"type": "object",
"properties": {
diff --git a/apiserver/docs/swagger.yaml b/apiserver/docs/swagger.yaml
index e023fa08a..aae1ff929 100644
--- a/apiserver/docs/swagger.yaml
+++ b/apiserver/docs/swagger.yaml
@@ -192,6 +192,79 @@ definitions:
total:
type: integer
type: object
+ rag.RadarData:
+ properties:
+ color:
+ type: string
+ type:
+ type: string
+ value:
+ type: number
+ type: object
+ rag.Report:
+ properties:
+ radarChart:
+ items:
+ $ref: '#/definitions/rag.RadarData'
+ type: array
+ scatterChart:
+ items:
+ $ref: '#/definitions/rag.ScatterData'
+ type: array
+ summary:
+ description: TODO
+ type: string
+ totalScore:
+ $ref: '#/definitions/rag.TotalScoreData'
+ type: object
+ rag.ReportDetail:
+ properties:
+ data:
+ items:
+ $ref: '#/definitions/rag.ReportLine'
+ type: array
+ total:
+ type: integer
+ type: object
+ rag.ReportLine:
+ properties:
+ answer:
+ type: string
+ contexts:
+ items:
+ type: string
+ type: array
+ costTime:
+ type: number
+ data:
+ additionalProperties:
+ type: number
+ type: object
+ groundTruths:
+ items:
+ type: string
+ type: array
+ question:
+ type: string
+ totalScore:
+ type: number
+ type: object
+ rag.ScatterData:
+ properties:
+ color:
+ type: string
+ score:
+ type: number
+ type:
+ type: string
+ type: object
+ rag.TotalScoreData:
+ properties:
+ color:
+ type: string
+ score:
+ type: number
+ type: object
retriever.Reference:
properties:
answer:
@@ -1173,6 +1246,108 @@ paths:
summary: get app's prompt starters
tags:
- application
+ /rags/detail:
+ get:
+ consumes:
+ - application/json
+ description: Get detail data of a rag
+ parameters:
+ - description: rag name
+ in: query
+ name: ragName
+ required: true
+ type: string
+ - description: Name of the bucket
+ in: header
+ name: namespace
+ required: true
+ type: string
+ - description: application name
+ in: query
+ name: appName
+ required: true
+ type: string
+ - description: default is 1
+ in: query
+ name: page
+ type: integer
+ - description: default is 10
+ in: query
+ name: size
+ type: string
+ - description: rag metrcis
+ in: query
+ name: sortBy
+ type: string
+ - description: desc or asc
+ in: query
+ name: order
+ type: string
+ produces:
+ - application/json
+ responses:
+ "200":
+ description: OK
+ schema:
+ $ref: '#/definitions/rag.ReportDetail'
+ "400":
+ description: Bad Request
+ schema:
+ additionalProperties:
+ type: string
+ type: object
+ "500":
+ description: Internal Server Error
+ schema:
+ additionalProperties:
+ type: string
+ type: object
+ summary: Get detail data of a rag
+ tags:
+ - RAG
+ /rags/report:
+ get:
+ consumes:
+ - application/json
+ description: Get a summary of rag
+ parameters:
+ - description: rag name
+ in: query
+ name: ragName
+ required: true
+ type: string
+ - description: Name of the bucket
+ in: header
+ name: namespace
+ required: true
+ type: string
+ - description: application name
+ in: query
+ name: appName
+ required: true
+ type: string
+ produces:
+ - application/json
+ responses:
+ "200":
+ description: OK
+ schema:
+ $ref: '#/definitions/rag.Report'
+ "400":
+ description: Bad Request
+ schema:
+ additionalProperties:
+ type: string
+ type: object
+ "500":
+ description: Internal Server Error
+ schema:
+ additionalProperties:
+ type: string
+ type: object
+ summary: Get a summary of rag
+ tags:
+ - RAG
securityDefinitions:
ApiKeyAuth:
description: API token for authorization
diff --git a/apiserver/pkg/rag/report.go b/apiserver/pkg/rag/report.go
new file mode 100644
index 000000000..cf35a7006
--- /dev/null
+++ b/apiserver/pkg/rag/report.go
@@ -0,0 +1,274 @@
+/*
+Copyright 2024 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package rag
+
+import (
+ "context"
+ "encoding/csv"
+ "fmt"
+ "io"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/minio/minio-go/v7"
+ "k8s.io/klog/v2"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+
+ "github.com/kubeagi/arcadia/api/evaluation/v1alpha1"
+ "github.com/kubeagi/arcadia/apiserver/pkg/common"
+)
+
+const (
+ totalScore = "total_score"
+
+ // TODO: support for color change via env
+ blueColorEnv = "BLUE_ENV"
+ blue = "blue" // 散点图的颜色
+
+ orangeEnv = "ORANGE_RNV"
+ orange = "orange" // 差
+
+ greenEnv = "GREEN_ENV"
+ green = "green" // 好
+
+ summarySuggestionTemplate = `通过此次评估,您的智能体得分偏低,主要体现在 %s 这 %d 项指标得分偏低。
+
+建议您对特定场景应用的模型进行模型精调;%s。`
+ noSuggestionTempalte = `通过此次评估,您的 RAG 方案得分 %.2f`
+)
+
+var (
+ metricChinese = map[string]string{
+ string(v1alpha1.AnswerRelevancy): "答案相关度",
+ string(v1alpha1.AnswerSimilarity): "答案相似度",
+ string(v1alpha1.AnswerCorrectness): "答案正确性",
+ string(v1alpha1.Faithfulness): "忠实度",
+ string(v1alpha1.ContextPrecision): "知识库精度",
+ string(v1alpha1.ContextRelevancy): "知识库相关度",
+ string(v1alpha1.ContextRecall): "知识库召回率",
+ string(v1alpha1.AspectCritique): "暂时没用到",
+ }
+
+ suggestionChinese = map[string]string{
+ string(v1alpha1.AnswerRelevancy): "调整 Embedding 模型",
+ string(v1alpha1.AnswerSimilarity): "调整 Embedding 模型",
+ string(v1alpha1.AnswerCorrectness): "调整模型配置或更换模型",
+ string(v1alpha1.Faithfulness): "调整模型配置或更换模型",
+ string(v1alpha1.ContextPrecision): "调整 Embedding 模型",
+ string(v1alpha1.ContextRelevancy): "调整 Embedding 模型",
+ string(v1alpha1.ContextRecall): "调整 QA 数据", // 知识库相似度?
+ string(v1alpha1.AspectCritique): "暂时没用到",
+ }
+)
+
+type (
+ RadarData struct {
+ Type string `json:"type"`
+ Value float64 `json:"value"`
+ Color string `json:"color"`
+ }
+
+ TotalScoreData struct {
+ Score float64 `json:"score"`
+ Color string `json:"color"`
+ }
+
+ ScatterData struct {
+ Score float64 `json:"score"`
+ Type string `json:"type"`
+ Color string `json:"color"`
+ }
+
+ Report struct {
+ RadarChart []RadarData `json:"radarChart"`
+ TotalScore TotalScoreData `json:"totalScore"`
+ ScatterChart []ScatterData `json:"scatterChart"`
+
+ // TODO
+ Summary string `json:"summary"`
+ }
+
+ // 忠实度、答案相关度、答案语义相似度、答案正确性、知识库相关度、知识库精度、知识库相似度
+ // question,ground_truths,answer,contexts
+ ReportLine struct {
+ Question string `json:"question"`
+ GroundTruths []string `json:"groundTruths"`
+ Answer string `json:"answer"`
+ Contexts []string `json:"contexts"`
+ Data map[string]float64 `json:"data"`
+ TotalScore float64 `json:"totalScore"`
+ CostTime float64 `json:"costTime"`
+ }
+ ReportDetail struct {
+ Data []ReportLine `json:"data"`
+ Total int `json:"total"`
+ }
+)
+
+func ParseSummary(
+ ctx context.Context, c client.Client,
+ appName, ragName, namespace string,
+ metricThresholds map[string]float64) (Report, error) {
+ source, err := common.SystemDatasourceOSS(ctx, c, nil)
+ if err != nil {
+ klog.Errorf("failed to get system datasource error %s", err)
+ return Report{}, err
+ }
+
+ filePath := fmt.Sprintf("evals/%s/%s/summary.csv", appName, ragName)
+ object, err := source.Client.GetObject(ctx, namespace, filePath, minio.GetObjectOptions{})
+ if err != nil {
+ klog.Errorf("failed to get summary.csv file error %s", err)
+ return Report{}, err
+ }
+ csvReader := csv.NewReader(object)
+ report := Report{TotalScore: TotalScoreData{}, RadarChart: []RadarData{}, ScatterChart: []ScatterData{}}
+ radarChecker := make(map[string]int)
+ scatterChecker := make(map[string]int)
+
+ changeTotalScoreColor := false
+
+ metrics := make([]string, 0)
+ metricSuggesstion := make([]string, 0)
+
+ // skip the first line
+ firstLine := true
+ for {
+ line, err := csvReader.Read()
+ if err != nil {
+ if err != io.EOF {
+ return Report{}, err
+ }
+ break
+ }
+ if firstLine {
+ firstLine = false
+ continue
+ }
+ if len(line) != 2 {
+ return Report{}, fmt.Errorf("the summary file should only have two columns")
+ }
+ score, err := strconv.ParseFloat(line[1], 64)
+ if err != nil {
+ klog.Errorf("failed to parse thresholds for indicator %s, source value %s", line[0], line[1])
+ return Report{}, err
+ }
+ if line[0] == totalScore {
+ report.TotalScore = TotalScoreData{Score: score, Color: green}
+ continue
+ }
+ nextRadarIndex := len(report.RadarChart)
+ idx, ok := radarChecker[line[0]]
+ if !ok {
+ radarChecker[line[0]] = nextRadarIndex
+ idx = nextRadarIndex
+ report.RadarChart = append(report.RadarChart, RadarData{Type: line[0]})
+ }
+ report.RadarChart[idx].Value = score
+ report.RadarChart[idx].Color = green
+ if threshold, ok := metricThresholds[line[0]]; ok && score < threshold {
+ report.RadarChart[idx].Color = orange
+ metrics = append(metrics, metricChinese[line[0]])
+ metricSuggesstion = append(metricSuggesstion, suggestionChinese[line[0]])
+ changeTotalScoreColor = true
+ }
+
+ nextScatterIndex := len(report.ScatterChart)
+ idx, ok = scatterChecker[line[0]]
+ if !ok {
+ scatterChecker[line[0]] = nextScatterIndex
+ report.ScatterChart = append(report.ScatterChart, ScatterData{Type: line[0], Color: blue})
+ idx = nextScatterIndex
+ }
+ report.ScatterChart[idx].Score = score
+ }
+
+ if changeTotalScoreColor {
+ report.TotalScore.Color = orange
+ report.Summary = fmt.Sprintf(summarySuggestionTemplate, strings.Join(metrics, "、"), len(metrics), strings.Join(metricSuggesstion, "、"))
+ } else {
+ report.Summary = fmt.Sprintf(noSuggestionTempalte, report.TotalScore.Score)
+ }
+ return report, nil
+}
+
+func ParseResult(
+ ctx context.Context, c client.Client,
+ page, pageSize int,
+ appName, ragName, namespace, sortBy, order string) (ReportDetail, error) {
+ source, err := common.SystemDatasourceOSS(ctx, c, nil)
+ if err != nil {
+ klog.Errorf("failed to get system datasource error %s", err)
+ return ReportDetail{}, err
+ }
+
+ filePath := fmt.Sprintf("evals/%s/%s/result.csv", appName, ragName)
+ object, err := source.Client.GetObject(ctx, namespace, filePath, minio.GetObjectOptions{})
+ if err != nil {
+ klog.Errorf("failed to get result.csv file error %s", err)
+ return ReportDetail{}, err
+ }
+ csvReader := csv.NewReader(object)
+
+ data, err := csvReader.ReadAll()
+ if err != nil {
+ klog.Error("failed to read csv error %s", err)
+ return ReportDetail{}, err
+ }
+ if len(data) == 0 {
+ klog.Error("this may not be a normal csv file with one line of data in it: %s", filePath)
+ return ReportDetail{}, nil
+ }
+ if len(data) == 1 {
+ klog.Error("there's only one header row. %s", filePath)
+ return ReportDetail{}, nil
+ }
+
+ result := make([]ReportLine, len(data)-1)
+ header := data[0]
+ for i, line := range data[1:] {
+ item := ReportLine{
+ Question: line[1],
+ GroundTruths: []string{line[2]},
+ Answer: line[3],
+ Contexts: []string{line[4]},
+ Data: make(map[string]float64),
+ }
+ sum := float64(0)
+ for i := 5; i < len(line); i++ {
+ f, _ := strconv.ParseFloat(line[i], 64)
+ item.Data[header[i]] = f
+ sum += f
+ }
+ item.TotalScore = sum / float64(len(line)-5)
+ result[i] = item
+ }
+
+ start, end := common.PagePosition(page, pageSize, len(data)-1)
+ if sortBy != "" {
+ if _, ok := result[0].Data[sortBy]; ok {
+ sort.Slice(result, func(i, j int) bool {
+ if order == "desc" {
+ return result[i].Data[sortBy] > result[j].Data[sortBy]
+ }
+ return result[i].Data[sortBy] < result[j].Data[sortBy]
+ })
+ }
+ }
+ result = result[start:end]
+ return ReportDetail{Data: result, Total: len(data) - 1}, nil
+}
diff --git a/apiserver/service/minio_server.go b/apiserver/service/minio_server.go
index 4532e466c..c0bceafeb 100644
--- a/apiserver/service/minio_server.go
+++ b/apiserver/service/minio_server.go
@@ -32,6 +32,7 @@ import (
"k8s.io/klog/v2"
"github.com/kubeagi/arcadia/api/base/v1alpha1"
+ evaluationarcadiav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1"
gqlconfig "github.com/kubeagi/arcadia/apiserver/config"
"github.com/kubeagi/arcadia/apiserver/pkg/auth"
"github.com/kubeagi/arcadia/apiserver/pkg/client"
@@ -917,4 +918,6 @@ func registerMinIOAPI(group *gin.RouterGroup, conf gqlconfig.ServerConfig) {
// create a webcrawler file for versioneddataset
group.POST("/versioneddataset/files/webcrawler", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "create", "versioneddatasets"), api.CreateWebCrawlerFile)
}
+
+ group.GET("/rags/files/downloadlink", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, evaluationarcadiav1alpha1.GroupVersion, "get", "rags"), api.GetDownloadLink)
}
diff --git a/apiserver/service/rag_server.go b/apiserver/service/rag_server.go
new file mode 100644
index 000000000..5ce1b4ef9
--- /dev/null
+++ b/apiserver/service/rag_server.go
@@ -0,0 +1,136 @@
+/*
+Copyright 2024 KubeAGI.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package service
+
+import (
+ "fmt"
+ "net/http"
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+ "k8s.io/apimachinery/pkg/types"
+ "k8s.io/klog/v2"
+ ctrl "sigs.k8s.io/controller-runtime"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+
+ "github.com/kubeagi/arcadia/api/evaluation/v1alpha1"
+ gqlconfig "github.com/kubeagi/arcadia/apiserver/config"
+ "github.com/kubeagi/arcadia/apiserver/pkg/auth"
+ "github.com/kubeagi/arcadia/apiserver/pkg/oidc"
+ "github.com/kubeagi/arcadia/apiserver/pkg/rag"
+)
+
+type RagAPI struct {
+ c client.Client
+}
+
+const (
+ ragNameQuery = "ragName"
+ appNameQuery = "appName"
+ namespaceHeadr = "namespace"
+)
+
+// @Summary Get a summary of rag
+// @Schemes
+// @Description Get a summary of rag
+// @Tags RAG
+// @Accept json
+// @Produce json
+// @Param ragName query string true "rag name"
+// @Param namespace header string true "Name of the bucket"
+// @Param appName query string true "application name"
+// @Success 200 {object} rag.Report
+// @Failure 400 {object} map[string]string
+// @Failure 500 {object} map[string]string
+// @Router /rags/report [get]
+func (r *RagAPI) Summary(ctx *gin.Context) {
+ ragName := ctx.Query(ragNameQuery)
+ appName := ctx.Query(appNameQuery)
+ namespace := ctx.GetHeader(namespaceHeadr)
+
+ rr := v1alpha1.RAG{}
+ if err := r.c.Get(ctx, types.NamespacedName{
+ Namespace: namespace, Name: ragName,
+ }, &rr); err != nil {
+ klog.Error(fmt.Sprintf("can't get rag by name %s", ragName))
+ ctx.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
+ "message": fmt.Sprintf("can't get rag by name %s", ragName),
+ })
+ return
+ }
+ thresholds := make(map[string]float64)
+ for _, param := range rr.Spec.Metrics {
+ thresholds[string(param.Kind)] = float64(param.ToleranceThreshbold) / 100.0
+ }
+
+ report, err := rag.ParseSummary(ctx.Request.Context(), r.c, appName, ragName, namespace, thresholds)
+ if err != nil {
+ klog.Errorf("an error occurred generating the report, error %s", err)
+ ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
+ "message": err.Error(),
+ })
+ return
+ }
+ ctx.JSON(http.StatusOK, report)
+}
+
+// @Summary Get detail data of a rag
+// @Schemes
+// @Description Get detail data of a rag
+// @Tags RAG
+// @Accept json
+// @Produce json
+// @Param ragName query string true "rag name"
+// @Param namespace header string true "Name of the bucket"
+// @Param appName query string true "application name"
+// @Param page query int false "default is 1"
+// @Param size query string false "default is 10"
+// @Param sortBy query string false "rag metrcis"
+// @Param order query string false "desc or asc"
+// @Success 200 {object} rag.ReportDetail
+// @Failure 400 {object} map[string]string
+// @Failure 500 {object} map[string]string
+// @Router /rags/detail [get]
+func (r *RagAPI) ReportDetail(ctx *gin.Context) {
+ page, _ := strconv.Atoi(ctx.DefaultQuery("page", "1"))
+ pageSize, _ := strconv.Atoi(ctx.DefaultQuery("size", "10"))
+ sortBy := ctx.Query("sortBy")
+ order := ctx.DefaultQuery("order", "desc")
+ ragName := ctx.Query(ragNameQuery)
+ appName := ctx.Query(appNameQuery)
+ namespace := ctx.GetHeader(namespaceHeadr)
+
+ result, err := rag.ParseResult(ctx.Request.Context(), r.c, page, pageSize, appName, ragName, namespace, sortBy, order)
+ if err != nil {
+ ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
+ "message": err.Error(),
+ })
+ return
+ }
+ ctx.JSON(http.StatusOK, result)
+}
+
+func registerRAG(g *gin.RouterGroup, conf gqlconfig.ServerConfig) {
+ cfg := ctrl.GetConfigOrDie()
+ c, err := client.New(cfg, client.Options{Scheme: conf.Scheme})
+ if err != nil {
+ panic(err)
+ }
+ api := RagAPI{c: c}
+
+ g.GET("/report", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "get", "rags"), api.Summary)
+ g.GET("/detail", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "get", "rags"), api.ReportDetail)
+}
diff --git a/apiserver/service/router.go b/apiserver/service/router.go
index 8101ffa0c..04cf9e815 100644
--- a/apiserver/service/router.go
+++ b/apiserver/service/router.go
@@ -61,6 +61,9 @@ func NewServerAndRun(conf config.ServerConfig) {
// for ops apis with graphql
registerGraphQL(r, bffGroup, conf)
+ ragGroup := r.Group("/rags")
+ registerRAG(ragGroup, conf)
+
// for chat server with Restful apis
chatGroup := r.Group("/chat")
registerChat(chatGroup, conf)