Skip to content

Commit

Permalink
feat: support llm chat on replicate
Browse files Browse the repository at this point in the history
  • Loading branch information
Laisky committed Dec 19, 2024
1 parent 4dd2b9d commit 48e8b6b
Show file tree
Hide file tree
Showing 12 changed files with 380 additions and 33 deletions.
3 changes: 2 additions & 1 deletion common/render/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package render
import (
"encoding/json"
"fmt"
"strings"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"strings"
)

func StringData(c *gin.Context, str string) {
Expand Down
2 changes: 1 addition & 1 deletion monitor/manage.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
strings.Contains(lowerMessage, "credit") ||
strings.Contains(lowerMessage, "balance") ||
strings.Contains(lowerMessage, "permission denied") ||
strings.Contains(lowerMessage, "organization has been restricted") || // groq
strings.Contains(lowerMessage, "organization has been restricted") || // groq

Check warning on line 37 in monitor/manage.go

View check run for this annotation

Codecov / codecov/patch

monitor/manage.go#L37

Added line #L37 was not covered by tests
strings.Contains(lowerMessage, "已欠费") {
return true
}
Expand Down
6 changes: 3 additions & 3 deletions relay/adaptor/ollama/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,

Check warning on line 35 in relay/adaptor/ollama/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/ollama/main.go#L34-L35

Added lines #L34 - L35 were not covered by tests
},
Stream: request.Stream,
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
for scanner.Scan() {
data := scanner.Text()
if strings.HasPrefix(data, "}") {
data = strings.TrimPrefix(data, "}") + "}"
data = strings.TrimPrefix(data, "}") + "}"

Check warning on line 125 in relay/adaptor/ollama/main.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/ollama/main.go#L125

Added line #L125 was not covered by tests
}

var ollamaResponse ChatResponse
Expand Down
7 changes: 4 additions & 3 deletions relay/adaptor/openai/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package openai

import (
"fmt"
"strings"

"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"strings"
)

func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage {

Check warning on line 11 in relay/adaptor/openai/helper.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/openai/helper.go#L11

Added line #L11 was not covered by tests
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.CompletionTokens = CountTokenText(responseText, modelName)

Check warning on line 14 in relay/adaptor/openai/helper.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/openai/helper.go#L14

Added line #L14 was not covered by tests
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
Expand Down
55 changes: 53 additions & 2 deletions relay/adaptor/replicate/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net/http"
"slices"
"strings"
"time"

"github.com/gin-gonic/gin"
Expand Down Expand Up @@ -39,7 +40,55 @@ func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("not implemented")
if !request.Stream {
// TODO: support non-stream mode
return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
}

Check warning on line 46 in relay/adaptor/replicate/adaptor.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/adaptor.go#L42-L46

Added lines #L42 - L46 were not covered by tests

// Build the prompt from OpenAI messages
var promptBuilder strings.Builder
for _, message := range request.Messages {
switch msgCnt := message.Content.(type) {
case string:
promptBuilder.WriteString(message.Role)
promptBuilder.WriteString(": ")
promptBuilder.WriteString(msgCnt)
promptBuilder.WriteString("\n")
default:

Check warning on line 57 in relay/adaptor/replicate/adaptor.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/adaptor.go#L49-L57

Added lines #L49 - L57 were not covered by tests
}
}

replicateRequest := ReplicateChatRequest{
Input: ChatInput{
Prompt: promptBuilder.String(),
MaxTokens: request.MaxTokens,
Temperature: 1.0,
TopP: 1.0,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
},
}

// Map optional fields
if request.Temperature != nil {
replicateRequest.Input.Temperature = *request.Temperature
}
if request.TopP != nil {
replicateRequest.Input.TopP = *request.TopP
}
if request.PresencePenalty != nil {
replicateRequest.Input.PresencePenalty = *request.PresencePenalty
}
if request.FrequencyPenalty != nil {
replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
}
if request.MaxTokens > 0 {
replicateRequest.Input.MaxTokens = request.MaxTokens
} else if request.MaxTokens == 0 {
replicateRequest.Input.MaxTokens = 500
}

Check warning on line 89 in relay/adaptor/replicate/adaptor.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/adaptor.go#L61-L89

Added lines #L61 - L89 were not covered by tests

return replicateRequest, nil

Check warning on line 91 in relay/adaptor/replicate/adaptor.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/adaptor.go#L91

Added line #L91 was not covered by tests
}

func (a *Adaptor) Init(meta *meta.Meta) {
Expand All @@ -61,14 +110,16 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
logger.Info(c, "send image request to replicate")
logger.Info(c, "send request to replicate")
return adaptor.DoRequestHelper(a, c, meta, requestBody)

Check warning on line 114 in relay/adaptor/replicate/adaptor.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/adaptor.go#L112-L114

Added lines #L112 - L114 were not covered by tests
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp)
case relaymode.ChatCompletions:
err, usage = ChatHandler(c, resp)
default:
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)

Check warning on line 124 in relay/adaptor/replicate/adaptor.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/adaptor.go#L117-L124

Added lines #L117 - L124 were not covered by tests
}
Expand Down
191 changes: 191 additions & 0 deletions relay/adaptor/replicate/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package replicate

import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)

func ChatHandler(c *gin.Context, resp *http.Response) (
srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
if resp.StatusCode != http.StatusCreated {
payload, _ := io.ReadAll(resp.Body)
return openai.ErrorWrapper(
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
"bad_status_code", http.StatusInternalServerError),
nil
}

Check warning on line 28 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L21-L28

Added lines #L21 - L28 were not covered by tests

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}

Check warning on line 33 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L30-L33

Added lines #L30 - L33 were not covered by tests

respData := new(ChatResponse)
if err = json.Unmarshal(respBody, respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}

Check warning on line 38 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L35-L38

Added lines #L35 - L38 were not covered by tests

for {
err = func() error {
// get task
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, respData.URLs.Get, nil)
if err != nil {
return errors.Wrap(err, "new request")
}

Check warning on line 47 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L40-L47

Added lines #L40 - L47 were not covered by tests

taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
taskResp, err := http.DefaultClient.Do(taskReq)
if err != nil {
return errors.Wrap(err, "get task")
}
defer taskResp.Body.Close()

if taskResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(taskResp.Body)
return errors.Errorf("bad status code [%d]%s",
taskResp.StatusCode, string(payload))
}

Check warning on line 60 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L49-L60

Added lines #L49 - L60 were not covered by tests

taskBody, err := io.ReadAll(taskResp.Body)
if err != nil {
return errors.Wrap(err, "read task response")
}

Check warning on line 65 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L62-L65

Added lines #L62 - L65 were not covered by tests

taskData := new(ChatResponse)
if err = json.Unmarshal(taskBody, taskData); err != nil {
return errors.Wrap(err, "decode task response")
}

Check warning on line 70 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L67-L70

Added lines #L67 - L70 were not covered by tests

switch taskData.Status {
case "succeeded":
case "failed", "canceled":
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
default:
time.Sleep(time.Second * 3)
return errNextLoop

Check warning on line 78 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L72-L78

Added lines #L72 - L78 were not covered by tests
}

if taskData.URLs.Stream == "" {
return errors.New("stream url is empty")
}

Check warning on line 83 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L81-L83

Added lines #L81 - L83 were not covered by tests

// request stream url
responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
if err != nil {
return errors.Wrap(err, "chat stream handler")
}

Check warning on line 89 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L86-L89

Added lines #L86 - L89 were not covered by tests

ctxMeta := meta.GetByContext(c)
usage = openai.ResponseText2Usage(responseText,
ctxMeta.ActualModelName, ctxMeta.PromptTokens)
return nil

Check warning on line 94 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L91-L94

Added lines #L91 - L94 were not covered by tests
}()
if err != nil {
if errors.Is(err, errNextLoop) {
continue

Check warning on line 98 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L96-L98

Added lines #L96 - L98 were not covered by tests
}

return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil

Check warning on line 101 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L101

Added line #L101 was not covered by tests
}

break

Check warning on line 104 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L104

Added line #L104 was not covered by tests
}

return nil, usage

Check warning on line 107 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L107

Added line #L107 was not covered by tests
}

// Line prefixes and markers used when parsing the server-sent-events stream.
const (
// eventPrefix starts the line that names an event.
eventPrefix = "event: "
// dataPrefix starts a line that carries event payload.
dataPrefix = "data: "
// done holds the "[DONE]" marker.
// NOTE(review): not referenced in the visible part of this file; confirm it is still needed.
done = "[DONE]"
)

func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
// request stream endpoint
streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
if err != nil {
return "", errors.Wrap(err, "new request to stream")
}

Check warning on line 121 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L116-L121

Added lines #L116 - L121 were not covered by tests

streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
streamReq.Header.Set("Accept", "text/event-stream")
streamReq.Header.Set("Cache-Control", "no-store")

resp, err := http.DefaultClient.Do(streamReq)
if err != nil {
return "", errors.Wrap(err, "do request to stream")
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(resp.Body)
return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
}

Check warning on line 136 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L123-L136

Added lines #L123 - L136 were not covered by tests

scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)

common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue

Check warning on line 146 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L138-L146

Added lines #L138 - L146 were not covered by tests
}

// Handle comments starting with ':'
if strings.HasPrefix(line, ":") {
continue

Check warning on line 151 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L150-L151

Added lines #L150 - L151 were not covered by tests
}

// Parse SSE fields
if strings.HasPrefix(line, eventPrefix) {
event := strings.TrimSpace(line[len(eventPrefix):])
var data string
// Read the following lines to get data and id
for scanner.Scan() {
nextLine := scanner.Text()
if nextLine == "" {
break

Check warning on line 162 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L155-L162

Added lines #L155 - L162 were not covered by tests
}
if strings.HasPrefix(nextLine, dataPrefix) {
data = nextLine[len(dataPrefix):]
} else if strings.HasPrefix(nextLine, "id:") {
// id = strings.TrimSpace(nextLine[len("id:"):])
}

Check warning on line 168 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L164-L168

Added lines #L164 - L168 were not covered by tests
}

if event == "output" {
render.StringData(c, data)
responseText += data
} else if event == "done" {
render.Done(c)
doneRendered = true
break

Check warning on line 177 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L171-L177

Added lines #L171 - L177 were not covered by tests
}
}
}

if err := scanner.Err(); err != nil {
return "", errors.Wrap(err, "scan stream")
}

Check warning on line 184 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L182-L184

Added lines #L182 - L184 were not covered by tests

if !doneRendered {
render.Done(c)
}

Check warning on line 188 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L186-L188

Added lines #L186 - L188 were not covered by tests

return responseText, nil

Check warning on line 190 in relay/adaptor/replicate/chat.go

View check run for this annotation

Codecov / codecov/patch

relay/adaptor/replicate/chat.go#L190

Added line #L190 was not covered by tests
}
36 changes: 18 additions & 18 deletions relay/adaptor/replicate/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,24 @@ var ModelList = []string{
// -------------------------------------
// language model
// -------------------------------------
// "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor
// "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor
// "ibm-granite/granite-3.0-8b-instruct", // TODO: implement the adaptor
// "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor
// "meta/llama-2-13b", // TODO: implement the adaptor
// "meta/llama-2-13b-chat", // TODO: implement the adaptor
// "meta/llama-2-70b", // TODO: implement the adaptor
// "meta/llama-2-70b-chat", // TODO: implement the adaptor
// "meta/llama-2-7b", // TODO: implement the adaptor
// "meta/llama-2-7b-chat", // TODO: implement the adaptor
// "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor
// "meta/meta-llama-3-70b", // TODO: implement the adaptor
// "meta/meta-llama-3-70b-instruct", // TODO: implement the adaptor
// "meta/meta-llama-3-8b", // TODO: implement the adaptor
// "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor
// "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor
// "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor
// "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor
"ibm-granite/granite-20b-code-instruct-8k",
"ibm-granite/granite-3.0-2b-instruct",
"ibm-granite/granite-3.0-8b-instruct",
"ibm-granite/granite-8b-code-instruct-128k",
"meta/llama-2-13b",
"meta/llama-2-13b-chat",
"meta/llama-2-70b",
"meta/llama-2-70b-chat",
"meta/llama-2-7b",
"meta/llama-2-7b-chat",
"meta/meta-llama-3.1-405b-instruct",
"meta/meta-llama-3-70b",
"meta/meta-llama-3-70b-instruct",
"meta/meta-llama-3-8b",
"meta/meta-llama-3-8b-instruct",
"mistralai/mistral-7b-instruct-v0.2",
"mistralai/mistral-7b-v0.1",
"mistralai/mixtral-8x7b-instruct-v0.1",
// -------------------------------------
// video model
// -------------------------------------
Expand Down
Loading

0 comments on commit 48e8b6b

Please sign in to comment.