Skip to content

Commit

Permalink
Merge pull request #12 from sugarshop/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
mengmengmengqiang committed May 20, 2023
2 parents 930eed1 + 5164bef commit 645b84d
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 6 deletions.
22 changes: 22 additions & 0 deletions db/postgresql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package db

import (
postgrest "github.com/nedpals/postgrest-go/pkg"
supa "github.com/nedpals/supabase-go"
)

// Init 数据库初始化连接
func Init() {
supabaseUrl := "https://sfisgjpeqptcluzmtbup.supabase.co"
supabaseKey := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6InNmaXNnanBlcXB0Y2x1em10YnVwIiwicm9sZSI6ImFub24iLCJpYXQiOjE2ODM5MzAyMzAsImV4cCI6MTk5OTUwNjIzMH0.eHhleg3ev4YGA1yHosohWwzxOZxNEh4hP1PavfMF-X0"
supabase := supa.CreateClient(supabaseUrl, supabaseKey)
completionDB = supabase.DB
}

// completion DB
var completionDB *postgrest.Client

// CompletionDB completion DB
func CompletionDB() *postgrest.Client {
return completionDB
}
12 changes: 12 additions & 0 deletions db/postgresql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package db

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestInit(t *testing.T) {
Init()
res := CompletionDB()
assert.NotNil(t, res)
}
4 changes: 2 additions & 2 deletions deployment/prod/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ kind: Service
apiVersion: v1
metadata:
namespace: default
name: asgard-gateway
name: asgard-gateway-svc
labels:
app: asgard-gateway
spec:
Expand Down Expand Up @@ -54,6 +54,6 @@ spec:
pathType: Prefix
backend:
service:
name: asgard-gateway
name: asgard-gateway-svc
port:
number: 8080
6 changes: 6 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,33 @@ require (
github.com/apache/thrift v0.13.0
github.com/cloudwego/kitex v0.5.2
github.com/gin-gonic/gin v1.9.0
github.com/nedpals/postgrest-go v0.1.3
github.com/nedpals/supabase-go v0.2.0
github.com/sashabaranov/go-openai v1.9.3
github.com/stretchr/testify v1.8.1
)

require (
github.com/bytedance/sonic v1.8.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/choleraehyq/pid v0.0.16 // indirect
github.com/cloudwego/thriftgo v0.2.9 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.11.2 // indirect
github.com/goccy/go-json v0.10.0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/mattn/go-isatty v0.0.17 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.6 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.9 // indirect
golang.org/x/arch v0.2.0 // indirect
Expand Down
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,12 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o=
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3 h1:mpL/HvfIgIejhVwAfxBQkwEjlhP5o0O9RAeTAjpwzxc=
github.com/google/pprof v0.0.0-20220608213341-c488b8fa1db3/go.mod h1:gSuNB+gJaOiQKLEZ+q+PK9Mq3SOzhRcw2GsGS/FhYDk=
Expand Down Expand Up @@ -130,6 +133,11 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nedpals/postgrest-go v0.1.2/go.mod h1:RGinB2OXsnGLcZMu5avS0U+b9npyZmk+ecK74UDi/xY=
github.com/nedpals/postgrest-go v0.1.3 h1:ZC3aPPx9rDTWQWzvnWI60lJWjAqgCCD/U6hcHp3NL0w=
github.com/nedpals/postgrest-go v0.1.3/go.mod h1:RGinB2OXsnGLcZMu5avS0U+b9npyZmk+ecK74UDi/xY=
github.com/nedpals/supabase-go v0.2.0 h1:ZuciOzOwfyKmsd/D/XAP4+pTx/2kpAhLbo8OS3WZQo8=
github.com/nedpals/supabase-go v0.2.0/go.mod h1:RSjFlnvLQ3nc9F4WhyartDlvT6smRY0tnJTLxum31d8=
github.com/nishanths/predeclared v0.0.0-20200524104333-86fad755b4d3/go.mod h1:nt3d53pc1VYcphSCIaYAJtnPYnr3Zyn8fMq2wvPGPso=
github.com/oleiade/lane v1.0.1/go.mod h1:IyTkraa4maLfjq/GmHR+Dxb4kCMtEGeb+qmhlrQ5Mk4=
github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU=
Expand Down
63 changes: 63 additions & 0 deletions handler/openai_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"io"
"log"
"net/http"
"time"
)

type OpenAIHandler struct {
Expand All @@ -37,17 +38,31 @@ func (h *OpenAIHandler) Completions(c *gin.Context) error {
return fmt.Errorf("StatusCode: %d, Type:%s, Code:%v", apiErr.HTTPStatusCode, apiErr.Type, apiErr.Code)
}

var curCompletion model.Completion
ch := make(chan bool)

gone := c.Stream(func(w io.Writer) bool {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
fmt.Printf("[Completion]: Stream finished")
go func() {
if saveErr := curCompletion.Save(); saveErr != nil {
fmt.Printf("[Completions]: failed to save completion: %+v", err)
ch <- false
} else {
ch <- true
}
close(ch)
}()
return false
}

if err != nil {
fmt.Printf("[Completion]: Stream error: %v\n", err)
return false
}
// wrap completion.
h.constructCompletion(&curCompletion, response)

jsonBytes, err := json.Marshal(response)
if err != nil {
Expand All @@ -62,12 +77,60 @@ func (h *OpenAIHandler) Completions(c *gin.Context) error {
return true
})
if gone {
// client disconnected in middle of stream
// do something after client is gone
log.Println("client is gone")
}

select {
case succ := <- ch:
if succ {
fmt.Printf("save success")
} else {
fmt.Printf("save failed")
}
case <-time.After(10 * time.Second): // 超时时间为10秒
fmt.Println("save to db timeout!")
}
return nil
}

func (h *OpenAIHandler) constructCompletion(cur *model.Completion, input interface{}) {
streamResponse, streamOk := input.(openai.ChatCompletionStreamResponse)
completionResponse, completionOk := input.(openai.ChatCompletionResponse)
var id string
var chatModel string
var role string
var content string
if streamOk {
id = streamResponse.ID
chatModel = streamResponse.Model
if len(streamResponse.Choices) > 0 {
role = streamResponse.Choices[0].Delta.Role
}
content = streamResponse.Choices[0].Delta.Content
}
if completionOk {
id = completionResponse.ID
chatModel = completionResponse.Model
if len(completionResponse.Choices) > 0 {
role = completionResponse.Choices[0].Message.Role
}
content = completionResponse.Choices[0].Message.Content
}

if len(cur.ChatID) == 0 {
cur.ChatID = id
}
if len(cur.Model) == 0 {
cur.Model = chatModel
}
if len(cur.Role) == 0 {
cur.Role = role
}
cur.Content += content
}

// OpenAIStream return completion stream of the OpenAIChat
func (h *OpenAIHandler) OpenAIStream(c *gin.Context, param *model.CompletionsReqBody) (*openai.ChatCompletionStream, *openai.APIError) {
client := openai.NewClient(param.Key)
Expand Down
1 change: 0 additions & 1 deletion handler/openai_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ func TestOpenAIHandler_Completions(t *testing.T) {
// Create a new HTTP request
reqBody := model.CompletionsReqBody{
Model: model.OpenAIModel{ID: "gpt-3.5-turbo"},
SystemPrompt: model.NEXTPUBLICDEFAULTSYSTEMPROMPT,
Temperature: 1,
Key: "key",
Messages: []openai.ChatCompletionMessage{
Expand Down
4 changes: 4 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"github.com/gin-gonic/gin"
"github.com/sugarshop/asgard-gateway/db"
"github.com/sugarshop/asgard-gateway/handler"
"net/http"
"os"
Expand All @@ -17,6 +18,9 @@ func main() {
})
})

// init db
db.Init()

// register other api
handler.Register(engine)

Expand Down
28 changes: 28 additions & 0 deletions model/completion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package model

import (
"github.com/sugarshop/asgard-gateway/db"
"log"
"time"
)

type Completion struct {
ID int64 `json:"id,omitempty"`
CreatedAt time.Time `json:"created_at"`
ChatID string `json:"chat_id"`
Model string `json:"model"`
Content string `json:"content"`
Role string `json:"role"`
}

func (s *Completion) Save() error {
// set Now Time()
s.CreatedAt = time.Now()
var results []Completion
err := db.CompletionDB().From("completion").Insert(s).Execute(&results)
if err != nil {
log.Printf("[Save]: err :%+v", err)
return err
}
return nil
}
20 changes: 20 additions & 0 deletions model/completion_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package model

import (
"github.com/stretchr/testify/assert"
"github.com/sugarshop/asgard-gateway/db"
"testing"
)

func TestCompletion_Save(t *testing.T) {
db.Init()
compl := Completion{
ChatID: "chat-001",
Model: "gpt-3.5-turbo-0301",
Content: "i love u",
Role: "test",
}
err := compl.Save()

assert.Nil(t, err)
}
5 changes: 2 additions & 3 deletions model/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ func (e *OpenAIError) Error() string {

type CompletionsReqBody struct {
Model OpenAIModel `json:"model"`
SystemPrompt string `json:"systemPrompt"`
Temperature float64 `json:"temperature"`
Key string `json:"key"`
Messages []openai.ChatCompletionMessage `json:"messages"`
Key string `json:"key"`
Temperature float64 `json:"temperature"`
}

// LogprobResult represents logprob result of Choice.
Expand Down

0 comments on commit 645b84d

Please sign in to comment.