diff --git a/README.md b/README.md index 1f708af70..07c86cb81 100644 --- a/README.md +++ b/README.md @@ -527,6 +527,202 @@ func main() { ``` + +
+Generate Embeddings + +```go +package main + +import ( + "context" + "encoding/gob" + "fmt" + "os" + + "github.com/sashabaranov/go-openai" +) + +func getEmbedding(ctx context.Context, client *openai.Client, input []string) ([]float32, error) { + + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ + Input: input, + Model: openai.AdaEmbeddingV2, + }) + + if err != nil { + return nil, err + } + + return resp.Data[0].Embedding, nil +} + +func main() { + + ctx := context.Background() + client := openai.NewClient("your token") + + // example selections + selections := []string{ + "Welcome to the go-openai interface, which will be the gateway for golang software engineers to enter the OpenAI development world.", + "It was tasty and fresh. The other one I bought was old and tasted moldy. But this one was good.", + "Great coffee at a good price. I'm a subscription buyer and I buy this month after month. What more can I say?", + "This chocolate is amazing..I love the taste and smell, this is the only chocolate for me...I found a new love!", + "I love this coffee! And such a great price. Will buy more when I am running out which will be soon.", + "The Raspberry Tea Syrup is great. I can use it for hot and cold drinks as well in certain recipes.", + "Everyone that dips with this loves it! So easy to use! Olive oil and tasty bread is all you need.", + "This is a favorite of mine for using over ice. Even bought it to give out as Christmas gifts last year.", + "If you like a great , hot, sauce then buy this. If spicy with heat isn't to your liking then don't buy it.", + "My name is Aceld, and I am a Golang software development engineer. I like young and beautiful girls.", + "The competition was held over two days,24 July and 2 August. The qualifying round was the first day with the apparatus final on the second day.", + "There are 4 types of gymnastics apparatus: floor, vault, pommel horse, and rings. The apparatus final is a competition between the top 8 gymnasts in each apparatus.", + } + + // Generate embeddings + var selectionsEmbeddings [][]float32 + for _, selection := range selections { + embedding, err := getEmbedding(ctx, client, []string{selection}) + if err != nil { + fmt.Printf("GetEmedding error: %v\n", err) + return + } + selectionsEmbeddings = append(selectionsEmbeddings, embedding) + } + + // Write embeddings binary data to file + file, err := os.Create("embeddings.bin") + if err != nil { + fmt.Printf("Create file error: %v\n", err) + return + } + defer file.Close() + + encoder := gob.NewEncoder(file) + err = encoder.Encode(selectionsEmbeddings) + if err != nil { + fmt.Printf("Encode error: %v\n", err) + return + } + + return +} +``` +
+ +
+Embedding Similarity Search + +```go +package main + +import ( + "context" + "encoding/gob" + "fmt" + "io/ioutil" + "os" + "sort" + "strings" + + "github.com/sashabaranov/go-openai" +) + +func getEmbedding(ctx context.Context, client *openai.Client, input []string) ([]float32, error) { + resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{ + Input: input, + Model: openai.AdaEmbeddingV2, + }) + + if err != nil { + return nil, err + } + + return resp.Data[0].Embedding, nil +} + +// Sort the index in descending order of similarity +func sortIndexes(scores []float32) []int { + indexes := make([]int, len(scores)) + for i := range indexes { + indexes[i] = i + } + sort.SliceStable(indexes, func(i, j int) bool { + return scores[indexes[i]] > scores[indexes[j]] + }) + return indexes +} + +func main() { + ctx := context.Background() + client := openai.NewClient("your token") + + // "embeddings.bin" from exp: + file, err := os.Open("embeddings.bin") + if err != nil { + panic(err) + } + defer file.Close() + + // load all embeddings from local binary file + var allEmbeddings [][]float32 + decoder := gob.NewDecoder(file) + if err := decoder.Decode(&allEmbeddings); err != nil { + fmt.Printf("Decode error: %v\n", err) + return + } + + // make some input you like + input := "I am a Golang Software Engineer, I like Go and OpenAI." + + // get embedding of input + inputEmbd, err := getEmbedding(ctx, client, []string{input}) + if err != nil { + fmt.Printf("GetEmedding error: %v\n", err) + return + } + + // Calculate similarity through cosine matching algorithm + var questionScores []float32 + for _, embed := range allEmbeddings { + // OpenAI embeddings are normalized to length 1, which means that: + // Cosine similarity can be computed slightly faster using just a dot product + score := openai.DotProduct(embed, inputEmbd) + questionScores = append(questionScores, score) + } + + // Take the subscripts of the top few selections with the highest similarity + sortedIndexes := sortIndexes(questionScores) + sortedIndexes = sortedIndexes[:3] // Top 3 + + fmt.Println("input:", input) + fmt.Println("----------------------") + fmt.Println("similarity section:") + selectionsFile, err := os.Open("selections.txt") + if err != nil { + fmt.Printf("Open file error: %v\n", err) + return + } + defer selectionsFile.Close() + + fileData, err := ioutil.ReadAll(selectionsFile) + if err != nil { + fmt.Printf("ReadAll file error: %v\n", err) + return + } + + // Split by line + selections := strings.Split(string(fileData), "\n") + + for _, index := range sortedIndexes { + selection := selections[index] + fmt.Printf("%.4f %s\n", questionScores[index], selection) + } + + return +} +``` +
+
JSON Schema for function calling @@ -593,19 +789,20 @@ The `Parameters` field of a `FunctionDefinition` can accept either of the above Open-AI maintains clear documentation on how to [handle API errors](https://platform.openai.com/docs/guides/error-codes/api-errors) example: -``` +```go e := &openai.APIError{} + if errors.As(err, &e) { - switch e.HTTPStatusCode { - case 401: - // invalid auth or key (do not retry) + switch e.HTTPStatusCode { + case 401: + // invalid auth or key (do not retry) case 429: - // rate limiting or engine overload (wait and retry) + // rate limiting or engine overload (wait and retry) case 500: - // openai server error (retry) + // openai server error (retry) default: - // unhandled - } + // unhandled + } } ``` diff --git a/embeddings_test.go b/embeddings_test.go index 47c4f5108..df26859c9 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -1,15 +1,16 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "bytes" "context" "encoding/json" "fmt" + "math" "net/http" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { @@ -116,3 +117,21 @@ func TestEmbeddingEndpoint(t *testing.T) { _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") } + +func TestDotProduct(t *testing.T) { + v1 := []float32{1, 2, 3} + v2 := []float32{2, 4, 6} + expected := float32(28.0) + result := DotProduct(v1, v2) + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + v1 = []float32{1, 0, 0} + v2 = []float32{0, 1, 0} + expected = float32(0.0) + result = DotProduct(v1, v2) + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } +} diff --git a/embeddings_utils.go b/embeddings_utils.go new file mode 100644 index 000000000..76efb66b2 --- /dev/null +++ b/embeddings_utils.go @@ -0,0 +1,11 @@ +package openai + +// DotProduct Calculate dot product of two vectors. +func DotProduct(v1, v2 []float32) float32 { + var result float32 + // Iterate over vectors and calculate dot product. + for i := 0; i < len(v1); i++ { + result += v1[i] * v2[i] + } + return result +}