diff --git a/connection.go b/connection.go index 4ea8dbd77..7321838c2 100644 --- a/connection.go +++ b/connection.go @@ -85,6 +85,10 @@ func (sc *snowflakeConn) exec( var err error counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter + queryContext, err := buildQueryContext(sc.queryContextCache) + if err != nil { + logger.Errorf("error while building query context: %v", err) + } req := execRequest{ SQLText: query, AsyncExec: noResult, @@ -92,6 +96,7 @@ func (sc *snowflakeConn) exec( IsInternal: isInternal, DescribeOnly: describeOnly, SequenceID: counter, + QueryContext: queryContext, } if key := ctx.Value(multiStatementCount); key != nil { req.Parameters[string(multiStatementCount)] = key @@ -173,6 +178,27 @@ func extractQueryContext(data *execResponse) (queryContext, error) { return queryContext, err } +func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) { + rqc := requestQueryContext{} + if qcc == nil || len(qcc.entries) == 0 { + logger.Debugf("empty qcc") + return rqc, nil + } + for _, qce := range qcc.entries { + contextData := contextData{} + if qce.Context == "" { + contextData.Base64Data = qce.Context + } + rqc.Entries = append(rqc.Entries, requestQueryContextEntry{ + ID: qce.ID, + Priority: qce.Priority, + Timestamp: qce.Timestamp, + Context: contextData, + }) + } + return rqc, nil +} + func (sc *snowflakeConn) Begin() (driver.Tx, error) { return sc.BeginTx(sc.ctx, driver.TxOptions{}) } diff --git a/driver_test.go b/driver_test.go index fadbe3a55..1d6662ebe 100644 --- a/driver_test.go +++ b/driver_test.go @@ -387,6 +387,7 @@ func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) { } sct := &SCTest{t, sc} + test(sct) } diff --git a/htap.go b/htap.go index d8730be36..93806801f 100644 --- a/htap.go +++ b/htap.go @@ -16,10 +16,10 @@ type queryContext struct { } type queryContextEntry struct { - ID int `json:"id"` - Timestamp int64 `json:"timestamp"` - Priority int `json:"priority"` - Context any `json:"context,omitempty"` + ID int `json:"id"` + Timestamp int64 `json:"timestamp"` + Priority int `json:"priority"` + Context string `json:"context,omitempty"` } type queryContextCache struct { diff --git a/htap_test.go b/htap_test.go index 8aae85810..a724424f7 100644 --- a/htap_test.go +++ b/htap_test.go @@ -1,106 +1,13 @@ package gosnowflake import ( - "encoding/json" + "database/sql/driver" + "fmt" "reflect" - "strings" "testing" + "time" ) -func TestMarshallAndDecodeOpaqueContext(t *testing.T) { - testcases := []struct { - json string - qc queryContextEntry - }{ - { - json: `{ - "id": 1, - "timestamp": 2, - "priority": 3 - }`, - qc: queryContextEntry{1, 2, 3, nil}, - }, - { - json: `{ - "id": 1, - "timestamp": 2, - "priority": 3, - "context": "abc" - }`, - qc: queryContextEntry{1, 2, 3, "abc"}, - }, - { - json: `{ - "id": 1, - "timestamp": 2, - "priority": 3, - "context": { - "val": "abc" - } - }`, - qc: queryContextEntry{1, 2, 3, map[string]interface{}{"val": "abc"}}, - }, - { - json: `{ - "id": 1, - "timestamp": 2, - "priority": 3, - "context": [ - "abc" - ] - }`, - qc: queryContextEntry{1, 2, 3, []any{"abc"}}, - }, - { - json: `{ - "id": 1, - "timestamp": 2, - "priority": 3, - "context": [ - { - "val": "abc" - } - ] - }`, - qc: queryContextEntry{1, 2, 3, []any{map[string]interface{}{"val": "abc"}}}, - }, - } - - for _, tc := range testcases { - t.Run(trimWhitespaces(tc.json), func(t *testing.T) { - var qc queryContextEntry - - err := json.NewDecoder(strings.NewReader(tc.json)).Decode(&qc) - if err != nil { - t.Fatalf("failed to decode json. %v", err) - } - - if !reflect.DeepEqual(tc.qc, qc) { - t.Errorf("failed to decode json. expected: %v, got: %v", tc.qc, qc) - } - - bytes, err := json.Marshal(qc) - if err != nil { - t.Fatalf("failed to encode json. %v", err) - } - - resultJSON := string(bytes) - if resultJSON != trimWhitespaces(tc.json) { - t.Errorf("failed to encode json. epxected: %v, got: %v", trimWhitespaces(tc.json), resultJSON) - } - }) - } -} - -func trimWhitespaces(s string) string { - return strings.ReplaceAll( - strings.ReplaceAll( - strings.ReplaceAll(s, "\t", ""), - " ", ""), - "\n", "", - ) -} - func TestSortingByPriority(t *testing.T) { qcc := (&queryContextCache{}).init() sc := htapTestSnowflakeConn() @@ -302,9 +209,9 @@ func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryC } func TestPruneBySessionValue(t *testing.T) { - qce1 := queryContextEntry{1, 1, 1, nil} - qce2 := queryContextEntry{2, 2, 2, nil} - qce3 := queryContextEntry{3, 3, 3, nil} + qce1 := queryContextEntry{1, 1, 1, ""} + qce2 := queryContextEntry{2, 2, 2, ""} + qce3 := queryContextEntry{3, 3, 3, ""} testcases := []struct { size string @@ -352,12 +259,12 @@ func TestPruneBySessionValue(t *testing.T) { } func TestPruneByDefaultValue(t *testing.T) { - qce1 := queryContextEntry{1, 1, 1, nil} - qce2 := queryContextEntry{2, 2, 2, nil} - qce3 := queryContextEntry{3, 3, 3, nil} - qce4 := queryContextEntry{4, 4, 4, nil} - qce5 := queryContextEntry{5, 5, 5, nil} - qce6 := queryContextEntry{6, 6, 6, nil} + qce1 := queryContextEntry{1, 1, 1, ""} + qce2 := queryContextEntry{2, 2, 2, ""} + qce3 := queryContextEntry{3, 3, 3, ""} + qce4 := queryContextEntry{4, 4, 4, ""} + qce5 := queryContextEntry{5, 5, 5, ""} + qce6 := queryContextEntry{6, 6, 6, ""} sc := &snowflakeConn{ cfg: &Config{ @@ -383,7 +290,7 @@ func TestPruneByDefaultValue(t *testing.T) { } func TestNoQcesClearsCache(t *testing.T) { - qce1 := queryContextEntry{1, 1, 1, nil} + qce1 := queryContextEntry{1, 1, 1, ""} sc := &snowflakeConn{ cfg: &Config{ @@ -426,3 +333,79 @@ func TestQueryContextCacheDisabled(t *testing.T) { } }) } + +func TestHybridTablesE2E(t *testing.T) { + if runningOnGithubAction() && !runningOnAWS() { + t.Skip("HTAP is enabled only on AWS") + } + runID := time.Now().UnixMilli() + testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID) + testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID) + runSnowflakeConnTest(t, func(sct *SCTest) { + dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil) + defer dbQuery.Close() + currentDb := make([]driver.Value, 1) + dbQuery.Next(currentDb) + defer func() { + sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil) + sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil) + sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil) + }() + + t.Run("Run tests on first database", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil) + sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil) + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil) + rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows.Close() + row := make([]driver.Value, 2) + rows.Next(row) + if row[0] != "1" || row[1] != "a" { + t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1]) + } + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil) + rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows2.Close() + rows2.Next(row) + if row[0] != "1" || row[1] != "a" { + t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1]) + } + rows2.Next(row) + if row[0] != "2" || row[1] != "b" { + t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1]) + } + if len(sct.sc.queryContextCache.entries) != 2 { + t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + t.Run("Run tests on second database", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil) + sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil) + sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil) + + rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil) + defer rows.Close() + row := make([]driver.Value, 2) + rows.Next(row) + if row[0] != "3" || row[1] != "c" { + t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1]) + } + if len(sct.sc.queryContextCache.entries) != 3 { + t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + t.Run("Run tests on first database again", func(t *testing.T) { + sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil) + + sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil) + + rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil) + defer rows.Close() + if len(sct.sc.queryContextCache.entries) != 3 { + t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries) + } + }) + }) +} diff --git a/query.go b/query.go index ffd3983cd..300233d0e 100644 --- a/query.go +++ b/query.go @@ -28,6 +28,22 @@ type execRequest struct { Parameters map[string]interface{} `json:"parameters,omitempty"` Bindings map[string]execBindParameter `json:"bindings,omitempty"` BindStage string `json:"bindStage,omitempty"` + QueryContext requestQueryContext `json:"queryContextDTO,omitempty"` +} + +type requestQueryContext struct { + Entries []requestQueryContextEntry `json:"entries,omitempty"` +} + +type requestQueryContextEntry struct { + Context contextData `json:"context,omitempty"` + ID int `json:"id"` + Priority int `json:"priority"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +type contextData struct { + Base64Data string `json:"base64Data,omitempty"` } type execResponseRowType struct {