From d39a66cce4e1c411c03e08f78cbbf58a10396be3 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Wed, 19 Jul 2023 08:47:12 +0200 Subject: [PATCH] SNOW-859548 Reuse connection in tests --- arrow_test.go | 177 ++++++++++++++++++++++-------------- async_test.go | 21 +++-- bindings_test.go | 71 ++++++++------- connection_test.go | 6 +- driver_test.go | 91 +++++++++--------- file_transfer_agent_test.go | 4 +- multistatement_test.go | 10 +- priv_key_test.go | 8 +- put_get_test.go | 45 +++++---- put_get_user_stage_test.go | 2 + put_get_with_aws_test.go | 2 + rows_test.go | 4 +- statement_test.go | 49 +++++----- transaction_test.go | 32 +++---- 14 files changed, 302 insertions(+), 220 deletions(-) diff --git a/arrow_test.go b/arrow_test.go index 6b4706f16..7ff56fbca 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -14,8 +14,9 @@ import ( ) func TestArrowBigInt(t *testing.T) { - db := openDB(t) - dbt := &DBTest{t, db} + conn := openConn(t) + defer conn.Close() + dbt := &DBTest{t, conn} testcases := []struct { num string @@ -24,7 +25,7 @@ func TestArrowBigInt(t *testing.T) { }{ {"10000000000000000000000000000000000000", 38, 0}, {"-10000000000000000000000000000000000000", 38, 0}, - {"12345678901234567890123456789012345678", 38, 0}, + {"12345678901234567890123456789012345678", 38, 0}, // #pragma: allowlist secret {"-12345678901234567890123456789012345678", 38, 0}, {"99999999999999999999999999999999999999", 38, 0}, {"-99999999999999999999999999999999999999", 38, 0}, @@ -53,8 +54,9 @@ func TestArrowBigInt(t *testing.T) { } func TestArrowBigFloat(t *testing.T) { - db := openDB(t) - dbt := &DBTest{t, db} + conn := openConn(t) + defer conn.Close() + dbt := &DBTest{t, conn} testcases := []struct { num string @@ -95,7 +97,12 @@ func TestArrowBigFloat(t *testing.T) { func TestArrowIntPrecision(t *testing.T) { db := openDB(t) - dbt := &DBTest{t, db} + defer db.Close() + + _, err := db.Exec(forceJSON) + if err != nil { + t.Fatalf("failed to set JSON as result type: %v", err) + } intTestcases := []struct { num string @@ -104,7 +111,7 @@ func TestArrowIntPrecision(t *testing.T) { }{ {"10000000000000000000000000000000000000", 38, 0}, {"-10000000000000000000000000000000000000", 38, 0}, - {"12345678901234567890123456789012345678", 38, 0}, + {"12345678901234567890123456789012345678", 38, 0}, // pragma: allowlist secret {"-12345678901234567890123456789012345678", 38, 0}, {"99999999999999999999999999999999999999", 38, 0}, {"-99999999999999999999999999999999999999", 38, 0}, @@ -112,68 +119,77 @@ func TestArrowIntPrecision(t *testing.T) { t.Run("arrow_disabled_scan_int64", func(t *testing.T) { for _, tc := range intTestcases { - dbt.mustExec(forceJSON) - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v int64 if err := rows.Scan(&v); err == nil { - dbt.Error("should fail to scan") + t.Error("should fail to scan") } } }) t.Run("arrow_disabled_scan_string", func(t *testing.T) { for _, tc := range intTestcases { - dbt.mustExec(forceJSON) - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) + } + defer rows.Close() if !rows.Next() { - dbt.Error("failed to query") + t.Error("failed to query") } defer rows.Close() var v int64 if err := rows.Scan(&v); err == nil { - dbt.Error("should fail to scan") + t.Error("should fail to scan") } } }) t.Run("arrow_enabled_scan_big_int", func(t *testing.T) { for _, tc := range intTestcases { - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v string if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } if !strings.EqualFold(v, tc.num) { - dbt.Errorf("int value mismatch: expected %v, got %v", tc.num, v) + t.Errorf("int value mismatch: expected %v, got %v", tc.num, v) } } }) t.Run("arrow_high_precision_enabled_scan_big_int", func(t *testing.T) { for _, tc := range intTestcases { - rows := dbt.mustQueryContext( - WithHigherPrecision(context.Background()), - fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.QueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v *big.Int if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } b, ok := new(big.Int).SetString(tc.num, 10) if !ok { - dbt.Errorf("failed to convert %v big.Int.", tc.num) + t.Errorf("failed to convert %v big.Int.", tc.num) } if v.Cmp(b) != 0 { - dbt.Errorf("big.Int value mismatch: expected %v, got %v", b, v) + t.Errorf("big.Int value mismatch: expected %v, got %v", b, v) } } }) @@ -184,7 +200,12 @@ func TestArrowIntPrecision(t *testing.T) { // to check the value as precision could be lost. func TestArrowFloatPrecision(t *testing.T) { db := openDB(t) - dbt := &DBTest{t, db} + defer db.Close() + + _, err := db.Exec(forceJSON) + if err != nil { + t.Fatalf("failed to set JSON as result type: %v", err) + } fltTestcases := []struct { num string @@ -202,109 +223,125 @@ func TestArrowFloatPrecision(t *testing.T) { t.Run("arrow_disabled_scan_float64", func(t *testing.T) { for _, tc := range fltTestcases { - dbt.mustExec(forceJSON) - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v float64 if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_disabled_scan_float32", func(t *testing.T) { for _, tc := range fltTestcases { - dbt.mustExec(forceJSON) - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v float32 if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_disabled_scan_string", func(t *testing.T) { for _, tc := range fltTestcases { - dbt.mustExec(forceJSON) - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v string if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } if !strings.EqualFold(v, tc.num) { - dbt.Errorf("int value mismatch: expected %v, got %v", tc.num, v) + t.Errorf("int value mismatch: expected %v, got %v", tc.num, v) } } }) t.Run("arrow_enabled_scan_float64", func(t *testing.T) { for _, tc := range fltTestcases { - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v float64 if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_enabled_scan_float32", func(t *testing.T) { for _, tc := range fltTestcases { - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v float32 if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_enabled_scan_string", func(t *testing.T) { for _, tc := range fltTestcases { - rows := dbt.mustQuery(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.Query(fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v string if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } } }) t.Run("arrow_high_precision_enabled_scan_big_float", func(t *testing.T) { for _, tc := range fltTestcases { - rows := dbt.mustQueryContext( - WithHigherPrecision(context.Background()), - fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) - if !rows.Next() { - dbt.Error("failed to query") + rows, err := db.QueryContext(WithHigherPrecision(context.Background()), fmt.Sprintf(selectNumberSQL, tc.num, tc.prec, tc.sc)) + if err != nil { + t.Fatalf("failed to query: %v", err) } defer rows.Close() + if !rows.Next() { + t.Error("failed to query") + } var v *big.Float if err := rows.Scan(&v); err != nil { - dbt.Errorf("failed to scan. %#v", err) + t.Errorf("failed to scan. %#v", err) } prec := v.Prec() b, ok := new(big.Float).SetPrec(prec).SetString(tc.num) if !ok { - dbt.Errorf("failed to convert %v to big.Float.", tc.num) + t.Errorf("failed to convert %v to big.Float.", tc.num) } if v.Cmp(b) != 0 { - dbt.Errorf("big.Float value mismatch: expected %v, got %v", b, v) + t.Errorf("big.Float value mismatch: expected %v, got %v", b, v) } } }) @@ -317,6 +354,7 @@ func TestArrowTimePrecision(t *testing.T) { dbt.mustExec("INSERT INTO t VALUES ('23:59:59.99999', '23:59:59.999999', '23:59:59.9999999', '23:59:59.99999999');") rows := dbt.mustQuery("select * from t") + defer rows.Close() var c5, c6, c7, c8 time.Time for rows.Next() { if err := rows.Scan(&c5, &c6, &c7, &c8); err != nil { @@ -361,10 +399,11 @@ func TestArrowTimePrecision(t *testing.T) { '9999-12-31T23:59:59.99999999' );`) - rows = dbt.mustQuery("select * from t_ntz") + rows2 := dbt.mustQuery("select * from t_ntz") + defer rows2.Close() var c1, c2, c3, c4 time.Time - for rows.Next() { - if err := rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8); err != nil { + for rows2.Next() { + if err := rows2.Scan(&c1, &c2, &c3, &c4, &c5, &c6, &c7, &c8); err != nil { t.Errorf("values were not scanned: %v", err) } } diff --git a/async_test.go b/async_test.go index d69619c8a..e0df8acba 100644 --- a/async_test.go +++ b/async_test.go @@ -4,6 +4,7 @@ package gosnowflake import ( "context" + "database/sql" "fmt" "testing" ) @@ -105,10 +106,18 @@ func TestMultipleAsyncQueries(t *testing.T) { ch1 := make(chan string) ch2 := make(chan string) + db := openDB(t) + runTests(t, dsn, func(dbt *DBTest) { - rows1 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) + rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30)) + if err != nil { + t.Fatalf("can't read rows1: %v", err) + } defer rows1.Close() - rows2 := dbt.mustQueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) + rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10)) + if err != nil { + t.Fatalf("can't read rows2: %v", err) + } defer rows2.Close() go retrieveRows(rows1, ch1) @@ -124,7 +133,7 @@ func TestMultipleAsyncQueries(t *testing.T) { }) } -func retrieveRows(rows *RowsExtended, ch chan string) { +func retrieveRows(rows *sql.Rows, ch chan string) { var s string for rows.Next() { if err := rows.Scan(&s); err != nil { @@ -138,13 +147,13 @@ func retrieveRows(rows *RowsExtended, ch chan string) { } func TestLongRunningAsyncQuery(t *testing.T) { - db := openDB(t) - defer db.Close() + conn := openConn(t) + defer conn.Close() ctx, _ := WithMultiStatement(context.Background(), 0) query := "CALL SYSTEM$WAIT(50, 'SECONDS');use snowflake_sample_data" - rows, err := db.QueryContext(WithAsyncMode(ctx), query) + rows, err := conn.QueryContext(WithAsyncMode(ctx), query) if err != nil { t.Fatalf("failed to run a query. %v, err: %v", query, err) } diff --git a/bindings_test.go b/bindings_test.go index 79d4be32e..1fb41af5d 100644 --- a/bindings_test.go +++ b/bindings_test.go @@ -68,7 +68,7 @@ func TestBindingUint64(t *testing.T) { expected := uint64(18446744073709551615) for _, v := range types { dbt.mustExec(fmt.Sprintf("CREATE TABLE test (id int, value %v)", v)) - if _, err := dbt.db.Exec("INSERT INTO test VALUES (1, ?)", expected); err == nil { + if _, err := dbt.exec("INSERT INTO test VALUES (1, ?)", expected); err == nil { dbt.Fatal("should fail as uint64 values with high bit set are not supported.") } else { logger.Infof("expected err: %v", err) @@ -84,7 +84,7 @@ func TestBindingDateTimeTimestamp(t *testing.T) { expected := time.Now() dbt.mustExec( "CREATE OR REPLACE TABLE tztest (id int, ntz timestamp_ntz, ltz timestamp_ltz, dt date, tm time)") - stmt, err := dbt.db.Prepare("INSERT INTO tztest(id,ntz,ltz,dt,tm) VALUES(1,?,?,?,?)") + stmt, err := dbt.prepare("INSERT INTO tztest(id,ntz,ltz,dt,tm) VALUES(1,?,?,?,?)") if err != nil { dbt.Fatal(err.Error()) } @@ -175,7 +175,7 @@ func TestBindingTimestampTZ(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { expected := time.Now() dbt.mustExec("CREATE OR REPLACE TABLE tztest (id int, tz timestamp_tz)") - stmt, err := dbt.db.Prepare("INSERT INTO tztest(id,tz) VALUES(1, ?)") + stmt, err := dbt.prepare("INSERT INTO tztest(id,tz) VALUES(1, ?)") if err != nil { dbt.Fatal(err.Error()) } @@ -214,7 +214,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { runInsertQuery := false for i := 0; i < 2; i++ { if !runInsertQuery { - _, err := dbt.db.Exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) + _, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) if err != nil { dbt.Fatal(err.Error()) } @@ -223,7 +223,7 @@ func TestBindingTimePtrInStruct(t *testing.T) { // Update row with a new time value expectedTime = time.Now().Add(1) testStruct.timeVal = &expectedTime - _, err := dbt.db.Exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) + _, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) if err != nil { dbt.Fatal(err.Error()) } @@ -261,7 +261,7 @@ func TestBindingTimeInStruct(t *testing.T) { runInsertQuery := false for i := 0; i < 2; i++ { if !runInsertQuery { - _, err := dbt.db.Exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) + _, err := dbt.exec("INSERT INTO timeStructTest(id,tz) VALUES(?, ?)", testStruct.id, testStruct.timeVal) if err != nil { dbt.Fatal(err.Error()) } @@ -270,7 +270,7 @@ func TestBindingTimeInStruct(t *testing.T) { // Update row with a new time value expectedTime = time.Now().Add(1) testStruct.timeVal = expectedTime - _, err := dbt.db.Exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) + _, err := dbt.exec("UPDATE timeStructTest SET tz = ? where id = ?", testStruct.timeVal, testStruct.id) if err != nil { dbt.Fatal(err.Error()) } @@ -630,7 +630,7 @@ func testBindingArray(t *testing.T, bulk bool) { dbt.mustExec(createTableSQL) defer dbt.mustExec(deleteTableSQL) if bulk { - if _, err := dbt.db.Exec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1"); err != nil { + if _, err := dbt.exec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1"); err != nil { t.Error(err) } } @@ -705,6 +705,7 @@ func TestBulkArrayBinding(t *testing.T) { } dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?)", dbname), Array(&intArr), Array(&strArr)) rows := dbt.mustQuery("select * from " + dbname) + defer rows.Close() cnt := 0 var i int var s string @@ -747,6 +748,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { fmt.Sprintf("INSERT INTO %s VALUES (?)", tempTableName), Array(&randomStrings)) rows := dbt.mustQuery("select count(*) from " + tempTableName) + defer rows.Close() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -756,6 +758,7 @@ func TestBulkArrayMultiPartBinding(t *testing.T) { } rows := dbt.mustQuery("select count(*) from " + tempTableName) + defer rows.Close() if rows.Next() { var count int if err := rows.Scan(&count); err != nil { @@ -778,12 +781,13 @@ func TestBulkArrayMultiPartBindingInt(t *testing.T) { for i := startNum; i < endNum; i++ { intArr[i-startNum] = i } - _, err := dbt.db.Exec("insert into binding_test values (?)", Array(&intArr)) + _, err := dbt.exec("insert into binding_test values (?)", Array(&intArr)) if err != nil { t.Errorf("Should have succeeded to insert. err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1") + defer rows.Close() cnt := startNum var i int for rows.Next() { @@ -825,12 +829,13 @@ func TestBulkArrayMultiPartBindingWithNull(t *testing.T) { stringArr[2] = nil stringArr[3] = nil - _, err := dbt.db.Exec("insert into binding_test values (?, ?)", Array(&intArr), Array(&stringArr)) + _, err := dbt.exec("insert into binding_test values (?, ?)", Array(&intArr), Array(&stringArr)) if err != nil { t.Errorf("Should have succeeded to insert. err: %v", err) } rows := dbt.mustQuery("select * from binding_test order by c1,c2") + defer rows.Close() cnt := startNum var i sql.NullInt32 var s sql.NullString @@ -903,28 +908,27 @@ func TestFunctionParameters(t *testing.T) { LANGUAGE SQL AS 'select param1';`, tc.paramType, tc.paramType) dbt.mustExec(query) - if rows, err := dbt.db.Query("select * from table(NULLPARAMFUNCTION(?))", tc.input); err != nil { + rows, err := dbt.query("select * from table(NULLPARAMFUNCTION(?))", tc.input) + if err != nil { t.Fatal(err) - } else { - if rows.Err() != nil { - t.Fatal(err) - } else { - if !rows.Next() { - t.Fatal() - } else { - var r1 any - err = rows.Scan(&r1) - if err != nil { - t.Fatal(err) - } - if tc.nullResult && r1 != nil { - t.Fatalf("the result for %v is of type %v but should be null", tc.paramType, reflect.TypeOf(r1)) - } - if !tc.nullResult && r1 == nil { - t.Fatalf("the result for %v should not be null", tc.paramType) - } - } - } + } + defer rows.Close() + if rows.Err() != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("no rows fetched") + } + var r1 any + err = rows.Scan(&r1) + if err != nil { + t.Fatal(err) + } + if tc.nullResult && r1 != nil { + t.Fatalf("the result for %v is of type %v but should be null", tc.paramType, reflect.TypeOf(r1)) + } + if !tc.nullResult && r1 == nil { + t.Fatalf("the result for %v should not be null", tc.paramType) } }) } @@ -987,7 +991,7 @@ func TestVariousBindingModes(t *testing.T) { t.Run(tc.testDesc+" "+bindingMode.param, func(t *testing.T) { query := fmt.Sprintf(`CREATE OR REPLACE TABLE BINDING_MODES(param1 %v)`, tc.paramType) dbt.mustExec(query) - if _, err := dbt.db.Exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil { + if _, err := dbt.exec(fmt.Sprintf("INSERT INTO BINDING_MODES VALUES (%v)", bindingMode.param), bindingMode.transform(tc.input)); err != nil { t.Fatal(err) } if tc.isNil { @@ -995,10 +999,11 @@ func TestVariousBindingModes(t *testing.T) { } else { query = fmt.Sprintf("SELECT * FROM BINDING_MODES WHERE param1 = %v", bindingMode.param) } - rows, err := dbt.db.Query(query, bindingMode.transform(tc.input)) + rows, err := dbt.query(query, bindingMode.transform(tc.input)) if err != nil { t.Fatal(err) } + defer rows.Close() if !rows.Next() { t.Fatal("Expected to return a row") } diff --git a/connection_test.go b/connection_test.go index 83196df78..f32141032 100644 --- a/connection_test.go +++ b/connection_test.go @@ -32,13 +32,13 @@ func TestInvalidConnection(t *testing.T) { if err := db.Close(); err != nil { t.Error("should not cause error in the second call of Close") } - if _, err := db.Exec("CREATE TABLE OR REPLACE test0(c1 int)"); err == nil { + if _, err := db.ExecContext(context.Background(), "CREATE TABLE OR REPLACE test0(c1 int)"); err == nil { t.Error("should fail to run Exec") } - if _, err := db.Query("SELECT CURRENT_TIMESTAMP()"); err == nil { + if _, err := db.QueryContext(context.Background(), "SELECT CURRENT_TIMESTAMP()"); err == nil { t.Error("should fail to run Query") } - if _, err := db.Begin(); err == nil { + if _, err := db.BeginTx(context.Background(), nil); err == nil { t.Error("should fail to run Begin") } } diff --git a/driver_test.go b/driver_test.go index 869e87f73..3b1306147 100644 --- a/driver_test.go +++ b/driver_test.go @@ -163,7 +163,7 @@ func TestMain(m *testing.M) { type DBTest struct { *testing.T - db *sql.DB + conn *sql.Conn } func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExtended) { @@ -187,7 +187,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExten close(c) }() - rs, err := dbt.db.QueryContext(ctx, query, args...) + rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { dbt.fail("query", query, err) } @@ -218,7 +218,7 @@ func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...i close(c) }() - rs, err := dbt.db.QueryContext(ctx, query, args...) + rs, err := dbt.conn.QueryContext(ctx, query, args...) if err != nil { dbt.fail("query", query, err) } @@ -228,8 +228,13 @@ func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...i } } +func (dbt *DBTest) query(query string, args ...any) (*sql.Rows, error) { + return dbt.conn.QueryContext(context.Background(), query, args...) +} + func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...interface{}) { rows := dbt.mustQuery(query, args...) + defer rows.Close() cnt := 0 for rows.Next() { cnt++ @@ -239,6 +244,10 @@ func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...inte } } +func (dbt *DBTest) prepare(query string) (*sql.Stmt, error) { + return dbt.conn.PrepareContext(context.Background(), query) +} + func (dbt *DBTest) fail(method, query string, err error) { if len(query) > 300 { query = "[query too large to print]" @@ -247,21 +256,21 @@ func (dbt *DBTest) fail(method, query string, err error) { } func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.Exec(query, args...) - if err != nil { - dbt.fail("exec", query, err) - } - return res + return dbt.mustExecContext(context.Background(), query, args...) } func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) { - res, err := dbt.db.ExecContext(ctx, query, args...) + res, err := dbt.conn.ExecContext(ctx, query, args...) if err != nil { dbt.fail("exec context", query, err) } return res } +func (dbt *DBTest) exec(query string, args ...any) (sql.Result, error) { + return dbt.conn.ExecContext(context.Background(), query, args...) +} + func (dbt *DBTest) mustDecimalSize(ct *sql.ColumnType) (pr int64, sc int64) { var ok bool pr, sc, ok = ct.DecimalSize() @@ -304,7 +313,7 @@ func (dbt *DBTest) mustNullable(ct *sql.ColumnType) (canNull bool) { } func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { - stmt, err := dbt.db.Prepare(query) + stmt, err := dbt.conn.PrepareContext(context.Background(), query) if err != nil { dbt.fail("prepare", query, err) } @@ -312,20 +321,18 @@ func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { } func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { - db, err := sql.Open("snowflake", dsn) - if err != nil { - t.Fatalf("error connecting: %s", err.Error()) - } - defer db.Close() + ctx := context.Background() + conn := openConn(t) + dbt := &DBTest{t, conn} + defer conn.Close() - if _, err = db.Exec("DROP TABLE IF EXISTS test"); err != nil { + if _, err := conn.ExecContext(ctx, "DROP TABLE IF EXISTS test"); err != nil { t.Fatalf("failed to drop table: %v", err) } - dbt := &DBTest{t, db} for _, test := range tests { test(dbt) - dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.conn.ExecContext(ctx, "DROP TABLE IF EXISTS test") } } @@ -415,7 +422,7 @@ func TestCommentOnlyQuery(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { query := "--" // just a comment, no query - rows, err := dbt.db.Query(query) + rows, err := dbt.query(query) if err == nil { rows.Close() dbt.fail("query", query, err) @@ -432,12 +439,12 @@ func TestEmptyQuery(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { query := "select 1 from dual where 1=0" // just a comment, no query - rows := dbt.db.QueryRow(query) - var v1 interface{} + rows := dbt.conn.QueryRowContext(context.Background(), query) + var v1 any if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } - rows = dbt.db.QueryRowContext(context.Background(), query) + rows = dbt.conn.QueryRowContext(context.Background(), query) if err := rows.Scan(&v1); err != sql.ErrNoRows { dbt.Errorf("should fail. err: %v", err) } @@ -448,7 +455,7 @@ func TestEmptyQueryWithRequestID(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { query := "select 1" ctx := WithRequestID(context.Background(), NewUUID()) - rows := dbt.db.QueryRowContext(ctx, query) + rows := dbt.conn.QueryRowContext(ctx, query) var v1 interface{} if err := rows.Scan(&v1); err != nil { dbt.Errorf("should not have failed with valid request id. err: %v", err) @@ -695,7 +702,7 @@ func testString(t *testing.T, json bool) { gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.` dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) - if err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { + if err := dbt.conn.QueryRowContext(context.Background(), "SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) } else if out != in { dbt.Errorf("BLOB: %s != %s", in, out) @@ -994,13 +1001,13 @@ func testNULL(t *testing.T, json bool) { if json { dbt.mustExec(forceJSON) } - nullStmt, err := dbt.db.Prepare("SELECT NULL") + nullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT NULL") if err != nil { dbt.Fatal(err) } defer nullStmt.Close() - nonNullStmt, err := dbt.db.Prepare("SELECT 1") + nonNullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT 1") if err != nil { dbt.Fatal(err) } @@ -1101,7 +1108,7 @@ func testNULL(t *testing.T, json bool) { // Insert nil b = nil success := false - if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ? IS NULL", b).Scan(&success); err != nil { dbt.Fatal(err) } if !success { @@ -1110,7 +1117,7 @@ func testNULL(t *testing.T, json bool) { } // Check input==output with input==nil b = nil - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { dbt.Fatal(err) } if b != nil { @@ -1118,7 +1125,7 @@ func testNULL(t *testing.T, json bool) { } // Check input==output with input!=nil b = []byte("") - if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { dbt.Fatal(err) } if b == nil { @@ -1258,7 +1265,7 @@ func TestDML(t *testing.T) { } func insertData(dbt *DBTest, commit bool) error { - tx, err := dbt.db.Begin() + tx, err := dbt.conn.BeginTx(context.Background(), nil) if err != nil { dbt.Fatalf("failed to begin transaction: %v", err) } @@ -1314,7 +1321,7 @@ func queryTestTx(tx *sql.Tx) (*map[int]string, error) { func queryTest(dbt *DBTest) (*map[int]string, error) { var c1 int var c2 string - rows, err := dbt.db.Query("SELECT c1, c2 FROM test") + rows, err := dbt.query("SELECT c1, c2 FROM test") if err != nil { return nil, err } @@ -1334,7 +1341,7 @@ func TestCancelQuery(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - _, err := dbt.db.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))") + _, err := dbt.conn.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))") if err == nil { dbt.Fatal("No timeout error returned") } @@ -1345,8 +1352,8 @@ func TestCancelQuery(t *testing.T) { } func TestPing(t *testing.T) { - db := openDB(t) - if err := db.Ping(); err != nil { + db := openConn(t) + if err := db.PingContext(context.Background()); err != nil { t.Fatalf("failed to ping. err: %v", err) } if err := db.PingContext(context.Background()); err != nil { @@ -1355,7 +1362,7 @@ func TestPing(t *testing.T) { if err := db.Close(); err != nil { t.Fatalf("failed to close db. err: %v", err) } - if err := db.Ping(); err == nil { + if err := db.PingContext(context.Background()); err == nil { t.Fatal("should have failed to ping") } if err := db.PingContext(context.Background()); err == nil { @@ -1387,10 +1394,10 @@ $$ func TestTimezoneSessionParameter(t *testing.T) { createDSN(PSTLocation) - db := openDB(t) - defer db.Close() + conn := openConn(t) + defer conn.Close() - rows, err := db.Query("SHOW PARAMETERS LIKE 'TIMEZONE'") + rows, err := conn.QueryContext(context.Background(), "SHOW PARAMETERS LIKE 'TIMEZONE'") if err != nil { t.Errorf("failed to run show parameters. err: %v", err) } @@ -1416,7 +1423,7 @@ func TestLargeSetResultCancel(t *testing.T) { go func() { // attempt to run a 100 seconds query, but it should be canceled in 1 second timelimit := 100 - rows, err := dbt.db.QueryContext( + rows, err := dbt.conn.QueryContext( ctx, fmt.Sprintf("SELECT COUNT(*) FROM TABLE(GENERATOR(timelimit=>%v))", timelimit)) if err != nil { @@ -1652,6 +1659,7 @@ func TestClientSessionKeepAliveParameter(t *testing.T) { createDSNWithClientSessionKeepAlive() runTests(t, dsn, func(dbt *DBTest) { rows := dbt.mustQuery("SHOW PARAMETERS LIKE 'CLIENT_SESSION_KEEP_ALIVE'") + defer rows.Close() if !rows.Next() { t.Fatal("failed to get timezone.") } @@ -1664,8 +1672,8 @@ func TestClientSessionKeepAliveParameter(t *testing.T) { t.Fatalf("failed to get an expected client_session_keep_alive. got: %v", p.Value) } - rows = dbt.mustQuery("select count(*) from table(generator(timelimit=>30))") - defer rows.Close() + rows2 := dbt.mustQuery("select count(*) from table(generator(timelimit=>30))") + defer rows2.Close() }) } @@ -1673,6 +1681,7 @@ func TestTimePrecision(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("create or replace table z3 (t1 time(5))") rows := dbt.mustQuery("select * from z3") + defer rows.Close() cols, err := rows.ColumnTypes() if err != nil { t.Error(err) diff --git a/file_transfer_agent_test.go b/file_transfer_agent_test.go index e589d5ecd..294407539 100644 --- a/file_transfer_agent_test.go +++ b/file_transfer_agent_test.go @@ -70,7 +70,7 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { f.Close() runTests(t, dsn, func(dbt *DBTest) { - if _, err = dbt.db.Exec("use role sysadmin"); err != nil { + if _, err = dbt.exec("use role sysadmin"); err != nil { t.Skip("snowflake admin account not accessible") } dbt.mustExec("rm @~/test_get") @@ -79,7 +79,7 @@ func TestUnitDownloadWithInvalidLocalPath(t *testing.T) { dbt.mustExec(sqlText) sqlText = fmt.Sprintf("get @~/test_get/data.txt file://%v\\get", tmpDir) - if _, err = dbt.db.Query(sqlText); err == nil { + if _, err = dbt.query(sqlText); err == nil { t.Fatalf("should return local path not directory error.") } dbt.mustExec("rm @~/test_get") diff --git a/multistatement_test.go b/multistatement_test.go index bd645cdee..9bc7b53a8 100644 --- a/multistatement_test.go +++ b/multistatement_test.go @@ -50,6 +50,7 @@ func TestMultiStatementQueryResultSet(t *testing.T) { var v4 string runTests(t, dsn, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, multiStmtQuery) + defer rows.Close() // first statement if rows.Next() { @@ -149,7 +150,8 @@ func TestMultiStatementQueryNoResultSet(t *testing.T) { c1 number, c2 string) as select 10, 'z'`) defer dbt.mustExec("drop table if exists tfmuest_multi_statement_txn") - dbt.mustQueryContext(ctx, multiStmtQuery) + rows := dbt.mustQueryContext(ctx, multiStmtQuery) + defer rows.Close() }) } @@ -330,8 +332,8 @@ func TestMultiStatementCountZero(t *testing.T) { } func TestMultiStatementCountMismatch(t *testing.T) { - db := openDB(t) - defer db.Close() + conn := openConn(t) + defer conn.Close() multiStmtQuery := "select 123;\n" + "select 456;\n" + @@ -339,7 +341,7 @@ func TestMultiStatementCountMismatch(t *testing.T) { "select '000';" ctx, _ := WithMultiStatement(context.Background(), 3) - if _, err := db.QueryContext(ctx, multiStmtQuery); err == nil { + if _, err := conn.QueryContext(ctx, multiStmtQuery); err == nil { t.Fatal("should have failed to query multiple statements") } } diff --git a/priv_key_test.go b/priv_key_test.go index c3f2646f6..adba84c91 100644 --- a/priv_key_test.go +++ b/priv_key_test.go @@ -73,20 +73,20 @@ func appendPrivateKeyString(dsn *string, key *rsa.PrivateKey) string { func TestJWTAuthentication(t *testing.T) { // For private key generated on the fly, we want to load the public key to the server first if !customPrivateKey { - db := openDB(t) + conn := openConn(t) + defer conn.Close() // Load server's public key to database pubKeyByte, err := x509.MarshalPKIXPublicKey(testPrivKey.Public()) if err != nil { t.Fatalf("error marshaling public key: %s", err.Error()) } - if _, err = db.Exec("USE ROLE ACCOUNTADMIN"); err != nil { + if _, err = conn.ExecContext(context.Background(), "USE ROLE ACCOUNTADMIN"); err != nil { t.Fatalf("error changin role: %s", err.Error()) } encodedKey := base64.StdEncoding.EncodeToString(pubKeyByte) - if _, err = db.Exec(fmt.Sprintf("ALTER USER %v set rsa_public_key='%v'", username, encodedKey)); err != nil { + if _, err = conn.ExecContext(context.Background(), fmt.Sprintf("ALTER USER %v set rsa_public_key='%v'", username, encodedKey)); err != nil { t.Fatalf("error setting server's public key: %s", err.Error()) } - db.Close() } // Test that a valid private key can pass diff --git a/put_get_test.go b/put_get_test.go index 20d18a448..d1666dc9a 100644 --- a/put_get_test.go +++ b/put_get_test.go @@ -164,7 +164,7 @@ func createTestData(dbt *DBTest) (*tcPutGetData, error) { bucket, } - if _, err = dbt.db.Exec("use role sysadmin"); err != nil { + if _, err = dbt.exec("use role sysadmin"); err != nil { return nil, err } dbt.mustExec(fmt.Sprintf( @@ -211,6 +211,7 @@ func TestPutLocalFile(t *testing.T) { var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string rows := dbt.mustQuery("copy into gotest_putget_t1") + defer rows.Close() for rows.Next() { rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9) if s1 != "LOADED" { @@ -218,19 +219,20 @@ func TestPutLocalFile(t *testing.T) { } } - rows = dbt.mustQuery("select count(*) from gotest_putget_t1") + rows2 := dbt.mustQuery("select count(*) from gotest_putget_t1") + defer rows2.Close() var i int - if rows.Next() { - rows.Scan(&i) + if rows2.Next() { + rows2.Scan(&i) if i != 75 { t.Fatalf("expected 75 rows, got %v", i) } } - rows = dbt.mustQuery(`select STATUS from information_schema - .load_history where table_name='gotest_putget_t1'`) - if rows.Next() { - rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9) + rows3 := dbt.mustQuery(`select STATUS from information_schema .load_history where table_name='gotest_putget_t1'`) + rows3.Close() + if rows3.Next() { + rows3.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9) if s1 != "LOADED" { t.Fatal("not loaded") } @@ -257,7 +259,7 @@ func TestPutWithAutoCompressFalse(t *testing.T) { defer f.Close() runTests(t, dsn, func(dbt *DBTest) { - if _, err = dbt.db.Exec("use role sysadmin"); err != nil { + if _, err = dbt.exec("use role sysadmin"); err != nil { t.Skip("snowflake admin account not accessible") } dbt.mustExec("rm @~/test_put_uncompress_file") @@ -266,6 +268,7 @@ func TestPutWithAutoCompressFalse(t *testing.T) { dbt.mustExec(sqlText) defer dbt.mustExec("rm @~/test_put_uncompress_file") rows := dbt.mustQuery("ls @~/test_put_uncompress_file") + defer rows.Close() var file, s1, s2, s3 string if rows.Next() { if err := rows.Scan(&file, &s1, &s2, &s3); err != nil { @@ -296,7 +299,7 @@ func TestPutOverwrite(t *testing.T) { f.Close() runTests(t, dsn, func(dbt *DBTest) { - if _, err = dbt.db.Exec("use role sysadmin"); err != nil { + if _, err = dbt.exec("use role sysadmin"); err != nil { t.Skip("snowflake admin account not accessible") } dbt.mustExec("rm @~/test_put_overwrite") @@ -306,6 +309,7 @@ func TestPutOverwrite(t *testing.T) { WithFileStream(context.Background(), f), fmt.Sprintf("put 'file://%v' @~/test_put_overwrite", strings.ReplaceAll(testData, "\\", "\\\\"))) + defer rows.Close() f.Close() defer dbt.mustExec("rm @~/test_put_overwrite") var s0, s1, s2, s3, s4, s5, s6, s7 string @@ -327,6 +331,7 @@ func TestPutOverwrite(t *testing.T) { WithFileStream(ctx, f), fmt.Sprintf("put 'file://%v' @~/test_put_overwrite", strings.ReplaceAll(testData, "\\", "\\\\"))) + defer rows.Close() f.Close() if rows.Next() { if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { @@ -353,6 +358,7 @@ func TestPutOverwrite(t *testing.T) { } rows = dbt.mustQuery("ls @~/test_put_overwrite") + defer rows.Close() if rows.Next() { if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { t.Fatal(err) @@ -417,6 +423,7 @@ func testPutGet(t *testing.T, isStream bool) { sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName) rows = dbt.mustQuery(sqlText) } + defer rows.Close() var s0, s1, s2, s3, s4, s5, s6, s7 string if rows.Next() { @@ -439,10 +446,10 @@ func testPutGet(t *testing.T, isStream bool) { sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") - rows = dbt.mustQuery(sqlText) - defer rows.Close() - for rows.Next() { - if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { + rows2 := dbt.mustQuery(sqlText) + defer rows2.Close() + for rows2.Next() { + if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil { t.Error(err) } if !strings.HasPrefix(s0, "data_") { @@ -535,6 +542,7 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) { sqlText = fmt.Sprintf( sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName) rows = dbt.mustQuery(sqlText) + defer rows.Close() var s0, s1, s2, s3, s4, s5, s6, s7 string if rows.Next() { @@ -557,10 +565,10 @@ func TestPutGetGcsDownscopedCredential(t *testing.T) { sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir) sqlText = strings.ReplaceAll(sql, "\\", "\\\\") - rows = dbt.mustQuery(sqlText) - defer rows.Close() - for rows.Next() { - if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { + rows2 := dbt.mustQuery(sqlText) + defer rows2.Close() + for rows2.Next() { + if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil { t.Error(err) } if !strings.HasPrefix(s0, "data_") { @@ -624,6 +632,7 @@ func TestPutLargeFile(t *testing.T) { dbt.mustExec(sqlText) defer dbt.mustExec("rm @~/test_put_largefile") rows := dbt.mustQuery("ls @~/test_put_largefile") + defer rows.Close() var file, s1, s2, s3 string if rows.Next() { if err := rows.Scan(&file, &s1, &s2, &s3); err != nil { diff --git a/put_get_user_stage_test.go b/put_get_user_stage_test.go index ef499f68d..aeef80c19 100644 --- a/put_get_user_stage_test.go +++ b/put_get_user_stage_test.go @@ -87,6 +87,7 @@ func putGetUserStage(t *testing.T, tmpDir string, numberOfFiles int, numberOfLin dbt.mustExec(fmt.Sprintf("copy into %v from @%v", dbname, stageName)) rows := dbt.mustQuery("select count(*) from " + dbname) + defer rows.Close() var cnt string if rows.Next() { rows.Scan(&cnt) @@ -134,6 +135,7 @@ func TestPutLoadFromUserStage(t *testing.T) { rows := dbt.mustQuery(fmt.Sprintf(`copy into gotest_putget_t2 from @%v file_format = (field_delimiter = '|' error_on_column_count_mismatch =false) purge=true`, data.stage)) + defer rows.Close() var s0, s1, s2, s3, s4, s5 string var s6, s7, s8, s9 interface{} orders100 := fmt.Sprintf("s3://%v/%v/orders_100.csv.gz", diff --git a/put_get_with_aws_test.go b/put_get_with_aws_test.go index 8c2e91bf1..1a43b2c67 100644 --- a/put_get_with_aws_test.go +++ b/put_get_with_aws_test.go @@ -70,6 +70,7 @@ func TestLoadS3(t *testing.T) { AWS_SECRET_KEY='%v') file_format=(skip_header=1 null_if=('') field_optionally_enclosed_by='\"')`, data.awsAccessKeyID, data.awsSecretAccessKey)) + defer rows.Close() var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string cnt := 0 for rows.Next() { @@ -303,6 +304,7 @@ func TestPutGetAWSStage(t *testing.T) { sql := "put 'file://%v' @~/%v auto_compress=false" sqlText := fmt.Sprintf(sql, strings.ReplaceAll(fname, "\\", "\\\\"), stageName) rows := dbt.mustQuery(sqlText) + defer rows.Close() var s0, s1, s2, s3, s4, s5, s6, s7 string if rows.Next() { diff --git a/rows_test.go b/rows_test.go index 3520be40c..dabca4cf4 100644 --- a/rows_test.go +++ b/rows_test.go @@ -57,7 +57,7 @@ var ( // Special cases where rows are already closed func TestRowsClose(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { - rows, err := dbt.db.Query("SELECT 1") + rows, err := dbt.query("SELECT 1") if err != nil { dbt.Fatal(err) } @@ -77,7 +77,7 @@ func TestRowsClose(t *testing.T) { func TestResultNoRows(t *testing.T) { // DDL runTests(t, dsn, func(dbt *DBTest) { - row, err := dbt.db.Exec("CREATE OR REPLACE TABLE test(c1 int)") + row, err := dbt.exec("CREATE OR REPLACE TABLE test(c1 int)") if err != nil { t.Fatalf("failed to execute DDL. err: %v", err) } diff --git a/statement_test.go b/statement_test.go index 6ba879410..51e9f236b 100644 --- a/statement_test.go +++ b/statement_test.go @@ -18,22 +18,32 @@ func openDB(t *testing.T) *sql.DB { var err error if db, err = sql.Open("snowflake", dsn); err != nil { - t.Fatalf("failed to open db. %v, err: %v", dsn, err) + t.Fatalf("failed to open db. %v", err) } + return db } -func TestGetQueryID(t *testing.T) { - db := openDB(t) - defer db.Close() +func openConn(t *testing.T) *sql.Conn { + var db *sql.DB + var conn *sql.Conn + var err error - ctx := context.TODO() - conn, err := db.Conn(ctx) - if err != nil { - t.Error(err) + if db, err = sql.Open("snowflake", dsn); err != nil { + t.Fatalf("failed to open db. %v, err: %v", dsn, err) + } + if conn, err = db.Conn(context.Background()); err != nil { + t.Fatalf("failed to open connection: %v", err) } + return conn +} - if err = conn.Raw(func(x interface{}) error { +func TestGetQueryID(t *testing.T) { + ctx := context.Background() + conn := openConn(t) + defer conn.Close() + + if err := conn.Raw(func(x interface{}) error { rows, err := x.(driver.QueryerContext).QueryContext(ctx, "select 1", nil) if err != nil { return err @@ -150,6 +160,7 @@ func TestWithDescribeOnly(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { ctx := WithDescribeOnly(context.Background()) rows := dbt.mustQueryContext(ctx, selectVariousTypes) + defer rows.Close() cols, err := rows.Columns() if err != nil { t.Error(err) @@ -178,7 +189,7 @@ func TestCallStatement(t *testing.T) { expected := "1 \"[2,3]\" [2,3]" var out string - dbt.db.Exec("ALTER SESSION SET USE_STATEMENT_TYPE_CALL_FOR_STORED_PROC_CALLS = true") + dbt.exec("ALTER SESSION SET USE_STATEMENT_TYPE_CALL_FOR_STORED_PROC_CALLS = true") dbt.mustExec("create or replace procedure " + "TEST_SP_CALL_STMT_ENABLED(in1 float, in2 variant) " + @@ -188,7 +199,7 @@ func TestCallStatement(t *testing.T) { "return res.getColumnValueAsString(1) + ' ' + res.getColumnValueAsString(2) + ' ' + IN2; " + "$$;") - stmt, err := dbt.db.Prepare("call TEST_SP_CALL_STMT_ENABLED(?, to_variant(?))") + stmt, err := dbt.conn.PrepareContext(context.Background(), "call TEST_SP_CALL_STMT_ENABLED(?, to_variant(?))") if err != nil { dbt.Errorf("failed to prepare query: %v", err) } @@ -207,19 +218,15 @@ func TestCallStatement(t *testing.T) { } func TestStmtExec(t *testing.T) { - db := openDB(t) - defer db.Close() + ctx := context.Background() + conn := openConn(t) + defer conn.Close() - if _, err := db.Exec(`create or replace table test_table(col1 int, col2 int)`); err != nil { + if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil { t.Fatalf("failed to create table: %v", err) } - ctx := context.Background() - conn, err := db.Conn(ctx) - if err != nil { - t.Error(err) - } - if err = conn.Raw(func(x interface{}) error { + if err := conn.Raw(func(x interface{}) error { stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (1, 2)") if err != nil { t.Error(err) @@ -237,7 +244,7 @@ func TestStmtExec(t *testing.T) { t.Fatalf("failed to drop table: %v", err) } - if _, err := db.Exec("drop table if exists test_table"); err != nil { + if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil { t.Fatalf("failed to drop table: %v", err) } } diff --git a/transaction_test.go b/transaction_test.go index e32aa7b8f..5027e47f0 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -15,23 +15,23 @@ func TestTransactionOptions(t *testing.T) { var tx *sql.Tx var err error - db := openDB(t) - defer db.Close() + conn := openConn(t) + defer conn.Close() - tx, err = db.BeginTx(context.Background(), &sql.TxOptions{}) + tx, err = conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { t.Fatal("failed to start transaction.") } if err = tx.Rollback(); err != nil { t.Fatal("failed to rollback") } - if _, err = db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}); err == nil { + if _, err = conn.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}); err == nil { t.Fatal("should have failed.") } if driverErr, ok := err.(*SnowflakeError); !ok || driverErr.Number != ErrNoReadOnlyTransaction { t.Fatalf("should have returned Snowflake Error: %v", errMsgNoReadOnlyTransaction) } - if _, err = db.BeginTx(context.Background(), &sql.TxOptions{Isolation: 100}); err == nil { + if _, err = conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: 100}); err == nil { t.Fatal("should have failed.") } if driverErr, ok := err.(*SnowflakeError); !ok || driverErr.Number != ErrNoDefaultTransactionIsolationLevel { @@ -44,25 +44,25 @@ func TestTransactionContext(t *testing.T) { var tx *sql.Tx var err error - db := openDB(t) - defer db.Close() + conn := openConn(t) + defer conn.Close() - ctx2 := context.Background() + ctx := context.Background() pingWithRetry := withRetry(PingFunc, 5, 3*time.Second) - err = pingWithRetry(context.Background(), db) + err = pingWithRetry(context.Background(), conn) if err != nil { t.Fatal(err) } - tx, err = db.BeginTx(ctx2, nil) + tx, err = conn.BeginTx(ctx, nil) if err != nil { t.Fatal(err) } defer tx.Rollback() - _, err = tx.ExecContext(ctx2, "SELECT SYSTEM$WAIT(10, 'SECONDS')") + _, err = tx.ExecContext(ctx, "SELECT SYSTEM$WAIT(10, 'SECONDS')") if err != nil { t.Fatal(err) } @@ -71,17 +71,15 @@ func TestTransactionContext(t *testing.T) { if err != nil { t.Fatal(err) } - - defer db.Close() } -func PingFunc(ctx context.Context, db *sql.DB) error { - return db.PingContext(ctx) +func PingFunc(ctx context.Context, conn *sql.Conn) error { + return conn.PingContext(ctx) } // Helper function for SNOW-823072 repro -func withRetry(fn func(context.Context, *sql.DB) error, numAttempts int, timeout time.Duration) func(context.Context, *sql.DB) error { - return func(ctx context.Context, db *sql.DB) error { +func withRetry(fn func(context.Context, *sql.Conn) error, numAttempts int, timeout time.Duration) func(context.Context, *sql.Conn) error { + return func(ctx context.Context, db *sql.Conn) error { for currAttempt := 1; currAttempt <= numAttempts; currAttempt++ { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel()