Skip to content

Commit

Permalink
Add an option to distribute reads among all nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
o1egl committed May 26, 2020
1 parent 5bac59b commit ee6c7dc
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 29 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
Expand Down
78 changes: 52 additions & 26 deletions mssqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ func ping(w *wrapper) (err error) {

// DBs sqlx wrapper supports querying master-slave database connections for HA and scalability, auto-balancer integrated.
type DBs struct {
driverName string
driverName string
readFromAll bool

masters *balancer
slaves *balancer
Expand All @@ -65,6 +66,13 @@ func (dbs *DBs) getDBs(s []*wrapper) ([]*sqlx.DB, int) {
return r, n
}

func (dbs *DBs) getReadBalancer() *balancer {
if dbs.readFromAll {
return dbs.all
}
return dbs.slaves
}

// GetAllMasters get all master database connections, included failing one.
func (dbs *DBs) GetAllMasters() ([]*sqlx.DB, int) {
return dbs.getDBs(dbs._masters)
Expand Down Expand Up @@ -590,7 +598,7 @@ func _namedQuery(ctx context.Context, target *balancer, query string, arg interf
// NamedQuery do named query.
// Any named placeholder parameters are replaced with fields from arg.
func (dbs *DBs) NamedQuery(query string, arg interface{}) (*sqlx.Rows, error) {
return _namedQuery(context.Background(), dbs.slaves, query, arg)
return _namedQuery(context.Background(), dbs.getReadBalancer(), query, arg)
}

// NamedQueryOnMaster do named query on master.
Expand All @@ -602,7 +610,7 @@ func (dbs *DBs) NamedQueryOnMaster(query string, arg interface{}) (*sqlx.Rows, e
// NamedQueryContext do named query with context.
// Any named placeholder parameters are replaced with fields from arg.
func (dbs *DBs) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*sqlx.Rows, error) {
return _namedQuery(ctx, dbs.slaves, query, arg)
return _namedQuery(ctx, dbs.getReadBalancer(), query, arg)
}

// NamedQueryContextOnMaster do named query with context on master.
Expand Down Expand Up @@ -699,7 +707,7 @@ func _query(ctx context.Context, target *balancer, query string, args ...interfa
// Query executes a query on slaves that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (dbs *DBs) Query(query string, args ...interface{}) (r *sql.Rows, err error) {
_, r, err = _query(context.Background(), dbs.slaves, query, args...)
_, r, err = _query(context.Background(), dbs.getReadBalancer(), query, args...)
return
}

Expand All @@ -713,7 +721,7 @@ func (dbs *DBs) QueryOnMaster(query string, args ...interface{}) (r *sql.Rows, e
// QueryContext executes a query on slaves that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (dbs *DBs) QueryContext(ctx context.Context, query string, args ...interface{}) (r *sql.Rows, err error) {
_, r, err = _query(ctx, dbs.slaves, query, args...)
_, r, err = _query(ctx, dbs.getReadBalancer(), query, args...)
return
}

Expand Down Expand Up @@ -758,7 +766,7 @@ func _queryx(ctx context.Context, target *balancer, query string, args ...interf
// Queryx executes a query on slaves that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (dbs *DBs) Queryx(query string, args ...interface{}) (r *sqlx.Rows, err error) {
_, r, err = _queryx(context.Background(), dbs.slaves, query, args...)
_, r, err = _queryx(context.Background(), dbs.getReadBalancer(), query, args...)
return
}

Expand All @@ -772,7 +780,7 @@ func (dbs *DBs) QueryxOnMaster(query string, args ...interface{}) (r *sqlx.Rows,
// QueryxContext executes a query on slaves that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
func (dbs *DBs) QueryxContext(ctx context.Context, query string, args ...interface{}) (r *sqlx.Rows, err error) {
_, r, err = _queryx(ctx, dbs.slaves, query, args...)
_, r, err = _queryx(ctx, dbs.getReadBalancer(), query, args...)
return
}

Expand Down Expand Up @@ -801,7 +809,7 @@ func _queryRow(ctx context.Context, target *balancer, query string, args ...inte
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (dbs *DBs) QueryRow(query string, args ...interface{}) (r *sql.Row, err error) {
_, r, err = _queryRow(context.Background(), dbs.slaves, query, args...)
_, r, err = _queryRow(context.Background(), dbs.getReadBalancer(), query, args...)
return
}

Expand All @@ -817,7 +825,7 @@ func (dbs *DBs) QueryRowOnMaster(query string, args ...interface{}) (r *sql.Row,
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (dbs *DBs) QueryRowContext(ctx context.Context, query string, args ...interface{}) (r *sql.Row, err error) {
_, r, err = _queryRow(ctx, dbs.slaves, query, args...)
_, r, err = _queryRow(ctx, dbs.getReadBalancer(), query, args...)
return
}

Expand Down Expand Up @@ -848,7 +856,7 @@ func _queryRowx(ctx context.Context, target *balancer, query string, args ...int
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (dbs *DBs) QueryRowx(query string, args ...interface{}) (r *sqlx.Row, err error) {
_, r, err = _queryRowx(context.Background(), dbs.slaves, query, args...)
_, r, err = _queryRowx(context.Background(), dbs.getReadBalancer(), query, args...)
return
}

Expand All @@ -864,7 +872,7 @@ func (dbs *DBs) QueryRowxOnMaster(query string, args ...interface{}) (r *sqlx.Ro
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (dbs *DBs) QueryRowxContext(ctx context.Context, query string, args ...interface{}) (r *sqlx.Row, err error) {
_, r, err = _queryRowx(ctx, dbs.slaves, query, args...)
_, r, err = _queryRowx(ctx, dbs.getReadBalancer(), query, args...)
return
}

Expand Down Expand Up @@ -904,7 +912,7 @@ func _select(ctx context.Context, target *balancer, dest interface{}, query stri
// Select do select on slaves.
// Any placeholder parameters are replaced with supplied args.
func (dbs *DBs) Select(dest interface{}, query string, args ...interface{}) (err error) {
_, err = _select(context.Background(), dbs.slaves, dest, query, args...)
_, err = _select(context.Background(), dbs.getReadBalancer(), dest, query, args...)
return
}

Expand All @@ -918,7 +926,7 @@ func (dbs *DBs) SelectOnMaster(dest interface{}, query string, args ...interface
// SelectContext do select on slaves with context.
// Any placeholder parameters are replaced with supplied args.
func (dbs *DBs) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
_, err = _select(ctx, dbs.slaves, dest, query, args...)
_, err = _select(ctx, dbs.getReadBalancer(), dest, query, args...)
return
}

Expand Down Expand Up @@ -974,7 +982,7 @@ func (dbs *DBs) GetOnMaster(dest interface{}, query string, args ...interface{})
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
func (dbs *DBs) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
_, err = _get(ctx, dbs.slaves, dest, query, args...)
_, err = _get(ctx, dbs.getReadBalancer(), dest, query, args...)
return
}

Expand Down Expand Up @@ -1437,12 +1445,32 @@ func (dbs *DBs) BeginTxx(ctx context.Context, opts *sql.TxOptions) (res *sqlx.Tx
}
}

type clusterArgs struct {
isWsrep bool
readFromAll bool
}

type Option func(*clusterArgs)

// IsWsrep indicates galera/wsrep cluster
func IsWsrep() Option {
return func(o *clusterArgs) {
o.isWsrep = true
}
}

// ReadFromAll sets flag to send read requests to masters and slaves
func ReadFromAll() Option {
return func(o *clusterArgs) {
o.readFromAll = true
}
}

// ConnectMasterSlaves to master-slave databases, healthchecks will ensure they are working
// driverName: mysql, postgres, etc.
// masterDSNs: data source names of Masters.
// slaveDSNs: data source names of Slaves.
// args: args[0] = true to indicates galera/wsrep cluster.
func ConnectMasterSlaves(driverName string, masterDSNs []string, slaveDSNs []string, args ...interface{}) (*DBs, []error) {
func ConnectMasterSlaves(driverName string, masterDSNs []string, slaveDSNs []string, options ...Option) (*DBs, []error) {
// Validate slave address
if slaveDSNs == nil {
slaveDSNs = []string{}
Expand All @@ -1452,12 +1480,9 @@ func ConnectMasterSlaves(driverName string, masterDSNs []string, slaveDSNs []str
masterDSNs = []string{}
}

isWsrep := false
if len(args) > 0 {
switch args[0].(type) {
case bool:
isWsrep = args[0].(bool)
}
args := &clusterArgs{}
for _, opt := range options {
opt(args)
}

nMaster := len(masterDSNs)
Expand All @@ -1466,15 +1491,16 @@ func ConnectMasterSlaves(driverName string, masterDSNs []string, slaveDSNs []str

errResult := make([]error, nAll)
dbs := &DBs{
driverName: driverName,
driverName: driverName,
readFromAll: args.readFromAll,

masters: newBalancer(nil, nMaster>>2, nMaster, isWsrep),
masters: newBalancer(nil, nMaster>>2, nMaster, args.isWsrep),
_masters: make([]*wrapper, nMaster),

slaves: newBalancer(nil, nSlave>>2, nSlave, isWsrep),
slaves: newBalancer(nil, nSlave>>2, nSlave, args.isWsrep),
_slaves: make([]*wrapper, nSlave),

all: newBalancer(nil, nAll>>2, nAll, isWsrep),
all: newBalancer(nil, nAll>>2, nAll, args.isWsrep),
_all: make([]*wrapper, nAll),
}

Expand Down
6 changes: 3 additions & 3 deletions mssqlx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ func TestConnectMasterSlave(t *testing.T) {
t.Fatal("DestroySlave fail")
}

db, _ = ConnectMasterSlaves(driver, masterDSNs, slaveDSNs, true)
db, _ = ConnectMasterSlaves(driver, masterDSNs, slaveDSNs, IsWsrep())
if _, c := db.GetAllMasters(); c != 3 {
t.Fatal("Initialize master slave fail")
}
Expand All @@ -454,12 +454,12 @@ func TestConnectMasterSlave(t *testing.T) {
t.Fatal("Destroy fail")
}

db, _ = ConnectMasterSlaves(driver, nil, slaveDSNs, true)
db, _ = ConnectMasterSlaves(driver, nil, slaveDSNs, IsWsrep())
if _, c := db.GetAllMasters(); c != 0 {
t.Fatal("Initialize master slave fail")
}

db, _ = ConnectMasterSlaves(driver, nil, nil, true)
db, _ = ConnectMasterSlaves(driver, nil, nil, IsWsrep())
if _, c := db.GetAllSlaves(); c != 0 {
t.Fatal("Initialize master slave fail")
}
Expand Down

0 comments on commit ee6c7dc

Please sign in to comment.