diff --git a/internal/apps/ai-proxy/filters/message-context/filter.go b/internal/apps/ai-proxy/filters/message-context/filter.go index 77ea406678b..7fb971feb8a 100644 --- a/internal/apps/ai-proxy/filters/message-context/filter.go +++ b/internal/apps/ai-proxy/filters/message-context/filter.go @@ -26,9 +26,9 @@ import ( "github.com/sashabaranov/go-openai" "gopkg.in/yaml.v3" - "github.com/erda-project/erda-infra/base/logs" promptpb "github.com/erda-project/erda-proto-go/apps/aiproxy/prompt/pb" sessionpb "github.com/erda-project/erda-proto-go/apps/aiproxy/session/pb" + "github.com/erda-project/erda/internal/apps/ai-proxy/common/ctxhelper" "github.com/erda-project/erda/internal/apps/ai-proxy/models/message" "github.com/erda-project/erda/internal/apps/ai-proxy/providers/dao" "github.com/erda-project/erda/internal/apps/ai-proxy/vars" @@ -70,7 +70,7 @@ func (c *SessionContext) Enable(_ context.Context, req *http.Request) bool { func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, infor reverseproxy.HttpInfor) (signal reverseproxy.Signal, err error) { var ( - l = ctx.Value(reverseproxy.LoggerCtxKey{}).(logs.Logger) + l = ctxhelper.GetLogger(ctx) db = ctx.Value(vars.CtxKeyDAO{}).(dao.DAO) ) @@ -99,6 +99,11 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i return reverseproxy.Intercept, err } for _, msg := range chatCompletionRequest.Messages { + // handle user message, wrap by '|start| your question here |end|' + // to avoid from content-filter + if msg.Role == openai.ChatMessageRoleUser { + msg.Content = strutil.Concat("|start|", msg.Content, "|end|") + } requestedMessages = append(requestedMessages, msg) } @@ -176,7 +181,11 @@ func (c *SessionContext) OnRequest(ctx context.Context, _ http.ResponseWriter, i // set to request body chatCompletionRequest.Messages = allMessages - b, _ := json.Marshal(&chatCompletionRequest) + b, err := json.Marshal(&chatCompletionRequest) + if err != nil { + l.Errorf("failed to marshal request body, err: %v", err) + return reverseproxy.Intercept, err + } infor.SetBody(io.NopCloser(bytes.NewBuffer(b)), int64(len(b))) return reverseproxy.Continue, nil