Skip to content

Commit

Permalink
Introduce interface for Collection to make it easier to test
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoeppke committed Nov 25, 2024
1 parent 9505952 commit f6c8a42
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 22 deletions.
7 changes: 4 additions & 3 deletions cmd/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ package main

import (
"fmt"

"github.com/jensteichert/colt"
"go.mongodb.org/mongo-driver/bson"
)

type Database struct {
Todos *colt.Collection[*Todo]
Todos colt.Collection[*Todo]
}
type Todo struct {
colt.DocWithTimestamps `bson:",inline"`
Title string `bson:"title" json:"title"`
Title string `bson:"title" json:"title"`
}

func(t *Todo) BeforeInsert() error {
func (t *Todo) BeforeInsert() error {
t.DocWithTimestamps.BeforeInsert()
fmt.Println("BeforeInsert executed")
return nil
Expand Down
37 changes: 26 additions & 11 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,26 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
)

type Collection[T Document] struct {

type Collection[T Document] interface {
CountDocuments(filter interface{}) (int64, error)
CreateIndex(keys primitive.D) error
DeleteById(id string) error
Find(filter interface{}, opts ...*options.FindOptions) ([]T, error)
FindById(id interface{}) (T, error)
FindOne(filter interface{}) (T, error)
Insert(model T) (T, error)
NewId() primitive.ObjectID
UpdateById(id string, model T) error
UpdateMany(filter interface{}, doc primitive.M) error
UpdateOne(filter interface{}, model T) error
}

type CollectionImpl[T Document] struct {
collection *mongo.Collection
}

func (repo *Collection[T]) Insert(model T) (T, error) {
func (repo *CollectionImpl[T]) Insert(model T) (T, error) {
if model.GetID() == "" {
model.SetID(model.NewID())
}
Expand All @@ -28,11 +43,11 @@ func (repo *Collection[T]) Insert(model T) (T, error) {
return model, err
}

func (repo *Collection[T]) UpdateById(id string, model T) error {
func (repo *CollectionImpl[T]) UpdateById(id string, model T) error {
return repo.UpdateOne(bson.M{"_id": id}, model)
}

func (repo *Collection[T]) UpdateOne(filter interface{}, model T) error {
func (repo *CollectionImpl[T]) UpdateOne(filter interface{}, model T) error {
if hook, ok := any(model).(BeforeUpdateHook); ok {
if err := hook.BeforeUpdate(); err != nil {
return err
Expand All @@ -43,13 +58,13 @@ func (repo *Collection[T]) UpdateOne(filter interface{}, model T) error {
return err
}

func (repo *Collection[T]) UpdateMany(filter interface{}, doc bson.M) error {
func (repo *CollectionImpl[T]) UpdateMany(filter interface{}, doc bson.M) error {
_, err := repo.collection.UpdateMany(DefaultContext(), filter, doc)
return err
}


func (repo *Collection[T]) DeleteById(id string) error {
func (repo *CollectionImpl[T]) DeleteById(id string) error {
res, err := repo.collection.DeleteOne(DefaultContext(), bson.M{"_id": id})

if err != nil {
Expand All @@ -63,18 +78,18 @@ func (repo *Collection[T]) DeleteById(id string) error {
return nil
}

func (repo *Collection[T]) FindById(id interface{}) (T, error) {
func (repo *CollectionImpl[T]) FindById(id interface{}) (T, error) {
return repo.FindOne(bson.M{"_id": id})
}

func (repo *Collection[T]) FindOne(filter interface{}) (T, error) {
func (repo *CollectionImpl[T]) FindOne(filter interface{}) (T, error) {
var target T
err := repo.collection.FindOne(DefaultContext(), filter).Decode(&target)

return target, err
}

func (repo *Collection[T]) Find(filter interface{}, opts ...*options.FindOptions) ([]T, error) {
func (repo *CollectionImpl[T]) Find(filter interface{}, opts ...*options.FindOptions) ([]T, error) {
csr, err := repo.collection.Find(DefaultContext(), filter, opts...)

var result = []T{}
Expand All @@ -85,11 +100,11 @@ func (repo *Collection[T]) Find(filter interface{}, opts ...*options.FindOptions
return result, nil
}

func (repo *Collection[T]) CountDocuments(filter interface{}) (int64, error) {
func (repo *CollectionImpl[T]) CountDocuments(filter interface{}) (int64, error) {
count, err := repo.collection.CountDocuments(DefaultContext(), filter)
return count, err
}

func (repo *Collection[T]) NewId() primitive.ObjectID {
func (repo *CollectionImpl[T]) NewId() primitive.ObjectID {
return primitive.NewObjectID()
}
4 changes: 2 additions & 2 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ func DefaultContext() context.Context {
return ctx
}

func GetCollection[T Document](db *Database, collectionName string) *Collection[T] {
return &Collection[T]{db.db.Collection(collectionName)}
func GetCollection[T Document](db *Database, collectionName string) Collection[T] {
return &CollectionImpl[T]{db.db.Collection(collectionName)}
}
4 changes: 2 additions & 2 deletions indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
"go.mongodb.org/mongo-driver/mongo"
)

func (repo *Collection[T]) CreateIndex(keys bson.D) error {
func (repo *CollectionImpl[T]) CreateIndex(keys bson.D) error {
mod := mongo.IndexModel{
Keys: keys,
Keys: keys,
Options: nil,
}
_, err := repo.collection.Indexes().CreateOne(DefaultContext(), mod)
Expand Down
9 changes: 5 additions & 4 deletions indexes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ package colt

import (
"fmt"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"math/rand"
"testing"
"time"

"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
)

func TestCollection_CreateIndex(t *testing.T) {
rand.Seed(time.Now().UnixNano())
mockDb.Connect("mongodb://localhost:27017/colt?readPreference=primary&directConnection=true&ssl=false", "colt")

collection := GetCollection[*testdoc](&mockDb, "testdocs")
collection := &CollectionImpl[*testdoc]{collection: mockDb.db.Collection("testdocs")}

var indxs = []interface{}{}
indexCursor, _ := collection.collection.Indexes().List(DefaultContext())
Expand All @@ -38,7 +39,7 @@ func TestCollection_CreateMultiKeyIndex(t *testing.T) {
rand.Seed(time.Now().UnixNano())
mockDb.Connect("mongodb://localhost:27017/colt?readPreference=primary&directConnection=true&ssl=false", "colt")

collection := GetCollection[*testdoc](&mockDb, "testdocs")
collection := &CollectionImpl[*testdoc]{collection: mockDb.db.Collection("testdocs")}

var indxs = []interface{}{}
indexCursor, _ := collection.collection.Indexes().List(DefaultContext())
Expand Down

0 comments on commit f6c8a42

Please sign in to comment.