From bc66ce15920b9b9d551ed2a30e90bd7e5e25f2b4 Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sat, 6 Jul 2024 13:19:41 +0800 Subject: [PATCH] feat: refactor AwsClaude to Aws to support both llama3 and claude (#1601) * feat: refactor AwsClaude to Aws to support both llama3 and claude * fix: aws llama3 ratio --- relay/adaptor/aws/{adapter.go => adaptor.go} | 78 +++--- relay/adaptor/aws/claude/adapter.go | 37 +++ relay/adaptor/aws/{ => claude}/main.go | 40 ++- relay/adaptor/aws/{ => claude}/model.go | 0 relay/adaptor/aws/llama3/adapter.go | 37 +++ relay/adaptor/aws/llama3/main.go | 231 ++++++++++++++++++ relay/adaptor/aws/llama3/main_test.go | 45 ++++ relay/adaptor/aws/llama3/model.go | 29 +++ relay/adaptor/aws/registry.go | 39 +++ relay/adaptor/aws/utils/adaptor.go | 51 ++++ relay/adaptor/aws/utils/utils.go | 16 ++ relay/billing/ratio/model.go | 44 +++- relay/controller/audio.go | 9 +- relay/controller/helper.go | 9 +- relay/controller/image.go | 7 +- relay/controller/text.go | 7 +- web/berry/src/constants/ChannelConstants.js | 2 +- .../src/constants/channel.constants.js | 2 +- 18 files changed, 595 insertions(+), 88 deletions(-) rename relay/adaptor/aws/{adapter.go => adaptor.go} (71%) create mode 100644 relay/adaptor/aws/claude/adapter.go rename relay/adaptor/aws/{ => claude}/main.go (86%) rename relay/adaptor/aws/{ => claude}/model.go (100%) create mode 100644 relay/adaptor/aws/llama3/adapter.go create mode 100644 relay/adaptor/aws/llama3/main.go create mode 100644 relay/adaptor/aws/llama3/main_test.go create mode 100644 relay/adaptor/aws/llama3/model.go create mode 100644 relay/adaptor/aws/registry.go create mode 100644 relay/adaptor/aws/utils/adaptor.go create mode 100644 relay/adaptor/aws/utils/utils.go diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adaptor.go similarity index 71% rename from relay/adaptor/aws/adapter.go rename to relay/adaptor/aws/adaptor.go index 7245d3d9fe..62221346d8 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adaptor.go @@ -1,17 +1,16 @@ package aws import ( - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" - "github.com/songquanpeng/one-api/common/ctxkey" + "errors" "io" "net/http" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" "github.com/gin-gonic/gin" - "github.com/pkg/errors" "github.com/songquanpeng/one-api/relay/adaptor" - "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" ) @@ -19,60 +18,44 @@ import ( var _ adaptor.Adaptor = new(Adaptor) type Adaptor struct { - meta *meta.Meta - awsClient *bedrockruntime.Client + awsAdapter utils.AwsAdapter + + Meta *meta.Meta + AwsClient *bedrockruntime.Client } func (a *Adaptor) Init(meta *meta.Meta) { - a.meta = meta - a.awsClient = bedrockruntime.New(bedrockruntime.Options{ + a.Meta = meta + a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ Region: meta.Config.Region, Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), }) } -func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - return "", nil -} - -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { - return nil -} - func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - claudeReq := anthropic.ConvertRequest(*request) - c.Set(ctxkey.RequestModel, request.Model) - c.Set(ctxkey.ConvertedRequest, claudeReq) - return claudeReq, nil -} - -func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { - if request == nil { - return nil, errors.New("request is nil") + adaptor := GetAdaptor(request.Model) + if adaptor == nil { + return nil, errors.New("adaptor not found") } - return request, nil -} -func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { - return nil, nil + a.awsAdapter = adaptor + return adaptor.ConvertRequest(c, relayMode, request) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { - if meta.IsStream { - err, usage = StreamHandler(c, a.awsClient) - } else { - err, usage = Handler(c, a.awsClient, meta.ActualModelName) + if a.awsAdapter == nil { + return nil, utils.WrapErr(errors.New("awsAdapter is nil")) } - return + return a.awsAdapter.DoResponse(c, a.AwsClient, meta) } func (a *Adaptor) GetModelList() (models []string) { - for n := range awsModelIDMap { - models = append(models, n) + for model := range adaptors { + models = append(models, model) } return } @@ -80,3 +63,22 @@ func (a *Adaptor) GetModelList() (models []string) { func (a *Adaptor) GetChannelName() string { return "aws" } + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} diff --git a/relay/adaptor/aws/claude/adapter.go b/relay/adaptor/aws/claude/adapter.go new file mode 100644 index 0000000000..eb3c9fb85c --- /dev/null +++ b/relay/adaptor/aws/claude/adapter.go @@ -0,0 +1,37 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ utils.AwsAdapter = new(Adaptor) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + claudeReq := anthropic.ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, claudeReq) + return claudeReq, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, awsCli) + } else { + err, usage = Handler(c, awsCli, meta.ActualModelName) + } + return +} diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/claude/main.go similarity index 86% rename from relay/adaptor/aws/main.go rename to relay/adaptor/aws/claude/main.go index 72f40ddcdf..7142e46f72 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -5,8 +5,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/ctxkey" - "github.com/songquanpeng/one-api/relay/adaptor/openai" "io" "net/http" @@ -17,23 +15,17 @@ import ( "github.com/jinzhu/copier" "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/adaptor/openai" relaymodel "github.com/songquanpeng/one-api/relay/model" ) -func wrapErr(err error) *relaymodel.ErrorWithStatusCode { - return &relaymodel.ErrorWithStatusCode{ - StatusCode: http.StatusInternalServerError, - Error: relaymodel.Error{ - Message: fmt.Sprintf("%s", err.Error()), - }, - } -} - // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html -var awsModelIDMap = map[string]string{ +var AwsModelIDMap = map[string]string{ "claude-instant-1.2": "anthropic.claude-instant-v1", "claude-2.0": "anthropic.claude-v2", "claude-2.1": "anthropic.claude-v2:1", @@ -44,7 +36,7 @@ var awsModelIDMap = map[string]string{ } func awsModelID(requestModel string) (string, error) { - if awsModelID, ok := awsModelIDMap[requestModel]; ok { + if awsModelID, ok := AwsModelIDMap[requestModel]; ok { return awsModelID, nil } @@ -54,7 +46,7 @@ func awsModelID(requestModel string) (string, error) { func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { - return wrapErr(errors.Wrap(err, "awsModelID")), nil + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } awsReq := &bedrockruntime.InvokeModelInput{ @@ -65,30 +57,30 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (* claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) if !ok { - return wrapErr(errors.New("request not found")), nil + return utils.WrapErr(errors.New("request not found")), nil } claudeReq := claudeReq_.(*anthropic.Request) awsClaudeReq := &Request{ AnthropicVersion: "bedrock-2023-05-31", } if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { - return wrapErr(errors.Wrap(err, "copy request")), nil + return utils.WrapErr(errors.Wrap(err, "copy request")), nil } awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { - return wrapErr(errors.Wrap(err, "marshal request")), nil + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil } awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) if err != nil { - return wrapErr(errors.Wrap(err, "InvokeModel")), nil + return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil } claudeResponse := new(anthropic.Response) err = json.Unmarshal(awsResp.Body, claudeResponse) if err != nil { - return wrapErr(errors.Wrap(err, "unmarshal response")), nil + return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil } openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) @@ -108,7 +100,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E createdTime := helper.GetTimestamp() awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { - return wrapErr(errors.Wrap(err, "awsModelID")), nil + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ @@ -119,7 +111,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E claudeReq_, ok := c.Get(ctxkey.ConvertedRequest) if !ok { - return wrapErr(errors.New("request not found")), nil + return utils.WrapErr(errors.New("request not found")), nil } claudeReq := claudeReq_.(*anthropic.Request) @@ -127,16 +119,16 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E AnthropicVersion: "bedrock-2023-05-31", } if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { - return wrapErr(errors.Wrap(err, "copy request")), nil + return utils.WrapErr(errors.Wrap(err, "copy request")), nil } awsReq.Body, err = json.Marshal(awsClaudeReq) if err != nil { - return wrapErr(errors.Wrap(err, "marshal request")), nil + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil } awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) if err != nil { - return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil } stream := awsResp.GetStream() defer stream.Close() diff --git a/relay/adaptor/aws/model.go b/relay/adaptor/aws/claude/model.go similarity index 100% rename from relay/adaptor/aws/model.go rename to relay/adaptor/aws/claude/model.go diff --git a/relay/adaptor/aws/llama3/adapter.go b/relay/adaptor/aws/llama3/adapter.go new file mode 100644 index 0000000000..83edbc9d25 --- /dev/null +++ b/relay/adaptor/aws/llama3/adapter.go @@ -0,0 +1,37 @@ +package aws + +import ( + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/songquanpeng/one-api/common/ctxkey" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +var _ utils.AwsAdapter = new(Adaptor) + +type Adaptor struct { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + llamaReq := ConvertRequest(*request) + c.Set(ctxkey.RequestModel, request.Model) + c.Set(ctxkey.ConvertedRequest, llamaReq) + return llamaReq, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, awsCli) + } else { + err, usage = Handler(c, awsCli, meta.ActualModelName) + } + return +} diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go new file mode 100644 index 0000000000..e5fcd89f12 --- /dev/null +++ b/relay/adaptor/aws/llama3/main.go @@ -0,0 +1,231 @@ +// Package aws provides the AWS adaptor for the relay service. +package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "text/template" + + "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/random" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// Only support llama-3-8b and llama-3-70b instruction models +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html +var AwsModelIDMap = map[string]string{ + "llama3-8b-8192": "meta.llama3-8b-instruct-v1:0", + "llama3-70b-8192": "meta.llama3-70b-instruct-v1:0", +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := AwsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +// promptTemplate with range +const promptTemplate = `<|begin_of_text|>{{range .Messages}}<|start_header_id|>{{.Role}}<|end_header_id|>{{.StringContent}}<|eot_id|>{{end}}<|start_header_id|>assistant<|end_header_id|> +` + +var promptTpl = template.Must(template.New("llama3-chat").Parse(promptTemplate)) + +func RenderPrompt(messages []relaymodel.Message) string { + var buf bytes.Buffer + err := promptTpl.Execute(&buf, struct{ Messages []relaymodel.Message }{messages}) + if err != nil { + logger.SysError("error rendering prompt messages: " + err.Error()) + } + return buf.String() +} + +func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { + llamaRequest := Request{ + MaxGenLen: textRequest.MaxTokens, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + } + if llamaRequest.MaxGenLen == 0 { + llamaRequest.MaxGenLen = 2048 + } + prompt := RenderPrompt(textRequest.Messages) + llamaRequest.Prompt = prompt + return &llamaRequest +} + +func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + llamaReq, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + + awsReq.Body, err = json.Marshal(llamaReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + var llamaResponse Response + err = json.Unmarshal(awsResp.Body, &llamaResponse) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := ResponseLlama2OpenAI(&llamaResponse) + openaiResp.Model = modelName + usage := relaymodel.Usage{ + PromptTokens: llamaResponse.PromptTokenCount, + CompletionTokens: llamaResponse.GenerationTokenCount, + TotalTokens: llamaResponse.PromptTokenCount + llamaResponse.GenerationTokenCount, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse { + var responseText string + if len(llamaResponse.Generation) > 0 { + responseText = llamaResponse.Generation + } + choice := openai.TextResponseChoice{ + Index: 0, + Message: relaymodel.Message{ + Role: "assistant", + Content: responseText, + Name: nil, + }, + FinishReason: llamaResponse.StopReason, + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + } + return &fullTextResponse +} + +func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { + createdTime := helper.GetTimestamp() + awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + llamaReq, ok := c.Get(ctxkey.ConvertedRequest) + if !ok { + return utils.WrapErr(errors.New("request not found")), nil + } + + awsReq.Body, err = json.Marshal(llamaReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return utils.WrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + var usage relaymodel.Usage + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + var llamaResp StreamResponse + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(&llamaResp) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + if llamaResp.PromptTokenCount > 0 { + usage.PromptTokens = llamaResp.PromptTokenCount + } + if llamaResp.StopReason == "stop" { + usage.CompletionTokens = llamaResp.GenerationTokenCount + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + response := StreamResponseLlama2OpenAI(&llamaResp) + response.Id = fmt.Sprintf("chatcmpl-%s", random.GetUUID()) + response.Model = c.GetString(ctxkey.OriginalModel) + response.Created = createdTime + jsonStr, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} + +func StreamResponseLlama2OpenAI(llamaResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = llamaResponse.Generation + choice.Delta.Role = "assistant" + finishReason := llamaResponse.StopReason + if finishReason != "null" { + choice.FinishReason = &finishReason + } + var openaiResponse openai.ChatCompletionsStreamResponse + openaiResponse.Object = "chat.completion.chunk" + openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} + return &openaiResponse +} diff --git a/relay/adaptor/aws/llama3/main_test.go b/relay/adaptor/aws/llama3/main_test.go new file mode 100644 index 0000000000..d539eee8a0 --- /dev/null +++ b/relay/adaptor/aws/llama3/main_test.go @@ -0,0 +1,45 @@ +package aws_test + +import ( + "testing" + + aws "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/stretchr/testify/assert" +) + +func TestRenderPrompt(t *testing.T) { + messages := []relaymodel.Message{ + { + Role: "user", + Content: "What's your name?", + }, + } + prompt := aws.RenderPrompt(messages) + expected := `<|begin_of_text|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +` + assert.Equal(t, expected, prompt) + + messages = []relaymodel.Message{ + { + Role: "system", + Content: "Your name is Kat. You are a detective.", + }, + { + Role: "user", + Content: "What's your name?", + }, + { + Role: "assistant", + Content: "Kat", + }, + { + Role: "user", + Content: "What's your job?", + }, + } + prompt = aws.RenderPrompt(messages) + expected = `<|begin_of_text|><|start_header_id|>system<|end_header_id|>Your name is Kat. You are a detective.<|eot_id|><|start_header_id|>user<|end_header_id|>What's your name?<|eot_id|><|start_header_id|>assistant<|end_header_id|>Kat<|eot_id|><|start_header_id|>user<|end_header_id|>What's your job?<|eot_id|><|start_header_id|>assistant<|end_header_id|> +` + assert.Equal(t, expected, prompt) +} diff --git a/relay/adaptor/aws/llama3/model.go b/relay/adaptor/aws/llama3/model.go new file mode 100644 index 0000000000..7b86c3b8ff --- /dev/null +++ b/relay/adaptor/aws/llama3/model.go @@ -0,0 +1,29 @@ +package aws + +// Request is the request to AWS Llama3 +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html +type Request struct { + Prompt string `json:"prompt"` + MaxGenLen int `json:"max_gen_len,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` +} + +// Response is the response from AWS Llama3 +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html +type Response struct { + Generation string `json:"generation"` + PromptTokenCount int `json:"prompt_token_count"` + GenerationTokenCount int `json:"generation_token_count"` + StopReason string `json:"stop_reason"` +} + +// {'generation': 'Hi', 'prompt_token_count': 15, 'generation_token_count': 1, 'stop_reason': None} +type StreamResponse struct { + Generation string `json:"generation"` + PromptTokenCount int `json:"prompt_token_count"` + GenerationTokenCount int `json:"generation_token_count"` + StopReason string `json:"stop_reason"` +} diff --git a/relay/adaptor/aws/registry.go b/relay/adaptor/aws/registry.go new file mode 100644 index 0000000000..5f6554808c --- /dev/null +++ b/relay/adaptor/aws/registry.go @@ -0,0 +1,39 @@ +package aws + +import ( + claude "github.com/songquanpeng/one-api/relay/adaptor/aws/claude" + llama3 "github.com/songquanpeng/one-api/relay/adaptor/aws/llama3" + "github.com/songquanpeng/one-api/relay/adaptor/aws/utils" +) + +type AwsModelType int + +const ( + AwsClaude AwsModelType = iota + 1 + AwsLlama3 +) + +var ( + adaptors = map[string]AwsModelType{} +) + +func init() { + for model := range claude.AwsModelIDMap { + adaptors[model] = AwsClaude + } + for model := range llama3.AwsModelIDMap { + adaptors[model] = AwsLlama3 + } +} + +func GetAdaptor(model string) utils.AwsAdapter { + adaptorType := adaptors[model] + switch adaptorType { + case AwsClaude: + return &claude.Adaptor{} + case AwsLlama3: + return &llama3.Adaptor{} + default: + return nil + } +} diff --git a/relay/adaptor/aws/utils/adaptor.go b/relay/adaptor/aws/utils/adaptor.go new file mode 100644 index 0000000000..4cb880f29d --- /dev/null +++ b/relay/adaptor/aws/utils/adaptor.go @@ -0,0 +1,51 @@ +package utils + +import ( + "errors" + "io" + "net/http" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" +) + +type AwsAdapter interface { + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + DoResponse(c *gin.Context, awsCli *bedrockruntime.Client, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) +} + +type Adaptor struct { + Meta *meta.Meta + AwsClient *bedrockruntime.Client +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.Meta = meta + a.AwsClient = bedrockruntime.New(bedrockruntime.Options{ + Region: meta.Config.Region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), + }) +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} diff --git a/relay/adaptor/aws/utils/utils.go b/relay/adaptor/aws/utils/utils.go new file mode 100644 index 0000000000..669dc62844 --- /dev/null +++ b/relay/adaptor/aws/utils/utils.go @@ -0,0 +1,16 @@ +package utils + +import ( + "net/http" + + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +func WrapErr(err error) *relaymodel.ErrorWithStatusCode { + return &relaymodel.ErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: err.Error(), + }, + } +} diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 56d31e1387..8a7d574366 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -2,6 +2,7 @@ package ratio import ( "encoding/json" + "fmt" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -169,6 +170,9 @@ var ModelRatio = map[string]float64{ "step-1v-32k": 0.024 * RMB, "step-1-32k": 0.024 * RMB, "step-1-200k": 0.15 * RMB, + // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ + "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens + "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens // https://cohere.com/pricing "command": 0.5, "command-nightly": 0.5, @@ -185,7 +189,11 @@ var ModelRatio = map[string]float64{ "deepl-ja": 25.0 / 1000 * USD, } -var CompletionRatio = map[string]float64{} +var CompletionRatio = map[string]float64{ + // aws llama3 + "llama3-8b-8192(33)": 0.0006 / 0.0003, + "llama3-70b-8192(33)": 0.0035 / 0.00265, +} var DefaultModelRatio map[string]float64 var DefaultCompletionRatio map[string]float64 @@ -234,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &ModelRatio) } -func GetModelRatio(name string) float64 { +func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } - ratio, ok := ModelRatio[name] - if !ok { - ratio, ok = DefaultModelRatio[name] + model := fmt.Sprintf("%s(%d)", name, channelType) + if ratio, ok := ModelRatio[model]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[model]; ok { + return ratio } - if !ok { - logger.SysError("model ratio not found: " + name) - return 30 + if ratio, ok := ModelRatio[name]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[name]; ok { + return ratio } - return ratio + logger.SysError("model ratio not found: " + name) + return 30 } func CompletionRatio2JSONString() string { @@ -265,7 +279,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &CompletionRatio) } -func GetCompletionRatio(name string) float64 { +func GetCompletionRatio(name string, channelType int) float64 { + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } + model := fmt.Sprintf("%s(%d)", name, channelType) + if ratio, ok := CompletionRatio[model]; ok { + return ratio + } + if ratio, ok := DefaultCompletionRatio[model]; ok { + return ratio + } if ratio, ok := CompletionRatio[name]; ok { return ratio } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 47da350ba1..0d537772a0 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,6 +7,10 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" @@ -21,9 +25,6 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { @@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := billingratio.GetModelRatio(audioModel) + modelRatio := billingratio.GetModelRatio(audioModel, channelType) groupRatio := billingratio.GetGroupRatio(group) ratio := modelRatio * groupRatio var quota int64 diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 846366bc1c..1843bee702 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,6 +4,10 @@ import ( "context" "errors" "fmt" + "math" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -16,9 +20,6 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "math" - "net/http" - "strings" ) func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { @@ -96,7 +97,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M return } var quota int64 - completionRatio := billingratio.GetCompletionRatio(textRequest.Model) + completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) diff --git a/relay/controller/image.go b/relay/controller/image.go index ebe4e5b610..6dbf5c8f31 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -6,6 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" @@ -17,8 +20,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { @@ -166,7 +167,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = bytes.NewBuffer(jsonStr) } - modelRatio := billingratio.GetModelRatio(imageModel) + modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) diff --git a/relay/controller/text.go b/relay/controller/text.go index 6ed19b1de8..0d3c56b07d 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -4,6 +4,9 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" @@ -14,8 +17,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { @@ -35,7 +36,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := billingratio.GetModelRatio(textRequest.Model) + modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index aacc8d473c..881f66bd75 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -13,7 +13,7 @@ export const CHANNEL_OPTIONS = { }, 33: { key: 33, - text: 'AWS Claude', + text: 'AWS', value: 33, color: 'primary' }, diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index adf50a40f0..1b4c1910d5 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -1,7 +1,7 @@ export const CHANNEL_OPTIONS = [ {key: 1, text: 'OpenAI', value: 1, color: 'green'}, {key: 14, text: 'Anthropic Claude', value: 14, color: 'black'}, - {key: 33, text: 'AWS Claude', value: 33, color: 'black'}, + {key: 33, text: 'AWS', value: 33, color: 'black'}, {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'}, {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'}, {key: 24, text: 'Google Gemini', value: 24, color: 'orange'},