diff --git a/README.md b/README.md
index 440c40968..c618cd7fa 100644
--- a/README.md
+++ b/README.md
@@ -483,6 +483,62 @@ func main() {
```
+
+Embedding Semantic Similarity
+
+```go
+package main
+
+import (
+ "context"
+ "log"
+ openai "github.com/sashabaranov/go-openai"
+
+)
+
+func main() {
+ client := openai.NewClient("your-token")
+
+ // Create an EmbeddingRequest for the user query
+ queryReq := openai.EmbeddingRequest{
+ Input: []string{"How many chucks would a woodchuck chuck"},
+ Model: openai.AdaEmbeddingv2,
+ }
+
+ // Create an embedding for the user query
+ queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq)
+ if err != nil {
+ log.Fatal("Error creating query embedding:", err)
+ }
+
+ // Create an EmbeddingRequest for the target text
+ targetReq := openai.EmbeddingRequest{
+ Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"},
+ Model: openai.AdaEmbeddingv2,
+ }
+
+ // Create an embedding for the target text
+ targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq)
+ if err != nil {
+ log.Fatal("Error creating target embedding:", err)
+ }
+
+ // Now that we have the embeddings for the user query and the target text, we
+ // can calculate their similarity.
+ queryEmbedding := queryResponse.Data[0]
+ targetEmbedding := targetResponse.Data[0]
+
+ similarity, err := queryEmbedding.DotProduct(&targetEmbedding)
+ if err != nil {
+ log.Fatal("Error calculating dot product:", err)
+ }
+
+ log.Printf("The similarity score between the query and the target is %f", similarity)
+}
+
+```
+
+
Azure OpenAI Embeddings
diff --git a/embeddings.go b/embeddings.go
index 5ba91f235..660bc24c3 100644
--- a/embeddings.go
+++ b/embeddings.go
@@ -4,10 +4,13 @@ import (
"context"
"encoding/base64"
"encoding/binary"
+ "errors"
"math"
"net/http"
)
+var ErrVectorLengthMismatch = errors.New("vector length mismatch")
+
// EmbeddingModel enumerates the models which can be used
// to generate Embedding vectors.
type EmbeddingModel int
@@ -124,6 +127,23 @@ type Embedding struct {
Index int `json:"index"`
}
+// DotProduct calculates the dot product of the embedding vector with another
+// embedding vector. Both vectors must have the same length; otherwise, an
+// ErrVectorLengthMismatch is returned. The method returns the calculated dot
+// product as a float32 value.
+func (e *Embedding) DotProduct(other *Embedding) (float32, error) {
+ if len(e.Embedding) != len(other.Embedding) {
+ return 0, ErrVectorLengthMismatch
+ }
+
+ var dotProduct float32
+ for i := range e.Embedding {
+ dotProduct += e.Embedding[i] * other.Embedding[i]
+ }
+
+ return dotProduct, nil
+}
+
// EmbeddingResponse is the response from a Create embeddings request.
type EmbeddingResponse struct {
Object string `json:"object"`
diff --git a/embeddings_test.go b/embeddings_test.go
index 9c48c5b8f..72e8c245f 100644
--- a/embeddings_test.go
+++ b/embeddings_test.go
@@ -4,7 +4,9 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
+ "math"
"net/http"
"reflect"
"testing"
@@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
})
}
}
+
+func TestDotProduct(t *testing.T) {
+ v1 := &Embedding{Embedding: []float32{1, 2, 3}}
+ v2 := &Embedding{Embedding: []float32{2, 4, 6}}
+ expected := float32(28.0)
+
+ result, err := v1.DotProduct(v2)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ if math.Abs(float64(result-expected)) > 1e-12 {
+ t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
+ }
+
+ v1 = &Embedding{Embedding: []float32{1, 0, 0}}
+ v2 = &Embedding{Embedding: []float32{0, 1, 0}}
+ expected = float32(0.0)
+
+ result, err = v1.DotProduct(v2)
+ if err != nil {
+ t.Errorf("Unexpected error: %v", err)
+ }
+
+ if math.Abs(float64(result-expected)) > 1e-12 {
+ t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
+ }
+
+ // Test for VectorLengthMismatchError
+ v1 = &Embedding{Embedding: []float32{1, 0, 0}}
+ v2 = &Embedding{Embedding: []float32{0, 1}}
+ _, err = v1.DotProduct(v2)
+ if !errors.Is(err, ErrVectorLengthMismatch) {
+ t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
+ }
+}