From bdbc328552f24b121df487a7f6d9ed0e8f5ad96d Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Mon, 24 Jul 2023 14:17:32 +0200 Subject: [PATCH] SNOW-857631 Handle multistatement query type --- arrow_test.go | 21 +++++++++++++++++++++ connection.go | 3 ++- connection_util.go | 4 ++-- multistatement_test.go | 6 ++---- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/arrow_test.go b/arrow_test.go index 1234a6765..cca07603d 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -13,6 +13,27 @@ import ( "time" ) +//A test just to show Snowflake version +func TestCheckVersion(t *testing.T) { + conn := openConn(t) + defer conn.Close() + + rows, err := conn.QueryContext(context.Background(), "SELECT current_version()") + if err != nil { + t.Error(err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatalf("failed to find any row") + } + var s string + if err = rows.Scan(&s); err != nil { + t.Fatal(err) + } + println(s) +} + func TestArrowBigInt(t *testing.T) { conn := openConn(t) defer conn.Close() diff --git a/connection.go b/connection.go index 7321838c2..4638175e9 100644 --- a/connection.go +++ b/connection.go @@ -37,9 +37,10 @@ const ( ) const ( - statementTypeIDMulti = int64(0x1000) + statementTypeIDSelect = int64(0x1000) statementTypeIDDml = int64(0x3000) statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500) + statementTypeIDMultistatement = int64(0xA000) ) const ( diff --git a/connection_util.go b/connection_util.go index c3e4cb60a..737a025dd 100644 --- a/connection_util.go +++ b/connection_util.go @@ -213,8 +213,8 @@ func updateRows(data execResponseData) (int64, error) { // Note that the statement type code is also equivalent to type INSERT, so an // additional check of the name is required func isMultiStmt(data *execResponseData) bool { - return data.StatementTypeID == statementTypeIDMulti && - data.RowType[0].Name == "multiple statement execution" + var isMultistatementByReturningSelect = data.StatementTypeID == statementTypeIDSelect && data.RowType[0].Name == "multiple statement execution" + return isMultistatementByReturningSelect || data.StatementTypeID == statementTypeIDMultistatement } func getResumeQueryID(ctx context.Context) (string, error) { diff --git a/multistatement_test.go b/multistatement_test.go index 8086e81b5..ce0d713bf 100644 --- a/multistatement_test.go +++ b/multistatement_test.go @@ -23,10 +23,7 @@ func TestMultiStatementExecuteNoResultSet(t *testing.T) { "commit;" runDBTest(t, func(dbt *DBTest) { - dbt.mustExec("drop table if exists test_multi_statement_txn") - dbt.mustExec(`create or replace table test_multi_statement_txn( - c1 number, c2 string) as select 10, 'z'`) - defer dbt.mustExec("drop table if exists test_multi_statement_txn") + dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`) res := dbt.mustExecContext(ctx, multiStmtQuery) count, err := res.RowsAffected() @@ -48,6 +45,7 @@ func TestMultiStatementQueryResultSet(t *testing.T) { var v1, v2, v3 int64 var v4 string + runDBTest(t, func(dbt *DBTest) { rows := dbt.mustQueryContext(ctx, multiStmtQuery) defer rows.Close()