Skip to content

Commit

Permalink
Allow executing SET commands
Browse files Browse the repository at this point in the history
  • Loading branch information
exAspArk committed Nov 19, 2024
1 parent bc4f518 commit e3cdfcc
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 7 deletions.
27 changes: 27 additions & 0 deletions src/custom_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,33 @@ func (orderedMap *OrderedMap) Values() []string {

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

type Set struct {
valueByItem map[string]bool
}

func NewSet(items []string) *Set {
set := &Set{
valueByItem: make(map[string]bool),
}

for _, item := range items {
set.Add(item)
}

return set
}

func (set *Set) Add(item string) {
set.valueByItem[item] = true
}

func (set Set) Contains(item string) bool {
_, ok := set.valueByItem[item]
return ok
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

type SchemaTable struct {
Schema string
Table string
Expand Down
12 changes: 5 additions & 7 deletions src/parser.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package main

import (
"slices"

pgQuery "github.com/pganalyze/pg_query_go/v5"
)

var PG_SYSTEM_TABLES = []string{
var PG_SYSTEM_TABLES = NewSet([]string{
"pg_aggregate",
"pg_am",
"pg_amop",
Expand Down Expand Up @@ -71,9 +69,9 @@ var PG_SYSTEM_TABLES = []string{
"pg_ts_template",
"pg_type",
"pg_user_mapping",
}
})

var PG_SYSTEM_VIEWS = []string{
var PG_SYSTEM_VIEWS = NewSet([]string{
"pg_stat_activity",
"pg_stat_replication",
"pg_stat_wal_receiver",
Expand Down Expand Up @@ -117,7 +115,7 @@ var PG_SYSTEM_VIEWS = []string{
"pg_statio_all_sequences",
"pg_statio_sys_sequences",
"pg_statio_user_sequences",
}
})

var PG_INFORMATION_SCHEMA_TABLES_VALUE_BY_COLUMN = NewOrderedMap([][]string{
{"table_catalog", "bemidb"},
Expand Down Expand Up @@ -149,7 +147,7 @@ var PG_STATIO_USER_TABLES_VALUE_BY_COLUMN = NewOrderedMap([][]string{
})

func IsSystemTable(table string) bool {
return slices.Contains(PG_SYSTEM_TABLES, table) || slices.Contains(PG_SYSTEM_VIEWS, table)
return PG_SYSTEM_TABLES.Contains(table) || PG_SYSTEM_VIEWS.Contains(table)
}

func RawSelectColumns(selectStatement *pgQuery.SelectStmt) []string {
Expand Down
5 changes: 5 additions & 0 deletions src/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,11 @@ func (proxy *Proxy) remapQuery(query string) (string, error) {
return pgQuery.Deparse(queryTree)
}

if statementNode != nil && statementNode.GetVariableSetStmt() != nil {
queryTree = proxy.selectRemapper.RemapQueryTreeWithSet(queryTree)
return pgQuery.Deparse(queryTree)
}

LogDebug(proxy.config, queryTree)
return "", errors.New("Unsupported query type")
}
Expand Down
5 changes: 5 additions & 0 deletions src/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func TestHandleQuery(t *testing.T) {
"description": {"table_catalog", "table_schema", "table"},
"values": {"bemidb", "public", "test_table"},
},
// SET
"SET client_encoding TO 'UTF8'": {
"description": {"Success"},
"values": {},
},
// Iceberg data
"SELECT COUNT(*) AS count FROM public.test_table": {
"description": {"count"},
Expand Down
23 changes: 23 additions & 0 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ var REMAPPED_CONSTANT_BY_PG_FUNCTION_NAME = map[string]string{
"pg_indexes_size": "0",
}

var KNOWN_SET_STATEMENTS = NewSet([]string{
"client_encoding", // SET client_encoding TO 'UTF8'
"client_min_messages", // SET client_min_messages TO 'warning'
"standard_conforming_strings", // SET standard_conforming_strings = on
"intervalstyle", // SET intervalstyle = iso_8601
})

type SelectRemapper struct {
icebergReader *IcebergReader
config *Config
Expand All @@ -39,6 +46,22 @@ func (selectRemapper *SelectRemapper) RemapQueryTreeWithSelect(queryTree *pgQuer
return queryTree
}

// No-op
func (selectRemapper *SelectRemapper) RemapQueryTreeWithSet(queryTree *pgQuery.ParseResult) *pgQuery.ParseResult {
setStatement := queryTree.Stmts[0].Stmt.GetVariableSetStmt()

if !KNOWN_SET_STATEMENTS.Contains(setStatement.Name) {
LogError(selectRemapper.config, "Unsupported SET ", setStatement.Name, ":", setStatement)
}

queryTree.Stmts[0].Stmt.GetVariableSetStmt().Name = "schema"
queryTree.Stmts[0].Stmt.GetVariableSetStmt().Args = []*pgQuery.Node{
pgQuery.MakeAConstStrNode("main", 0),
}

return queryTree
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
Expand Down

0 comments on commit e3cdfcc

Please sign in to comment.