diff --git a/README.md b/README.md index 4702a97..50e31ea 100644 --- a/README.md +++ b/README.md @@ -24,33 +24,56 @@ Given, you have a mysql database called **txdb_test** and a table **users** with column. ``` go - package main +package main + +import ( + "database/sql" + "log" + + "github.com/DATA-DOG/go-txdb" + _ "github.com/go-sql-driver/mysql" +) + +func init() { + // we register an sql driver named "txdb" + txdb.Register("txdb", "mysql", "root@/txdb_test") +} + +func main() { + // dsn serves as an unique identifier for connection pool + db, err := sql.Open("txdb", "identifier") + if err != nil { + log.Fatal(err) + } + defer db.Close() - import ( - "database/sql" - "log" + if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil { + log.Fatal(err) + } +} +``` - "github.com/DATA-DOG/go-txdb" - _ "github.com/go-sql-driver/mysql" - ) +You can also use [`sql.OpenDB`](https://golang.org/pkg/database/sql/#OpenDB) (added in Go 1.10) rather than registering a txdb driver instance, if you prefer: - func init() { - // we register an sql driver named "txdb" - txdb.Register("txdb", "mysql", "root@/txdb_test") - } +``` go +package main + +import ( + "database/sql" + "log" + + "github.com/DATA-DOG/go-txdb" + _ "github.com/go-sql-driver/mysql" +) + +func main() { + db := sql.OpenDB(txdb.New("mysql", "root@/txdb_test")) + defer db.Close() - func main() { - // dsn serves as an unique identifier for connection pool - db, err := sql.Open("txdb", "identifier") - if err != nil { - log.Fatal(err) - } - defer db.Close() - - if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil { - log.Fatal(err) - } + if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil { + log.Fatal(err) } +} ``` Every time you will run this application, it will remain in the same state as before. diff --git a/db.go b/db.go index 4e8f1c0..60f280d 100644 --- a/db.go +++ b/db.go @@ -54,6 +54,7 @@ Every time you will run this application, it will remain in the same state as be package txdb import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -62,6 +63,18 @@ import ( "sync" ) +// New returns a [database/sql/driver.Connector], which can be passed to +// [database/sql.OpenDB]. This can be used in place of [Register]. +// It takes the same arguments as [Register], with the omission of name. +func New(drv, dsn string, options ...func(*conn) error) driver.Connector { + return &TxDriver{ + dsn: dsn, + drv: drv, + conns: make(map[string]*conn), + options: options, + } +} + // Register a txdb sql driver under the given sql driver name // which can be used to open a single transaction based database // connection. @@ -94,8 +107,6 @@ func Register(name, drv, dsn string, options ...func(*conn) error) { }) } -// txDriver is an sql driver which runs on single transaction -// when the Close is called, transaction is rolled back type conn struct { sync.Mutex tx *sql.Tx @@ -109,6 +120,8 @@ type conn struct { ctx interface{ Done() <-chan struct{} } } +// TxDriver is an sql driver which runs on single transaction +// when the Close is called, transaction is rolled back type TxDriver struct { sync.Mutex db *sql.DB @@ -120,6 +133,19 @@ type TxDriver struct { dsn string } +// Connect satisfies the [database/sql/driver.Connector] interface. +func (d *TxDriver) Connect(context.Context) (driver.Conn, error) { + // The DSN passed here doesn't matter, since it's only used to disambiguate + // connections, but that disambiguation happens in the call to New() when + // used through the driver.Connector interface. + return d.Open("connector") +} + +// Driver satisfies the [database/sql/driver.Connector] interface. +func (d *TxDriver) Driver() driver.Driver { + return d +} + func (d *TxDriver) DB() *sql.DB { return d.db } diff --git a/db_test.go b/db_test.go index 18c75dc..85a6e75 100644 --- a/db_test.go +++ b/db_test.go @@ -44,6 +44,18 @@ func drivers() []string { return all } +func TestShouldWorkWithOpenDB(t *testing.T) { + t.Parallel() + for _, d := range txDrivers { + db := sql.OpenDB(txdb.New(d.driver, d.dsn)) + defer db.Close() + _, err := db.Exec("SELECT 1") + if err != nil { + t.Fatal(err) + } + } +} + func TestShouldRunWithNestedTransaction(t *testing.T) { t.Parallel() for _, driver := range drivers() {