diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 3d5948fcec..4775ec88a8 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -9,44 +9,53 @@ import (
 	"net/http"
 	"one-api/common"
 	"strconv"
+	"strings"
 )
 
 var stopFinishReason = "stop"
 
+// tokenEncoderMap won't grow after initialization
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
+var defaultTokenEncoder *tiktoken.Tiktoken
 
 func InitTokenEncoders() {
 	common.SysLog("initializing token encoders")
-	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
 	if err != nil {
-		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
+		common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
+	}
+	defaultTokenEncoder = gpt35TokenEncoder
+	gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
+	if err != nil {
+		common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
 	}
 	for model, _ := range common.ModelRatio {
-		tokenEncoder, err := tiktoken.EncodingForModel(model)
-		if err != nil {
-			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
-			tokenEncoderMap[model] = fallbackTokenEncoder
-			continue
+		if strings.HasPrefix(model, "gpt-3.5") {
+			tokenEncoderMap[model] = gpt35TokenEncoder
+		} else if strings.HasPrefix(model, "gpt-4") {
+			tokenEncoderMap[model] = gpt4TokenEncoder
+		} else {
+			tokenEncoderMap[model] = nil
 		}
-		tokenEncoderMap[model] = tokenEncoder
 	}
 	common.SysLog("token encoders initialized")
 }
 
 func getTokenEncoder(model string) *tiktoken.Tiktoken {
-	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
+	tokenEncoder, ok := tokenEncoderMap[model]
+	if ok && tokenEncoder != nil {
 		return tokenEncoder
 	}
-	tokenEncoder, err := tiktoken.EncodingForModel(model)
-	if err != nil {
-		common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
-		tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
+	if ok {
+		tokenEncoder, err := tiktoken.EncodingForModel(model)
 		if err != nil {
-			common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
+			common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
+			tokenEncoder = defaultTokenEncoder
 		}
+		tokenEncoderMap[model] = tokenEncoder
+		return tokenEncoder
 	}
-	tokenEncoderMap[model] = tokenEncoder
-	return tokenEncoder
+	return defaultTokenEncoder
 }
 
 func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {