Skip to content

Commit

Permalink
Merge branch 'refactor/collection'
Browse files Browse the repository at this point in the history
  • Loading branch information
victorguarana committed May 1, 2024
2 parents f15e6bb + 39150df commit 8adc9dd
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 51 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ The code bellow sets up a connection to a MongoDB database, creates a collection
package main

import (
"context"
"fmt"
"github.com/victorguarana/gomongo"
"time"

"github.com/victorguarana/gomongo/gomongo"
)

type Movie struct {
Expand Down
49 changes: 26 additions & 23 deletions gomongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"go.mongodb.org/mongo-driver/mongo"
)

// Collection should always implement ICollection
var _ ICollection[any] = Collection[any]{}

var (
ErrEmptyID = errors.New("id can not be nil")
ErrConnectionNotInitialized = errors.New("connection was not initialized")
Expand All @@ -35,7 +38,7 @@ type Index struct {
Name string
}

type Collection[T any] interface {
type ICollection[T any] interface {
All(ctx context.Context) ([]T, error)
Create(ctx context.Context, doc T) (ID, error)
Count(ctx context.Context) (int, error)
Expand All @@ -59,40 +62,40 @@ type Collection[T any] interface {
Name() string
}

type collection[T any] struct {
type Collection[T any] struct {
mongoCollection *mongo.Collection
}

func NewCollection[T any](database *Database, collectionName string) (Collection[T], error) {
func NewCollection[T any](database Database, collectionName string) (Collection[T], error) {
if err := validateDatabase(database); err != nil {
return nil, ErrConnectionNotInitialized
return Collection[T]{}, ErrConnectionNotInitialized
}

return &collection[T]{
return Collection[T]{
mongoCollection: database.mongoDatabase.Collection(collectionName),
}, nil
}

// All returns all objects of a collection
func (c *collection[T]) All(ctx context.Context) ([]T, error) {
func (c Collection[T]) All(ctx context.Context) ([]T, error) {
emptyFilter := bson.M{}
emptyOrder := map[string]OrderBy{}
return where[T](ctx, c.mongoCollection, emptyFilter, emptyOrder)
}

// Count returns the number of objects of a collection
func (c *collection[T]) Count(ctx context.Context) (int, error) {
func (c Collection[T]) Count(ctx context.Context) (int, error) {
emptyFilter := bson.M{}
return count(ctx, c.mongoCollection, emptyFilter)
}

// Create inserts a new object into a collection and returns the id of the inserted document
func (c *collection[T]) Create(ctx context.Context, instance T) (ID, error) {
func (c Collection[T]) Create(ctx context.Context, instance T) (ID, error) {
return create(ctx, c.mongoCollection, instance)
}

// DeleteID deletes an object of a collection by id
func (c *collection[T]) DeleteID(ctx context.Context, id ID) error {
func (c Collection[T]) DeleteID(ctx context.Context, id ID) error {
if err := validateReceivedID(id); err != nil {
return err
}
Expand All @@ -102,7 +105,7 @@ func (c *collection[T]) DeleteID(ctx context.Context, id ID) error {
}

// FindID returns an object of a collection by id
func (c *collection[T]) FindID(ctx context.Context, id ID) (T, error) {
func (c Collection[T]) FindID(ctx context.Context, id ID) (T, error) {
if err := validateReceivedID(id); err != nil {
var t T
return t, err
Expand All @@ -114,42 +117,42 @@ func (c *collection[T]) FindID(ctx context.Context, id ID) (T, error) {
}

// FindOne returns an object of a collection by filter
func (c *collection[T]) FindOne(ctx context.Context, filter any) (T, error) {
func (c Collection[T]) FindOne(ctx context.Context, filter any) (T, error) {
filter = validateReceivedFilter(filter)
emptyOrder := map[string]OrderBy{}
return findOne[T](ctx, c.mongoCollection, filter, emptyOrder)
}

// First returns the first object of a collection in natural order
func (c *collection[T]) First(ctx context.Context) (T, error) {
func (c Collection[T]) First(ctx context.Context) (T, error) {
emptyFilter := bson.M{}
emptyOrder := map[string]OrderBy{}
return findOne[T](ctx, c.mongoCollection, emptyFilter, emptyOrder)
}

// FirstInserted returns the first object of a collection ordered by id
func (c *collection[T]) FirstInserted(ctx context.Context, filter any) (T, error) {
func (c Collection[T]) FirstInserted(ctx context.Context, filter any) (T, error) {
filter = validateReceivedFilter(filter)
order := map[string]OrderBy{"_id": OrderAsc}
return findOne[T](ctx, c.mongoCollection, filter, order)
}

// Last returns the last object of a collection in natural order
func (c *collection[T]) Last(ctx context.Context) (T, error) {
func (c Collection[T]) Last(ctx context.Context) (T, error) {
emptyFilter := bson.M{}
order := map[string]OrderBy{"$natural": OrderDesc}
return findOne[T](ctx, c.mongoCollection, emptyFilter, order)
}

// LastInserted returns the last object of a collection ordered by id
func (c *collection[T]) LastInserted(ctx context.Context, filter any) (T, error) {
func (c Collection[T]) LastInserted(ctx context.Context, filter any) (T, error) {
filter = validateReceivedFilter(filter)
order := map[string]OrderBy{"_id": OrderDesc}
return findOne[T](ctx, c.mongoCollection, filter, order)
}

// Update updates an object of a collection by id
func (c *collection[T]) UpdateID(ctx context.Context, id ID, instance T) error {
func (c Collection[T]) UpdateID(ctx context.Context, id ID, instance T) error {
if err := validateReceivedID(id); err != nil {
return err
}
Expand All @@ -159,14 +162,14 @@ func (c *collection[T]) UpdateID(ctx context.Context, id ID, instance T) error {
}

// Where returns all objects of a collection by filter
func (c *collection[T]) Where(ctx context.Context, filter any) ([]T, error) {
func (c Collection[T]) Where(ctx context.Context, filter any) ([]T, error) {
filter = validateReceivedFilter(filter)
emptyOrder := map[string]OrderBy{}
return where[T](ctx, c.mongoCollection, filter, emptyOrder)
}

// WhereWithOrder returns all objects of a collection by filter and order
func (c *collection[T]) WhereWithOrder(ctx context.Context, filter any, order map[string]OrderBy) ([]T, error) {
func (c Collection[T]) WhereWithOrder(ctx context.Context, filter any, order map[string]OrderBy) ([]T, error) {
filter = validateReceivedFilter(filter)
order, err := validateReceivedOrder(order)
if err != nil {
Expand All @@ -175,7 +178,7 @@ func (c *collection[T]) WhereWithOrder(ctx context.Context, filter any, order ma
return where[T](ctx, c.mongoCollection, filter, order)
}

func (c *collection[T]) CreateUniqueIndex(ctx context.Context, index Index) error {
func (c Collection[T]) CreateUniqueIndex(ctx context.Context, index Index) error {
if err := validateReceivedIndex(index); err != nil {
return err
}
Expand All @@ -184,22 +187,22 @@ func (c *collection[T]) CreateUniqueIndex(ctx context.Context, index Index) erro
}

// ListIndexes returns all indexes of a collection
func (c *collection[T]) ListIndexes(ctx context.Context) ([]Index, error) {
func (c Collection[T]) ListIndexes(ctx context.Context) ([]Index, error) {
return listIndexes(ctx, c.mongoCollection)
}

// DeleteIndex deletes an index of a collection
func (c *collection[T]) DeleteIndex(ctx context.Context, indexName string) error {
func (c Collection[T]) DeleteIndex(ctx context.Context, indexName string) error {
return deleteIndex(ctx, c.mongoCollection, indexName)
}

// Drop deletes a collection
func (c *collection[T]) Drop(ctx context.Context) error {
func (c Collection[T]) Drop(ctx context.Context) error {
return drop(ctx, c.mongoCollection)
}

// Name returns the name of a collection
func (c *collection[T]) Name() string {
func (c Collection[T]) Name() string {
return c.mongoCollection.Name()
}

Expand Down
38 changes: 19 additions & 19 deletions gomongo/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ var _ = Describe("NewCollection", Ordered, func() {
mongodbContainerURI string
mongodbContainer *mongodb.MongoDBContainer

gomongoDatabase *Database
gomongoDatabase Database
)

BeforeAll(func() {
Expand All @@ -73,19 +73,11 @@ var _ = Describe("NewCollection", Ordered, func() {
})
})

Context("when database is nil", func() {
It("should return error", func() {
receivedCollection, receivedErr := NewCollection[DummyStruct](nil, collectionName)
Expect(receivedErr).To(MatchError(ErrConnectionNotInitialized))
Expect(receivedCollection).To(BeNil())
})
})

Context("when database is not initialized", func() {
It("should return error", func() {
receivedCollection, receivedErr := NewCollection[DummyStruct](&Database{}, collectionName)
receivedCollection, receivedErr := NewCollection[DummyStruct](Database{}, collectionName)
Expect(receivedErr).To(MatchError(ErrConnectionNotInitialized))
Expect(receivedCollection).To(BeNil())
Expect(receivedCollection).To(Equal(Collection[DummyStruct]{}))
})
})

Expand All @@ -102,19 +94,19 @@ var _ = Describe("NewCollection", Ordered, func() {
})
})

var _ = Describe("collection{}", Ordered, func() {
var _ = Describe("Collection{}", Ordered, func() {
var (
databaseName = "database_test"
collectionName = "collection_test"

mongodbContainerURI string
mongodbContainer *mongodb.MongoDBContainer

sut collection[DummyStruct]
err error
sut Collection[DummyStruct]
)

BeforeAll(func() {
var err error
mongodbContainer, mongodbContainerURI = runMongoContainer(context.Background())
sut, err = initializeCollection(context.Background(), mongodbContainerURI, databaseName, collectionName)
if err != nil {
Expand All @@ -140,6 +132,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount := randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -409,6 +402,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount := randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -502,6 +496,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount := randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -667,6 +662,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount := randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -848,6 +844,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount := randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -985,6 +982,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount := randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -1125,6 +1123,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount = randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -1295,6 +1294,7 @@ var _ = Describe("collection{}", Ordered, func() {

BeforeAll(func() {
By("populating with Create")
var err error
dummiesCount := randomIntBetween(10, 20)
dummies, err = populateCollectionWithManyFakeDocuments(sut, dummiesCount)
if err != nil {
Expand Down Expand Up @@ -1700,18 +1700,18 @@ var _ = Describe("collection{}", Ordered, func() {
})
})

func initializeCollection(ctx context.Context, mongoURI, databaseName, collectionName string) (collection[DummyStruct], error) {
func initializeCollection(ctx context.Context, mongoURI, databaseName, collectionName string) (Collection[DummyStruct], error) {
gomongoDatabase, err := NewDatabase(ctx, ConnectionSettings{
URI: mongoURI,
DatabaseName: databaseName,
ConnectionTimeout: time.Second,
})

if err != nil {
return collection[DummyStruct]{}, fmt.Errorf("Could not create database: %e", err)
return Collection[DummyStruct]{}, fmt.Errorf("Could not create database: %e", err)
}

sut := collection[DummyStruct]{
sut := Collection[DummyStruct]{
mongoCollection: gomongoDatabase.mongoDatabase.Collection(collectionName),
}

Expand All @@ -1730,7 +1730,7 @@ func randomIntBetween(min, max int) int {
return rand.Intn(max-min) + min
}

func populateCollectionWithManyFakeDocuments(collection collection[DummyStruct], n int) ([]DummyStruct, error) {
func populateCollectionWithManyFakeDocuments(collection Collection[DummyStruct], n int) ([]DummyStruct, error) {
dummies, err := generateDummyStructs(n)
if err != nil {
return nil, err
Expand All @@ -1754,7 +1754,7 @@ func generateDummyStructs(n int) ([]DummyStruct, error) {
return dummies, nil
}

func insertManyInCollection(collection collection[DummyStruct], dummies []DummyStruct) error {
func insertManyInCollection(collection Collection[DummyStruct], dummies []DummyStruct) error {
for i, dummy := range dummies {
var err error
dummies[i].ID, err = collection.Create(context.Background(), dummy)
Expand Down
14 changes: 7 additions & 7 deletions gomongo/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ type Database struct {
mongoDatabase *mongo.Database
}

func NewDatabase(ctx context.Context, cs ConnectionSettings) (*Database, error) {
func NewDatabase(ctx context.Context, cs ConnectionSettings) (Database, error) {
if err := cs.validate(); err != nil {
return nil, err
return Database{}, err
}

mongoClient, err := mongoClient(ctx, &cs)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrGomongoCanNotConnect, err)
return Database{}, fmt.Errorf("%w: %w", ErrGomongoCanNotConnect, err)
}

if err := pingMongoServer(&cs, mongoClient, ctx); err != nil {
return nil, fmt.Errorf("%w: %w", ErrGomongoCanNotConnect, err)
return Database{}, fmt.Errorf("%w: %w", ErrGomongoCanNotConnect, err)
}

return &Database{
return Database{
mongoClient.Database(cs.DatabaseName),
}, nil
}
Expand Down Expand Up @@ -62,8 +62,8 @@ func pingMongoServer(cs *ConnectionSettings, mongoClient *mongo.Client, ctx cont
return mongoClient.Ping(ctx, nil)
}

func validateDatabase(database *Database) error {
if database == nil || database.mongoDatabase == nil {
func validateDatabase(database Database) error {
if database.mongoDatabase == nil {
return ErrConnectionNotInitialized
}

Expand Down
2 changes: 1 addition & 1 deletion gomongo/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var _ = Describe("NewDatabase", Ordered, func() {
It("returns error", func() {
receivedDatabase, receivedErr := NewDatabase(context.Background(), connectionSettings)
Expect(receivedErr).To(MatchError(ErrGomongoCanNotConnect))
Expect(receivedDatabase).To(BeNil())
Expect(receivedDatabase).To(Equal(Database{}))
})
})
})

0 comments on commit 8adc9dd

Please sign in to comment.