From 6c8986c05423985e6d31a215fc2b4c7350b05025 Mon Sep 17 00:00:00 2001 From: Johnny Steenbergen Date: Tue, 14 Feb 2023 14:15:53 -0600 Subject: [PATCH] fix: update context handling in *Context db methods to stop context pollution --- db.go | 46 ++++++++++++++++++++++-- db_go18.go | 53 +++++++++++++++++++++++----- db_go18_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 11 deletions(-) diff --git a/db.go b/db.go index 12358c2..38b9703 100644 --- a/db.go +++ b/db.go @@ -103,6 +103,9 @@ type conn struct { drv *txDriver saves uint savePoint SavePoint + + cancel func() + ctx interface{ Done() <-chan struct{} } } type txDriver struct { @@ -135,7 +138,13 @@ func (d *txDriver) Open(dsn string) (driver.Conn, error) { } c, ok := d.conns[dsn] if !ok { - c = &conn{dsn: dsn, drv: d, savePoint: &defaultSavePoint{}} + c = &conn{ + dsn: dsn, + drv: d, + savePoint: &defaultSavePoint{}, + cancel: func() {}, + ctx: stubCtx{}, + } for _, opt := range d.options { if e := opt(c); e != nil { return c, e @@ -181,6 +190,7 @@ func (c *conn) Close() (err error) { if c.opened == 0 { if c.tx != nil { c.tx.Rollback() + c.cancel() c.tx = nil } c.drv.deleteConn(c.dsn) @@ -305,11 +315,17 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } type stmt struct { - st *sql.Stmt + mu sync.Mutex + st *sql.Stmt + done chan bool } func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { - return s.st.Exec(mapArgs(args)...) + dr, err := s.st.Exec(mapArgs(args)...) + if err != nil { + s.closeDone(true) + } + return dr, err } func (s *stmt) NumInput() int { @@ -317,17 +333,35 @@ func (s *stmt) NumInput() int { } func (s *stmt) Close() error { + s.closeDone(false) return s.st.Close() } func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { rows, err := s.st.Query(mapArgs(args)...) if err != nil { + s.closeDone(true) return nil, err } return buildRows(rows) } +func (s *stmt) closeDone(withErr bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done == nil { + return + } + + select { + case s.done <- withErr: + default: + } + + close(s.done) + s.done = nil +} + type rows struct { rows [][]driver.Value pos int @@ -414,3 +448,9 @@ func (rs *rowSets) Close() error { func (rs *rowSets) Next(dest []driver.Value) error { return rs.sets[rs.pos].Next(dest) } + +type stubCtx struct{} + +func (s stubCtx) Done() <-chan struct{} { + return nil +} diff --git a/db_go18.go b/db_go18.go index 2301d7c..1a53a89 100644 --- a/db_go18.go +++ b/db_go18.go @@ -41,14 +41,26 @@ func (rs *rowSets) NextResultSet() error { return nil } -func (c *conn) beginTxOnce(ctx context.Context) (*sql.Tx, error) { +func (c *conn) beginTxOnce(ctx context.Context, done <-chan struct{}) (*sql.Tx, error) { if c.tx == nil { - tx, err := c.drv.db.BeginTx(ctx, &sql.TxOptions{}) + rootCtx, cancel := context.WithCancel(context.Background()) + tx, err := c.drv.db.BeginTx(rootCtx, &sql.TxOptions{}) if err != nil { + cancel() return nil, err } - c.tx = tx + c.tx, c.ctx, c.cancel = tx, rootCtx, cancel } + go func() { + select { + case <-ctx.Done(): + // operation was interrupted by context cancel, so we cancel parent as well + c.cancel() + case <-done: + // operation was successfully finished, so we don't close ctx on tx + case <-c.ctx.Done(): + } + }() return c.tx, nil } @@ -57,7 +69,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam c.Lock() defer c.Unlock() - tx, err := c.beginTxOnce(ctx) + done := make(chan struct{}) + defer close(done) + + tx, err := c.beginTxOnce(ctx, done) if err != nil { return nil, err } @@ -76,7 +91,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name c.Lock() defer c.Unlock() - tx, err := c.beginTxOnce(ctx) + done := make(chan struct{}) + defer close(done) + + tx, err := c.beginTxOnce(ctx, done) if err != nil { return nil, err } @@ -94,7 +112,10 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e c.Lock() defer c.Unlock() - tx, err := c.beginTxOnce(ctx) + done := make(chan struct{}) + defer close(done) + + tx, err := c.beginTxOnce(ctx, done) if err != nil { return nil, err } @@ -103,7 +124,18 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e if err != nil { return nil, err } - return &stmt{st: st}, nil + + stmtFailedStr := make(chan bool) + go func() { + select { + case <-c.ctx.Done(): + case erred := <-stmtFailedStr: + if erred { + c.cancel() + } + } + }() + return &stmt{st: st, done: stmtFailedStr}, nil } // Implement the "Pinger" interface @@ -113,13 +145,18 @@ func (c *conn) Ping(ctx context.Context) error { // Implement the "StmtExecContext" interface func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - return s.st.ExecContext(ctx, mapNamedArgs(args)...) + dr, err := s.st.ExecContext(ctx, mapNamedArgs(args)...) + if err != nil { + s.closeDone(true) + } + return dr, err } // Implement the "StmtQueryContext" interface func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { rows, err := s.st.QueryContext(ctx, mapNamedArgs(args)...) if err != nil { + s.closeDone(true) return nil, err } return buildRows(rows) diff --git a/db_go18_test.go b/db_go18_test.go index a24cfaa..eafc308 100644 --- a/db_go18_test.go +++ b/db_go18_test.go @@ -5,6 +5,8 @@ package txdb import ( "context" "database/sql" + "sort" + "strings" "testing" ) @@ -64,3 +66,95 @@ func TestShouldBeAbleToPingWithContext(t *testing.T) { } } } + +func TestShouldHandleStmtsWithoutContextPollution(t *testing.T) { + t.Parallel() + for _, driver := range drivers() { + t.Run(driver, func(t *testing.T) { + db, err := sql.Open(driver, "contextpollution") + if err != nil { + t.Fatalf(driver+": failed to open a connection, have you run 'make test'? err: %s", err) + } + defer db.Close() + + insertSQL := "INSERT INTO users (username, email) VALUES(?, ?)" + if strings.Index(driver, "psql_") == 0 { + insertSQL = "INSERT INTO users (username, email) VALUES($1, $2)" + } + + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + + _, err = db.ExecContext(ctx1, insertSQL, "first", "first@foo.com") + if err != nil { + t.Fatalf("unexpected error inserting user 1: %s", err) + } + cancel1() + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + _, err = db.ExecContext(ctx2, insertSQL, "second", "second@foo.com") + if err != nil { + t.Fatalf("unexpected error inserting user 2: %s", err) + } + cancel2() + + const selectQuery = ` +select username +from users +where username = 'first' OR username = 'second'` + + rows, err := db.QueryContext(context.Background(), selectQuery) + if err != nil { + t.Fatalf("unexpected error querying users: %s", err) + } + defer rows.Close() + + assertRows := func(t *testing.T, rows *sql.Rows) { + t.Helper() + + var users []string + for rows.Next() { + var user string + err := rows.Scan(&user) + if err != nil { + t.Errorf("unexpected scan failure: %s", err) + continue + } + users = append(users, user) + } + sort.Strings(users) + + wanted := []string{"first", "second"} + + if len(users) != 2 { + t.Fatalf("invalid users received; want=%v\tgot=%v", wanted, users) + } + for i, want := range wanted { + if got := users[i]; want != got { + t.Errorf("invalid user; want=%s\tgot=%s", want, got) + } + } + } + + assertRows(t, rows) + + ctx3, cancel3 := context.WithCancel(context.Background()) + defer cancel3() + + stmt, err := db.PrepareContext(ctx3, selectQuery) + if err != nil { + t.Fatalf("unexpected error preparing stmt: %s", err) + } + + rows, err = stmt.QueryContext(context.TODO()) + if err != nil { + t.Fatalf("unexpected error in stmt querying users: %s", err) + } + defer rows.Close() + + assertRows(t, rows) + }) + } +}