diff --git a/db_go18.go b/db_go18.go index 91bf4db..72cb390 100644 --- a/db_go18.go +++ b/db_go18.go @@ -54,8 +54,13 @@ func (c *conn) beginTxOnce(ctx context.Context, done <-chan struct{}) (*sql.Tx, go func() { select { case <-ctx.Done(): - // operation was interrupted by context cancel, so we cancel parent as well - c.cancel() + select { + case <-done: + // the operation successfully finished at the "same time" as context cancellation, so we won't close ctx on tx + default: + // operation was interrupted by context cancel, so we cancel parent as well + c.cancel() + } case <-done: // operation was successfully finished, so we don't close ctx on tx case <-c.ctx.Done(): diff --git a/db_test.go b/db_test.go index c511425..eb3fbdc 100644 --- a/db_test.go +++ b/db_test.go @@ -808,3 +808,38 @@ func TestIssue49(t *testing.T) { } }) } + +func TestShouldRunWithHeavyWork(t *testing.T) { + t.Parallel() + + testFn := func(t *testing.T, db *sql.DB) { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + row, err := db.QueryContext(ctx, "SELECT 1 from HeavyWork") + if err != nil { + t.Fatalf("failed to query users: %s", err) + } + if err := row.Close(); err != nil { + t.Fatalf("failed to close rows: %s", err) + } + } + + txDrivers.Run(t, func(t *testing.T, driver *testDriver) { + db, err := sql.Open(driver.name, "HeavyWork") + if err != nil { + t.Fatalf("failed to open a connection: %s", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS HeavyWork (id INT, name VARCHAR(255))") + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + + for i := 0; i < 10000; i++ { + testFn(t, db) + } + }) +}