Skip to content

Commit

Permalink
go/vt/vitessdriver: implement driver.{Connector,DriverContext} (#13704)
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Layher <mdlayher@planetscale.com>
  • Loading branch information
mdlayher authored Aug 7, 2023
1 parent 8a78949 commit 54f0b33
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 46 deletions.
16 changes: 10 additions & 6 deletions go/vt/vitessdriver/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,16 @@ func (cv *converter) bindVarsFromNamedValues(args []driver.NamedValue) (map[stri
return bindVars, nil
}

func newConverter(cfg *Configuration) (c *converter, err error) {
c = &converter{
location: time.UTC,
func newConverter(cfg *Configuration) (*converter, error) {
c := &converter{location: time.UTC}
if cfg.DefaultLocation == "" {
return c, nil
}
if cfg.DefaultLocation != "" {
c.location, err = time.LoadLocation(cfg.DefaultLocation)

loc, err := time.LoadLocation(cfg.DefaultLocation)
if err != nil {
return nil, err
}
return
c.location = loc
return c, nil
}
111 changes: 85 additions & 26 deletions go/vt/vitessdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,30 @@ var (

// Type-check interfaces.
var (
_ driver.QueryerContext = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.StmtQueryContext = &stmt{}
_ driver.StmtExecContext = &stmt{}
_ interface {
driver.Connector
} = &connector{}

_ interface {
driver.Driver
driver.DriverContext
} = drv{}

_ interface {
driver.Conn
driver.ConnBeginTx
driver.ConnPrepareContext
driver.ExecerContext
driver.Pinger
driver.QueryerContext
driver.Tx
} = &conn{}

_ interface {
driver.Stmt
driver.StmtExecContext
driver.StmtQueryContext
} = &stmt{}
)

func init() {
Expand Down Expand Up @@ -94,8 +114,7 @@ func OpenWithConfiguration(c Configuration) (*sql.DB, error) {
return sql.Open(c.DriverName, json)
}

type drv struct {
}
type drv struct{}

// Open implements the database/sql/driver.Driver interface.
//
Expand All @@ -112,25 +131,65 @@ type drv struct {
//
// For a description of the available fields, see the Configuration struct.
func (d drv) Open(name string) (driver.Conn, error) {
c := &conn{}
err := json.Unmarshal([]byte(name), c)
conn, err := d.OpenConnector(name)
if err != nil {
return nil, err
}

c.setDefaults()
return conn.Connect(context.Background())
}

if c.convert, err = newConverter(&c.Configuration); err != nil {
// OpenConnector implements the database/sql/driver.DriverContext interface.
//
// See the documentation of Open for details on the format of name.
func (d drv) OpenConnector(name string) (driver.Connector, error) {
var cfg Configuration
if err := json.Unmarshal([]byte(name), &cfg); err != nil {
return nil, err
}

if err = c.dial(); err != nil {
cfg.setDefaults()
return d.newConnector(cfg)
}

// A connector holds immutable state for the creation of additional conns via
// the Connect method.
type connector struct {
drv drv
cfg Configuration
convert *converter
}

func (d drv) newConnector(cfg Configuration) (driver.Connector, error) {
convert, err := newConverter(&cfg)
if err != nil {
return nil, err
}

return c, nil
return &connector{
drv: d,
cfg: cfg,
convert: convert,
}, nil
}

// Connect implements the database/sql/driver.Connector interface.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
conn := &conn{
cfg: c.cfg,
convert: c.convert,
}

if err := conn.dial(ctx); err != nil {
return nil, err
}

return conn, nil
}

// Driver implements the database/sql/driver.Connector interface.
func (c *connector) Driver() driver.Driver { return c.drv }

// Configuration holds all Vitess driver settings.
//
// Fields with documented default values do not have to be set explicitly.
Expand Down Expand Up @@ -202,32 +261,32 @@ func (c *Configuration) setDefaults() {
}

type conn struct {
Configuration
cfg Configuration
convert *converter
conn *vtgateconn.VTGateConn
session *vtgateconn.VTGateSession
}

func (c *conn) dial() error {
func (c *conn) dial(ctx context.Context) error {
var err error
c.conn, err = vtgateconn.DialProtocol(context.Background(), c.Protocol, c.Address)
c.conn, err = vtgateconn.DialProtocol(ctx, c.cfg.Protocol, c.cfg.Address)
if err != nil {
return err
}
if c.Configuration.SessionToken != "" {
sessionFromToken, err := sessionTokenToSession(c.Configuration.SessionToken)
if c.cfg.SessionToken != "" {
sessionFromToken, err := sessionTokenToSession(c.cfg.SessionToken)
if err != nil {
return err
}
c.session = c.conn.SessionFromPb(sessionFromToken)
} else {
c.session = c.conn.Session(c.Target, nil)
c.session = c.conn.Session(c.cfg.Target, nil)
}
return nil
}

func (c *conn) Ping(ctx context.Context) error {
if c.Streaming {
if c.cfg.Streaming {
return errors.New("Ping not allowed for streaming connections")
}

Expand Down Expand Up @@ -378,7 +437,7 @@ func sessionTokenToSession(sessionToken string) (*vtgatepb.Session, error) {

func (c *conn) Begin() (driver.Tx, error) {
// if we're loading from an existing session, we need to avoid starting a new transaction
if c.Configuration.SessionToken != "" {
if c.cfg.SessionToken != "" {
return c, nil
}

Expand All @@ -401,7 +460,7 @@ func (c *conn) Commit() error {
// if we're loading from an existing session, disallow committing/rolling back the transaction
// this isn't a technical limitation, but is enforced to prevent misuse, so that only
// the original creator of the transaction can commit/rollback
if c.Configuration.SessionToken != "" {
if c.cfg.SessionToken != "" {
return errors.New("calling Commit from a distributed tx is not allowed")
}

Expand All @@ -413,7 +472,7 @@ func (c *conn) Rollback() error {
// if we're loading from an existing session, disallow committing/rolling back the transaction
// this isn't a technical limitation, but is enforced to prevent misuse, so that only
// the original creator of the transaction can commit/rollback
if c.Configuration.SessionToken != "" {
if c.cfg.SessionToken != "" {
return errors.New("calling Rollback from a distributed tx is not allowed")
}

Expand All @@ -424,7 +483,7 @@ func (c *conn) Rollback() error {
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
ctx := context.TODO()

if c.Streaming {
if c.cfg.Streaming {
return nil, errors.New("Exec not allowed for streaming connections")
}
bindVars, err := c.convert.buildBindVars(args)
Expand All @@ -440,7 +499,7 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if c.Streaming {
if c.cfg.Streaming {
return nil, errors.New("Exec not allowed for streaming connections")
}

Expand All @@ -462,7 +521,7 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return nil, err
}

if c.Streaming {
if c.cfg.Streaming {
stream, err := c.session.StreamExecute(ctx, query, bindVars)
if err != nil {
return nil, err
Expand All @@ -488,7 +547,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, err
}

if c.Streaming {
if c.cfg.Streaming {
stream, err := c.session.StreamExecute(ctx, query, bv)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 54f0b33

Please sign in to comment.