diff --git a/README.md b/README.md index 440c40968..c618cd7fa 100644 --- a/README.md +++ b/README.md @@ -483,6 +483,62 @@ func main() { ``` + +Embedding Semantic Similarity + +```go +package main + +import ( + "context" + "log" + openai "github.com/sashabaranov/go-openai" + +) + +func main() { + client := openai.NewClient("your-token") + + // Create an EmbeddingRequest for the user query + queryReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck"}, + Model: openai.AdaEmbeddingv2, + } + + // Create an embedding for the user query + queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq) + if err != nil { + log.Fatal("Error creating query embedding:", err) + } + + // Create an EmbeddingRequest for the target text + targetReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"}, + Model: openai.AdaEmbeddingv2, + } + + // Create an embedding for the target text + targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq) + if err != nil { + log.Fatal("Error creating target embedding:", err) + } + + // Now that we have the embeddings for the user query and the target text, we + // can calculate their similarity. + queryEmbedding := queryResponse.Data[0] + targetEmbedding := targetResponse.Data[0] + + similarity, err := queryEmbedding.DotProduct(&targetEmbedding) + if err != nil { + log.Fatal("Error calculating dot product:", err) + } + + log.Printf("The similarity score between the query and the target is %f", similarity) +} + +``` + +
Azure OpenAI Embeddings diff --git a/embeddings.go b/embeddings.go index 5ba91f235..660bc24c3 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,10 +4,13 @@ import ( "context" "encoding/base64" "encoding/binary" + "errors" "math" "net/http" ) +var ErrVectorLengthMismatch = errors.New("vector length mismatch") + // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. type EmbeddingModel int @@ -124,6 +127,23 @@ type Embedding struct { Index int `json:"index"` } +// DotProduct calculates the dot product of the embedding vector with another +// embedding vector. Both vectors must have the same length; otherwise, an +// ErrVectorLengthMismatch is returned. The method returns the calculated dot +// product as a float32 value. +func (e *Embedding) DotProduct(other *Embedding) (float32, error) { + if len(e.Embedding) != len(other.Embedding) { + return 0, ErrVectorLengthMismatch + } + + var dotProduct float32 + for i := range e.Embedding { + dotProduct += e.Embedding[i] * other.Embedding[i] + } + + return dotProduct, nil +} + // EmbeddingResponse is the response from a Create embeddings request. type EmbeddingResponse struct { Object string `json:"object"` diff --git a/embeddings_test.go b/embeddings_test.go index 9c48c5b8f..72e8c245f 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" + "math" "net/http" "reflect" "testing" @@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { }) } } + +func TestDotProduct(t *testing.T) { + v1 := &Embedding{Embedding: []float32{1, 2, 3}} + v2 := &Embedding{Embedding: []float32{2, 4, 6}} + expected := float32(28.0) + + result, err := v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1, 0}} + expected = float32(0.0) + + result, err = v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + // Test for VectorLengthMismatchError + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1}} + _, err = v1.DotProduct(v2) + if !errors.Is(err, ErrVectorLengthMismatch) { + t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) + } +}