Skip to content

Commit

Permalink
perf: lazy initialization for token encoders (close #566)
Browse files Browse the repository at this point in the history
  • Loading branch information
songquanpeng committed Sep 29, 2023
1 parent 197d1d7 commit 594f06e
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions controller/relay-utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,53 @@ import (
"net/http"
"one-api/common"
"strconv"
"strings"
)

var stopFinishReason = "stop"

// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken

func InitTokenEncoders() {
common.SysLog("initializing token encoders")
fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model, _ := range common.ModelRatio {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
tokenEncoderMap[model] = fallbackTokenEncoder
continue
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
tokenEncoderMap[model] = nil
}
tokenEncoderMap[model] = tokenEncoder
}
common.SysLog("token encoders initialized")
}

func getTokenEncoder(model string) *tiktoken.Tiktoken {
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
return tokenEncoder
}
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
return defaultTokenEncoder
}

func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
Expand Down

0 comments on commit 594f06e

Please sign in to comment.