Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/on connect #571

Merged
merged 2 commits into from
May 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type ClusterOptions struct {

// Following options are copied from Options struct.

OnConnect func(*Conn) error

MaxRetries int
Password string

Expand Down Expand Up @@ -65,6 +67,8 @@ func (opt *ClusterOptions) clientOptions() *Options {
const disableIdleCheck = -1

return &Options{
OnConnect: opt.OnConnect,

MaxRetries: opt.MaxRetries,
Password: opt.Password,
ReadOnly: opt.ReadOnly,
Expand All @@ -77,7 +81,6 @@ func (opt *ClusterOptions) clientOptions() *Options {
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,

// IdleCheckFrequency is not copied to disable reaper
IdleCheckFrequency: disableIdleCheck,
}
}
Expand Down Expand Up @@ -349,7 +352,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {
opt: opt,
nodes: newClusterNodes(opt),
}
c.cmdable.process = c.Process
c.setProcessor(c.Process)

// Add initial nodes.
for _, addr := range opt.Addrs {
Expand Down Expand Up @@ -678,8 +681,7 @@ func (c *ClusterClient) Pipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExec,
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
pipe.setProcessor(pipe.Process)
return &pipe
}

Expand Down Expand Up @@ -801,8 +803,7 @@ func (c *ClusterClient) TxPipeline() Pipeliner {
pipe := Pipeline{
exec: c.txPipelineExec,
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
pipe.setProcessor(pipe.Process)
return &pipe
}

Expand Down
16 changes: 13 additions & 3 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Cmdable interface {
Pipeline() Pipeliner
Pipelined(fn func(Pipeliner) error) ([]Cmder, error)

ClientGetName() *StringCmd
Echo(message interface{}) *StringCmd
Ping() *StatusCmd
Quit() *StatusCmd
Expand Down Expand Up @@ -238,10 +239,10 @@ type Cmdable interface {
}

type StatefulCmdable interface {
Cmdable
Auth(password string) *StatusCmd
Select(index int) *StatusCmd
ClientSetName(name string) *BoolCmd
ClientGetName() *StringCmd
ReadOnly() *StatusCmd
ReadWrite() *StatusCmd
}
Expand All @@ -255,10 +256,20 @@ type cmdable struct {
process func(cmd Cmder) error
}

func (c *cmdable) setProcessor(fn func(Cmder) error) {
c.process = fn
}

type statefulCmdable struct {
cmdable
process func(cmd Cmder) error
}

func (c *statefulCmdable) setProcessor(fn func(Cmder) error) {
c.process = fn
c.cmdable.setProcessor(fn)
}

//------------------------------------------------------------------------------

func (c *statefulCmdable) Auth(password string) *StatusCmd {
Expand All @@ -280,7 +291,6 @@ func (c *cmdable) Ping() *StatusCmd {
}

func (c *cmdable) Wait(numSlaves int, timeout time.Duration) *IntCmd {

cmd := NewIntCmd("wait", numSlaves, int(timeout/time.Millisecond))
c.process(cmd)
return cmd
Expand Down Expand Up @@ -1639,7 +1649,7 @@ func (c *statefulCmdable) ClientSetName(name string) *BoolCmd {
}

// ClientGetName returns the name of the connection.
func (c *statefulCmdable) ClientGetName() *StringCmd {
func (c *cmdable) ClientGetName() *StringCmd {
cmd := NewStringCmd("client", "getname")
c.process(cmd)
return cmd
Expand Down
3 changes: 3 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ type Options struct {
// Network and Addr options.
Dialer func() (net.Conn, error)

// Hook that is called when new connection is established.
OnConnect func(*Conn) error

// Optional password. Must match the password specified in the
// requirepass server configuration option.
Password string
Expand Down
2 changes: 0 additions & 2 deletions pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
type pipelineExecer func([]Cmder) error

type Pipeliner interface {
Cmdable
StatefulCmdable
Process(cmd Cmder) error
Close() error
Expand All @@ -26,7 +25,6 @@ var _ Pipeliner = (*Pipeline)(nil)
// http://redis.io/topics/pipelining. It's safe for concurrent use
// by multiple goroutines.
type Pipeline struct {
cmdable
statefulCmdable

exec pipelineExecer
Expand Down
82 changes: 65 additions & 17 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ func (c *baseClient) String() string {
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
}

// Options returns read-only Options that were used to create the client.
func (c *baseClient) Options() *Options {
return c.opt
}

func (c *baseClient) conn() (*pool.Conn, bool, error) {
cn, isNew, err := c.connPool.Get()
if err != nil {
Expand Down Expand Up @@ -55,13 +50,23 @@ func (c *baseClient) putConn(cn *pool.Conn, err error) bool {
func (c *baseClient) initConn(cn *pool.Conn) error {
cn.Inited = true

if c.opt.Password == "" && c.opt.DB == 0 && !c.opt.ReadOnly {
if c.opt.Password == "" &&
c.opt.DB == 0 &&
!c.opt.ReadOnly &&
c.opt.OnConnect == nil {
return nil
}

// Temp client for Auth and Select.
client := newClient(c.opt, pool.NewSingleConnPool(cn))
_, err := client.Pipelined(func(pipe Pipeliner) error {
// Temp client to initialize connection.
conn := &Conn{
baseClient: baseClient{
opt: c.opt,
connPool: pool.NewSingleConnPool(cn),
},
}
conn.setProcessor(conn.Process)

_, err := conn.Pipelined(func(pipe Pipeliner) error {
if c.opt.Password != "" {
pipe.Auth(c.opt.Password)
}
Expand All @@ -76,7 +81,14 @@ func (c *baseClient) initConn(cn *pool.Conn) error {

return nil
})
return err
if err != nil {
return err
}

if c.opt.OnConnect != nil {
return c.opt.OnConnect(conn)
}
return nil
}

func (c *baseClient) Process(cmd Cmder) error {
Expand Down Expand Up @@ -182,7 +194,7 @@ func (c *baseClient) pipelineExecer(p pipelineProcessor) pipelineExecer {
}
}

func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (retry bool, firstErr error) {
func (c *baseClient) pipelineProcessCmds(cn *pool.Conn, cmds []Cmder) (bool, error) {
cn.SetWriteTimeout(c.opt.WriteTimeout)
if err := writeCmd(cn, cmds...); err != nil {
setCmdsErr(cmds, err)
Expand Down Expand Up @@ -294,7 +306,7 @@ func newClient(opt *Options, pool pool.Pooler) *Client {
connPool: pool,
},
}
client.cmdable.process = client.Process
client.setProcessor(client.Process)
return &client
}

Expand All @@ -307,10 +319,15 @@ func NewClient(opt *Options) *Client {
func (c *Client) copy() *Client {
c2 := new(Client)
*c2 = *c
c2.cmdable.process = c2.Process
c2.setProcessor(c2.Process)
return c2
}

// Options returns read-only Options that were used to create the client.
func (c *Client) Options() *Options {
return c.opt
}

// PoolStats returns connection pool stats.
func (c *Client) PoolStats() *PoolStats {
s := c.connPool.Stats()
Expand All @@ -332,8 +349,7 @@ func (c *Client) Pipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.pipelineProcessCmds),
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
pipe.setProcessor(pipe.Process)
return &pipe
}

Expand All @@ -346,8 +362,7 @@ func (c *Client) TxPipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.txPipelineProcessCmds),
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
pipe.setProcessor(pipe.Process)
return &pipe
}

Expand Down Expand Up @@ -377,3 +392,36 @@ func (c *Client) PSubscribe(channels ...string) *PubSub {
}
return pubsub
}

//------------------------------------------------------------------------------

// Conn is like Client, but its pool contains single connection.
type Conn struct {
baseClient
statefulCmdable
}

func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.Pipeline().pipelined(fn)
}

func (c *Conn) Pipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.pipelineProcessCmds),
}
pipe.setProcessor(pipe.Process)
return &pipe
}

func (c *Conn) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
return c.TxPipeline().pipelined(fn)
}

// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC.
func (c *Conn) TxPipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExecer(c.txPipelineProcessCmds),
}
pipe.setProcessor(pipe.Process)
return &pipe
}
23 changes: 23 additions & 0 deletions redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,26 @@ var _ = Describe("Client timeout", func() {
testTimeout()
})
})

var _ = Describe("Client OnConnect", func() {
var client *redis.Client

BeforeEach(func() {
opt := redisOptions()
opt.OnConnect = func(cn *redis.Conn) error {
return cn.ClientSetName("on_connect").Err()
}

client = redis.NewClient(opt)
})

AfterEach(func() {
Expect(client.Close()).NotTo(HaveOccurred())
})

It("calls OnConnect", func() {
name, err := client.ClientGetName().Result()
Expect(err).NotTo(HaveOccurred())
Expect(name).To(Equal("on_connect"))
})
})
9 changes: 6 additions & 3 deletions ring.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ type RingOptions struct {

// Following options are copied from Options struct.

OnConnect func(*Conn) error

DB int
Password string

Expand All @@ -52,6 +54,8 @@ func (opt *RingOptions) init() {

func (opt *RingOptions) clientOptions() *Options {
return &Options{
OnConnect: opt.OnConnect,

DB: opt.DB,
Password: opt.Password,

Expand Down Expand Up @@ -148,7 +152,7 @@ func NewRing(opt *RingOptions) *Ring {

cmdsInfoOnce: new(sync.Once),
}
ring.cmdable.process = ring.Process
ring.setProcessor(ring.Process)
for name, addr := range opt.Addrs {
clopt := opt.clientOptions()
clopt.Addr = addr
Expand Down Expand Up @@ -385,8 +389,7 @@ func (c *Ring) Pipeline() Pipeliner {
pipe := Pipeline{
exec: c.pipelineExec,
}
pipe.cmdable.process = pipe.Process
pipe.statefulCmdable.process = pipe.Process
pipe.setProcessor(pipe.Process)
return &pipe
}

Expand Down
6 changes: 5 additions & 1 deletion sentinel.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type FailoverOptions struct {

// Following options are copied from Options struct.

OnConnect func(*Conn) error

Password string
DB int

Expand All @@ -42,6 +44,8 @@ func (opt *FailoverOptions) options() *Options {
return &Options{
Addr: "FailoverClient",

OnConnect: opt.OnConnect,

DB: opt.DB,
Password: opt.Password,

Expand Down Expand Up @@ -82,7 +86,7 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
},
},
}
client.cmdable.process = client.Process
client.setProcessor(client.Process)

return &client
}
Expand Down
Loading