Skip to content
This repository has been archived by the owner on Mar 20, 2024. It is now read-only.

Commit

Permalink
support "Continue generating" (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
linweiyuan committed Jun 17, 2023
1 parent 3b69e1d commit 1be0b64
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 47 deletions.
89 changes: 82 additions & 7 deletions api/chatgpt/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package chatgpt

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -38,9 +39,6 @@ func CreateConversation(c *gin.Context) {
if request.ConversationID == nil || *request.ConversationID == "" {
request.ConversationID = nil
}
if request.Messages[0].Author.Role == "" {
request.Messages[0].Author.Role = defaultRole
}

if request.Model == gpt4Model || request.Model == gpt4BrowsingModel || request.Model == gpt4PluginsModel {
formParams := fmt.Sprintf(
Expand All @@ -60,6 +58,16 @@ func CreateConversation(c *gin.Context) {
request.ArkoseToken = responseMap["token"]
}

resp, done := sendConversationRequest(c, request)
if done {
return
}

handleConversationResponse(c, resp, request)
}

//goland:noinspection GoUnhandledErrorResult
func sendConversationRequest(c *gin.Context, request CreateConversationRequest) (*http.Response, bool) {
jsonBytes, _ := json.Marshal(request)
req, _ := http.NewRequest(http.MethodPost, apiPrefix+"/conversation", bytes.NewBuffer(jsonBytes))
req.Header.Set("User-Agent", api.UserAgent)
Expand All @@ -68,18 +76,85 @@ func CreateConversation(c *gin.Context) {
resp, err := api.Client.Do(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, api.ReturnMessage(err.Error()))
return
return nil, true
}

defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
responseMap := make(map[string]interface{})
json.NewDecoder(resp.Body).Decode(&responseMap)
c.AbortWithStatusJSON(resp.StatusCode, responseMap)
return
return nil, true
}

return resp, false
}

//goland:noinspection GoUnhandledErrorResult
func handleConversationResponse(c *gin.Context, resp *http.Response, request CreateConversationRequest) {
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")

isMaxTokens := false
continueParentMessageID := ""
continueConversationID := ""

defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
for {
if c.Request.Context().Err() != nil {
break
}

line, err := reader.ReadString('\n')
if err != nil {
break
}

line = strings.TrimSpace(line)
if strings.HasPrefix(line, "event") ||
strings.HasPrefix(line, "data: 20") ||
line == "" {
continue
}

responseJson := line[6:]
if strings.HasPrefix(responseJson, "[DONE]") && isMaxTokens {
continue
}

// no need to unmarshal every time, but if response content has this "max_tokens", need to further check
if strings.TrimSpace(responseJson) != "" && strings.Contains(responseJson, responseTypeMaxTokens) {
var createConversationResponse CreateConversationResponse
json.Unmarshal([]byte(responseJson), &createConversationResponse)
message := createConversationResponse.Message
if message.Metadata.FinishDetails.Type == responseTypeMaxTokens && createConversationResponse.Message.Status == responseStatusFinishedSuccessfully {
isMaxTokens = true
continueParentMessageID = message.ID
continueConversationID = createConversationResponse.ConversationID
}
}

c.Writer.Write([]byte(line + "\n\n"))
c.Writer.Flush()
}

api.HandleConversationResponse(c, resp)
if isMaxTokens {
var continueConversationRequest = CreateConversationRequest{
ArkoseToken: request.ArkoseToken,
HistoryAndTrainingDisabled: request.HistoryAndTrainingDisabled,
Model: request.Model,
TimezoneOffsetMin: request.TimezoneOffsetMin,

Action: actionContinue,
ParentMessageID: continueParentMessageID,
ConversationID: &continueConversationID,
}
resp, done := sendConversationRequest(c, continueConversationRequest)
if done {
return
}

handleConversationResponse(c, resp, continueConversationRequest)
}
}

//goland:noinspection GoUnhandledErrorResult
Expand Down
5 changes: 4 additions & 1 deletion api/chatgpt/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package chatgpt

const (
apiPrefix = "https://chat.openai.com/backend-api"
defaultRole = "user"
getConversationsErrorMessage = "Failed to get conversations."
generateTitleErrorMessage = "Failed to generate title."
getContentErrorMessage = "Failed to get content."
Expand All @@ -25,4 +24,8 @@ const (
gpt4PluginsModel = "gpt-4-plugins"
gpt4PublicKey = "35536E1E-65B4-4D96-9D97-6ADB7EFF8147"
gpt4TokenUrl = "https://tcr9i.chat.openai.com/fc/gt2/public_key/" + gpt4PublicKey

actionContinue = "continue"
responseTypeMaxTokens = "max_tokens"
responseStatusFinishedSuccessfully = "finished_successfully"
)
48 changes: 40 additions & 8 deletions api/chatgpt/typings.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ type UserLogin struct {
}

type CreateConversationRequest struct {
Action string `json:"action"`
Messages []Message `json:"messages"`
Model string `json:"model"`
ParentMessageID string `json:"parent_message_id"`
ConversationID *string `json:"conversation_id"`
PluginIDs []string `json:"plugin_ids"`
TimezoneOffsetMin int `json:"timezone_offset_min"`
ArkoseToken string `json:"arkose_token"`
Action string `json:"action"`
Messages *[]Message `json:"messages"`
Model string `json:"model"`
ParentMessageID string `json:"parent_message_id"`
ConversationID *string `json:"conversation_id"`
PluginIDs []string `json:"plugin_ids"`
TimezoneOffsetMin int `json:"timezone_offset_min"`
ArkoseToken string `json:"arkose_token"`
HistoryAndTrainingDisabled bool `json:"history_and_training_disabled"`
}

type Message struct {
Expand All @@ -33,6 +34,37 @@ type Content struct {
Parts []string `json:"parts"`
}

type CreateConversationResponse struct {
Message struct {
ID string `json:"id"`
Author struct {
Role string `json:"role"`
Name interface{} `json:"name"`
Metadata struct {
} `json:"metadata"`
} `json:"author"`
CreateTime float64 `json:"create_time"`
UpdateTime interface{} `json:"update_time"`
Content struct {
ContentType string `json:"content_type"`
Parts []string `json:"parts"`
} `json:"content"`
Status string `json:"status"`
EndTurn bool `json:"end_turn"`
Weight float64 `json:"weight"`
Metadata struct {
MessageType string `json:"message_type"`
ModelSlug string `json:"model_slug"`
FinishDetails struct {
Type string `json:"type"`
} `json:"finish_details"`
} `json:"metadata"`
Recipient string `json:"recipient"`
} `json:"message"`
ConversationID string `json:"conversation_id"`
Error interface{} `json:"error"`
}

type FeedbackMessageRequest struct {
MessageID string `json:"message_id"`
ConversationID string `json:"conversation_id"`
Expand Down
29 changes: 0 additions & 29 deletions api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package api

//goland:noinspection GoSnakeCaseUsage
import (
"bufio"
"os"
"strings"

"github.com/gin-gonic/gin"
_ "github.com/linweiyuan/go-chatgpt-api/env"

http "github.com/bogdanfinn/fhttp"
tls_client "github.com/bogdanfinn/tls-client"
)

Expand Down Expand Up @@ -67,33 +65,6 @@ func GetAccessToken(accessToken string) string {
return accessToken
}

//goland:noinspection GoUnhandledErrorResult
func HandleConversationResponse(c *gin.Context, resp *http.Response) {
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")

reader := bufio.NewReader(resp.Body)
for {
if c.Request.Context().Err() != nil {
break
}

line, err := reader.ReadString('\n')
if err != nil {
break
}

line = strings.TrimSpace(line)
if strings.HasPrefix(line, "event") ||
strings.HasPrefix(line, "data: 20") ||
line == "" {
continue
}

c.Writer.Write([]byte(line + "\n\n"))
c.Writer.Flush()
}
}

//goland:noinspection GoUnhandledErrorResult,SpellCheckingInspection
func NewHttpClient() tls_client.HttpClient {
client, _ := tls_client.NewHttpClient(tls_client.NewNoopLogger(), []tls_client.HttpClientOption{
Expand Down
33 changes: 31 additions & 2 deletions api/platform/api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package platform

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -34,7 +35,7 @@ func CreateCompletions(c *gin.Context) {

defer resp.Body.Close()
if request.Stream {
api.HandleConversationResponse(c, resp)
handleCompletionsResponse(c, resp)
} else {
io.Copy(c.Writer, resp.Body)
}
Expand All @@ -52,12 +53,39 @@ func CreateChatCompletions(c *gin.Context) {

defer resp.Body.Close()
if request.Stream {
api.HandleConversationResponse(c, resp)
handleCompletionsResponse(c, resp)
} else {
io.Copy(c.Writer, resp.Body)
}
}

//goland:noinspection GoUnhandledErrorResult
func handleCompletionsResponse(c *gin.Context, resp *http.Response) {
c.Writer.Header().Set("Content-Type", "text/event-stream; charset=utf-8")

reader := bufio.NewReader(resp.Body)
for {
if c.Request.Context().Err() != nil {
break
}

line, err := reader.ReadString('\n')
if err != nil {
break
}

line = strings.TrimSpace(line)
if strings.HasPrefix(line, "event") ||
strings.HasPrefix(line, "data: 20") ||
line == "" {
continue
}

c.Writer.Write([]byte(line + "\n\n"))
c.Writer.Flush()
}
}

//goland:noinspection GoUnhandledErrorResult
func CreateEdit(c *gin.Context) {
var request CreateEditRequest
Expand Down Expand Up @@ -100,6 +128,7 @@ func CreateEmbeddings(c *gin.Context) {
io.Copy(c.Writer, resp.Body)
}

//goland:noinspection GoUnhandledErrorResult
func CreateModeration(c *gin.Context) {
var request CreateModerationRequest
c.ShouldBindJSON(&request)
Expand Down

0 comments on commit 1be0b64

Please sign in to comment.