Skip to content

Commit

Permalink
chore: tests have been refactored to match the encoding format passed…
Browse files Browse the repository at this point in the history
… by request
  • Loading branch information
henomis committed Sep 9, 2023
1 parent 250a433 commit 94550a8
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,33 @@ func TestEmbeddingEndpoint(t *testing.T) {
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
}

sampleBase64Embeddings := []Base64Embedding{
{Embedding: "pHCdP4XrkUDhevxA"},
{Embedding: "/1jku0G/rLvA/EI8"},
}

server.RegisterHandler(
"/v1/embeddings",
func(w http.ResponseWriter, r *http.Request) {
resBytes, _ := json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
var req struct {
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"`
}
_ = json.NewDecoder(r.Body).Decode(&req)

var resBytes []byte
if req.EncodingFormat == EmbeddingEncodingFormatBase64 {
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings})
} else {
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
}
fmt.Fprintln(w, string(resBytes))
},
)
// test create embeddings with strings (simple embedding request)
res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
checks.NoError(t, err, "CreateEmbeddings error")
if len(res.Data) != len(sampleEmbeddings) {
t.Errorf("Expected %d embeddings, got %d", len(sampleEmbeddings), len(res.Data))
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with strings (simple embedding request)
Expand All @@ -126,22 +141,22 @@ func TestEmbeddingEndpoint(t *testing.T) {
},
)
checks.NoError(t, err, "CreateEmbeddings error")
if len(res.Data) != len(sampleEmbeddings) {
t.Errorf("Expected %d embeddings, got %d", len(sampleEmbeddings), len(res.Data))
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with strings
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
checks.NoError(t, err, "CreateEmbeddings strings error")
if len(res.Data) != len(sampleEmbeddings) {
t.Errorf("Expected %d embeddings, got %d", len(sampleEmbeddings), len(res.Data))
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}

// test create embeddings with tokens
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
checks.NoError(t, err, "CreateEmbeddings tokens error")
if len(res.Data) != len(sampleEmbeddings) {
t.Errorf("Expected %d embeddings, got %d", len(sampleEmbeddings), len(res.Data))
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}
}

Expand Down

0 comments on commit 94550a8

Please sign in to comment.