Skip to content

Commit

Permalink
Merge pull request #48 from jsteenb2/fix/context_pollution
Browse files Browse the repository at this point in the history
fix: update context handling in *Context db methods to stop context pollution
  • Loading branch information
l3pp4rd authored Feb 23, 2023
2 parents efef918 + 6c8986c commit 164f0b8
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 11 deletions.
46 changes: 43 additions & 3 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ type conn struct {
drv *txDriver
saves uint
savePoint SavePoint

cancel func()
ctx interface{ Done() <-chan struct{} }
}

type txDriver struct {
Expand Down Expand Up @@ -135,7 +138,13 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) {
}
c, ok := d.conns[dsn]
if !ok {
c = &conn{dsn: dsn, drv: d, savePoint: &defaultSavePoint{}}
c = &conn{
dsn: dsn,
drv: d,
savePoint: &defaultSavePoint{},
cancel: func() {},
ctx: stubCtx{},
}
for _, opt := range d.options {
if e := opt(c); e != nil {
return c, e
Expand Down Expand Up @@ -181,6 +190,7 @@ func (c *conn) Close() (err error) {
if c.opened == 0 {
if c.tx != nil {
c.tx.Rollback()
c.cancel()
c.tx = nil
}
c.drv.deleteConn(c.dsn)
Expand Down Expand Up @@ -305,29 +315,53 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
}

type stmt struct {
st *sql.Stmt
mu sync.Mutex
st *sql.Stmt
done chan bool
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.st.Exec(mapArgs(args)...)
dr, err := s.st.Exec(mapArgs(args)...)
if err != nil {
s.closeDone(true)
}
return dr, err
}

func (s *stmt) NumInput() int {
return -1
}

func (s *stmt) Close() error {
s.closeDone(false)
return s.st.Close()
}

func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
rows, err := s.st.Query(mapArgs(args)...)
if err != nil {
s.closeDone(true)
return nil, err
}
return buildRows(rows)
}

func (s *stmt) closeDone(withErr bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.done == nil {
return
}

select {
case s.done <- withErr:
default:
}

close(s.done)
s.done = nil
}

type rows struct {
rows [][]driver.Value
pos int
Expand Down Expand Up @@ -414,3 +448,9 @@ func (rs *rowSets) Close() error {
func (rs *rowSets) Next(dest []driver.Value) error {
return rs.sets[rs.pos].Next(dest)
}

type stubCtx struct{}

func (s stubCtx) Done() <-chan struct{} {
return nil
}
53 changes: 45 additions & 8 deletions db_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,26 @@ func (rs *rowSets) NextResultSet() error {
return nil
}

func (c *conn) beginTxOnce(ctx context.Context) (*sql.Tx, error) {
func (c *conn) beginTxOnce(ctx context.Context, done <-chan struct{}) (*sql.Tx, error) {
if c.tx == nil {
tx, err := c.drv.db.BeginTx(ctx, &sql.TxOptions{})
rootCtx, cancel := context.WithCancel(context.Background())
tx, err := c.drv.db.BeginTx(rootCtx, &sql.TxOptions{})
if err != nil {
cancel()
return nil, err
}
c.tx = tx
c.tx, c.ctx, c.cancel = tx, rootCtx, cancel
}
go func() {
select {
case <-ctx.Done():
// operation was interrupted by context cancel, so we cancel parent as well
c.cancel()
case <-done:
// operation was successfully finished, so we don't close ctx on tx
case <-c.ctx.Done():
}
}()
return c.tx, nil
}

Expand All @@ -57,7 +69,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
c.Lock()
defer c.Unlock()

tx, err := c.beginTxOnce(ctx)
done := make(chan struct{})
defer close(done)

tx, err := c.beginTxOnce(ctx, done)
if err != nil {
return nil, err
}
Expand All @@ -76,7 +91,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
c.Lock()
defer c.Unlock()

tx, err := c.beginTxOnce(ctx)
done := make(chan struct{})
defer close(done)

tx, err := c.beginTxOnce(ctx, done)
if err != nil {
return nil, err
}
Expand All @@ -94,7 +112,10 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
c.Lock()
defer c.Unlock()

tx, err := c.beginTxOnce(ctx)
done := make(chan struct{})
defer close(done)

tx, err := c.beginTxOnce(ctx, done)
if err != nil {
return nil, err
}
Expand All @@ -103,7 +124,18 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
if err != nil {
return nil, err
}
return &stmt{st: st}, nil

stmtFailedStr := make(chan bool)
go func() {
select {
case <-c.ctx.Done():
case erred := <-stmtFailedStr:
if erred {
c.cancel()
}
}
}()
return &stmt{st: st, done: stmtFailedStr}, nil
}

// Implement the "Pinger" interface
Expand All @@ -113,13 +145,18 @@ func (c *conn) Ping(ctx context.Context) error {

// Implement the "StmtExecContext" interface
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
return s.st.ExecContext(ctx, mapNamedArgs(args)...)
dr, err := s.st.ExecContext(ctx, mapNamedArgs(args)...)
if err != nil {
s.closeDone(true)
}
return dr, err
}

// Implement the "StmtQueryContext" interface
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
rows, err := s.st.QueryContext(ctx, mapNamedArgs(args)...)
if err != nil {
s.closeDone(true)
return nil, err
}
return buildRows(rows)
Expand Down
94 changes: 94 additions & 0 deletions db_go18_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package txdb
import (
"context"
"database/sql"
"sort"
"strings"
"testing"
)

Expand Down Expand Up @@ -64,3 +66,95 @@ func TestShouldBeAbleToPingWithContext(t *testing.T) {
}
}
}

func TestShouldHandleStmtsWithoutContextPollution(t *testing.T) {
t.Parallel()
for _, driver := range drivers() {
t.Run(driver, func(t *testing.T) {
db, err := sql.Open(driver, "contextpollution")
if err != nil {
t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err)
}
defer db.Close()

insertSQL := "INSERT INTO users (username, email) VALUES(?, ?)"
if strings.Index(driver, "psql_") == 0 {
insertSQL = "INSERT INTO users (username, email) VALUES($1, $2)"
}

ctx1, cancel1 := context.WithCancel(context.Background())
defer cancel1()

_, err = db.ExecContext(ctx1, insertSQL, "first", "first@foo.com")
if err != nil {
t.Fatalf("unexpected error inserting user 1: %s", err)
}
cancel1()

ctx2, cancel2 := context.WithCancel(context.Background())
defer cancel2()

_, err = db.ExecContext(ctx2, insertSQL, "second", "second@foo.com")
if err != nil {
t.Fatalf("unexpected error inserting user 2: %s", err)
}
cancel2()

const selectQuery = `
select username
from users
where username = 'first' OR username = 'second'`

rows, err := db.QueryContext(context.Background(), selectQuery)
if err != nil {
t.Fatalf("unexpected error querying users: %s", err)
}
defer rows.Close()

assertRows := func(t *testing.T, rows *sql.Rows) {
t.Helper()

var users []string
for rows.Next() {
var user string
err := rows.Scan(&user)
if err != nil {
t.Errorf("unexpected scan failure: %s", err)
continue
}
users = append(users, user)
}
sort.Strings(users)

wanted := []string{"first", "second"}

if len(users) != 2 {
t.Fatalf("invalid users received; want=%v\tgot=%v", wanted, users)
}
for i, want := range wanted {
if got := users[i]; want != got {
t.Errorf("invalid user; want=%s\tgot=%s", want, got)
}
}
}

assertRows(t, rows)

ctx3, cancel3 := context.WithCancel(context.Background())
defer cancel3()

stmt, err := db.PrepareContext(ctx3, selectQuery)
if err != nil {
t.Fatalf("unexpected error preparing stmt: %s", err)
}

rows, err = stmt.QueryContext(context.TODO())
if err != nil {
t.Fatalf("unexpected error in stmt querying users: %s", err)
}
defer rows.Close()

assertRows(t, rows)
})
}
}

0 comments on commit 164f0b8

Please sign in to comment.