From 5d8b8f62c7262f215def5f9b722701af1ff569f1 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Tue, 29 Sep 2020 15:47:43 -0500 Subject: [PATCH] Add mockable DB to enable testing database errors --- models/assets.go | 2 +- models/groups.go | 6 ++-- models/groups_test.go | 70 +++++++++++++++++++++++++------------------ testsuite/db.go | 60 +++++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 33 deletions(-) create mode 100644 testsuite/db.go diff --git a/models/assets.go b/models/assets.go index 86c55e4c4..a686211ad 100644 --- a/models/assets.go +++ b/models/assets.go @@ -154,7 +154,7 @@ func NewOrgAssets(ctx context.Context, db *sqlx.DB, orgID OrgID, prev *OrgAssets } if prev == nil || refresh&RefreshGroups > 0 { - oa.groups, err = loadGroups(ctx, db, orgID) + oa.groups, err = LoadGroups(ctx, db, orgID) if err != nil { return nil, errors.Wrapf(err, "error loading group assets for org %d", orgID) } diff --git a/models/groups.go b/models/groups.go index c41c9da94..0190d3366 100644 --- a/models/groups.go +++ b/models/groups.go @@ -49,11 +49,11 @@ func (g *Group) Name() string { return g.g.Name } // Query returns the query string (if any) for this group func (g *Group) Query() string { return g.g.Query } -// loads the groups for the passed in org -func loadGroups(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Group, error) { +// LoadGroups loads the groups for the passed in org +func LoadGroups(ctx context.Context, db Queryer, orgID OrgID) ([]assets.Group, error) { start := time.Now() - rows, err := db.Queryx(selectGroupsSQL, orgID) + rows, err := db.QueryxContext(ctx, selectGroupsSQL, orgID) if err != nil { return nil, errors.Wrapf(err, "error querying groups for org: %d", orgID) } diff --git a/models/groups_test.go b/models/groups_test.go index d78f6a96b..f5fdc1d9d 100644 --- a/models/groups_test.go +++ b/models/groups_test.go @@ -1,38 +1,50 @@ -package models +package models_test import ( + "errors" "fmt" "testing" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" + "github.com/nyaruka/mailroom/models" "github.com/nyaruka/mailroom/testsuite" "github.com/lib/pq" "github.com/olivere/elastic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestGroups(t *testing.T) { +func TestLoadGroups(t *testing.T) { ctx := testsuite.CTX() - db := testsuite.DB() + db := testsuite.NewMockDB(testsuite.DB(), func(funcName string, call int) error { + // fail first query for groups + if funcName == "QueryxContext" && call == 0 { + return errors.New("boom") + } + return nil + }) - groups, err := loadGroups(ctx, db, 1) - assert.NoError(t, err) + groups, err := models.LoadGroups(ctx, db, 1) + require.EqualError(t, err, "error querying groups for org: 1: boom") + + groups, err = models.LoadGroups(ctx, db, 1) + require.NoError(t, err) tcs := []struct { - ID GroupID + ID models.GroupID UUID assets.GroupUUID Name string Query string }{ - {DoctorsGroupID, DoctorsGroupUUID, "Doctors", ""}, - {TestersGroupID, TestersGroupUUID, "Testers", ""}, + {models.DoctorsGroupID, models.DoctorsGroupUUID, "Doctors", ""}, + {models.TestersGroupID, models.TestersGroupUUID, "Testers", ""}, } assert.Equal(t, 2, len(groups)) for i, tc := range tcs { - group := groups[i].(*Group) + group := groups[i].(*models.Group) assert.Equal(t, tc.UUID, group.UUID()) assert.Equal(t, tc.ID, group.ID()) assert.Equal(t, tc.Name, group.Name()) @@ -45,27 +57,27 @@ func TestDynamicGroups(t *testing.T) { db := testsuite.DB() // insert an event on our campaign - var eventID CampaignEventID + var eventID models.CampaignEventID testsuite.DB().Get(&eventID, `INSERT INTO campaigns_campaignevent(is_active, created_on, modified_on, uuid, "offset", unit, event_type, delivery_hour, campaign_id, created_by_id, modified_by_id, flow_id, relative_to_id, start_mode) VALUES(TRUE, NOW(), NOW(), $1, 1000, 'W', 'F', -1, $2, 1, 1, $3, $4, 'I') RETURNING id`, - uuids.New(), DoctorRemindersCampaignID, FavoritesFlowID, JoinedFieldID) + uuids.New(), models.DoctorRemindersCampaignID, models.FavoritesFlowID, models.JoinedFieldID) // clear Cathy's value testsuite.DB().MustExec( `update contacts_contact set fields = fields - $2 - WHERE id = $1`, CathyID, JoinedFieldUUID) + WHERE id = $1`, models.CathyID, models.JoinedFieldUUID) // and populate Bob's testsuite.DB().MustExec( fmt.Sprintf(`update contacts_contact set fields = fields || '{"%s": { "text": "2029-09-15T12:00:00+00:00", "datetime": "2029-09-15T12:00:00+00:00" }}'::jsonb - WHERE id = $1`, JoinedFieldUUID), BobID) + WHERE id = $1`, models.JoinedFieldUUID), models.BobID) // clear our org cache so we reload org campaigns and events - FlushCache() - org, err := GetOrgAssets(ctx, db, Org1) + models.FlushCache() + org, err := models.GetOrgAssets(ctx, db, models.Org1) assert.NoError(t, err) esServer := testsuite.NewMockElasticServer() @@ -106,52 +118,52 @@ func TestDynamicGroups(t *testing.T) { } }` - cathyHit := fmt.Sprintf(contactHit, CathyID) - bobHit := fmt.Sprintf(contactHit, BobID) + cathyHit := fmt.Sprintf(contactHit, models.CathyID) + bobHit := fmt.Sprintf(contactHit, models.BobID) tcs := []struct { Query string ESResponse string - ContactIDs []ContactID - EventContactIDs []ContactID + ContactIDs []models.ContactID + EventContactIDs []models.ContactID }{ { "cathy", cathyHit, - []ContactID{CathyID}, - []ContactID{}, + []models.ContactID{models.CathyID}, + []models.ContactID{}, }, { "bob", bobHit, - []ContactID{BobID}, - []ContactID{BobID}, + []models.ContactID{models.BobID}, + []models.ContactID{models.BobID}, }, { "unchanged", bobHit, - []ContactID{BobID}, - []ContactID{BobID}, + []models.ContactID{models.BobID}, + []models.ContactID{models.BobID}, }, } for _, tc := range tcs { - err := UpdateGroupStatus(ctx, db, DoctorsGroupID, GroupStatusInitializing) + err := models.UpdateGroupStatus(ctx, db, models.DoctorsGroupID, models.GroupStatusInitializing) assert.NoError(t, err) esServer.NextResponse = tc.ESResponse - count, err := PopulateDynamicGroup(ctx, db, es, org, DoctorsGroupID, tc.Query) + count, err := models.PopulateDynamicGroup(ctx, db, es, org, models.DoctorsGroupID, tc.Query) assert.NoError(t, err, "error populating dynamic group for: %s", tc.Query) assert.Equal(t, count, len(tc.ContactIDs)) // assert the current group membership - contactIDs, err := ContactIDsForGroupIDs(ctx, db, []GroupID{DoctorsGroupID}) + contactIDs, err := models.ContactIDsForGroupIDs(ctx, db, []models.GroupID{models.DoctorsGroupID}) assert.Equal(t, tc.ContactIDs, contactIDs) testsuite.AssertQueryCount(t, db, `SELECT count(*) from contacts_contactgroup WHERE id = $1 AND status = 'R'`, - []interface{}{DoctorsGroupID}, 1, "wrong number of contacts in group for query: %s", tc.Query) + []interface{}{models.DoctorsGroupID}, 1, "wrong number of contacts in group for query: %s", tc.Query) testsuite.AssertQueryCount(t, db, `SELECT count(*) from campaigns_eventfire WHERE event_id = $1`, diff --git a/testsuite/db.go b/testsuite/db.go new file mode 100644 index 000000000..351c90380 --- /dev/null +++ b/testsuite/db.go @@ -0,0 +1,60 @@ +package testsuite + +import ( + "context" + "database/sql" + + "github.com/jmoiron/sqlx" +) + +type MockDB struct { + real *sqlx.DB + callCounts map[string]int + shouldErr func(funcName string, call int) error +} + +func NewMockDB(db *sqlx.DB, shouldErr func(funcName string, call int) error) *MockDB { + return &MockDB{ + real: db, + callCounts: make(map[string]int), + shouldErr: shouldErr, + } +} + +func (d *MockDB) check(funcName string) error { + call := d.callCounts[funcName] + d.callCounts[funcName]++ + return d.shouldErr(funcName, call) +} + +func (d *MockDB) Rebind(query string) string { + return d.Rebind(query) +} + +func (d *MockDB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + if err := d.check("QueryxContext"); err != nil { + return nil, err + } + return d.real.QueryxContext(ctx, query, args...) +} + +func (d *MockDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if err := d.check("ExecContext"); err != nil { + return nil, err + } + return d.real.ExecContext(ctx, query, args...) +} + +func (d *MockDB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + if err := d.check("NamedExecContext"); err != nil { + return nil, err + } + return d.real.NamedExecContext(ctx, query, arg) +} + +func (d *MockDB) GetContext(ctx context.Context, value interface{}, query string, args ...interface{}) error { + if err := d.check("GetContext"); err != nil { + return err + } + return d.real.GetContext(ctx, value, query, args...) +}