Skip to content

Commit

Permalink
implement RegisterDialContext.
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo82148 committed Feb 16, 2019
1 parent af9889e commit 89cc76d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
3 changes: 1 addition & 2 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.parseTime = mc.cfg.ParseTime

// Connect to Server
// TODO: needs RegisterDialContext
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(mc.cfg.Addr)
mc.netConn, err = dial(ctx, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
Expand Down
19 changes: 17 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"database/sql"
"database/sql/driver"
"net"
pkgnet "net"
"sync"
)

Expand All @@ -32,19 +33,33 @@ type MySQLDriver struct{}
// Custom dial functions must be registered with RegisterDial
type DialFunc func(addr string) (net.Conn, error)

// DialContextFunc is a function which can be used to establish the network connection using the provided context.
// Custom dial functions must be registered with RegisterDialContext
type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)

var (
dialsLock sync.RWMutex
dials map[string]DialFunc
dials map[string]DialContextFunc
)

// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
func RegisterDial(net string, dial DialFunc) {
dialContext := DialContextFunc(func(ctx context.Context, addr string) (pkgnet.Conn, error) {
return dial(addr)
})
RegisterDialContext(net, dialContext)
}

// RegisterDialContext registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
func RegisterDialContext(net string, dial DialContextFunc) {
dialsLock.Lock()
defer dialsLock.Unlock()
if dials == nil {
dials = make(map[string]DialFunc)
dials = make(map[string]DialContextFunc)
}
dials[net] = dial
}
Expand Down

0 comments on commit 89cc76d

Please sign in to comment.