Skip to content

Commit

Permalink
Merge pull request #41 from flimzy/connector
Browse files Browse the repository at this point in the history
Add support for the driver.Connector interface.
  • Loading branch information
Yiling-J authored Nov 23, 2023
2 parents 8f8ce42 + 744fbbc commit c347f3a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 24 deletions.
67 changes: 45 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,56 @@ Given, you have a mysql database called **txdb_test** and a table **users** with
column.

``` go
package main
package main

import (
"database/sql"
"log"

"github.com/DATA-DOG/go-txdb"
_ "github.com/go-sql-driver/mysql"
)

func init() {
// we register an sql driver named "txdb"
txdb.Register("txdb", "mysql", "root@/txdb_test")
}

func main() {
// dsn serves as an unique identifier for connection pool
db, err := sql.Open("txdb", "identifier")
if err != nil {
log.Fatal(err)
}
defer db.Close()

import (
"database/sql"
"log"
if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil {
log.Fatal(err)
}
}
```

"github.com/DATA-DOG/go-txdb"
_ "github.com/go-sql-driver/mysql"
)
You can also use [`sql.OpenDB`](https://golang.org/pkg/database/sql/#OpenDB) (added in Go 1.10) rather than registering a txdb driver instance, if you prefer:

func init() {
// we register an sql driver named "txdb"
txdb.Register("txdb", "mysql", "root@/txdb_test")
}
``` go
package main

import (
"database/sql"
"log"

"github.com/DATA-DOG/go-txdb"
_ "github.com/go-sql-driver/mysql"
)

func main() {
db := sql.OpenDB(txdb.New("mysql", "root@/txdb_test"))
defer db.Close()

func main() {
// dsn serves as an unique identifier for connection pool
db, err := sql.Open("txdb", "identifier")
if err != nil {
log.Fatal(err)
}
defer db.Close()

if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil {
log.Fatal(err)
}
if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil {
log.Fatal(err)
}
}
```

Every time you will run this application, it will remain in the same state as before.
Expand Down
30 changes: 28 additions & 2 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Every time you will run this application, it will remain in the same state as be
package txdb

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
Expand All @@ -62,6 +63,18 @@ import (
"sync"
)

// New returns a [database/sql/driver.Connector], which can be passed to
// [database/sql.OpenDB]. This can be used in place of [Register].
// It takes the same arguments as [Register], with the omission of name.
func New(drv, dsn string, options ...func(*conn) error) driver.Connector {
return &TxDriver{
dsn: dsn,
drv: drv,
conns: make(map[string]*conn),
options: options,
}
}

// Register a txdb sql driver under the given sql driver name
// which can be used to open a single transaction based database
// connection.
Expand Down Expand Up @@ -94,8 +107,6 @@ func Register(name, drv, dsn string, options ...func(*conn) error) {
})
}

// txDriver is an sql driver which runs on single transaction
// when the Close is called, transaction is rolled back
type conn struct {
sync.Mutex
tx *sql.Tx
Expand All @@ -109,6 +120,8 @@ type conn struct {
ctx interface{ Done() <-chan struct{} }
}

// TxDriver is an sql driver which runs on single transaction
// when the Close is called, transaction is rolled back
type TxDriver struct {
sync.Mutex
db *sql.DB
Expand All @@ -120,6 +133,19 @@ type TxDriver struct {
dsn string
}

// Connect satisfies the [database/sql/driver.Connector] interface.
func (d *TxDriver) Connect(context.Context) (driver.Conn, error) {
// The DSN passed here doesn't matter, since it's only used to disambiguate
// connections, but that disambiguation happens in the call to New() when
// used through the driver.Connector interface.
return d.Open("connector")
}

// Driver satisfies the [database/sql/driver.Connector] interface.
func (d *TxDriver) Driver() driver.Driver {
return d
}

func (d *TxDriver) DB() *sql.DB {
return d.db
}
Expand Down
12 changes: 12 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ func drivers() []string {
return all
}

func TestShouldWorkWithOpenDB(t *testing.T) {
t.Parallel()
for _, d := range txDrivers {
db := sql.OpenDB(txdb.New(d.driver, d.dsn))
defer db.Close()
_, err := db.Exec("SELECT 1")
if err != nil {
t.Fatal(err)
}
}
}

func TestShouldRunWithNestedTransaction(t *testing.T) {
t.Parallel()
for _, driver := range drivers() {
Expand Down

0 comments on commit c347f3a

Please sign in to comment.