Skip to content

Commit

Permalink
Merge pull request kubeagi#412 from Abirdcfly/chatmore
Browse files Browse the repository at this point in the history
feat: add more chat api
  • Loading branch information
bjwswang authored Dec 21, 2023
2 parents 4a222ae + bc1df20 commit cffe21f
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 86 deletions.
61 changes: 37 additions & 24 deletions apiserver/pkg/application/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,37 +211,49 @@ func GetApplication(ctx context.Context, c dynamic.Interface, name, namespace st
if err != nil {
return nil, err
}
app := &v1alpha1.Application{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "Application"), namespace, name, app); err != nil {
return nil, err
}

prompt := &apiprompt.Prompt{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "Prompt"), namespace, name, prompt); err != nil {
return nil, err
}
var (
chainConfig *apichain.CommonChainConfig
llmChainInput *apichain.LLMChainInput
retriever *apiretriever.KnowledgeBaseRetriever
)
qachain := &apichain.RetrievalQAChain{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "RetrievalQAChain"), namespace, name, qachain); err != nil {
return nil, err
}
if qachain.UID != "" {
chainConfig = &qachain.Spec.CommonChainConfig
llmChainInput = &qachain.Spec.Input.LLMChainInput
}
llmchain := &apichain.LLMChain{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "LLMChain"), namespace, name, llmchain); err != nil {
return nil, err
}
if llmchain.UID != "" {
chainConfig = &llmchain.Spec.CommonChainConfig
llmChainInput = &llmchain.Spec.Input
}
retriever := &apiretriever.KnowledgeBaseRetriever{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "KnowledgeBaseRetriever"), namespace, name, retriever); err != nil {
return nil, err
hasKnowledgeBaseRetriever := false
for _, node := range app.Spec.Nodes {
if node.Ref != nil && node.Ref.APIGroup != nil && *node.Ref.APIGroup == apiretriever.Group {
hasKnowledgeBaseRetriever = true
break
}
}
app := &v1alpha1.Application{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "Application"), namespace, name, app); err != nil {
return nil, err
if hasKnowledgeBaseRetriever {
qachain := &apichain.RetrievalQAChain{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "RetrievalQAChain"), namespace, name, qachain); err != nil {
return nil, err
}
if qachain.UID != "" {
chainConfig = &qachain.Spec.CommonChainConfig
llmChainInput = &qachain.Spec.Input.LLMChainInput
}
retriever = &apiretriever.KnowledgeBaseRetriever{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "KnowledgeBaseRetriever"), namespace, name, retriever); err != nil {
return nil, err
}
} else {
llmchain := &apichain.LLMChain{}
if err := getResource(ctx, c, common.SchemaOf(&common.ArcadiaAPIGroup, "LLMChain"), namespace, name, llmchain); err != nil {
return nil, err
}
if llmchain.UID != "" {
chainConfig = &llmchain.Spec.CommonChainConfig
llmChainInput = &llmchain.Spec.Input
}
}

return cr2app(prompt, chainConfig, llmChainInput, retriever, app)
Expand All @@ -264,8 +276,6 @@ func ListApplicationMeatadatas(ctx context.Context, c dynamic.Interface, input g
return res.Items[i].GetCreationTimestamp().After(res.Items[j].GetCreationTimestamp().Time)
})

totalCount := len(res.Items)

filterd := make([]generated.PageNode, 0)
for _, u := range res.Items {
if keyword != "" {
Expand All @@ -280,6 +290,8 @@ func ListApplicationMeatadatas(ctx context.Context, c dynamic.Interface, input g
}
filterd = append(filterd, m)
}
totalCount := len(filterd)

end := page * pageSize
if end > totalCount {
end = totalCount
Expand Down Expand Up @@ -517,6 +529,7 @@ func UpdateApplicationConfig(ctx context.Context, c dynamic.Interface, input gen
retriever.Spec.ScoreThreshold = float32(pointer.Float64Deref(input.ScoreThreshold, float64(retriever.Spec.ScoreThreshold)))
retriever.Spec.NumDocuments = pointer.IntDeref(input.NumDocuments, retriever.Spec.NumDocuments)
retriever.Spec.DocNullReturn = pointer.StringDeref(input.DocNullReturn, retriever.Spec.DocNullReturn)
retriever.Spec.Input.KnowledgeBaseRef.Name = *input.Knowledgebase
}, retriever); err != nil {
return nil, err
}
Expand Down
25 changes: 15 additions & 10 deletions apiserver/pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ import (
client1 "github.com/kubeagi/arcadia/apiserver/pkg/client"
)

type idtokenKey struct{}
type contextKey string

const (
idTokenContextKey contextKey = "idToken"
UserNameContextKey contextKey = "userName"
)

type User struct {
Name string `json:"name"`
Expand All @@ -61,11 +66,11 @@ func isBearerToken(token string) (bool, string) {
return head == "bearer" && len(payload) > 0, payload
}

func cani(c dynamic.Interface, oidcToken *oidc.IDToken, resource, verb, namespace string) (bool, error) {
func cani(c dynamic.Interface, oidcToken *oidc.IDToken, resource, verb, namespace string) (bool, string, error) {
u := &User{}
if err := oidcToken.Claims(u); err != nil {
klog.Errorf("parse user info from idToken, error %v", err)
return false, fmt.Errorf("can't parse user info")
return false, "", fmt.Errorf("can't parse user info")
}

av := av1.SubjectAccessReview{
Expand All @@ -87,15 +92,15 @@ func cani(c dynamic.Interface, oidcToken *oidc.IDToken, resource, verb, namespac
if err != nil {
err = fmt.Errorf("auth can-i failed, error %w", err)
klog.Error(err)
return false, err
return false, "", err
}

ok, found, err := unstructured.NestedBool(u1.Object, "status", "allowed")
if err != nil || !found {
klog.Warning("not found allowed filed or some errors occurred.")
return false, err
return false, "", err
}
return ok, nil
return ok, u.Name, nil
}

func AuthInterceptor(needAuth bool, oidcVerifier *oidc.IDTokenVerifier, verb, resources string) gin.HandlerFunc {
Expand Down Expand Up @@ -133,7 +138,7 @@ func AuthInterceptor(needAuth bool, oidcVerifier *oidc.IDTokenVerifier, verb, re
return
}
if verb != "" {
allowed, err := cani(client, oidcIDtoken, resources, verb, namespace)
allowed, userName, err := cani(client, oidcIDtoken, resources, verb, namespace)
if err != nil {
klog.Errorf("auth error: failed to checkout permission. error %s", err)
ctx.AbortWithStatusJSON(http.StatusForbidden, gin.H{
Expand All @@ -148,17 +153,17 @@ func AuthInterceptor(needAuth bool, oidcVerifier *oidc.IDTokenVerifier, verb, re
})
return
}
ctx.Request = ctx.Request.WithContext(context.WithValue(ctx.Request.Context(), UserNameContextKey, userName))
}

// for graphql query
ctx1 := context.WithValue(ctx.Request.Context(), idtokenKey{}, rawToken)
ctx.Request = ctx.Request.WithContext(ctx1)
ctx.Request = ctx.Request.WithContext(context.WithValue(ctx.Request.Context(), idTokenContextKey, rawToken))
ctx.Next()
}
}

func ForOIDCToken(ctx context.Context) *string {
v, _ := ctx.Value(idtokenKey{}).(string)
v, _ := ctx.Value(idTokenContextKey).(string)
if v == "" {
return nil
}
Expand Down
78 changes: 76 additions & 2 deletions apiserver/pkg/chat/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package chat
import (
"context"
"errors"
"sync"
"time"

"github.com/tmc/langchaingo/memory"
Expand All @@ -33,9 +34,13 @@ import (
"github.com/kubeagi/arcadia/apiserver/pkg/client"
"github.com/kubeagi/arcadia/pkg/application"
"github.com/kubeagi/arcadia/pkg/application/base"
"github.com/kubeagi/arcadia/pkg/application/retriever"
)

var Conversions = map[string]Conversion{}
var (
mu sync.Mutex
Conversions = map[string]Conversion{}
)

func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*ChatRespBody, error) {
token := auth.ForOIDCToken(ctx)
Expand All @@ -57,15 +62,22 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat
return nil, errors.New("application is not ready")
}
var conversion Conversion
currentUser, _ := ctx.Value(auth.UserNameContextKey).(string)
if req.ConversionID != "" {
var ok bool
conversion, ok = Conversions[req.ConversionID]
if !ok {
return nil, errors.New("conversion is not found")
}
if currentUser != "" && currentUser != conversion.User {
return nil, errors.New("conversion id not match with user")
}
if conversion.AppName != req.APPName || conversion.AppNamespce != req.AppNamespace {
return nil, errors.New("conversion id not match with app info")
}
if conversion.Debug != req.Debug {
return nil, errors.New("conversion id not match with debug")
}
} else {
conversion = Conversion{
ID: string(uuid.NewUUID()),
Expand All @@ -75,10 +87,13 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat
UpdatedAt: time.Now(),
Messages: make([]Message, 0),
History: memory.NewChatMessageHistory(),
User: currentUser,
Debug: req.Debug,
}
}
messageID := string(uuid.NewUUID())
conversion.Messages = append(conversion.Messages, Message{
ID: string(uuid.NewUUID()),
ID: messageID,
Query: req.Query,
Answer: "",
})
Expand All @@ -95,12 +110,71 @@ func AppRun(ctx context.Context, req ChatReqBody, respStream chan string) (*Chat

conversion.UpdatedAt = time.Now()
conversion.Messages[len(conversion.Messages)-1].Answer = out.Answer
conversion.Messages[len(conversion.Messages)-1].References = out.References
mu.Lock()
Conversions[conversion.ID] = conversion
mu.Unlock()
return &ChatRespBody{
ConversionID: conversion.ID,
MessageID: messageID,
Message: out.Answer,
CreatedAt: time.Now(),
References: out.References,
}, nil
}

func ListConversations(ctx context.Context, req APPMetadata) ([]Conversion, error) {
conversations := make([]Conversion, 0)
currentUser, _ := ctx.Value(auth.UserNameContextKey).(string)
mu.Lock()
for _, c := range Conversions {
if !c.Debug && c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && (currentUser == "" || currentUser == c.User) {
conversations = append(conversations, c)
}
}
mu.Unlock()
return conversations, nil
}

func DeleteConversation(ctx context.Context, conversionID string) error {
currentUser, _ := ctx.Value(auth.UserNameContextKey).(string)
mu.Lock()
defer mu.Unlock()
c, ok := Conversions[conversionID]
if ok && (currentUser == "" || currentUser == c.User) {
delete(Conversions, c.ID)
return nil
} else {
return errors.New("conversion is not found")
}
}

func ListMessages(ctx context.Context, req ConversionReqBody) (Conversion, error) {
currentUser, _ := ctx.Value(auth.UserNameContextKey).(string)
mu.Lock()
defer mu.Unlock()
for _, c := range Conversions {
if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && req.ConversionID == c.ID && (currentUser == "" || currentUser == c.User) {
return c, nil
}
}
return Conversion{}, errors.New("conversion is not found")
}

func GetMessageReferences(ctx context.Context, req MessageReqBody) ([]retriever.Reference, error) {
currentUser, _ := ctx.Value(auth.UserNameContextKey).(string)
mu.Lock()
defer mu.Unlock()
for _, c := range Conversions {
if c.AppName == req.APPName && c.AppNamespce == req.AppNamespace && c.ID == req.ConversionID && (currentUser == "" || currentUser == c.User) {
for _, m := range c.Messages {
if m.ID == req.MessageID {
return m.References, nil
}
}
}
}
return nil, errors.New("conversion or message is not found")
}

// todo Reuse the flow without having to rebuild req same, not finish, Flow doesn't start with/contain nodes that depend on incomingInput.question
58 changes: 39 additions & 19 deletions apiserver/pkg/chat/chat_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"time"

"github.com/tmc/langchaingo/memory"

"github.com/kubeagi/arcadia/pkg/application/retriever"
)

type ResponseMode string
Expand All @@ -30,33 +32,51 @@ const (
// todo isFlowValidForStream only some node(llm chain) support streaming
)

type APPMetadata struct {
APPName string `json:"app_name" binding:"required"`
AppNamespace string `json:"app_namespace" binding:"required"`
}

type ConversionReqBody struct {
APPMetadata `json:",inline"`
ConversionID string `json:"conversion_id"`
}

type MessageReqBody struct {
ConversionReqBody `json:",inline"`
MessageID string `json:"message_id"`
}

type ChatReqBody struct {
Query string `json:"query" binding:"required"`
ResponseMode ResponseMode `json:"response_mode" binding:"required"`
ConversionID string `json:"conversion_id"`
APPName string `json:"app_name" binding:"required"`
AppNamespace string `json:"app_namespace" binding:"required"`
Query string `json:"query" binding:"required"`
ResponseMode ResponseMode `json:"response_mode" binding:"required"`
ConversionReqBody `json:",inline"`
Debug bool `json:"-"`
}

type ChatRespBody struct {
ConversionID string `json:"conversion_id"`
MessageID string `json:"message_id"`
Message string `json:"message"`
CreatedAt time.Time `json:"created_at"`
ConversionID string `json:"conversion_id"`
MessageID string `json:"message_id"`
Message string `json:"message"`
CreatedAt time.Time `json:"created_at"`
References []retriever.Reference `json:"references,omitempty"`
}

type Conversion struct {
ID string `json:"id"`
AppName string `json:"app_name"`
AppNamespce string `json:"app_namespace"`
StartedAt time.Time `json:"started_at"`
UpdatedAt time.Time `json:"updated_at"`
Messages []Message `json:"messages"`
History *memory.ChatMessageHistory
ID string `json:"id"`
AppName string `json:"app_name"`
AppNamespce string `json:"app_namespace"`
StartedAt time.Time `json:"started_at"`
UpdatedAt time.Time `json:"updated_at"`
Messages []Message `json:"messages"`
History *memory.ChatMessageHistory `json:"-"`
User string `json:"-"`
Debug bool `json:"-"`
}

type Message struct {
ID string `json:"id"`
Query string `json:"query"`
Answer string `json:"answer"`
ID string `json:"id"`
Query string `json:"query"`
Answer string `json:"answer"`
References []retriever.Reference `json:"references,omitempty"`
}
Loading

0 comments on commit cffe21f

Please sign in to comment.