diff --git a/db_go18.go b/db_go18.go index 269b784..72cb390 100644 --- a/db_go18.go +++ b/db_go18.go @@ -53,16 +53,17 @@ func (c *conn) beginTxOnce(ctx context.Context, done <-chan struct{}) (*sql.Tx, } go func() { select { - case <-done: - // operation was successfully finished, so we don't close ctx on tx - case <-c.ctx.Done(): - default: + case <-ctx.Done(): select { - case <-ctx.Done(): + case <-done: + // the operation successfully finished at the "same time" as context cancellation, so we won't close ctx on tx + default: // operation was interrupted by context cancel, so we cancel parent as well c.cancel() - default: } + case <-done: + // operation was successfully finished, so we don't close ctx on tx + case <-c.ctx.Done(): } }() return c.tx, nil