diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 39c328b..1368738 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,7 +31,7 @@ jobs: mongodb-version: ${{ matrix.mongodb-version }} - name: Run tests - run: go test -v -covermode=count -coverprofile=coverage.out + run: go test `go list ./... | grep -v ./cmd` -v -covermode=count -coverprofile=coverage.out - name: Convert coverage.out to coverage.lcov uses: jandelgado/gcov2lcov-action@v1 diff --git a/collection.go b/collection.go index 81c6f6e..fd8222c 100644 --- a/collection.go +++ b/collection.go @@ -8,6 +8,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) + type Collection[T Document] struct { collection *mongo.Collection } @@ -95,6 +96,24 @@ func (repo *Collection[T]) CountDocuments(filter interface{}) (int64, error) { return count, err } +func (repo *Collection[T]) Aggregate(pipeline mongo.Pipeline, opts ...*options.AggregateOptions) ([]bson.M, error) { + csr, err := repo.collection.Aggregate(DefaultContext(), pipeline, opts...) + + var result = []bson.M{} + if err = csr.All(DefaultContext(), &result); err != nil { + return nil, err + } + + return result, nil +} + +func (repo *Collection[T]) Drop() error { + err := repo.collection.Drop(DefaultContext()) + return err +} + + func (repo *Collection[T]) NewId() primitive.ObjectID { return primitive.NewObjectID() } + diff --git a/collection_test.go b/collection_test.go index 1a0d742..bdb57c6 100644 --- a/collection_test.go +++ b/collection_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" "math/rand" "strings" "testing" @@ -231,8 +232,27 @@ func TestCollection_UpdateMany(t *testing.T) { // TODO func TestCollection_DeleteById(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + mockDb.Connect("mongodb://localhost:27017/colt?readPreference=primary&directConnection=true&ssl=false", "colt") + + collection := GetCollection[*testdoc](&mockDb, "testdocs") + + title := fmt.Sprint(rand.Int()) + doc := testdoc{Title: title} + + collection.Insert(&doc) + err := collection.DeleteById(doc.ID) + assert.Nil(t, err) + result, err := collection.FindById(doc.ID) + assert.Nil(t, result) + assert.NotNil(t, err) + + collection.Drop() + mockDb.Disconnect() } + + func TestCollection_CountDocuments(t *testing.T) { rand.Seed(time.Now().UnixNano()) mockDb.Connect("mongodb://localhost:27017/colt?readPreference=primary&directConnection=true&ssl=false", "colt") @@ -260,3 +280,34 @@ func TestCollection_CountDocuments(t *testing.T) { mockDb.Disconnect() } + + + +func TestCollection_Aggregate(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + mockDb.Connect("mongodb://localhost:27017/colt?readPreference=primary&directConnection=true&ssl=false", "colt") + + collection := GetCollection[*testdoc](&mockDb, "aggregatetest") + + title := fmt.Sprint(rand.Int()) + doc := testdoc{Title: title} + doc2 := testdoc{Title: title} + + collection.Insert(&doc) + collection.Insert(&doc2) + + result, err := collection.Aggregate(mongo.Pipeline{ + bson.D{ + {"$group", bson.D{ + {"_id", "$title"}, + {"count", bson.D{{"$sum", 1}}}, + }}}}) + + assert.Nil(t, err) + assert.NotNil(t, result) + assert.Equal(t, result[0]["_id"], title) + assert.Equal(t, result[0]["count"], int32(2)) + + collection.Drop() + mockDb.Disconnect() +}