diff --git a/db.go b/db.go index 60f280d..6550b0a 100644 --- a/db.go +++ b/db.go @@ -67,11 +67,14 @@ import ( // [database/sql.OpenDB]. This can be used in place of [Register]. // It takes the same arguments as [Register], with the omission of name. func New(drv, dsn string, options ...func(*conn) error) driver.Connector { - return &TxDriver{ - dsn: dsn, - drv: drv, - conns: make(map[string]*conn), - options: options, + return &txConnector{ + driver: &TxDriver{ + dsn: dsn, + drv: drv, + conns: make(map[string]*conn), + options: options, + }, + name: "connector", } } @@ -133,23 +136,43 @@ type TxDriver struct { dsn string } +var ( + _ driver.Driver = (*TxDriver)(nil) + _ driver.DriverContext = (*TxDriver)(nil) +) + +type txConnector struct { + driver *TxDriver + name string +} + +var _ driver.Connector = (*txConnector)(nil) + // Connect satisfies the [database/sql/driver.Connector] interface. -func (d *TxDriver) Connect(context.Context) (driver.Conn, error) { +func (c *txConnector) Connect(context.Context) (driver.Conn, error) { // The DSN passed here doesn't matter, since it's only used to disambiguate // connections, but that disambiguation happens in the call to New() when // used through the driver.Connector interface. - return d.Open("connector") + return c.driver.Open(c.name) } // Driver satisfies the [database/sql/driver.Connector] interface. -func (d *TxDriver) Driver() driver.Driver { - return d +func (c *txConnector) Driver() driver.Driver { + return c.driver } func (d *TxDriver) DB() *sql.DB { return d.db } +// OpenConnector satisfies the [database/sql/driver.DriverContext] interface. +func (d *TxDriver) OpenConnector(name string) (driver.Connector, error) { + return &txConnector{ + driver: d, + name: name, + }, nil +} + func (d *TxDriver) Open(dsn string) (driver.Conn, error) { d.Lock() defer d.Unlock()