Skip to content

Commit

Permalink
Add pg_backend_pid function support
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Dec 17, 2024
1 parent 527bf5e commit 3987690
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 195 deletions.
96 changes: 50 additions & 46 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func TestHandleQuery(t *testing.T) {
"description": {"pg_encoding_to_char"},
"values": {"UTF8"},
},
"SELECT pg_backend_pid()": {
"description": {"pg_backend_pid"},
"values": {"0"},
},
// PG system tables
"SELECT oid, typname AS typename FROM pg_type WHERE typname='geometry' OR typname='geography'": {
"description": {"oid", "typename"},
Expand Down Expand Up @@ -670,70 +674,70 @@ func TestHandleExecuteQuery(t *testing.T) {
}

func TestHandleMultipleQueries(t *testing.T) {
t.Run("Handles multiple SET statements", func(t *testing.T) {
query := `SET client_encoding TO 'UTF8';
t.Run("Handles multiple SET statements", func(t *testing.T) {
query := `SET client_encoding TO 'UTF8';
SET client_min_messages TO 'warning';
SET standard_conforming_strings = on;`
queryHandler := initQueryHandler()
queryHandler := initQueryHandler()

messages, err := queryHandler.HandleQuery(query)
messages, err := queryHandler.HandleQuery(query)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.RowDescription{},
&pgproto3.CommandComplete{},
})
})
testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.RowDescription{},
&pgproto3.CommandComplete{},
})
})

t.Run("Handles mixed SET and SELECT statements", func(t *testing.T) {
query := `SET client_encoding TO 'UTF8';
t.Run("Handles mixed SET and SELECT statements", func(t *testing.T) {
query := `SET client_encoding TO 'UTF8';
SELECT passwd FROM pg_shadow WHERE usename='bemidb';`
queryHandler := initQueryHandler()
queryHandler := initQueryHandler()

messages, err := queryHandler.HandleQuery(query)
messages, err := queryHandler.HandleQuery(query)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.RowDescription{},
&pgproto3.DataRow{},
&pgproto3.CommandComplete{},
})
testDataRowValues(t, messages[1], []string{"bemidb-encrypted"})
})
testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.RowDescription{},
&pgproto3.DataRow{},
&pgproto3.CommandComplete{},
})
testDataRowValues(t, messages[1], []string{"bemidb-encrypted"})
})

t.Run("Handles multiple SELECT statements", func(t *testing.T) {
query := `SELECT passwd FROM pg_shadow WHERE usename='bemidb';
t.Run("Handles multiple SELECT statements", func(t *testing.T) {
query := `SELECT passwd FROM pg_shadow WHERE usename='bemidb';
SELECT passwd FROM pg_shadow WHERE usename='bemidb';`
queryHandler := initQueryHandler()
queryHandler := initQueryHandler()

messages, err := queryHandler.HandleQuery(query)
messages, err := queryHandler.HandleQuery(query)

testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.RowDescription{},
&pgproto3.DataRow{},
&pgproto3.CommandComplete{},
})
testDataRowValues(t, messages[1], []string{"bemidb-encrypted"})
})
testNoError(t, err)
testMessageTypes(t, messages, []pgproto3.Message{
&pgproto3.RowDescription{},
&pgproto3.DataRow{},
&pgproto3.CommandComplete{},
})
testDataRowValues(t, messages[1], []string{"bemidb-encrypted"})
})

t.Run("Handles error in any of multiple statements", func(t *testing.T) {
query := `SET client_encoding TO 'UTF8';
t.Run("Handles error in any of multiple statements", func(t *testing.T) {
query := `SET client_encoding TO 'UTF8';
SELECT * FROM non_existent_table;
SET standard_conforming_strings = on;`
queryHandler := initQueryHandler()
queryHandler := initQueryHandler()

_, err := queryHandler.HandleQuery(query)
_, err := queryHandler.HandleQuery(query)

if (err == nil) {
t.Error("Expected an error for non-existent table, got nil")
return
}
if err == nil {
t.Error("Expected an error for non-existent table, got nil")
return
}

if !strings.Contains(err.Error(), "non_existent_table") {
t.Errorf("Expected error message to contain 'non_existent_table', got: %s", err.Error())
}
})
if !strings.Contains(err.Error(), "non_existent_table") {
t.Errorf("Expected error message to contain 'non_existent_table', got: %s", err.Error())
}
})
}

func initQueryHandler() *QueryHandler {
Expand Down
220 changes: 110 additions & 110 deletions src/query_parser_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,114 +84,114 @@ func (parser *QueryParserTable) MakePgRolesNode(user string, alias string) *pgQu

// pg_catalog.pg_extension -> VALUES(values...) t(columns...)
func (parser *QueryParserTable) MakePgExtensionNode(alias string) *pgQuery.Node {
columns := PG_EXTENSION_VALUE_BY_COLUMN.Keys()
staticRowValues := PG_EXTENSION_VALUE_BY_COLUMN.Values()
rowsValues := [][]string{staticRowValues}
return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias)
columns := PG_EXTENSION_VALUE_BY_COLUMN.Keys()
staticRowValues := PG_EXTENSION_VALUE_BY_COLUMN.Values()
rowsValues := [][]string{staticRowValues}
return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias)
}

// pg_catalog.pg_database -> VALUES(values...) t(columns...)
func (parser *QueryParserTable) MakePgDatabaseNode(alias string) *pgQuery.Node {
targetList := []*pgQuery.Node{
pgQuery.MakeResTargetNodeWithNameAndVal(
"oid",
pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("oid")}, 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datname",
pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("datname")}, 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datdba",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"encoding",
pgQuery.MakeAConstStrNode("6", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datlocprovider",
pgQuery.MakeAConstStrNode("c", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datistemplate",
pgQuery.MakeAConstStrNode("f", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datallowconn",
pgQuery.MakeAConstStrNode("t", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datconnlimit",
pgQuery.MakeAConstStrNode("-1", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datfrozenxid",
pgQuery.MakeAConstStrNode("722", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datminmxid",
pgQuery.MakeAConstStrNode("1", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"dattablespace",
pgQuery.MakeAConstStrNode("1663", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datcollate",
pgQuery.MakeAConstStrNode("en_US.UTF-8", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datctype",
pgQuery.MakeAConstStrNode("en_US.UTF-8", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"daticulocale",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"daticurules",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datcollversion",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datacl",
pgQuery.MakeAConstStrNode("", 0),
0,
),
}
targetList := []*pgQuery.Node{
pgQuery.MakeResTargetNodeWithNameAndVal(
"oid",
pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("oid")}, 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datname",
pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("datname")}, 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datdba",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"encoding",
pgQuery.MakeAConstStrNode("6", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datlocprovider",
pgQuery.MakeAConstStrNode("c", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datistemplate",
pgQuery.MakeAConstStrNode("f", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datallowconn",
pgQuery.MakeAConstStrNode("t", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datconnlimit",
pgQuery.MakeAConstStrNode("-1", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datfrozenxid",
pgQuery.MakeAConstStrNode("722", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datminmxid",
pgQuery.MakeAConstStrNode("1", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"dattablespace",
pgQuery.MakeAConstStrNode("1663", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datcollate",
pgQuery.MakeAConstStrNode("en_US.UTF-8", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datctype",
pgQuery.MakeAConstStrNode("en_US.UTF-8", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"daticulocale",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"daticurules",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datcollversion",
pgQuery.MakeAConstStrNode("", 0),
0,
),
pgQuery.MakeResTargetNodeWithNameAndVal(
"datacl",
pgQuery.MakeAConstStrNode("", 0),
0,
),
}

fromClause := &pgQuery.Node{
Node: &pgQuery.Node_RangeVar{
RangeVar: &pgQuery.RangeVar{
Schemaname: "pg_catalog",
Relname: "pg_database",
Inh: true,
Relpersistence: "p",
},
},
}
fromClause := &pgQuery.Node{
Node: &pgQuery.Node_RangeVar{
RangeVar: &pgQuery.RangeVar{
Schemaname: "pg_catalog",
Relname: "pg_database",
Inh: true,
Relpersistence: "p",
},
},
}

return parser.utils.MakeSubselectFromNode(targetList, fromClause, alias)
return parser.utils.MakeSubselectFromNode(targetList, fromClause, alias)
}

// System pg_* tables
Expand Down Expand Up @@ -618,14 +618,14 @@ var PG_ROLES_VALUE_BY_COLUMN = NewOrderedMap([][]string{
})

var PG_EXTENSION_VALUE_BY_COLUMN = NewOrderedMap([][]string{
{"oid", "13823"},
{"extname", "plpgsql"},
{"extowner", "10"},
{"extnamespace", "11"},
{"extrelocatable", "false"},
{"extversion", "1.0"},
{"extconfig", "NULL"},
{"extcondition", "NULL"},
{"oid", "13823"},
{"extname", "plpgsql"},
{"extowner", "10"},
{"extnamespace", "11"},
{"extrelocatable", "false"},
{"extversion", "1.0"},
{"extconfig", "NULL"},
{"extcondition", "NULL"},
})

type DuckDBKeyword struct {
Expand Down
Loading

0 comments on commit 3987690

Please sign in to comment.