feat(reactor): reactor ai function sdk use openai (#6040)
wangzhuzhen authored Sep 13, 2023
1 parent b0ad278 commit 8cad1c9
Showing 8 changed files with 67 additions and 161 deletions.
2 changes: 1 addition & 1 deletion internal/pkg/ai-functions/functions/function.go
@@ -33,7 +33,7 @@ type Function interface {
Description() string
SystemMessage() string
UserMessage() string
Schema() json.RawMessage
Schema() (json.RawMessage, error)
RequestOptions() []sdk.RequestOption
CompletionOptions() []sdk.PatchOption
Callback(ctx context.Context, arguments json.RawMessage, input interface{}, needAdjust bool) (any, error)
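
The only interface change here is that Schema() now returns an error alongside the raw schema. A minimal standalone sketch of what that means for implementers and callers; the staticFn type below is hypothetical and not part of erda:

package main

import (
    "encoding/json"
    "fmt"
)

// Function is a trimmed stand-in for the interface above: only the method
// whose signature changed in this commit is shown.
type Function interface {
    Schema() (json.RawMessage, error)
}

// staticFn is a hypothetical implementation whose schema is already JSON.
type staticFn struct{ schema json.RawMessage }

func (f staticFn) Schema() (json.RawMessage, error) {
    if !json.Valid(f.schema) {
        return nil, fmt.Errorf("schema is not valid JSON")
    }
    return f.schema, nil
}

func main() {
    var f Function = staticFn{schema: json.RawMessage(`{"type":"object"}`)}
    // Callers must now handle the error instead of trusting the raw bytes.
    s, err := f.Schema()
    if err != nil {
        panic(err)
    }
    fmt.Println(string(s))
}
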
8 changes: 5 additions & 3 deletions internal/pkg/ai-functions/functions/test-case/function.go
@@ -27,6 +27,7 @@ import (
"github.com/erda-project/erda/bundle"
"github.com/erda-project/erda/internal/pkg/ai-functions/functions"
"github.com/erda-project/erda/internal/pkg/ai-functions/sdk"
"github.com/erda-project/erda/pkg/strutil"
)

const Name = "create-test-case"
@@ -94,8 +95,9 @@ func (f *Function) UserMessage() string {
return "Not really implemented."
}

func (f *Function) Schema() json.RawMessage {
return Schema
func (f *Function) Schema() (json.RawMessage, error) {
schema, err := strutil.YamlOrJsonToJson(Schema)
return schema, err
}

func (f *Function) RequestOptions() []sdk.RequestOption {
@@ -107,7 +109,7 @@ func (f *Function) RequestOptions() []sdk.RequestOption {
func (f *Function) CompletionOptions() []sdk.PatchOption {
return []sdk.PatchOption{
sdk.PathOptionWithModel("gpt-35-turbo-16k"),
sdk.PathOptionWithTemperature("1"),
sdk.PathOptionWithTemperature(1),
}
}

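
Schema() above now routes through strutil.YamlOrJsonToJson, which is not shown in this diff. A hedged stand-in for that helper, assuming it accepts either JSON or YAML and returns JSON; the implementation below uses sigs.k8s.io/yaml and may differ from the real one:

package main

import (
    "encoding/json"
    "fmt"

    sigyaml "sigs.k8s.io/yaml"
)

// yamlOrJSONToJSON is a stand-in for strutil.YamlOrJsonToJson: if the input is
// already valid JSON it is returned as-is, otherwise it is treated as YAML and
// converted to JSON.
func yamlOrJSONToJSON(in []byte) (json.RawMessage, error) {
    if json.Valid(in) {
        return json.RawMessage(in), nil
    }
    out, err := sigyaml.YAMLToJSON(in)
    if err != nil {
        return nil, err
    }
    return json.RawMessage(out), nil
}

func main() {
    schemaYAML := []byte("type: object\nproperties:\n  name:\n    type: string\n")
    s, err := yamlOrJSONToJSON(schemaYAML)
    if err != nil {
        panic(err)
    }
    fmt.Println(string(s)) // {"properties":{"name":{"type":"string"}},"type":"object"}
}
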
@@ -45,7 +45,7 @@ func (h *AIFunction) createTestCaseForRequirementIDAndTestID(ctx context.Context
logrus.Debugf("parse createTestCase functionParams=%+v", functionParams)

if err := validateParamsForCreateTestcase(functionParams); err != nil {
return nil, errors.Wrapf(err, "process single testCase create faild")
return nil, errors.Wrapf(err, "validateParamsForCreateTestcase faild")
}

for _, tp := range functionParams.Requirements {
@@ -24,6 +24,7 @@ import (

"bou.ke/monkey"
"github.com/golang/protobuf/jsonpb"
"github.com/sashabaranov/go-openai"
"google.golang.org/protobuf/types/known/structpb"

"github.com/erda-project/erda-infra/base/logs"
@@ -179,18 +180,18 @@ func TestAIFunction_createTestCaseForRequirementIDAndTestID(t *testing.T) {
}

monkey.PatchInstanceMethod(reflect.TypeOf(&sdk.Client{}), "CreateCompletion", func(_ *sdk.Client,
ctx context.Context, req *sdk.CreateCompletionOptions) (*sdk.ChatCompletions, error) {
choices := make([]*sdk.ChatChoice, 0)
choices = append(choices, &sdk.ChatChoice{
ctx context.Context, req *openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) {
choices := make([]openai.ChatCompletionChoice, 0)
choices = append(choices, openai.ChatCompletionChoice{
Index: 0,
Message: &sdk.ChatMessage{
FunctionCall: &sdk.FunctionCall{
Message: openai.ChatCompletionMessage{
FunctionCall: &openai.FunctionCall{
Arguments: arguments,
},
},
})

return &sdk.ChatCompletions{
return &openai.ChatCompletionResponse{
Choices: choices,
}, nil
})
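
The patched CreateCompletion in the test above now returns a canned go-openai response instead of the old sdk types. A self-contained sketch of constructing that response shape; the function name and arguments below are made-up test data, and the import uses the canonical github.com path:

package main

import (
    "encoding/json"
    "fmt"

    "github.com/sashabaranov/go-openai"
)

// cannedFunctionCallResponse builds the kind of response the test stubs into
// sdk.Client.CreateCompletion: one choice whose message carries a
// function_call with pre-baked arguments.
func cannedFunctionCallResponse(arguments string) *openai.ChatCompletionResponse {
    return &openai.ChatCompletionResponse{
        Choices: []openai.ChatCompletionChoice{
            {
                Index: 0,
                Message: openai.ChatCompletionMessage{
                    Role: openai.ChatMessageRoleAssistant,
                    FunctionCall: &openai.FunctionCall{
                        Name:      "create-test-case",
                        Arguments: arguments,
                    },
                },
            },
        },
    }
}

func main() {
    resp := cannedFunctionCallResponse(`{"name":"login test"}`)
    args := resp.Choices[0].Message.FunctionCall.Arguments
    fmt.Println(json.Valid([]byte(args)), args)
}
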
57 changes: 34 additions & 23 deletions internal/pkg/ai-functions/handler/utils/handler_function_utils.go
@@ -22,6 +22,8 @@ import (
"net/url"

"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"

"github.com/erda-project/erda-proto-go/apps/aifunction/pb"
"github.com/erda-project/erda/internal/pkg/ai-functions/functions"
@@ -35,39 +37,44 @@ func GetChatMessageFunctionCallArguments(ctx context.Context, factory functions.
var (
err error
f = factory(ctx, "", req.GetBackground())
systemMsg = &sdk.ChatMessage{
systemMsg = openai.ChatCompletionMessage{
Role: "system",
Content: f.SystemMessage(),
Name: "system",
}
userMsg = &sdk.ChatMessage{
userMsg = openai.ChatCompletionMessage{
Role: "user",
Content: prompt,
Name: "erda",
}
fd = &sdk.FunctionDefinition{
Name: f.Name(),
Description: f.Description(),
Parameters: f.Schema(),
}
options = &sdk.CreateCompletionOptions{
Messages: []*sdk.ChatMessage{systemMsg, userMsg}, // todo: history messages
Functions: []*sdk.FunctionDefinition{fd},
FunctionCall: sdk.FunctionCall{Name: fd.Name},
Temperature: "1", // default 1, can be modified by f.CompletionOptions()
Stream: false,
Model: "gpt-35-turbo-16k", // default the newest model, can be modified by f.CompletionOptions()
}
)

schema, err := f.Schema()
if err != nil {
return nil, err
}

fd := openai.FunctionDefinition{
Name: f.Name(),
Description: f.Description(),
Parameters: schema,
}
logrus.Debugf("openai.FunctionDefinition fd.Parameters string: %s\n", fd.Parameters)

options := &openai.ChatCompletionRequest{
Messages: []openai.ChatCompletionMessage{systemMsg, userMsg}, // todo: history messages
Functions: []openai.FunctionDefinition{fd},
FunctionCall: openai.FunctionCall{Name: fd.Name},
Temperature: 1, // default 1, can be modified by f.CompletionOptions()
Stream: false,
Model: "gpt-35-turbo-16k", // default the newest model, can be modified by f.CompletionOptions()
}

cos := f.CompletionOptions()
for _, o := range cos {
o(options)
}
if valid := json.Valid(fd.Parameters); !valid {
if fd.Parameters, err = strutil.YamlOrJsonToJson(f.Schema()); err != nil {
return nil, err
}
}

// Add authentication info via request options: call ai-proxy as a specific user of a specific organization;
// the erda-auth filter in ai-proxy calls back erda.cloud's openai to check whether that enterprise and user are allowed to use the AI capability
ros := append(f.RequestOptions(), func(r *http.Request) {
@@ -83,12 +90,16 @@ func GetChatMessageFunctionCallArguments(ctx context.Context, factory functions.
if err != nil {
return nil, errors.Wrap(err, "failed to CreateCompletion")
}
if len(completion.Choices) == 0 || completion.Choices[0].Message == nil || completion.Choices[0].Message.FunctionCall == nil {
if len(completion.Choices) == 0 || completion.Choices[0].Message.FunctionCall == nil {
return nil, errors.New("no idea") // todo: do not return error, response friendly
}
// todo: check index out of range and invalid memory reference
arguments := completion.Choices[0].Message.FunctionCall.JSONMessageArguments()
if err = fd.VerifyArguments(arguments); err != nil {
arguments, err := strutil.YamlOrJsonToJson([]byte(completion.Choices[0].Message.FunctionCall.Arguments))
if err != nil {
arguments = json.RawMessage(completion.Choices[0].Message.FunctionCall.Arguments)
}

if err = sdk.VerifyArguments(fd.Parameters.(json.RawMessage), arguments); err != nil {
return nil, errors.Wrap(err, "invalid arguments from FunctionCall")
}

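
Pulling the pieces of GetChatMessageFunctionCallArguments together, a compressed sketch of the request it now builds with go-openai types and how the function-call arguments come back; the schema, prompts, and response payload below are invented for illustration, and the real call goes through sdk.Client.CreateCompletion and the ai-proxy:

package main

import (
    "encoding/json"
    "fmt"

    "github.com/sashabaranov/go-openai"
)

func main() {
    schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}`)

    fd := openai.FunctionDefinition{
        Name:        "create-test-case",
        Description: "create a test case from a requirement",
        Parameters:  schema,
    }

    req := &openai.ChatCompletionRequest{
        Model: "gpt-35-turbo-16k",
        Messages: []openai.ChatCompletionMessage{
            {Role: openai.ChatMessageRoleSystem, Content: "You generate test cases.", Name: "system"},
            {Role: openai.ChatMessageRoleUser, Content: "Write a test case for the login page.", Name: "erda"},
        },
        Functions:    []openai.FunctionDefinition{fd},
        FunctionCall: openai.FunctionCall{Name: fd.Name}, // force this function to be called
        Temperature:  1,
        Stream:       false,
    }

    // Pretend we already have a response and pull the arguments out, the same
    // way GetChatMessageFunctionCallArguments does after CreateCompletion.
    resp := openai.ChatCompletionResponse{
        Choices: []openai.ChatCompletionChoice{{
            Message: openai.ChatCompletionMessage{
                FunctionCall: &openai.FunctionCall{Name: fd.Name, Arguments: `{"name":"login test"}`},
            },
        }},
    }
    if len(resp.Choices) == 0 || resp.Choices[0].Message.FunctionCall == nil {
        fmt.Println("no function call in response")
        return
    }
    fmt.Println(req.Model, resp.Choices[0].Message.FunctionCall.Arguments)
}
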
120 changes: 3 additions & 117 deletions internal/pkg/ai-functions/sdk/api.go
@@ -25,60 +25,10 @@ import (
"github.com/erda-project/erda/pkg/strutil"
)

type CreateCompletionOptions struct {
Messages []*ChatMessage `json:"messages"`
Functions []*FunctionDefinition `json:"functions,omitempty" yaml:"functions,omitempty"`
FunctionCall FunctionCall `json:"function_call,omitempty" yaml:"function_call,omitempty"`
MaxTokens int32 `json:"max_tokens,omitempty" yaml:"maxTokens,omitempty"`
Temperature json.Number `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP json.Number `json:"top_p,omitempty" yaml:"top_p,omitempty"`
LogitBias map[string]int32 `json:"logit_bias,omitempty" yaml:"logit_bias,omitempty"`
User string `json:"user,omitempty" yaml:"user,omitempty"`
N int32 `json:"n,omitempty" yaml:"n,omitempty"`
Stop string `json:"stop,omitempty" yaml:"stop,omitempty"`
PresencePenalty json.Number `json:"presence_penalty,omitempty" yaml:"presence_penalty,omitempty"`
FrequencyPenalty json.Number `json:"frequency_penalty,omitempty" yaml:"frequency_penalty,omitempty"`
Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"`
Model string `json:"model,omitempty" yaml:"model,omitempty"`
}

func (c *CreateCompletionOptions) Validate() error {
if len(c.Messages) == 0 {
return errors.New("messages is required")
}
return nil
}

type ChatCompletions struct {
ID string `json:"id" yaml:"id"`
Object string `json:"object,omitempty" yaml:"object,omitempty"`
Created uint64 `json:"created" yaml:"created"`
Model string `json:"model,omitempty" yaml:"model,omitempty"`
Choices []*ChatChoice `json:"choices" yaml:"choices"`
Usage *CompletionsUsage `json:"usage,omitempty" yaml:"usage,omitempty"`
PromptAnnotations []*PromptFilterResult `json:"prompt_annotations,omitempty" yaml:"prompt_annotations,omitempty"`
}

type ChatMessage struct {
Role string `json:"role" yaml:"role"`
Content string `json:"content" yaml:"content"`
Name string `json:"name,omitempty" yaml:"name,omitempty"`
// FunctionCall The name and arguments of a function that should be called, as generated by the model.
FunctionCall *FunctionCall `json:"function_call,omitempty" yaml:"function_call,omitempty"`
}

// FunctionDefinition The definition of a caller-specified function that chat completions may invoke in response to matching user input.
type FunctionDefinition struct {
Name string `json:"name" yaml:"name"`
Description string `json:"description" yaml:"description"`
// Parameters: The parameters the functions accepts, described as a JSON Schema object.
Parameters json.RawMessage `json:"parameters" yaml:"parameters"`
}

// VerifyArguments verifies that the given JSON conforms to the JSON Schema FunctionDefinition.Parameters
func (fd *FunctionDefinition) VerifyArguments(data json.RawMessage) (err error) {
func VerifyArguments(parameters, data json.RawMessage) (err error) {
// fd.Parameters and data may be either JSON or Yaml structured, convert to JSON structured uniformly.
if fd.Parameters, err = strutil.YamlOrJsonToJson(fd.Parameters); err != nil {
if parameters, err = strutil.YamlOrJsonToJson(parameters); err != nil {
return errors.Wrap(err, "failed to unmarshal Parameters to JSON")
}
if data, err = strutil.YamlOrJsonToJson(data); err != nil {
@@ -87,7 +37,7 @@ func (fd *FunctionDefinition) VerifyArguments(data json.RawMessage) (err error)
if valid := json.Valid(data); !valid {
return errors.New("data is invalid JSON")
}
ls := gojsonschema.NewBytesLoader(fd.Parameters)
ls := gojsonschema.NewBytesLoader(parameters)
ld := gojsonschema.NewBytesLoader(data)
result, err := gojsonschema.Validate(ls, ld)
if err != nil {
@@ -102,67 +52,3 @@
}
return errors.New(strings.Join(ss, "; "))
}

type FunctionCall struct {
Name string `json:"name" yaml:"name"`
Arguments string `json:"arguments,omitempty" yaml:"arguments,omitempty"`
}

func (fc *FunctionCall) JSONMessageArguments() json.RawMessage {
data, err := strutil.YamlOrJsonToJson([]byte(fc.Arguments))
if err != nil {
return json.RawMessage(fc.Arguments)
}
return data
}

type ChatChoice struct {
Index int32 `json:"index" yaml:"index"`
Message *ChatMessage `json:"message" yaml:"message"`
FinishReason string `json:"finish_reason,omitempty" yaml:"finish_reason,omitempty"`
Delta *ChatMessage `json:"delta,omitempty" yaml:"delta,omitempty"`
ContentFilterResults *ContentFilterResults `json:"content_filter_results,omitempty" yaml:"content_filter_results,omitempty"`
}

type PromptFilterResult struct {
PromptIndex int32 `json:"prompt_index"`
ContentFilterResults ContentFilterResults `json:"content_filter_results" yaml:"content_filter_results"`
}

type CompletionsUsage struct {
// todo:
}

type ContentFilterResults struct {
Sexual ContentFilterResult `json:"sexual" yaml:"sexual"`
Violence ContentFilterResult `json:"violence" yaml:"violence"`
Hate ContentFilterResult `json:"hate" yaml:"hate"`
SelfHarm ContentFilterResult `json:"self_harm" yaml:"self_harm"`
Error ErrorBase `json:"error" yaml:"error"`
}

type ContentFilterResult struct {
Severity string `json:"severity" yaml:"severity"`
Filtered bool `json:"filtered" yaml:"filtered"`
}

type ErrorBase struct {
Code string `json:"code" yaml:"code"`
Message string `json:"error" yaml:"error"`
}

type Error struct {
*ErrorBase
Code string
Message string
Param string
Type string
InnerError InnerError
}

type InnerError struct {
Code InnerErrorCode `json:"code" yaml:"code"`
ContentFilterResults ContentFilterResults `json:"contentFilterResults" yaml:"contentFilterResults"`
}

type InnerErrorCode string
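
VerifyArguments is now a package-level function that takes the schema explicitly instead of hanging off FunctionDefinition. A standalone sketch of the same gojsonschema validation, minus the YAML normalization step (both inputs assumed to already be JSON):

package main

import (
    "fmt"
    "strings"

    "github.com/xeipuuv/gojsonschema"
)

// verifyArguments mirrors sdk.VerifyArguments: validate the model-produced
// arguments against the function's JSON Schema and join any schema errors.
func verifyArguments(parameters, data []byte) error {
    result, err := gojsonschema.Validate(
        gojsonschema.NewBytesLoader(parameters),
        gojsonschema.NewBytesLoader(data),
    )
    if err != nil {
        return err
    }
    if result.Valid() {
        return nil
    }
    var ss []string
    for _, resultErr := range result.Errors() {
        ss = append(ss, resultErr.String())
    }
    return fmt.Errorf("%s", strings.Join(ss, "; "))
}

func main() {
    schema := []byte(`{"type":"object","properties":{"name":{"type":"string"}},"required":["name"]}`)
    fmt.Println(verifyArguments(schema, []byte(`{"name":"login test"}`))) // <nil>
    fmt.Println(verifyArguments(schema, []byte(`{}`)))                    // missing-required error
}
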
20 changes: 12 additions & 8 deletions internal/pkg/ai-functions/sdk/client.go
@@ -25,6 +25,7 @@ import (
"os"

"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
"github.com/sirupsen/logrus"

"github.com/erda-project/erda/pkg/http/httputil"
@@ -65,16 +66,19 @@ func (c *Client) HttpClient() *http.Client {
return c.client
}

func (c *Client) CreateCompletion(ctx context.Context, req *CreateCompletionOptions) (*ChatCompletions, error) {
func (c *Client) CreateCompletion(ctx context.Context, req *openai.ChatCompletionRequest) (*openai.ChatCompletionResponse, error) {
req.Stream = false
logrus.Infof("CreateCompletion with req: %+v\n", *req)

var buf bytes.Buffer

if err := json.NewEncoder(&buf).Encode(req); err != nil {
return nil, errors.Wrap(err, "failed to Encode CreateCompletionOptions")
}
aiProxyClientAK := os.Getenv(AIProxyClientAK)
if aiProxyClientAK == "" {
err := errors.Errorf("env %s not set", AIProxyClientAK)
return nil, errors.Wrap(err, "failed to Encode CreateCompletionOptions")
return nil, errors.Wrap(err, "failed to get ai proxy client ak")
}
request, err := http.NewRequest(http.MethodPost, c.URLV1ChatCompletion(), &buf)
if err != nil {
@@ -86,7 +90,7 @@ func (c *Client) CreateCompletion(ctx context.Context, req *CreateCompletionOpti
o(request)
}

logrus.Debugf("WWWZZZ AI-Proxy Request: %+v", request)
logrus.Debugf("Post AI-Proxy Request: %+v", request)
response, err := c.HttpClient().Do(request)
if err != nil {
return nil, errors.Wrapf(err, "failed to Do http request to %s", c.URLV1ChatCompletion())
@@ -100,24 +104,24 @@ func (c *Client) CreateCompletion(ctx context.Context, req *CreateCompletionOpti
}
return nil, errors.Errorf("response not ok, status: %s, message: %s", response.Status, string(data))
}
var chatCompletion ChatCompletions
var chatCompletion openai.ChatCompletionResponse
if err = json.NewDecoder(response.Body).Decode(&chatCompletion); err != nil {
return nil, errors.Wrap(err, "failed to Decode response to ChatCompletion")
}

return &chatCompletion, nil
}

type PatchOption func(option *CreateCompletionOptions)
type PatchOption func(option *openai.ChatCompletionRequest)

func PathOptionWithModel(model string) PatchOption {
return func(cco *CreateCompletionOptions) {
return func(cco *openai.ChatCompletionRequest) {
cco.Model = model
}
}

func PathOptionWithTemperature(temperature json.Number) PatchOption {
return func(cco *CreateCompletionOptions) {
func PathOptionWithTemperature(temperature float32) PatchOption {
return func(cco *openai.ChatCompletionRequest) {
cco.Temperature = temperature
}
}
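
PatchOption now mutates *openai.ChatCompletionRequest directly, and the temperature option takes a float32 to match the go-openai request type. A standalone sketch of the same functional-option pattern, with local stand-ins for PathOptionWithModel and PathOptionWithTemperature:

package main

import (
    "fmt"

    "github.com/sashabaranov/go-openai"
)

// patchOption mirrors sdk.PatchOption after this commit: a functional option
// that mutates the go-openai request instead of the old in-house
// CreateCompletionOptions struct.
type patchOption func(*openai.ChatCompletionRequest)

func withModel(model string) patchOption {
    return func(req *openai.ChatCompletionRequest) { req.Model = model }
}

// Temperature is now a float32, matching openai.ChatCompletionRequest,
// instead of the json.Number string used before.
func withTemperature(t float32) patchOption {
    return func(req *openai.ChatCompletionRequest) { req.Temperature = t }
}

func main() {
    req := &openai.ChatCompletionRequest{Model: "gpt-35-turbo-16k", Temperature: 1}
    for _, o := range []patchOption{withModel("gpt-4"), withTemperature(0.2)} {
        o(req)
    }
    fmt.Println(req.Model, req.Temperature) // gpt-4 0.2
}
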