From dc10c0039eeff485858f05b0abf21baa5cda656a Mon Sep 17 00:00:00 2001 From: 0xff-dev Date: Thu, 1 Feb 2024 13:54:23 +0800 Subject: [PATCH] feat: add a new api to get rag evaluation report --- apiserver/config/config.go | 14 ++ apiserver/docs/docs.go | 264 ++++++++++++++++++++++++++++ apiserver/docs/swagger.json | 264 ++++++++++++++++++++++++++++ apiserver/docs/swagger.yaml | 175 +++++++++++++++++++ apiserver/pkg/rag/report.go | 274 ++++++++++++++++++++++++++++++ apiserver/service/minio_server.go | 3 + apiserver/service/rag_server.go | 136 +++++++++++++++ apiserver/service/router.go | 3 + 8 files changed, 1133 insertions(+) create mode 100644 apiserver/pkg/rag/report.go create mode 100644 apiserver/service/rag_server.go diff --git a/apiserver/config/config.go b/apiserver/config/config.go index 1c446347f..3d1f2260a 100644 --- a/apiserver/config/config.go +++ b/apiserver/config/config.go @@ -19,8 +19,14 @@ import ( "flag" "os" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/klog/v2" + "github.com/kubeagi/arcadia/api/base/v1alpha1" + evaluationarcadiav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" "github.com/kubeagi/arcadia/apiserver/pkg/dataprocessing" ) @@ -39,6 +45,8 @@ type ServerConfig struct { IssuerURL, MasterURL, ClientID, ClientSecret string DataProcessURL string + + Scheme *runtime.Scheme } func NewServerFlags() ServerConfig { @@ -58,6 +66,12 @@ func NewServerFlags() ServerConfig { klog.InitFlags(nil) flag.Parse() + s.Scheme = runtime.NewScheme() + utilruntime.Must(clientgoscheme.AddToScheme(s.Scheme)) + utilruntime.Must(v1.AddToScheme(s.Scheme)) + utilruntime.Must(evaluationarcadiav1alpha1.AddToScheme(s.Scheme)) + utilruntime.Must(v1alpha1.AddToScheme(s.Scheme)) + dataprocessing.Init(s.DataProcessURL) return *s } diff --git a/apiserver/docs/docs.go b/apiserver/docs/docs.go index 6aa44aa32..3c28a9fb4 100644 --- a/apiserver/docs/docs.go +++ b/apiserver/docs/docs.go @@ -1128,6 +1128,158 @@ const docTemplate = `{ } } } + }, + "/rags/detail": { + "get": { + "description": "Get detail data of a rag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "RAG" + ], + "summary": "Get detail data of a rag", + "parameters": [ + { + "type": "string", + "description": "rag name", + "name": "ragName", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Name of the bucket", + "name": "namespace", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "application name", + "name": "appName", + "in": "query", + "required": true + }, + { + "type": "integer", + "description": "default is 1", + "name": "page", + "in": "query" + }, + { + "type": "string", + "description": "default is 10", + "name": "size", + "in": "query" + }, + { + "type": "string", + "description": "rag metrcis", + "name": "sortBy", + "in": "query" + }, + { + "type": "string", + "description": "desc or asc", + "name": "order", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/rag.ReportDetail" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } + }, + "/rags/report": { + "get": { + "description": "Get a summary of rag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "RAG" + ], + "summary": "Get a summary of rag", + "parameters": [ + { + "type": "string", + "description": "rag name", + "name": "ragName", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Name of the bucket", + "name": "namespace", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "application name", + "name": "appName", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/rag.Report" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } } }, "definitions": { @@ -1392,6 +1544,118 @@ const docTemplate = `{ } } }, + "rag.RadarData": { + "type": "object", + "properties": { + "color": { + "type": "string" + }, + "type": { + "type": "string" + }, + "value": { + "type": "number" + } + } + }, + "rag.Report": { + "type": "object", + "properties": { + "radarChart": { + "type": "array", + "items": { + "$ref": "#/definitions/rag.RadarData" + } + }, + "scatterChart": { + "type": "array", + "items": { + "$ref": "#/definitions/rag.ScatterData" + } + }, + "summary": { + "description": "TODO", + "type": "string" + }, + "totalScore": { + "$ref": "#/definitions/rag.TotalScoreData" + } + } + }, + "rag.ReportDetail": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/rag.ReportLine" + } + }, + "total": { + "type": "integer" + } + } + }, + "rag.ReportLine": { + "type": "object", + "properties": { + "answer": { + "type": "string" + }, + "contexts": { + "type": "array", + "items": { + "type": "string" + } + }, + "costTime": { + "type": "number" + }, + "data": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "groundTruths": { + "type": "array", + "items": { + "type": "string" + } + }, + "question": { + "type": "string" + }, + "totalScore": { + "type": "number" + } + } + }, + "rag.ScatterData": { + "type": "object", + "properties": { + "color": { + "type": "string" + }, + "score": { + "type": "number" + }, + "type": { + "type": "string" + } + } + }, + "rag.TotalScoreData": { + "type": "object", + "properties": { + "color": { + "type": "string" + }, + "score": { + "type": "number" + } + } + }, "retriever.Reference": { "type": "object", "properties": { diff --git a/apiserver/docs/swagger.json b/apiserver/docs/swagger.json index d5aeade74..a0ceab04c 100644 --- a/apiserver/docs/swagger.json +++ b/apiserver/docs/swagger.json @@ -1122,6 +1122,158 @@ } } } + }, + "/rags/detail": { + "get": { + "description": "Get detail data of a rag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "RAG" + ], + "summary": "Get detail data of a rag", + "parameters": [ + { + "type": "string", + "description": "rag name", + "name": "ragName", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Name of the bucket", + "name": "namespace", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "application name", + "name": "appName", + "in": "query", + "required": true + }, + { + "type": "integer", + "description": "default is 1", + "name": "page", + "in": "query" + }, + { + "type": "string", + "description": "default is 10", + "name": "size", + "in": "query" + }, + { + "type": "string", + "description": "rag metrcis", + "name": "sortBy", + "in": "query" + }, + { + "type": "string", + "description": "desc or asc", + "name": "order", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/rag.ReportDetail" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } + }, + "/rags/report": { + "get": { + "description": "Get a summary of rag", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "RAG" + ], + "summary": "Get a summary of rag", + "parameters": [ + { + "type": "string", + "description": "rag name", + "name": "ragName", + "in": "query", + "required": true + }, + { + "type": "string", + "description": "Name of the bucket", + "name": "namespace", + "in": "header", + "required": true + }, + { + "type": "string", + "description": "application name", + "name": "appName", + "in": "query", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/rag.Report" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } } }, "definitions": { @@ -1386,6 +1538,118 @@ } } }, + "rag.RadarData": { + "type": "object", + "properties": { + "color": { + "type": "string" + }, + "type": { + "type": "string" + }, + "value": { + "type": "number" + } + } + }, + "rag.Report": { + "type": "object", + "properties": { + "radarChart": { + "type": "array", + "items": { + "$ref": "#/definitions/rag.RadarData" + } + }, + "scatterChart": { + "type": "array", + "items": { + "$ref": "#/definitions/rag.ScatterData" + } + }, + "summary": { + "description": "TODO", + "type": "string" + }, + "totalScore": { + "$ref": "#/definitions/rag.TotalScoreData" + } + } + }, + "rag.ReportDetail": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/definitions/rag.ReportLine" + } + }, + "total": { + "type": "integer" + } + } + }, + "rag.ReportLine": { + "type": "object", + "properties": { + "answer": { + "type": "string" + }, + "contexts": { + "type": "array", + "items": { + "type": "string" + } + }, + "costTime": { + "type": "number" + }, + "data": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "groundTruths": { + "type": "array", + "items": { + "type": "string" + } + }, + "question": { + "type": "string" + }, + "totalScore": { + "type": "number" + } + } + }, + "rag.ScatterData": { + "type": "object", + "properties": { + "color": { + "type": "string" + }, + "score": { + "type": "number" + }, + "type": { + "type": "string" + } + } + }, + "rag.TotalScoreData": { + "type": "object", + "properties": { + "color": { + "type": "string" + }, + "score": { + "type": "number" + } + } + }, "retriever.Reference": { "type": "object", "properties": { diff --git a/apiserver/docs/swagger.yaml b/apiserver/docs/swagger.yaml index e023fa08a..aae1ff929 100644 --- a/apiserver/docs/swagger.yaml +++ b/apiserver/docs/swagger.yaml @@ -192,6 +192,79 @@ definitions: total: type: integer type: object + rag.RadarData: + properties: + color: + type: string + type: + type: string + value: + type: number + type: object + rag.Report: + properties: + radarChart: + items: + $ref: '#/definitions/rag.RadarData' + type: array + scatterChart: + items: + $ref: '#/definitions/rag.ScatterData' + type: array + summary: + description: TODO + type: string + totalScore: + $ref: '#/definitions/rag.TotalScoreData' + type: object + rag.ReportDetail: + properties: + data: + items: + $ref: '#/definitions/rag.ReportLine' + type: array + total: + type: integer + type: object + rag.ReportLine: + properties: + answer: + type: string + contexts: + items: + type: string + type: array + costTime: + type: number + data: + additionalProperties: + type: number + type: object + groundTruths: + items: + type: string + type: array + question: + type: string + totalScore: + type: number + type: object + rag.ScatterData: + properties: + color: + type: string + score: + type: number + type: + type: string + type: object + rag.TotalScoreData: + properties: + color: + type: string + score: + type: number + type: object retriever.Reference: properties: answer: @@ -1173,6 +1246,108 @@ paths: summary: get app's prompt starters tags: - application + /rags/detail: + get: + consumes: + - application/json + description: Get detail data of a rag + parameters: + - description: rag name + in: query + name: ragName + required: true + type: string + - description: Name of the bucket + in: header + name: namespace + required: true + type: string + - description: application name + in: query + name: appName + required: true + type: string + - description: default is 1 + in: query + name: page + type: integer + - description: default is 10 + in: query + name: size + type: string + - description: rag metrcis + in: query + name: sortBy + type: string + - description: desc or asc + in: query + name: order + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/rag.ReportDetail' + "400": + description: Bad Request + schema: + additionalProperties: + type: string + type: object + "500": + description: Internal Server Error + schema: + additionalProperties: + type: string + type: object + summary: Get detail data of a rag + tags: + - RAG + /rags/report: + get: + consumes: + - application/json + description: Get a summary of rag + parameters: + - description: rag name + in: query + name: ragName + required: true + type: string + - description: Name of the bucket + in: header + name: namespace + required: true + type: string + - description: application name + in: query + name: appName + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/rag.Report' + "400": + description: Bad Request + schema: + additionalProperties: + type: string + type: object + "500": + description: Internal Server Error + schema: + additionalProperties: + type: string + type: object + summary: Get a summary of rag + tags: + - RAG securityDefinitions: ApiKeyAuth: description: API token for authorization diff --git a/apiserver/pkg/rag/report.go b/apiserver/pkg/rag/report.go new file mode 100644 index 000000000..cf35a7006 --- /dev/null +++ b/apiserver/pkg/rag/report.go @@ -0,0 +1,274 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package rag + +import ( + "context" + "encoding/csv" + "fmt" + "io" + "sort" + "strconv" + "strings" + + "github.com/minio/minio-go/v7" + "k8s.io/klog/v2" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" + "github.com/kubeagi/arcadia/apiserver/pkg/common" +) + +const ( + totalScore = "total_score" + + // TODO: support for color change via env + blueColorEnv = "BLUE_ENV" + blue = "blue" // 散点图的颜色 + + orangeEnv = "ORANGE_RNV" + orange = "orange" // 差 + + greenEnv = "GREEN_ENV" + green = "green" // 好 + + summarySuggestionTemplate = `通过此次评估,您的智能体得分偏低,主要体现在 %s 这 %d 项指标得分偏低。 +
+建议您对特定场景应用的模型进行模型精调;%s。` + noSuggestionTempalte = `通过此次评估,您的 RAG 方案得分 %.2f` +) + +var ( + metricChinese = map[string]string{ + string(v1alpha1.AnswerRelevancy): "答案相关度", + string(v1alpha1.AnswerSimilarity): "答案相似度", + string(v1alpha1.AnswerCorrectness): "答案正确性", + string(v1alpha1.Faithfulness): "忠实度", + string(v1alpha1.ContextPrecision): "知识库精度", + string(v1alpha1.ContextRelevancy): "知识库相关度", + string(v1alpha1.ContextRecall): "知识库召回率", + string(v1alpha1.AspectCritique): "暂时没用到", + } + + suggestionChinese = map[string]string{ + string(v1alpha1.AnswerRelevancy): "调整 Embedding 模型", + string(v1alpha1.AnswerSimilarity): "调整 Embedding 模型", + string(v1alpha1.AnswerCorrectness): "调整模型配置或更换模型", + string(v1alpha1.Faithfulness): "调整模型配置或更换模型", + string(v1alpha1.ContextPrecision): "调整 Embedding 模型", + string(v1alpha1.ContextRelevancy): "调整 Embedding 模型", + string(v1alpha1.ContextRecall): "调整 QA 数据", // 知识库相似度? + string(v1alpha1.AspectCritique): "暂时没用到", + } +) + +type ( + RadarData struct { + Type string `json:"type"` + Value float64 `json:"value"` + Color string `json:"color"` + } + + TotalScoreData struct { + Score float64 `json:"score"` + Color string `json:"color"` + } + + ScatterData struct { + Score float64 `json:"score"` + Type string `json:"type"` + Color string `json:"color"` + } + + Report struct { + RadarChart []RadarData `json:"radarChart"` + TotalScore TotalScoreData `json:"totalScore"` + ScatterChart []ScatterData `json:"scatterChart"` + + // TODO + Summary string `json:"summary"` + } + + // 忠实度、答案相关度、答案语义相似度、答案正确性、知识库相关度、知识库精度、知识库相似度 + // question,ground_truths,answer,contexts + ReportLine struct { + Question string `json:"question"` + GroundTruths []string `json:"groundTruths"` + Answer string `json:"answer"` + Contexts []string `json:"contexts"` + Data map[string]float64 `json:"data"` + TotalScore float64 `json:"totalScore"` + CostTime float64 `json:"costTime"` + } + ReportDetail struct { + Data []ReportLine `json:"data"` + Total int `json:"total"` + } +) + +func ParseSummary( + ctx context.Context, c client.Client, + appName, ragName, namespace string, + metricThresholds map[string]float64) (Report, error) { + source, err := common.SystemDatasourceOSS(ctx, c, nil) + if err != nil { + klog.Errorf("failed to get system datasource error %s", err) + return Report{}, err + } + + filePath := fmt.Sprintf("evals/%s/%s/summary.csv", appName, ragName) + object, err := source.Client.GetObject(ctx, namespace, filePath, minio.GetObjectOptions{}) + if err != nil { + klog.Errorf("failed to get summary.csv file error %s", err) + return Report{}, err + } + csvReader := csv.NewReader(object) + report := Report{TotalScore: TotalScoreData{}, RadarChart: []RadarData{}, ScatterChart: []ScatterData{}} + radarChecker := make(map[string]int) + scatterChecker := make(map[string]int) + + changeTotalScoreColor := false + + metrics := make([]string, 0) + metricSuggesstion := make([]string, 0) + + // skip the first line + firstLine := true + for { + line, err := csvReader.Read() + if err != nil { + if err != io.EOF { + return Report{}, err + } + break + } + if firstLine { + firstLine = false + continue + } + if len(line) != 2 { + return Report{}, fmt.Errorf("the summary file should only have two columns") + } + score, err := strconv.ParseFloat(line[1], 64) + if err != nil { + klog.Errorf("failed to parse thresholds for indicator %s, source value %s", line[0], line[1]) + return Report{}, err + } + if line[0] == totalScore { + report.TotalScore = TotalScoreData{Score: score, Color: green} + continue + } + nextRadarIndex := len(report.RadarChart) + idx, ok := radarChecker[line[0]] + if !ok { + radarChecker[line[0]] = nextRadarIndex + idx = nextRadarIndex + report.RadarChart = append(report.RadarChart, RadarData{Type: line[0]}) + } + report.RadarChart[idx].Value = score + report.RadarChart[idx].Color = green + if threshold, ok := metricThresholds[line[0]]; ok && score < threshold { + report.RadarChart[idx].Color = orange + metrics = append(metrics, metricChinese[line[0]]) + metricSuggesstion = append(metricSuggesstion, suggestionChinese[line[0]]) + changeTotalScoreColor = true + } + + nextScatterIndex := len(report.ScatterChart) + idx, ok = scatterChecker[line[0]] + if !ok { + scatterChecker[line[0]] = nextScatterIndex + report.ScatterChart = append(report.ScatterChart, ScatterData{Type: line[0], Color: blue}) + idx = nextScatterIndex + } + report.ScatterChart[idx].Score = score + } + + if changeTotalScoreColor { + report.TotalScore.Color = orange + report.Summary = fmt.Sprintf(summarySuggestionTemplate, strings.Join(metrics, "、"), len(metrics), strings.Join(metricSuggesstion, "、")) + } else { + report.Summary = fmt.Sprintf(noSuggestionTempalte, report.TotalScore.Score) + } + return report, nil +} + +func ParseResult( + ctx context.Context, c client.Client, + page, pageSize int, + appName, ragName, namespace, sortBy, order string) (ReportDetail, error) { + source, err := common.SystemDatasourceOSS(ctx, c, nil) + if err != nil { + klog.Errorf("failed to get system datasource error %s", err) + return ReportDetail{}, err + } + + filePath := fmt.Sprintf("evals/%s/%s/result.csv", appName, ragName) + object, err := source.Client.GetObject(ctx, namespace, filePath, minio.GetObjectOptions{}) + if err != nil { + klog.Errorf("failed to get result.csv file error %s", err) + return ReportDetail{}, err + } + csvReader := csv.NewReader(object) + + data, err := csvReader.ReadAll() + if err != nil { + klog.Error("failed to read csv error %s", err) + return ReportDetail{}, err + } + if len(data) == 0 { + klog.Error("this may not be a normal csv file with one line of data in it: %s", filePath) + return ReportDetail{}, nil + } + if len(data) == 1 { + klog.Error("there's only one header row. %s", filePath) + return ReportDetail{}, nil + } + + result := make([]ReportLine, len(data)-1) + header := data[0] + for i, line := range data[1:] { + item := ReportLine{ + Question: line[1], + GroundTruths: []string{line[2]}, + Answer: line[3], + Contexts: []string{line[4]}, + Data: make(map[string]float64), + } + sum := float64(0) + for i := 5; i < len(line); i++ { + f, _ := strconv.ParseFloat(line[i], 64) + item.Data[header[i]] = f + sum += f + } + item.TotalScore = sum / float64(len(line)-5) + result[i] = item + } + + start, end := common.PagePosition(page, pageSize, len(data)-1) + if sortBy != "" { + if _, ok := result[0].Data[sortBy]; ok { + sort.Slice(result, func(i, j int) bool { + if order == "desc" { + return result[i].Data[sortBy] > result[j].Data[sortBy] + } + return result[i].Data[sortBy] < result[j].Data[sortBy] + }) + } + } + result = result[start:end] + return ReportDetail{Data: result, Total: len(data) - 1}, nil +} diff --git a/apiserver/service/minio_server.go b/apiserver/service/minio_server.go index 4532e466c..c0bceafeb 100644 --- a/apiserver/service/minio_server.go +++ b/apiserver/service/minio_server.go @@ -32,6 +32,7 @@ import ( "k8s.io/klog/v2" "github.com/kubeagi/arcadia/api/base/v1alpha1" + evaluationarcadiav1alpha1 "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" gqlconfig "github.com/kubeagi/arcadia/apiserver/config" "github.com/kubeagi/arcadia/apiserver/pkg/auth" "github.com/kubeagi/arcadia/apiserver/pkg/client" @@ -917,4 +918,6 @@ func registerMinIOAPI(group *gin.RouterGroup, conf gqlconfig.ServerConfig) { // create a webcrawler file for versioneddataset group.POST("/versioneddataset/files/webcrawler", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "create", "versioneddatasets"), api.CreateWebCrawlerFile) } + + group.GET("/rags/files/downloadlink", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, evaluationarcadiav1alpha1.GroupVersion, "get", "rags"), api.GetDownloadLink) } diff --git a/apiserver/service/rag_server.go b/apiserver/service/rag_server.go new file mode 100644 index 000000000..5ce1b4ef9 --- /dev/null +++ b/apiserver/service/rag_server.go @@ -0,0 +1,136 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package service + +import ( + "fmt" + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + "k8s.io/apimachinery/pkg/types" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeagi/arcadia/api/evaluation/v1alpha1" + gqlconfig "github.com/kubeagi/arcadia/apiserver/config" + "github.com/kubeagi/arcadia/apiserver/pkg/auth" + "github.com/kubeagi/arcadia/apiserver/pkg/oidc" + "github.com/kubeagi/arcadia/apiserver/pkg/rag" +) + +type RagAPI struct { + c client.Client +} + +const ( + ragNameQuery = "ragName" + appNameQuery = "appName" + namespaceHeadr = "namespace" +) + +// @Summary Get a summary of rag +// @Schemes +// @Description Get a summary of rag +// @Tags RAG +// @Accept json +// @Produce json +// @Param ragName query string true "rag name" +// @Param namespace header string true "Name of the bucket" +// @Param appName query string true "application name" +// @Success 200 {object} rag.Report +// @Failure 400 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /rags/report [get] +func (r *RagAPI) Summary(ctx *gin.Context) { + ragName := ctx.Query(ragNameQuery) + appName := ctx.Query(appNameQuery) + namespace := ctx.GetHeader(namespaceHeadr) + + rr := v1alpha1.RAG{} + if err := r.c.Get(ctx, types.NamespacedName{ + Namespace: namespace, Name: ragName, + }, &rr); err != nil { + klog.Error(fmt.Sprintf("can't get rag by name %s", ragName)) + ctx.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ + "message": fmt.Sprintf("can't get rag by name %s", ragName), + }) + return + } + thresholds := make(map[string]float64) + for _, param := range rr.Spec.Metrics { + thresholds[string(param.Kind)] = float64(param.ToleranceThreshbold) / 100.0 + } + + report, err := rag.ParseSummary(ctx.Request.Context(), r.c, appName, ragName, namespace, thresholds) + if err != nil { + klog.Errorf("an error occurred generating the report, error %s", err) + ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ + "message": err.Error(), + }) + return + } + ctx.JSON(http.StatusOK, report) +} + +// @Summary Get detail data of a rag +// @Schemes +// @Description Get detail data of a rag +// @Tags RAG +// @Accept json +// @Produce json +// @Param ragName query string true "rag name" +// @Param namespace header string true "Name of the bucket" +// @Param appName query string true "application name" +// @Param page query int false "default is 1" +// @Param size query string false "default is 10" +// @Param sortBy query string false "rag metrcis" +// @Param order query string false "desc or asc" +// @Success 200 {object} rag.ReportDetail +// @Failure 400 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /rags/detail [get] +func (r *RagAPI) ReportDetail(ctx *gin.Context) { + page, _ := strconv.Atoi(ctx.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(ctx.DefaultQuery("size", "10")) + sortBy := ctx.Query("sortBy") + order := ctx.DefaultQuery("order", "desc") + ragName := ctx.Query(ragNameQuery) + appName := ctx.Query(appNameQuery) + namespace := ctx.GetHeader(namespaceHeadr) + + result, err := rag.ParseResult(ctx.Request.Context(), r.c, page, pageSize, appName, ragName, namespace, sortBy, order) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ + "message": err.Error(), + }) + return + } + ctx.JSON(http.StatusOK, result) +} + +func registerRAG(g *gin.RouterGroup, conf gqlconfig.ServerConfig) { + cfg := ctrl.GetConfigOrDie() + c, err := client.New(cfg, client.Options{Scheme: conf.Scheme}) + if err != nil { + panic(err) + } + api := RagAPI{c: c} + + g.GET("/report", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "get", "rags"), api.Summary) + g.GET("/detail", auth.AuthInterceptor(conf.EnableOIDC, oidc.Verifier, v1alpha1.GroupVersion, "get", "rags"), api.ReportDetail) +} diff --git a/apiserver/service/router.go b/apiserver/service/router.go index 8101ffa0c..04cf9e815 100644 --- a/apiserver/service/router.go +++ b/apiserver/service/router.go @@ -61,6 +61,9 @@ func NewServerAndRun(conf config.ServerConfig) { // for ops apis with graphql registerGraphQL(r, bffGroup, conf) + ragGroup := r.Group("/rags") + registerRAG(ragGroup, conf) + // for chat server with Restful apis chatGroup := r.Group("/chat") registerChat(chatGroup, conf)