Skip to content

Commit

Permalink
SNOW-895537: Send query context with request (#904)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Sep 12, 2023
1 parent 213a5ac commit 1043498
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 110 deletions.
26 changes: 26 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,18 @@ func (sc *snowflakeConn) exec(
var err error
counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter

queryContext, err := buildQueryContext(sc.queryContextCache)
if err != nil {
logger.Errorf("error while building query context: %v", err)
}
req := execRequest{
SQLText: query,
AsyncExec: noResult,
Parameters: map[string]interface{}{},
IsInternal: isInternal,
DescribeOnly: describeOnly,
SequenceID: counter,
QueryContext: queryContext,
}
if key := ctx.Value(multiStatementCount); key != nil {
req.Parameters[string(multiStatementCount)] = key
Expand Down Expand Up @@ -173,6 +178,27 @@ func extractQueryContext(data *execResponse) (queryContext, error) {
return queryContext, err
}

func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
rqc := requestQueryContext{}
if qcc == nil || len(qcc.entries) == 0 {
logger.Debugf("empty qcc")
return rqc, nil
}
for _, qce := range qcc.entries {
contextData := contextData{}
if qce.Context == "" {
contextData.Base64Data = qce.Context
}
rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
ID: qce.ID,
Priority: qce.Priority,
Timestamp: qce.Timestamp,
Context: contextData,
})
}
return rqc, nil
}

func (sc *snowflakeConn) Begin() (driver.Tx, error) {
return sc.BeginTx(sc.ctx, driver.TxOptions{})
}
Expand Down
1 change: 1 addition & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
}

sct := &SCTest{t, sc}

test(sct)
}

Expand Down
8 changes: 4 additions & 4 deletions htap.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ type queryContext struct {
}

type queryContextEntry struct {
ID int `json:"id"`
Timestamp int64 `json:"timestamp"`
Priority int `json:"priority"`
Context any `json:"context,omitempty"`
ID int `json:"id"`
Timestamp int64 `json:"timestamp"`
Priority int `json:"priority"`
Context string `json:"context,omitempty"`
}

type queryContextCache struct {
Expand Down
195 changes: 89 additions & 106 deletions htap_test.go
Original file line number Diff line number Diff line change
@@ -1,106 +1,13 @@
package gosnowflake

import (
"encoding/json"
"database/sql/driver"
"fmt"
"reflect"
"strings"
"testing"
"time"
)

func TestMarshallAndDecodeOpaqueContext(t *testing.T) {
testcases := []struct {
json string
qc queryContextEntry
}{
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3
}`,
qc: queryContextEntry{1, 2, 3, nil},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": "abc"
}`,
qc: queryContextEntry{1, 2, 3, "abc"},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": {
"val": "abc"
}
}`,
qc: queryContextEntry{1, 2, 3, map[string]interface{}{"val": "abc"}},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": [
"abc"
]
}`,
qc: queryContextEntry{1, 2, 3, []any{"abc"}},
},
{
json: `{
"id": 1,
"timestamp": 2,
"priority": 3,
"context": [
{
"val": "abc"
}
]
}`,
qc: queryContextEntry{1, 2, 3, []any{map[string]interface{}{"val": "abc"}}},
},
}

for _, tc := range testcases {
t.Run(trimWhitespaces(tc.json), func(t *testing.T) {
var qc queryContextEntry

err := json.NewDecoder(strings.NewReader(tc.json)).Decode(&qc)
if err != nil {
t.Fatalf("failed to decode json. %v", err)
}

if !reflect.DeepEqual(tc.qc, qc) {
t.Errorf("failed to decode json. expected: %v, got: %v", tc.qc, qc)
}

bytes, err := json.Marshal(qc)
if err != nil {
t.Fatalf("failed to encode json. %v", err)
}

resultJSON := string(bytes)
if resultJSON != trimWhitespaces(tc.json) {
t.Errorf("failed to encode json. epxected: %v, got: %v", trimWhitespaces(tc.json), resultJSON)
}
})
}
}

func trimWhitespaces(s string) string {
return strings.ReplaceAll(
strings.ReplaceAll(
strings.ReplaceAll(s, "\t", ""),
" ", ""),
"\n", "",
)
}

func TestSortingByPriority(t *testing.T) {
qcc := (&queryContextCache{}).init()
sc := htapTestSnowflakeConn()
Expand Down Expand Up @@ -302,9 +209,9 @@ func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryC
}

func TestPruneBySessionValue(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce2 := queryContextEntry{2, 2, 2, nil}
qce3 := queryContextEntry{3, 3, 3, nil}
qce1 := queryContextEntry{1, 1, 1, ""}
qce2 := queryContextEntry{2, 2, 2, ""}
qce3 := queryContextEntry{3, 3, 3, ""}

testcases := []struct {
size string
Expand Down Expand Up @@ -352,12 +259,12 @@ func TestPruneBySessionValue(t *testing.T) {
}

func TestPruneByDefaultValue(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce2 := queryContextEntry{2, 2, 2, nil}
qce3 := queryContextEntry{3, 3, 3, nil}
qce4 := queryContextEntry{4, 4, 4, nil}
qce5 := queryContextEntry{5, 5, 5, nil}
qce6 := queryContextEntry{6, 6, 6, nil}
qce1 := queryContextEntry{1, 1, 1, ""}
qce2 := queryContextEntry{2, 2, 2, ""}
qce3 := queryContextEntry{3, 3, 3, ""}
qce4 := queryContextEntry{4, 4, 4, ""}
qce5 := queryContextEntry{5, 5, 5, ""}
qce6 := queryContextEntry{6, 6, 6, ""}

sc := &snowflakeConn{
cfg: &Config{
Expand All @@ -383,7 +290,7 @@ func TestPruneByDefaultValue(t *testing.T) {
}

func TestNoQcesClearsCache(t *testing.T) {
qce1 := queryContextEntry{1, 1, 1, nil}
qce1 := queryContextEntry{1, 1, 1, ""}

sc := &snowflakeConn{
cfg: &Config{
Expand Down Expand Up @@ -426,3 +333,79 @@ func TestQueryContextCacheDisabled(t *testing.T) {
}
})
}

func TestHybridTablesE2E(t *testing.T) {
if runningOnGithubAction() && !runningOnAWS() {
t.Skip("HTAP is enabled only on AWS")
}
runID := time.Now().UnixMilli()
testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID)
testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID)
runSnowflakeConnTest(t, func(sct *SCTest) {
dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil)
defer dbQuery.Close()
currentDb := make([]driver.Value, 1)
dbQuery.Next(currentDb)
defer func() {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil)
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil)
}()

t.Run("Run tests on first database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil)
rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}

sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil)
rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows2.Close()
rows2.Next(row)
if row[0] != "1" || row[1] != "a" {
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
}
rows2.Next(row)
if row[0] != "2" || row[1] != "b" {
t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 2 {
t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on second database", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil)
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil)
sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil)
defer rows.Close()
row := make([]driver.Value, 2)
rows.Next(row)
if row[0] != "3" || row[1] != "c" {
t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1])
}
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
t.Run("Run tests on first database again", func(t *testing.T) {
sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil)

sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil)

rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
defer rows.Close()
if len(sct.sc.queryContextCache.entries) != 3 {
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
}
})
})
}
16 changes: 16 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ type execRequest struct {
Parameters map[string]interface{} `json:"parameters,omitempty"`
Bindings map[string]execBindParameter `json:"bindings,omitempty"`
BindStage string `json:"bindStage,omitempty"`
QueryContext requestQueryContext `json:"queryContextDTO,omitempty"`
}

type requestQueryContext struct {
Entries []requestQueryContextEntry `json:"entries,omitempty"`
}

type requestQueryContextEntry struct {
Context contextData `json:"context,omitempty"`
ID int `json:"id"`
Priority int `json:"priority"`
Timestamp int64 `json:"timestamp,omitempty"`
}

type contextData struct {
Base64Data string `json:"base64Data,omitempty"`
}

type execResponseRowType struct {
Expand Down

0 comments on commit 1043498

Please sign in to comment.