From e9f117ff72cbe83597c186b18fca72f337de5aba Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Sat, 21 Dec 2024 20:32:30 +0800 Subject: [PATCH] feat: add gemini-2.0-flash-exp and fix race condition in processChannelRelayError (#1983) Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com> --- controller/relay.go | 7 +++---- relay/adaptor/gemini/adaptor.go | 8 +++++++- relay/adaptor/gemini/constants.go | 5 ++++- relay/adaptor/vertexai/gemini/adapter.go | 5 ++++- relay/billing/ratio/model.go | 13 ++++++++----- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 49358e2597..038123b359 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -60,7 +60,7 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { @@ -87,8 +87,7 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - // BUG: bizErr is in race condition - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) } if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { @@ -122,7 +121,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool { return true } -func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { +func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index 12f48c715a..a86fde40b8 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -24,7 +24,12 @@ func (a *Adaptor) Init(meta *meta.Meta) { } func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { - version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) + defaultVersion := config.GeminiVersion + if meta.ActualModelName == "gemini-2.0-flash-exp" { + defaultVersion = "v1beta" + } + + version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion) action := "" switch meta.Mode { case relaymode.Embeddings: @@ -36,6 +41,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { if meta.IsStream { action = "streamGenerateContent?alt=sse" } + return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil } diff --git a/relay/adaptor/gemini/constants.go b/relay/adaptor/gemini/constants.go index b0f84dfc55..fa53d63baf 100644 --- a/relay/adaptor/gemini/constants.go +++ b/relay/adaptor/gemini/constants.go @@ -3,5 +3,8 @@ package gemini // https://ai.google.dev/models/gemini var ModelList = []string{ - "gemini-pro", "gemini-1.0-pro", "gemini-1.5-flash", "gemini-1.5-pro", "text-embedding-004", "aqa", + "gemini-pro", "gemini-1.0-pro", + "gemini-1.5-flash", "gemini-1.5-pro", + "text-embedding-004", "aqa", + "gemini-2.0-flash-exp", } diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index f86baee0e2..0557b075c6 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -15,7 +15,10 @@ import ( ) var ModelList = []string{ - "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", + "gemini-pro", "gemini-pro-vision", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", + "gemini-1.5-pro-002", "gemini-1.5-flash-002", + "gemini-2.0-flash-exp", } type Adaptor struct { diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 2b581ffcc1..613d2b3124 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -108,11 +108,14 @@ var ModelRatio = map[string]float64{ "bge-large-en": 0.002 * RMB, "tao-8k": 0.002 * RMB, // https://ai.google.dev/pricing - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-1.0-pro": 1, - "gemini-1.5-flash": 1, - "gemini-1.5-pro": 1, - "aqa": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro": 1, + "gemini-1.5-pro": 1, + "gemini-1.5-pro-001": 1, + "gemini-1.5-flash": 1, + "gemini-1.5-flash-001": 1, + "gemini-2.0-flash-exp": 1, + "aqa": 1, // https://open.bigmodel.cn/pricing "glm-4": 0.1 * RMB, "glm-4v": 0.1 * RMB,