Skip to content

Commit

Permalink
Merge pull request #24 from basenana/fix/token_limit
Browse files Browse the repository at this point in the history
fix: add limiter of openai
  • Loading branch information
zwwhdls authored Nov 19, 2023
2 parents 6badc71 + 4428c31 commit bf3e331
Show file tree
Hide file tree
Showing 14 changed files with 50 additions and 41 deletions.
4 changes: 3 additions & 1 deletion cmd/apps/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package apps

import (
"context"

"github.com/spf13/cobra"

"github.com/basenana/friday/pkg/friday"
Expand All @@ -38,7 +40,7 @@ var IngestCmd = &cobra.Command{
}

func ingest(ps string) error {
err := friday.Fri.IngestFromOriginFile(ps)
err := friday.Fri.IngestFromOriginFile(context.TODO(), ps)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion flow/operator/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ func (i *ingestOperator) Do(ctx context.Context, param *flow.Parameter) error {
Source: source,
Content: knowledge,
}
return friday.Fri.IngestFromFile(doc)
return friday.Fri.IngestFromFile(context.TODO(), doc)
}
4 changes: 3 additions & 1 deletion pkg/build/withvector/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package withvector

import (
"context"

"github.com/basenana/friday/config"
"github.com/basenana/friday/pkg/embedding"
huggingfaceembedding "github.com/basenana/friday/pkg/embedding/huggingface"
Expand Down Expand Up @@ -55,7 +57,7 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto
}
if conf.EmbeddingType == config.EmbeddingHuggingFace {
embeddingModel = huggingfaceembedding.NewHuggingFace(conf.EmbeddingUrl, conf.EmbeddingModel)
testEmbed, _, err := embeddingModel.VectorQuery("test")
testEmbed, _, err := embeddingModel.VectorQuery(context.TODO(), "test")
if err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/embedding/huggingface/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package huggingface

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -52,7 +53,7 @@ type vectorResults struct {
Result [][]float32
}

func (h HuggingFace) VectorQuery(doc string) ([]float32, map[string]interface{}, error) {
func (h HuggingFace) VectorQuery(ctx context.Context, doc string) ([]float32, map[string]interface{}, error) {
path := "/embeddings/query"

model := h.model
Expand All @@ -72,7 +73,7 @@ func (h HuggingFace) VectorQuery(doc string) ([]float32, map[string]interface{},
return res.Result, nil, err
}

func (h HuggingFace) VectorDocs(docs []string) ([][]float32, []map[string]interface{}, error) {
func (h HuggingFace) VectorDocs(ctx context.Context, docs []string) ([][]float32, []map[string]interface{}, error) {
path := "/embeddings/docs"

model := h.model
Expand Down
6 changes: 4 additions & 2 deletions pkg/embedding/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

package embedding

import "context"

type Embedding interface {
VectorQuery(doc string) ([]float32, map[string]interface{}, error)
VectorDocs(docs []string) ([][]float32, []map[string]interface{}, error)
VectorQuery(ctx context.Context, doc string) ([]float32, map[string]interface{}, error)
VectorDocs(ctx context.Context, docs []string) ([][]float32, []map[string]interface{}, error)
}
8 changes: 4 additions & 4 deletions pkg/embedding/openai/v1/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func NewOpenAIEmbedding(baseUrl, key string, qpm, burst int) embedding.Embedding
}
}

func (o *OpenAIEmbedding) VectorQuery(doc string) ([]float32, map[string]interface{}, error) {
res, err := o.Embedding(context.TODO(), doc)
func (o *OpenAIEmbedding) VectorQuery(ctx context.Context, doc string) ([]float32, map[string]interface{}, error) {
res, err := o.Embedding(ctx, doc)
if err != nil {
return nil, nil, err
}
Expand All @@ -48,12 +48,12 @@ func (o *OpenAIEmbedding) VectorQuery(doc string) ([]float32, map[string]interfa
return res.Data[0].Embedding, metadata, nil
}

func (o *OpenAIEmbedding) VectorDocs(docs []string) ([][]float32, []map[string]interface{}, error) {
func (o *OpenAIEmbedding) VectorDocs(ctx context.Context, docs []string) ([][]float32, []map[string]interface{}, error) {
res := make([][]float32, len(docs))
metadata := make([]map[string]interface{}, len(docs))

for i, doc := range docs {
r, err := o.Embedding(context.TODO(), doc)
r, err := o.Embedding(ctx, doc)
if err != nil {
return nil, nil, err
}
Expand Down
17 changes: 9 additions & 8 deletions pkg/friday/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package friday

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
Expand All @@ -31,7 +32,7 @@ import (
)

// IngestFromFile ingest a whole file providing models.File
func (f *Friday) IngestFromFile(file models.File) error {
func (f *Friday) IngestFromFile(ctx context.Context, file models.File) error {
elements := []models.Element{}
parentDir := filepath.Dir(file.Source)
// split doc
Expand All @@ -48,11 +49,11 @@ func (f *Friday) IngestFromFile(file models.File) error {
elements = append(elements, e)
}
// ingest
return f.Ingest(elements)
return f.Ingest(ctx, elements)
}

// Ingest ingest elements of a file
func (f *Friday) Ingest(elements []models.Element) error {
func (f *Friday) Ingest(ctx context.Context, elements []models.Element) error {
f.Log.Debugf("Ingesting %d ...", len(elements))
for i, element := range elements {
// id: sha256(source)-group
Expand All @@ -67,7 +68,7 @@ func (f *Friday) Ingest(elements []models.Element) error {
continue
}

vectors, m, err := f.Embedding.VectorQuery(element.Content)
vectors, m, err := f.Embedding.VectorQuery(ctx, element.Content)
if err != nil {
return err
}
Expand All @@ -83,7 +84,7 @@ func (f *Friday) Ingest(elements []models.Element) error {
}

// IngestFromElementFile ingest a whole file given an element-style origin file
func (f *Friday) IngestFromElementFile(ps string) error {
func (f *Friday) IngestFromElementFile(ctx context.Context, ps string) error {
doc, err := os.ReadFile(ps)
if err != nil {
return err
Expand All @@ -93,11 +94,11 @@ func (f *Friday) IngestFromElementFile(ps string) error {
return err
}
merged := f.Spliter.Merge(elements)
return f.Ingest(merged)
return f.Ingest(ctx, merged)
}

// IngestFromOriginFile ingest a whole file given an origin file
func (f *Friday) IngestFromOriginFile(ps string) error {
func (f *Friday) IngestFromOriginFile(ctx context.Context, ps string) error {
fs, err := files.ReadFiles(ps)
if err != nil {
return err
Expand All @@ -120,5 +121,5 @@ func (f *Friday) IngestFromOriginFile(ps string) error {
}
}

return f.Ingest(elements)
return f.Ingest(ctx, elements)
}
8 changes: 5 additions & 3 deletions pkg/friday/ingest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package friday

import (
"context"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

Expand Down Expand Up @@ -51,7 +53,7 @@ var _ = Describe("TestIngest", func() {
},
},
}
err := loFriday.Ingest(elements)
err := loFriday.Ingest(context.TODO(), elements)
Expect(err).Should(BeNil())
})
})
Expand All @@ -77,10 +79,10 @@ type FakeEmbedding struct{}

var _ embedding.Embedding = FakeEmbedding{}

func (f FakeEmbedding) VectorQuery(doc string) ([]float32, map[string]interface{}, error) {
func (f FakeEmbedding) VectorQuery(ctx context.Context, doc string) ([]float32, map[string]interface{}, error) {
return []float32{}, map[string]interface{}{}, nil
}

func (f FakeEmbedding) VectorDocs(docs []string) ([][]float32, []map[string]interface{}, error) {
func (f FakeEmbedding) VectorDocs(ctx context.Context, docs []string) ([][]float32, []map[string]interface{}, error) {
return [][]float32{}, []map[string]interface{}{}, nil
}
6 changes: 3 additions & 3 deletions pkg/friday/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

func (f *Friday) Question(ctx context.Context, q string) (string, error) {
prompt := prompts.NewQuestionPrompt()
c, err := f.searchDocs(q)
c, err := f.searchDocs(ctx, q)
if err != nil {
return "", err
}
Expand All @@ -44,9 +44,9 @@ func (f *Friday) Question(ctx context.Context, q string) (string, error) {
return c, nil
}

func (f *Friday) searchDocs(q string) (string, error) {
func (f *Friday) searchDocs(ctx context.Context, q string) (string, error) {
f.Log.Debugf("vector query for %s ...", q)
qv, _, err := f.Embedding.VectorQuery(q)
qv, _, err := f.Embedding.VectorQuery(ctx, q)
if err != nil {
return "", fmt.Errorf("vector embedding error: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/friday/question_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ var _ = Describe("TestQuestion", func() {
Expect(ans).Should(Equal("I am an answer"))
})
It("searchDocs should be succeed", func() {
ans, err := loFriday.searchDocs("I am a question")
ans, err := loFriday.searchDocs(context.TODO(), "I am a question")
Expect(err).Should(BeNil())
Expect(ans).Should(Equal("There are logs of questions"))
})
Expand Down Expand Up @@ -83,11 +83,11 @@ type FakeQuestionEmbedding struct{}

var _ embedding.Embedding = FakeQuestionEmbedding{}

func (f FakeQuestionEmbedding) VectorQuery(doc string) ([]float32, map[string]interface{}, error) {
func (f FakeQuestionEmbedding) VectorQuery(ctx context.Context, doc string) ([]float32, map[string]interface{}, error) {
return []float32{}, map[string]interface{}{}, nil
}

func (f FakeQuestionEmbedding) VectorDocs(docs []string) ([][]float32, []map[string]interface{}, error) {
func (f FakeQuestionEmbedding) VectorDocs(ctx context.Context, docs []string) ([][]float32, []map[string]interface{}, error) {
return [][]float32{}, []map[string]interface{}{}, nil
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/llm/client/openai/v1/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ func (o *OpenAIV1) Chat(ctx context.Context, prompt prompts.PromptTemplate, para
answer, err := o.chat(ctx, prompt, parameters)
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "rate_limit_exceeded") {
o.log.Warnf("meets rate limit exceeded, sleep %d second and retry", o.rateLimit)
time.Sleep(time.Duration(o.rateLimit) * time.Second)
if strings.Contains(errMsg, "rate_limit_exceeded") || strings.Contains(errMsg, "Rate limit reached") {
o.log.Warn("meets rate limit exceeded, sleep 30 seconds and retry")
time.Sleep(time.Duration(30) * time.Second)
return o.chat(ctx, prompt, parameters)
}
return nil, err
Expand All @@ -68,7 +68,7 @@ func (o *OpenAIV1) chat(ctx context.Context, prompt prompts.PromptTemplate, para
data := map[string]interface{}{
"model": model,
"messages": []interface{}{map[string]string{"role": "user", "content": p}},
"max_tokens": 4096,
"max_tokens": 1024,
"temperature": 0.7,
"top_p": 1,
"frequency_penalty": 0,
Expand Down
9 changes: 4 additions & 5 deletions pkg/llm/client/openai/v1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ const (
type OpenAIV1 struct {
log logger.Logger

baseUri string
key string
rateLimit int
baseUri string
key string

limiter *rate.Limiter
}
Expand All @@ -52,7 +51,7 @@ func NewOpenAIV1(baseUrl, key string, qpm, burst int) *OpenAIV1 {
burst = defaultBurst
}

limiter := rate.NewLimiter(rate.Limit(qpm/60), burst)
limiter := rate.NewLimiter(rate.Limit(qpm), burst*60)

return &OpenAIV1{
log: logger.NewLogger("openai"),
Expand All @@ -65,7 +64,7 @@ func NewOpenAIV1(baseUrl, key string, qpm, burst int) *OpenAIV1 {
var _ llm.LLM = &OpenAIV1{}

func (o *OpenAIV1) request(ctx context.Context, path string, method string, body io.Reader) ([]byte, error) {
err := o.limiter.Wait(ctx)
err := o.limiter.WaitN(ctx, 60)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/llm/client/openai/v1/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func (o *OpenAIV1) Completion(ctx context.Context, prompt prompts.PromptTemplate
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "rate_limit_exceeded") {
o.log.Warnf("meets rate limit exceeded, sleep %d second and retry", o.rateLimit)
time.Sleep(time.Duration(o.rateLimit) * time.Second)
o.log.Warn("meets rate limit exceeded, sleep 30 seconds and retry")
time.Sleep(time.Duration(30) * time.Second)
return o.completion(ctx, prompt, parameters)
}
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions pkg/llm/client/openai/v1/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ func (o *OpenAIV1) Embedding(ctx context.Context, doc string) (*EmbeddingResult,
if err != nil {
errMsg := err.Error()
if strings.Contains(errMsg, "rate_limit_exceeded") {
o.log.Warnf("meets rate limit exceeded, sleep %d second and retry", o.rateLimit)
time.Sleep(time.Duration(o.rateLimit) * time.Second)
o.log.Warn("meets rate limit exceeded, sleep 30 seconds and retry")
time.Sleep(time.Duration(30) * time.Second)
return o.embedding(ctx, doc)
}
return nil, err
Expand Down

0 comments on commit bf3e331

Please sign in to comment.