Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for the driver.Connector interface. #41

Merged
merged 5 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,56 @@ Given, you have a mysql database called **txdb_test** and a table **users** with
column.

``` go
package main
package main

import (
"database/sql"
"log"

"github.com/DATA-DOG/go-txdb"
_ "github.com/go-sql-driver/mysql"
)

func init() {
// we register an sql driver named "txdb"
txdb.Register("txdb", "mysql", "root@/txdb_test")
}

func main() {
// dsn serves as an unique identifier for connection pool
db, err := sql.Open("txdb", "identifier")
if err != nil {
log.Fatal(err)
}
defer db.Close()

import (
"database/sql"
"log"
if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil {
log.Fatal(err)
}
}
```

"github.com/DATA-DOG/go-txdb"
_ "github.com/go-sql-driver/mysql"
)
You can also use [`sql.OpenDB`](https://golang.org/pkg/database/sql/#OpenDB) (added in Go 1.10) rather than registering a txdb driver instance, if you prefer:

func init() {
// we register an sql driver named "txdb"
txdb.Register("txdb", "mysql", "root@/txdb_test")
}
``` go
package main

import (
"database/sql"
"log"

"github.com/DATA-DOG/go-txdb"
_ "github.com/go-sql-driver/mysql"
)

func main() {
db := sql.OpenDB(txdb.New("mysql", "root@/txdb_test"))
defer db.Close()

func main() {
// dsn serves as an unique identifier for connection pool
db, err := sql.Open("txdb", "identifier")
if err != nil {
log.Fatal(err)
}
defer db.Close()

if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil {
log.Fatal(err)
}
if _, err := db.Exec(`INSERT INTO users(username) VALUES("gopher")`); err != nil {
log.Fatal(err)
}
}
```

Every time you will run this application, it will remain in the same state as before.
Expand Down
30 changes: 28 additions & 2 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Every time you will run this application, it will remain in the same state as be
package txdb

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
Expand All @@ -62,6 +63,18 @@ import (
"sync"
)

// New returns a [database/sql/driver.Connector], which can be passed to
// [database/sql.OpenDB]. This can be used in place of [Register].
// It takes the same arguments as [Register], with the omission of name.
func New(drv, dsn string, options ...func(*conn) error) driver.Connector {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if New is the best function name for this. I'm open to alternative suggestions.

return &TxDriver{
dsn: dsn,
drv: drv,
conns: make(map[string]*conn),
options: options,
}
}

// Register a txdb sql driver under the given sql driver name
// which can be used to open a single transaction based database
// connection.
Expand Down Expand Up @@ -94,8 +107,6 @@ func Register(name, drv, dsn string, options ...func(*conn) error) {
})
}

// txDriver is an sql driver which runs on single transaction
// when the Close is called, transaction is rolled back
type conn struct {
sync.Mutex
tx *sql.Tx
Expand All @@ -109,6 +120,8 @@ type conn struct {
ctx interface{ Done() <-chan struct{} }
}

// TxDriver is an sql driver which runs on single transaction
// when the Close is called, transaction is rolled back
type TxDriver struct {
sync.Mutex
db *sql.DB
Expand All @@ -120,6 +133,19 @@ type TxDriver struct {
dsn string
}

// Connect satisfies the [database/sql/driver.Connector] interface.
func (d *TxDriver) Connect(context.Context) (driver.Conn, error) {
// The DSN passed here doesn't matter, since it's only used to disambiguate
// connections, but that disambiguation happens in the call to New() when
// used through the driver.Connector interface.
return d.Open("connector")
}

// Driver satisfies the [database/sql/driver.Connector] interface.
func (d *TxDriver) Driver() driver.Driver {
return d
}

func (d *TxDriver) DB() *sql.DB {
return d.db
}
Expand Down
12 changes: 12 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ func drivers() []string {
return all
}

func TestShouldWorkWithOpenDB(t *testing.T) {
t.Parallel()
for _, d := range txDrivers {
db := sql.OpenDB(txdb.New(d.driver, d.dsn))
defer db.Close()
_, err := db.Exec("SELECT 1")
if err != nil {
t.Fatal(err)
}
}
}

func TestShouldRunWithNestedTransaction(t *testing.T) {
t.Parallel()
for _, driver := range drivers() {
Expand Down