Skip to content

Commit

Permalink
Merge pull request #110 from nyaruka/less_sqlx
Browse files Browse the repository at this point in the history
Less sqlx
  • Loading branch information
rowanseymour authored Aug 28, 2023
2 parents 49b0e82 + 66c4991 commit d93a4ba
Show file tree
Hide file tree
Showing 34 changed files with 116 additions and 127 deletions.
2 changes: 1 addition & 1 deletion core/goflow/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestSimulatorAirtime(t *testing.T) {
func TestSimulatorTicket(t *testing.T) {
ctx, rt := testsuite.Runtime()

ticketer, err := models.LookupTicketerByUUID(ctx, rt.DB, testdata.Mailgun.UUID)
ticketer, err := models.LookupTicketerByUUID(ctx, rt.DB.DB, testdata.Mailgun.UUID)
require.NoError(t, err)

svc, err := goflow.Simulator(rt.Config).Services().Ticket(flows.NewTicketer(ticketer))
Expand Down
8 changes: 4 additions & 4 deletions core/models/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func FlushCache() {
// org assets passed in to prevent refetching locations
func NewOrgAssets(ctx context.Context, rt *runtime.Runtime, orgID OrgID, prev *OrgAssets, refresh Refresh) (*OrgAssets, error) {
// assets are immutable in mailroom so safe to load from readonly database connection
db := rt.ReadonlyDB
db := rt.ReadonlyDB.DB

// build our new assets
oa := &OrgAssets{
Expand Down Expand Up @@ -520,7 +520,7 @@ func (a *OrgAssets) FlowByUUID(flowUUID assets.FlowUUID) (assets.Flow, error) {
return a.flowByUUID[flowUUID]
},
func(ctx context.Context) (*Flow, error) {
return LoadFlowByUUID(ctx, a.rt.ReadonlyDB, a.orgID, flowUUID)
return LoadFlowByUUID(ctx, a.rt.ReadonlyDB.DB, a.orgID, flowUUID)
},
)
}
Expand All @@ -537,7 +537,7 @@ func (a *OrgAssets) FlowByName(name string) (assets.Flow, error) {
return nil
},
func(ctx context.Context) (*Flow, error) {
return LoadFlowByName(ctx, a.rt.ReadonlyDB, a.orgID, name)
return LoadFlowByName(ctx, a.rt.ReadonlyDB.DB, a.orgID, name)
},
)
}
Expand All @@ -549,7 +549,7 @@ func (a *OrgAssets) FlowByID(flowID FlowID) (*Flow, error) {
return a.flowByID[flowID]
},
func(ctx context.Context) (*Flow, error) {
return LoadFlowByID(ctx, a.rt.ReadonlyDB, a.orgID, flowID)
return LoadFlowByID(ctx, a.rt.ReadonlyDB.DB, a.orgID, flowID)
},
)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions core/models/campaigns.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,10 @@ func (e *CampaignEvent) Campaign() *Campaign { return e.campaign }
func (e *CampaignEvent) StartMode() StartMode { return e.e.StartMode }

// loadCampaigns loads all the campaigns for the passed in org
func loadCampaigns(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]*Campaign, error) {
func loadCampaigns(ctx context.Context, db *sql.DB, orgID OrgID) ([]*Campaign, error) {
start := time.Now()

rows, err := db.Queryx(selectCampaignsSQL, orgID)
rows, err := db.QueryContext(ctx, selectCampaignsSQL, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying campaigns for org: %d", orgID)
}
Expand Down
9 changes: 5 additions & 4 deletions core/models/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package models

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"math"
Expand Down Expand Up @@ -129,8 +130,8 @@ func (c *Channel) ChannelReference() *assets.ChannelReference {

// GetChannelsByID fetches channels by ID - NOTE these are "lite" channels and only include fields for sending, and
// that this function will return deleted channels.
func GetChannelsByID(ctx context.Context, db Queryer, ids []ChannelID) ([]*Channel, error) {
rows, err := db.QueryxContext(ctx, sqlSelectChannelsByID, pq.Array(ids))
func GetChannelsByID(ctx context.Context, db *sql.DB, ids []ChannelID) ([]*Channel, error) {
rows, err := db.QueryContext(ctx, sqlSelectChannelsByID, pq.Array(ids))
if err != nil {
return nil, errors.Wrapf(err, "error querying channels by id")
}
Expand Down Expand Up @@ -166,10 +167,10 @@ WHERE
) r;`

// loadChannels loads all the channels for the passed in org
func loadChannels(ctx context.Context, db Queryer, orgID OrgID) ([]assets.Channel, error) {
func loadChannels(ctx context.Context, db *sql.DB, orgID OrgID) ([]assets.Channel, error) {
start := time.Now()

rows, err := db.QueryxContext(ctx, sqlSelectChannels, orgID)
rows, err := db.QueryContext(ctx, sqlSelectChannels, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying channels for org: %d", orgID)
}
Expand Down
6 changes: 3 additions & 3 deletions core/models/classifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package models

import (
"context"
"database/sql"
"database/sql/driver"
"time"

"github.com/jmoiron/sqlx"
"github.com/nyaruka/gocommon/dbutil"
"github.com/nyaruka/goflow/assets"
"github.com/nyaruka/goflow/flows"
Expand Down Expand Up @@ -127,10 +127,10 @@ func (c *Classifier) AsService(cfg *runtime.Config, classifier *flows.Classifier
}

// loadClassifiers loads all the classifiers for the passed in org
func loadClassifiers(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Classifier, error) {
func loadClassifiers(ctx context.Context, db *sql.DB, orgID OrgID) ([]assets.Classifier, error) {
start := time.Now()

rows, err := db.Queryx(sqlSelectClassifiers, orgID)
rows, err := db.QueryContext(ctx, sqlSelectClassifiers, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying classifiers for org: %d", orgID)
}
Expand Down
6 changes: 3 additions & 3 deletions core/models/contacts.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func LoadContacts(ctx context.Context, db Queryer, oa *OrgAssets, ids []ContactI

start := time.Now()

rows, err := db.QueryxContext(ctx, sqlSelectContact, pq.Array(ids), oa.OrgID())
rows, err := db.QueryContext(ctx, sqlSelectContact, pq.Array(ids), oa.OrgID())
if err != nil {
return nil, errors.Wrap(err, "error selecting contacts")
}
Expand Down Expand Up @@ -887,7 +887,7 @@ func insertContactAndURNs(ctx context.Context, db Queryer, orgID OrgID, userID U
// set that goflow and mailroom depend on.
func URNForURN(ctx context.Context, db Queryer, oa *OrgAssets, u urns.URN) (urns.URN, error) {
urn := &ContactURN{}
rows, err := db.QueryxContext(ctx,
rows, err := db.QueryContext(ctx,
`SELECT row_to_json(r) FROM (SELECT id, scheme, path, display, auth, channel_id, priority FROM contacts_contacturn WHERE identity = $1 AND org_id = $2) r;`,
u.Identity(), oa.OrgID(),
)
Expand Down Expand Up @@ -948,7 +948,7 @@ func GetOrCreateURN(ctx context.Context, db Queryer, oa *OrgAssets, contactID Co
// but occasionally we need to load URNs one by one and this accomplishes that
func URNForID(ctx context.Context, db Queryer, oa *OrgAssets, urnID URNID) (urns.URN, error) {
urn := &ContactURN{}
rows, err := db.QueryxContext(ctx,
rows, err := db.QueryContext(ctx,
`SELECT row_to_json(r) FROM (SELECT id, scheme, path, display, auth, channel_id, priority FROM contacts_contacturn WHERE id = $1) r;`,
urnID,
)
Expand Down
6 changes: 3 additions & 3 deletions core/models/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package models

import (
"context"
"database/sql"
"time"

"github.com/jmoiron/sqlx"
"github.com/nyaruka/gocommon/dbutil"
"github.com/nyaruka/goflow/assets"
"github.com/pkg/errors"
Expand Down Expand Up @@ -45,10 +45,10 @@ func (f *Field) Type() assets.FieldType { return f.f.Type }
func (f *Field) System() bool { return f.f.System }

// loadFields loads the assets for the passed in db
func loadFields(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Field, []assets.Field, error) {
func loadFields(ctx context.Context, db *sql.DB, orgID OrgID) ([]assets.Field, []assets.Field, error) {
start := time.Now()

rows, err := db.Queryx(sqlSelectFields, orgID)
rows, err := db.QueryContext(ctx, sqlSelectFields, orgID)
if err != nil {
return nil, nil, errors.Wrapf(err, "error querying fields for org: %d", orgID)
}
Expand Down
11 changes: 6 additions & 5 deletions core/models/flows.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package models

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -132,24 +133,24 @@ func FlowIDForUUID(ctx context.Context, tx *sqlx.Tx, oa *OrgAssets, flowUUID ass
return flowID, err
}

func LoadFlowByUUID(ctx context.Context, db Queryer, orgID OrgID, flowUUID assets.FlowUUID) (*Flow, error) {
func LoadFlowByUUID(ctx context.Context, db *sql.DB, orgID OrgID, flowUUID assets.FlowUUID) (*Flow, error) {
return loadFlow(ctx, db, sqlSelectFlowByUUID, orgID, flowUUID)
}

func LoadFlowByName(ctx context.Context, db Queryer, orgID OrgID, name string) (*Flow, error) {
func LoadFlowByName(ctx context.Context, db *sql.DB, orgID OrgID, name string) (*Flow, error) {
return loadFlow(ctx, db, sqlSelectFlowByName, orgID, name)
}

func LoadFlowByID(ctx context.Context, db Queryer, orgID OrgID, flowID FlowID) (*Flow, error) {
func LoadFlowByID(ctx context.Context, db *sql.DB, orgID OrgID, flowID FlowID) (*Flow, error) {
return loadFlow(ctx, db, sqlSelectFlowByID, orgID, flowID)
}

// loads the flow with the passed in UUID
func loadFlow(ctx context.Context, db Queryer, sql string, orgID OrgID, arg interface{}) (*Flow, error) {
func loadFlow(ctx context.Context, db *sql.DB, sql string, orgID OrgID, arg interface{}) (*Flow, error) {
start := time.Now()
flow := &Flow{}

rows, err := db.QueryxContext(ctx, sql, orgID, arg)
rows, err := db.QueryContext(ctx, sql, orgID, arg)
if err != nil {
return nil, errors.Wrapf(err, "error querying flow by: %v", arg)
}
Expand Down
8 changes: 4 additions & 4 deletions core/models/flows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,23 +115,23 @@ func TestLoadFlows(t *testing.T) {

for _, tc := range tcs {
// test loading by UUID
dbFlow, err := models.LoadFlowByUUID(ctx, rt.DB, tc.org.ID, tc.uuid)
dbFlow, err := models.LoadFlowByUUID(ctx, rt.DB.DB, tc.org.ID, tc.uuid)
assert.NoError(t, err)
assertFlow(&tc, dbFlow)

// test loading by name
dbFlow, err = models.LoadFlowByName(ctx, rt.DB, tc.org.ID, tc.name)
dbFlow, err = models.LoadFlowByName(ctx, rt.DB.DB, tc.org.ID, tc.name)
assert.NoError(t, err)
assertFlow(&tc, dbFlow)

// test loading by ID
dbFlow, err = models.LoadFlowByID(ctx, rt.DB, tc.org.ID, tc.id)
dbFlow, err = models.LoadFlowByID(ctx, rt.DB.DB, tc.org.ID, tc.id)
assert.NoError(t, err)
assertFlow(&tc, dbFlow)
}

// test loading flow with wrong org
dbFlow, err := models.LoadFlowByID(ctx, rt.DB, testdata.Org2.ID, testdata.Favorites.ID)
dbFlow, err := models.LoadFlowByID(ctx, rt.DB.DB, testdata.Org2.ID, testdata.Favorites.ID)
assert.NoError(t, err)
assert.Nil(t, dbFlow)
}
Expand Down
7 changes: 3 additions & 4 deletions core/models/globals.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package models

import (
"context"
"database/sql"
"encoding/json"
"time"

"github.com/nyaruka/gocommon/dbutil"
"github.com/nyaruka/goflow/assets"

"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
Expand All @@ -32,10 +31,10 @@ func (g *Global) UnmarshalJSON(data []byte) error { return json.Unmarshal(data,
func (g *Global) MarshalJSON() ([]byte, error) { return json.Marshal(g.g) }

// loads the globals for the passed in org
func loadGlobals(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Global, error) {
func loadGlobals(ctx context.Context, db *sql.DB, orgID OrgID) ([]assets.Global, error) {
start := time.Now()

rows, err := db.Queryx(selectGlobalsSQL, orgID)
rows, err := db.QueryContext(ctx, selectGlobalsSQL, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying globals for org: %d", orgID)
}
Expand Down
5 changes: 3 additions & 2 deletions core/models/groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package models

import (
"context"
"database/sql"
"time"

"github.com/nyaruka/gocommon/dbutil"
Expand Down Expand Up @@ -65,10 +66,10 @@ func (g *Group) Status() GroupStatus { return g.g.Status }
func (g *Group) Type() GroupType { return g.g.Type }

// LoadGroups loads the groups for the passed in org
func LoadGroups(ctx context.Context, db Queryer, orgID OrgID) ([]assets.Group, error) {
func LoadGroups(ctx context.Context, db *sql.DB, orgID OrgID) ([]assets.Group, error) {
start := time.Now()

rows, err := db.QueryxContext(ctx, selectGroupsSQL, orgID)
rows, err := db.QueryContext(ctx, selectGroupsSQL, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying groups for org: %d", orgID)
}
Expand Down
14 changes: 1 addition & 13 deletions core/models/groups_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package models_test

import (
"errors"
"testing"

"github.com/nyaruka/goflow/assets"
Expand All @@ -15,18 +14,7 @@ import (
func TestLoadGroups(t *testing.T) {
ctx, rt := testsuite.Runtime()

db := testsuite.NewMockDB(rt.DB, func(funcName string, call int) error {
// fail first query for groups
if funcName == "QueryxContext" && call == 0 {
return errors.New("boom")
}
return nil
})

_, err := models.LoadGroups(ctx, db, testdata.Org1.ID)
require.EqualError(t, err, "error querying groups for org: 1: boom")

groups, err := models.LoadGroups(ctx, db, 1)
groups, err := models.LoadGroups(ctx, rt.DB.DB, 1)
require.NoError(t, err)

tcs := []struct {
Expand Down
2 changes: 1 addition & 1 deletion core/models/http_logs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestHTTPLogger(t *testing.T) {
},
}))

mailgun, err := models.LookupTicketerByUUID(ctx, rt.DB, testdata.Mailgun.UUID)
mailgun, err := models.LookupTicketerByUUID(ctx, rt.DB.DB, testdata.Mailgun.UUID)
require.NoError(t, err)

logger := &models.HTTPLogger{}
Expand Down
5 changes: 3 additions & 2 deletions core/models/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package models

import (
"context"
"database/sql"
"time"

"github.com/nyaruka/gocommon/dbutil"
Expand Down Expand Up @@ -33,10 +34,10 @@ func (l *Label) UUID() assets.LabelUUID { return l.l.UUID }
func (l *Label) Name() string { return l.l.Name }

// loads the labels for the passed in org
func loadLabels(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Label, error) {
func loadLabels(ctx context.Context, db *sql.DB, orgID OrgID) ([]assets.Label, error) {
start := time.Now()

rows, err := db.Queryx(sqlSelectLabels, orgID)
rows, err := db.QueryContext(ctx, sqlSelectLabels, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying labels for org: %d", orgID)
}
Expand Down
6 changes: 3 additions & 3 deletions core/models/locations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package models

import (
"context"
"database/sql"
"encoding/json"
"time"

"github.com/nyaruka/goflow/assets"
"github.com/nyaruka/goflow/envs"

"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -45,10 +45,10 @@ func (l *Location) Aliases() []string { return l.Aliases_ }
func (l *Location) Children() []*Location { return l.Children_ }

// loadLocations loads all the locations for this org returning the root node
func loadLocations(ctx context.Context, db sqlx.Queryer, oa *OrgAssets) ([]assets.LocationHierarchy, error) {
func loadLocations(ctx context.Context, db *sql.DB, oa *OrgAssets) ([]assets.LocationHierarchy, error) {
start := time.Now()

rows, err := db.Query(loadLocationsSQL, oa.orgID)
rows, err := db.QueryContext(ctx, loadLocationsSQL, oa.orgID)
if err != nil {
return nil, errors.Wrapf(err, "error querying locations for org: %d", oa.orgID)
}
Expand Down
6 changes: 3 additions & 3 deletions core/models/orgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package models

import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
Expand All @@ -10,7 +11,6 @@ import (
"path/filepath"
"time"

"github.com/jmoiron/sqlx"
"github.com/nyaruka/gocommon/dbutil"
"github.com/nyaruka/gocommon/httpx"
"github.com/nyaruka/gocommon/jsonx"
Expand Down Expand Up @@ -181,11 +181,11 @@ func orgFromAssets(sa flows.SessionAssets) *Org {
}

// LoadOrg loads the org for the passed in id, returning any error encountered
func LoadOrg(ctx context.Context, cfg *runtime.Config, db sqlx.Queryer, orgID OrgID) (*Org, error) {
func LoadOrg(ctx context.Context, cfg *runtime.Config, db *sql.DB, orgID OrgID) (*Org, error) {
start := time.Now()

org := &Org{}
rows, err := db.Queryx(selectOrgByID, orgID)
rows, err := db.QueryContext(ctx, selectOrgByID, orgID)
if err != nil {
return nil, errors.Wrapf(err, "error loading org: %d", orgID)
}
Expand Down
Loading

0 comments on commit d93a4ba

Please sign in to comment.