Skip to content

Commit

Permalink
Add mockable DB to enable testing database errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Sep 29, 2020
1 parent fc94a2b commit 5d8b8f6
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 33 deletions.
2 changes: 1 addition & 1 deletion models/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func NewOrgAssets(ctx context.Context, db *sqlx.DB, orgID OrgID, prev *OrgAssets
}

if prev == nil || refresh&RefreshGroups > 0 {
oa.groups, err = loadGroups(ctx, db, orgID)
oa.groups, err = LoadGroups(ctx, db, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error loading group assets for org %d", orgID)
}
Expand Down
6 changes: 3 additions & 3 deletions models/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ func (g *Group) Name() string { return g.g.Name }
// Query returns the query string (if any) for this group
func (g *Group) Query() string { return g.g.Query }

// loads the groups for the passed in org
func loadGroups(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Group, error) {
// LoadGroups loads the groups for the passed in org
func LoadGroups(ctx context.Context, db Queryer, orgID OrgID) ([]assets.Group, error) {
start := time.Now()

rows, err := db.Queryx(selectGroupsSQL, orgID)
rows, err := db.QueryxContext(ctx, selectGroupsSQL, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying groups for org: %d", orgID)
}
Expand Down
70 changes: 41 additions & 29 deletions models/groups_test.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,50 @@
package models
package models_test

import (
"errors"
"fmt"
"testing"

"github.com/nyaruka/gocommon/uuids"
"github.com/nyaruka/goflow/assets"
"github.com/nyaruka/mailroom/models"
"github.com/nyaruka/mailroom/testsuite"

"github.com/lib/pq"
"github.com/olivere/elastic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGroups(t *testing.T) {
func TestLoadGroups(t *testing.T) {
ctx := testsuite.CTX()
db := testsuite.DB()
db := testsuite.NewMockDB(testsuite.DB(), func(funcName string, call int) error {
// fail first query for groups
if funcName == "QueryxContext" && call == 0 {
return errors.New("boom")
}
return nil
})

groups, err := loadGroups(ctx, db, 1)
assert.NoError(t, err)
groups, err := models.LoadGroups(ctx, db, 1)
require.EqualError(t, err, "error querying groups for org: 1: boom")

groups, err = models.LoadGroups(ctx, db, 1)
require.NoError(t, err)

tcs := []struct {
ID GroupID
ID models.GroupID
UUID assets.GroupUUID
Name string
Query string
}{
{DoctorsGroupID, DoctorsGroupUUID, "Doctors", ""},
{TestersGroupID, TestersGroupUUID, "Testers", ""},
{models.DoctorsGroupID, models.DoctorsGroupUUID, "Doctors", ""},
{models.TestersGroupID, models.TestersGroupUUID, "Testers", ""},
}

assert.Equal(t, 2, len(groups))
for i, tc := range tcs {
group := groups[i].(*Group)
group := groups[i].(*models.Group)
assert.Equal(t, tc.UUID, group.UUID())
assert.Equal(t, tc.ID, group.ID())
assert.Equal(t, tc.Name, group.Name())
Expand All @@ -45,27 +57,27 @@ func TestDynamicGroups(t *testing.T) {
db := testsuite.DB()

// insert an event on our campaign
var eventID CampaignEventID
var eventID models.CampaignEventID
testsuite.DB().Get(&eventID,
`INSERT INTO campaigns_campaignevent(is_active, created_on, modified_on, uuid, "offset", unit, event_type, delivery_hour,
campaign_id, created_by_id, modified_by_id, flow_id, relative_to_id, start_mode)
VALUES(TRUE, NOW(), NOW(), $1, 1000, 'W', 'F', -1, $2, 1, 1, $3, $4, 'I') RETURNING id`,
uuids.New(), DoctorRemindersCampaignID, FavoritesFlowID, JoinedFieldID)
uuids.New(), models.DoctorRemindersCampaignID, models.FavoritesFlowID, models.JoinedFieldID)

// clear Cathy's value
testsuite.DB().MustExec(
`update contacts_contact set fields = fields - $2
WHERE id = $1`, CathyID, JoinedFieldUUID)
WHERE id = $1`, models.CathyID, models.JoinedFieldUUID)

// and populate Bob's
testsuite.DB().MustExec(
fmt.Sprintf(`update contacts_contact set fields = fields ||
'{"%s": { "text": "2029-09-15T12:00:00+00:00", "datetime": "2029-09-15T12:00:00+00:00" }}'::jsonb
WHERE id = $1`, JoinedFieldUUID), BobID)
WHERE id = $1`, models.JoinedFieldUUID), models.BobID)

// clear our org cache so we reload org campaigns and events
FlushCache()
org, err := GetOrgAssets(ctx, db, Org1)
models.FlushCache()
org, err := models.GetOrgAssets(ctx, db, models.Org1)
assert.NoError(t, err)

esServer := testsuite.NewMockElasticServer()
Expand Down Expand Up @@ -106,52 +118,52 @@ func TestDynamicGroups(t *testing.T) {
}
}`

cathyHit := fmt.Sprintf(contactHit, CathyID)
bobHit := fmt.Sprintf(contactHit, BobID)
cathyHit := fmt.Sprintf(contactHit, models.CathyID)
bobHit := fmt.Sprintf(contactHit, models.BobID)

tcs := []struct {
Query string
ESResponse string
ContactIDs []ContactID
EventContactIDs []ContactID
ContactIDs []models.ContactID
EventContactIDs []models.ContactID
}{
{
"cathy",
cathyHit,
[]ContactID{CathyID},
[]ContactID{},
[]models.ContactID{models.CathyID},
[]models.ContactID{},
},
{
"bob",
bobHit,
[]ContactID{BobID},
[]ContactID{BobID},
[]models.ContactID{models.BobID},
[]models.ContactID{models.BobID},
},
{
"unchanged",
bobHit,
[]ContactID{BobID},
[]ContactID{BobID},
[]models.ContactID{models.BobID},
[]models.ContactID{models.BobID},
},
}

for _, tc := range tcs {
err := UpdateGroupStatus(ctx, db, DoctorsGroupID, GroupStatusInitializing)
err := models.UpdateGroupStatus(ctx, db, models.DoctorsGroupID, models.GroupStatusInitializing)
assert.NoError(t, err)

esServer.NextResponse = tc.ESResponse
count, err := PopulateDynamicGroup(ctx, db, es, org, DoctorsGroupID, tc.Query)
count, err := models.PopulateDynamicGroup(ctx, db, es, org, models.DoctorsGroupID, tc.Query)
assert.NoError(t, err, "error populating dynamic group for: %s", tc.Query)

assert.Equal(t, count, len(tc.ContactIDs))

// assert the current group membership
contactIDs, err := ContactIDsForGroupIDs(ctx, db, []GroupID{DoctorsGroupID})
contactIDs, err := models.ContactIDsForGroupIDs(ctx, db, []models.GroupID{models.DoctorsGroupID})
assert.Equal(t, tc.ContactIDs, contactIDs)

testsuite.AssertQueryCount(t, db,
`SELECT count(*) from contacts_contactgroup WHERE id = $1 AND status = 'R'`,
[]interface{}{DoctorsGroupID}, 1, "wrong number of contacts in group for query: %s", tc.Query)
[]interface{}{models.DoctorsGroupID}, 1, "wrong number of contacts in group for query: %s", tc.Query)

testsuite.AssertQueryCount(t, db,
`SELECT count(*) from campaigns_eventfire WHERE event_id = $1`,
Expand Down
60 changes: 60 additions & 0 deletions testsuite/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package testsuite

import (
"context"
"database/sql"

"github.com/jmoiron/sqlx"
)

type MockDB struct {
real *sqlx.DB
callCounts map[string]int
shouldErr func(funcName string, call int) error
}

func NewMockDB(db *sqlx.DB, shouldErr func(funcName string, call int) error) *MockDB {
return &MockDB{
real: db,
callCounts: make(map[string]int),
shouldErr: shouldErr,
}
}

func (d *MockDB) check(funcName string) error {
call := d.callCounts[funcName]
d.callCounts[funcName]++
return d.shouldErr(funcName, call)
}

func (d *MockDB) Rebind(query string) string {
return d.Rebind(query)
}

func (d *MockDB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
if err := d.check("QueryxContext"); err != nil {
return nil, err
}
return d.real.QueryxContext(ctx, query, args...)
}

func (d *MockDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if err := d.check("ExecContext"); err != nil {
return nil, err
}
return d.real.ExecContext(ctx, query, args...)
}

func (d *MockDB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
if err := d.check("NamedExecContext"); err != nil {
return nil, err
}
return d.real.NamedExecContext(ctx, query, arg)
}

func (d *MockDB) GetContext(ctx context.Context, value interface{}, query string, args ...interface{}) error {
if err := d.check("GetContext"); err != nil {
return err
}
return d.real.GetContext(ctx, value, query, args...)
}

0 comments on commit 5d8b8f6

Please sign in to comment.