diff --git a/README.md b/README.md
index 1f708af70..07c86cb81 100644
--- a/README.md
+++ b/README.md
@@ -527,6 +527,202 @@ func main() {
```
+
+
+Generate Embeddings
+
+```go
+package main
+
+import (
+ "context"
+ "encoding/gob"
+ "fmt"
+ "os"
+
+ "github.com/sashabaranov/go-openai"
+)
+
+func getEmbedding(ctx context.Context, client *openai.Client, input []string) ([]float32, error) {
+
+ resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
+ Input: input,
+ Model: openai.AdaEmbeddingV2,
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ return resp.Data[0].Embedding, nil
+}
+
+func main() {
+
+ ctx := context.Background()
+ client := openai.NewClient("your token")
+
+ // example selections
+ selections := []string{
+ "Welcome to the go-openai interface, which will be the gateway for golang software engineers to enter the OpenAI development world.",
+ "It was tasty and fresh. The other one I bought was old and tasted moldy. But this one was good.",
+ "Great coffee at a good price. I'm a subscription buyer and I buy this month after month. What more can I say?",
+ "This chocolate is amazing..I love the taste and smell, this is the only chocolate for me...I found a new love!",
+ "I love this coffee! And such a great price. Will buy more when I am running out which will be soon.",
+ "The Raspberry Tea Syrup is great. I can use it for hot and cold drinks as well in certain recipes.",
+ "Everyone that dips with this loves it! So easy to use! Olive oil and tasty bread is all you need.",
+ "This is a favorite of mine for using over ice. Even bought it to give out as Christmas gifts last year.",
+ "If you like a great , hot, sauce then buy this. If spicy with heat isn't to your liking then don't buy it.",
+ "My name is Aceld, and I am a Golang software development engineer. I like young and beautiful girls.",
+ "The competition was held over two days,24 July and 2 August. The qualifying round was the first day with the apparatus final on the second day.",
+ "There are 4 types of gymnastics apparatus: floor, vault, pommel horse, and rings. The apparatus final is a competition between the top 8 gymnasts in each apparatus.",
+ }
+
+ // Generate embeddings
+ var selectionsEmbeddings [][]float32
+ for _, selection := range selections {
+ embedding, err := getEmbedding(ctx, client, []string{selection})
+ if err != nil {
+ fmt.Printf("GetEmedding error: %v\n", err)
+ return
+ }
+ selectionsEmbeddings = append(selectionsEmbeddings, embedding)
+ }
+
+ // Write embeddings binary data to file
+ file, err := os.Create("embeddings.bin")
+ if err != nil {
+ fmt.Printf("Create file error: %v\n", err)
+ return
+ }
+ defer file.Close()
+
+ encoder := gob.NewEncoder(file)
+ err = encoder.Encode(selectionsEmbeddings)
+ if err != nil {
+ fmt.Printf("Encode error: %v\n", err)
+ return
+ }
+
+ return
+}
+```
+
+
+
+Embedding Similarity Search
+
+```go
+package main
+
+import (
+ "context"
+ "encoding/gob"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "sort"
+ "strings"
+
+ "github.com/sashabaranov/go-openai"
+)
+
+func getEmbedding(ctx context.Context, client *openai.Client, input []string) ([]float32, error) {
+ resp, err := client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
+ Input: input,
+ Model: openai.AdaEmbeddingV2,
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ return resp.Data[0].Embedding, nil
+}
+
+// Sort the index in descending order of similarity
+func sortIndexes(scores []float32) []int {
+ indexes := make([]int, len(scores))
+ for i := range indexes {
+ indexes[i] = i
+ }
+ sort.SliceStable(indexes, func(i, j int) bool {
+ return scores[indexes[i]] > scores[indexes[j]]
+ })
+ return indexes
+}
+
+func main() {
+ ctx := context.Background()
+ client := openai.NewClient("your token")
+
+ // "embeddings.bin" from exp:
+ file, err := os.Open("embeddings.bin")
+ if err != nil {
+ panic(err)
+ }
+ defer file.Close()
+
+ // load all embeddings from local binary file
+ var allEmbeddings [][]float32
+ decoder := gob.NewDecoder(file)
+ if err := decoder.Decode(&allEmbeddings); err != nil {
+ fmt.Printf("Decode error: %v\n", err)
+ return
+ }
+
+ // make some input you like
+ input := "I am a Golang Software Engineer, I like Go and OpenAI."
+
+ // get embedding of input
+ inputEmbd, err := getEmbedding(ctx, client, []string{input})
+ if err != nil {
+ fmt.Printf("GetEmedding error: %v\n", err)
+ return
+ }
+
+ // Calculate similarity through cosine matching algorithm
+ var questionScores []float32
+ for _, embed := range allEmbeddings {
+ // OpenAI embeddings are normalized to length 1, which means that:
+ // Cosine similarity can be computed slightly faster using just a dot product
+ score := openai.DotProduct(embed, inputEmbd)
+ questionScores = append(questionScores, score)
+ }
+
+ // Take the subscripts of the top few selections with the highest similarity
+ sortedIndexes := sortIndexes(questionScores)
+ sortedIndexes = sortedIndexes[:3] // Top 3
+
+ fmt.Println("input:", input)
+ fmt.Println("----------------------")
+ fmt.Println("similarity section:")
+ selectionsFile, err := os.Open("selections.txt")
+ if err != nil {
+ fmt.Printf("Open file error: %v\n", err)
+ return
+ }
+ defer selectionsFile.Close()
+
+ fileData, err := ioutil.ReadAll(selectionsFile)
+ if err != nil {
+ fmt.Printf("ReadAll file error: %v\n", err)
+ return
+ }
+
+ // Split by line
+ selections := strings.Split(string(fileData), "\n")
+
+ for _, index := range sortedIndexes {
+ selection := selections[index]
+ fmt.Printf("%.4f %s\n", questionScores[index], selection)
+ }
+
+ return
+}
+```
+
+
JSON Schema for function calling
@@ -593,19 +789,20 @@ The `Parameters` field of a `FunctionDefinition` can accept either of the above
Open-AI maintains clear documentation on how to [handle API errors](https://platform.openai.com/docs/guides/error-codes/api-errors)
example:
-```
+```go
e := &openai.APIError{}
+
if errors.As(err, &e) {
- switch e.HTTPStatusCode {
- case 401:
- // invalid auth or key (do not retry)
+ switch e.HTTPStatusCode {
+ case 401:
+ // invalid auth or key (do not retry)
case 429:
- // rate limiting or engine overload (wait and retry)
+ // rate limiting or engine overload (wait and retry)
case 500:
- // openai server error (retry)
+ // openai server error (retry)
default:
- // unhandled
- }
+ // unhandled
+ }
}
```
diff --git a/embeddings_test.go b/embeddings_test.go
index 47c4f5108..df26859c9 100644
--- a/embeddings_test.go
+++ b/embeddings_test.go
@@ -1,15 +1,16 @@
package openai_test
import (
- . "github.com/sashabaranov/go-openai"
- "github.com/sashabaranov/go-openai/internal/test/checks"
-
"bytes"
"context"
"encoding/json"
"fmt"
+ "math"
"net/http"
"testing"
+
+ . "github.com/sashabaranov/go-openai"
+ "github.com/sashabaranov/go-openai/internal/test/checks"
)
func TestEmbedding(t *testing.T) {
@@ -116,3 +117,21 @@ func TestEmbeddingEndpoint(t *testing.T) {
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
checks.NoError(t, err, "CreateEmbeddings tokens error")
}
+
+func TestDotProduct(t *testing.T) {
+ v1 := []float32{1, 2, 3}
+ v2 := []float32{2, 4, 6}
+ expected := float32(28.0)
+ result := DotProduct(v1, v2)
+ if math.Abs(float64(result-expected)) > 1e-12 {
+ t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
+ }
+
+ v1 = []float32{1, 0, 0}
+ v2 = []float32{0, 1, 0}
+ expected = float32(0.0)
+ result = DotProduct(v1, v2)
+ if math.Abs(float64(result-expected)) > 1e-12 {
+ t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
+ }
+}
diff --git a/embeddings_utils.go b/embeddings_utils.go
new file mode 100644
index 000000000..76efb66b2
--- /dev/null
+++ b/embeddings_utils.go
@@ -0,0 +1,11 @@
+package openai
+
+// DotProduct Calculate dot product of two vectors.
+func DotProduct(v1, v2 []float32) float32 {
+ var result float32
+ // Iterate over vectors and calculate dot product.
+ for i := 0; i < len(v1); i++ {
+ result += v1[i] * v2[i]
+ }
+ return result
+}