Skip to content

Commit

Permalink
wrap user prompt by |start| & |end| to avoid from content-filter
Browse files Browse the repository at this point in the history
  • Loading branch information
sfwn committed Sep 26, 2023
1 parent f821eda commit 2abcfaf
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions internal/apps/ai-proxy/filters/message-context/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import (
"github.com/sashabaranov/go-openai"
"gopkg.in/yaml.v3"

"github.com/erda-project/erda-infra/base/logs"
promptpb "github.com/erda-project/erda-proto-go/apps/aiproxy/prompt/pb"
sessionpb "github.com/erda-project/erda-proto-go/apps/aiproxy/session/pb"
"github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper"
"github.com/erda-project/erda/internal/apps/ai-proxy/models/message"
"github.com/erda-project/erda/internal/apps/ai-proxy/providers/dao"
"github.com/erda-project/erda/internal/apps/ai-proxy/vars"
Expand Down Expand Up @@ -70,7 +70,7 @@ func (c *SessionContext) Enable(_ context.Context, req *http.Request) bool {

func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, infor reverseproxy.HttpInfor) (signal reverseproxy.Signal, err error) {
var (
l = ctx.Value(reverseproxy.LoggerCtxKey{}).(logs.Logger)
l = ctxhelper.GetLogger(ctx)
db = ctx.Value(vars.CtxKeyDAO{}).(dao.DAO)
)

Expand Down Expand Up @@ -99,6 +99,11 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i
return reverseproxy.Intercept, err
}
for _, msg := range chatCompletionRequest.Messages {
// handle user message, wrap by '|start| your question here |end|'
// to avoid from content-filter
if msg.Role == openai.ChatMessageRoleUser {
msg.Content = strutil.Concat("|start|", msg.Content, "|end|")
}
requestedMessages = append(requestedMessages, msg)
}

Expand Down Expand Up @@ -176,7 +181,11 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i

// set to request body
chatCompletionRequest.Messages = allMessages
b, _ := json.Marshal(&chatCompletionRequest)
b, err := json.Marshal(&chatCompletionRequest)
if err != nil {
l.Errorf("failed to marshal request body, err: %v", err)
return reverseproxy.Intercept, err
}
infor.SetBody(io.NopCloser(bytes.NewBuffer(b)), int64(len(b)))

return reverseproxy.Continue, nil
Expand Down

0 comments on commit 2abcfaf

Please sign in to comment.