Skip to content

Commit

Permalink
Merge pull request #55 from philippgille/simplify-unit-tests
Browse files Browse the repository at this point in the history
Simplify unit tests
  • Loading branch information
philippgille authored Mar 17, 2024
2 parents c286d3f + 51845b7 commit f76265c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 123 deletions.
1 change: 1 addition & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ func TestCollection_Add_Error(t *testing.T) {
embeddings := [][]float32{vectors, vectors}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
contents := []string{"hello world", "hallo welt"}

// Empty IDs
err = c.Add(ctx, []string{}, embeddings, metadatas, contents)
if err == nil {
Expand Down
137 changes: 31 additions & 106 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chromem

import (
"context"
"reflect"
"slices"
"testing"
)
Expand Down Expand Up @@ -37,36 +38,18 @@ func TestDB_CreateCollection(t *testing.T) {
if !ok {
t.Fatal("expected collection", name, "not found")
}
if c2.Name != name {
t.Fatal("expected name", name, "got", c2.Name)
}
// The returned collection should also be the same
if c.Name != name {
t.Fatal("expected name", name, "got", c.Name)
}
// The collection's persistent dir should be empty
if c.persistDirectory != "" {
t.Fatal("expected empty persistent directory, got", c.persistDirectory)
}
// It's metadata should match
if len(c.metadata) != 1 || c.metadata["foo"] != "bar" {
t.Fatal("expected metadata", metadata, "got", c.metadata)
}
// Documents should be empty, but not nil
if c.documents == nil {
t.Fatal("expected non-nil documents, got nil")
}
if len(c.documents) != 0 {
t.Fatal("expected empty documents, got", len(c.documents))
}
// The embedding function should be the one we passed
// Check the embedding function first, then the rest with DeepEqual
gotVectors, err := c.embed(context.Background(), "test")
if err != nil {
t.Fatal("expected no error, got", err)
}
if !slices.Equal(gotVectors, vectors) {
t.Fatal("expected vectors", vectors, "got", gotVectors)
}
c.embed, c2.embed = nil, nil
if !reflect.DeepEqual(c, c2) {
t.Fatalf("expected collection %+v, got %+v", c, c2)
}
})

t.Run("NOK - Empty name", func(t *testing.T) {
Expand All @@ -88,8 +71,7 @@ func TestDB_ListCollections(t *testing.T) {

// Create initial collection
db := NewDB()
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc)
orig, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand All @@ -108,32 +90,18 @@ func TestDB_ListCollections(t *testing.T) {
if !ok {
t.Fatal("expected collection", name, "not found")
}
if c.Name != name {
t.Fatal("expected name", name, "got", c.Name)
}
// The collection's persistent dir should be empty
if c.persistDirectory != "" {
t.Fatal("expected empty persistent directory, got", c.persistDirectory)
}
// It's metadata should match
if len(c.metadata) != 1 || c.metadata["foo"] != "bar" {
t.Fatal("expected metadata", metadata, "got", c.metadata)
}
// Documents should be empty, but not nil
if c.documents == nil {
t.Fatal("expected non-nil documents, got nil")
}
if len(c.documents) != 0 {
t.Fatal("expected empty documents, got", len(c.documents))
}
// The embedding function should be the one we passed
// Check the embedding function first, then the rest with DeepEqual
gotVectors, err := c.embed(context.Background(), "test")
if err != nil {
t.Fatal("expected no error, got", err)
}
if !slices.Equal(gotVectors, vectors) {
t.Fatal("expected vectors", vectors, "got", gotVectors)
}
orig.embed, c.embed = nil, nil
if !reflect.DeepEqual(orig, c) {
t.Fatalf("expected collection %+v, got %+v", orig, c)
}

// And it should be a copy. Adding a value here should not reflect on the DB's
// collection.
Expand All @@ -154,42 +122,26 @@ func TestDB_GetCollection(t *testing.T) {

// Create initial collection
db := NewDB()
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc)
orig, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}

// Get collection
c := db.GetCollection(name, nil)

// Check expectations
if c.Name != name {
t.Fatal("expected name", name, "got", c.Name)
}
// The collection's persistent dir should be empty
if c.persistDirectory != "" {
t.Fatal("expected empty persistent directory, got", c.persistDirectory)
}
// It's metadata should match
if len(c.metadata) != 1 || c.metadata["foo"] != "bar" {
t.Fatal("expected metadata", metadata, "got", c.metadata)
}
// Documents should be empty, but not nil
if c.documents == nil {
t.Fatal("expected non-nil documents, got nil")
}
if len(c.documents) != 0 {
t.Fatal("expected empty documents, got", len(c.documents))
}
// The embedding function should be the one we passed
// Check the embedding function first, then the rest with DeepEqual
gotVectors, err := c.embed(context.Background(), "test")
if err != nil {
t.Fatal("expected no error, got", err)
}
if !slices.Equal(gotVectors, vectors) {
t.Fatal("expected vectors", vectors, "got", gotVectors)
}
orig.embed, c.embed = nil, nil
if !reflect.DeepEqual(orig, c) {
t.Fatalf("expected collection %+v, got %+v", orig, c)
}
}

func TestDB_GetOrCreateCollection(t *testing.T) {
Expand All @@ -206,8 +158,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
db := NewDB()
// Create collection so that the GetOrCreateCollection() call below only
// gets it.
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc)
orig, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand All @@ -223,33 +174,18 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
t.Fatal("expected collection, got nil")
}

// Check expectations
if c.Name != name {
t.Fatal("expected name", name, "got", c.Name)
}
// The collection's persistent dir should be empty
if c.persistDirectory != "" {
t.Fatal("expected empty persistent directory, got", c.persistDirectory)
}
// It's metadata should match
if len(c.metadata) != 1 || c.metadata["foo"] != "bar" {
t.Fatal("expected metadata", metadata, "got", c.metadata)
}
// Documents should be empty, but not nil
if c.documents == nil {
t.Fatal("expected non-nil documents, got nil")
}
if len(c.documents) != 0 {
t.Fatal("expected empty documents, got", len(c.documents))
}
// The embedding function should be the one we passed
// Check the embedding function first, then the rest with DeepEqual
gotVectors, err := c.embed(context.Background(), "test")
if err != nil {
t.Fatal("expected no error, got", err)
}
if !slices.Equal(gotVectors, vectors) {
t.Fatal("expected vectors", vectors, "got", gotVectors)
}
orig.embed, c.embed = nil, nil
if !reflect.DeepEqual(orig, c) {
t.Fatalf("expected collection %+v, got %+v", orig, c)
}
})

t.Run("Create", func(t *testing.T) {
Expand All @@ -266,32 +202,21 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
}

// Check like we check CreateCollection()
if c.Name != name {
t.Fatal("expected name", name, "got", c.Name)
}
// The collection's persistent dir should be empty
if c.persistDirectory != "" {
t.Fatal("expected empty persistent directory, got", c.persistDirectory)
}
// It's metadata should match
if len(c.metadata) != 1 || c.metadata["foo"] != "bar" {
t.Fatal("expected metadata", metadata, "got", c.metadata)
}
// Documents should be empty, but not nil
if c.documents == nil {
t.Fatal("expected non-nil documents, got nil")
}
if len(c.documents) != 0 {
t.Fatal("expected empty documents, got", len(c.documents))
c2, ok := db.collections[name]
if !ok {
t.Fatal("expected collection", name, "not found")
}
// The embedding function should be the one we passed
gotVectors, err := c.embed(context.Background(), "test")
if err != nil {
t.Fatal("expected no error, got", err)
}
if !slices.Equal(gotVectors, vectors) {
t.Fatal("expected vectors", vectors, "got", gotVectors)
}
c.embed, c2.embed = nil, nil
if !reflect.DeepEqual(c, c2) {
t.Fatalf("expected collection %+v, got %+v", c, c2)
}
})
}

Expand Down
20 changes: 9 additions & 11 deletions document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package chromem

import (
"context"
"slices"
"reflect"
"testing"
)

Expand Down Expand Up @@ -49,17 +49,15 @@ func TestDocument_New(t *testing.T) {
if err != nil {
t.Fatal("expected no error, got", err)
}
if d.ID != id {
t.Fatal("expected id", id, "got", d.ID)
// We can compare with DeepEqual after removing the embedding function
d.Embedding = nil
exp := Document{
ID: id,
Metadata: metadata,
Content: content,
}
if d.Metadata["foo"] != metadata["foo"] {
t.Fatal("expected metadata", metadata, "got", d.Metadata)
}
if !slices.Equal(d.Embedding, vectors) {
t.Fatal("expected vectors", vectors, "got", d.Embedding)
}
if d.Content != content {
t.Fatal("expected content", content, "got", d.Content)
if !reflect.DeepEqual(exp, d) {
t.Fatalf("expected %+v, got %+v", exp, d)
}
})
}
Expand Down
9 changes: 3 additions & 6 deletions persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"encoding/gob"
"os"
"slices"
"reflect"
"testing"
)

Expand Down Expand Up @@ -44,10 +44,7 @@ func TestPersistence(t *testing.T) {
if err != nil {
t.Fatal("expected nil, got", err)
}
if res.Foo != obj.Foo {
t.Fatal("expected", obj.Foo, "got", res.Foo)
}
if slices.Compare[[]float32](res.Bar, obj.Bar) != 0 {
t.Fatal("expected", obj.Bar, "got", res.Bar)
if !reflect.DeepEqual(obj, res) {
t.Fatalf("expected %+v, got %+v", obj, res)
}
}

0 comments on commit f76265c

Please sign in to comment.