diff --git a/src/query_handler_test.go b/src/query_handler_test.go index 1513cbd..cc67c03 100644 --- a/src/query_handler_test.go +++ b/src/query_handler_test.go @@ -47,6 +47,10 @@ func TestHandleQuery(t *testing.T) { "description": {"pg_encoding_to_char"}, "values": {"UTF8"}, }, + "SELECT pg_backend_pid()": { + "description": {"pg_backend_pid"}, + "values": {"0"}, + }, // PG system tables "SELECT oid, typname AS typename FROM pg_type WHERE typname='geometry' OR typname='geography'": { "description": {"oid", "typename"}, @@ -670,70 +674,70 @@ func TestHandleExecuteQuery(t *testing.T) { } func TestHandleMultipleQueries(t *testing.T) { - t.Run("Handles multiple SET statements", func(t *testing.T) { - query := `SET client_encoding TO 'UTF8'; + t.Run("Handles multiple SET statements", func(t *testing.T) { + query := `SET client_encoding TO 'UTF8'; SET client_min_messages TO 'warning'; SET standard_conforming_strings = on;` - queryHandler := initQueryHandler() + queryHandler := initQueryHandler() - messages, err := queryHandler.HandleQuery(query) + messages, err := queryHandler.HandleQuery(query) - testNoError(t, err) - testMessageTypes(t, messages, []pgproto3.Message{ - &pgproto3.RowDescription{}, - &pgproto3.CommandComplete{}, - }) - }) + testNoError(t, err) + testMessageTypes(t, messages, []pgproto3.Message{ + &pgproto3.RowDescription{}, + &pgproto3.CommandComplete{}, + }) + }) - t.Run("Handles mixed SET and SELECT statements", func(t *testing.T) { - query := `SET client_encoding TO 'UTF8'; + t.Run("Handles mixed SET and SELECT statements", func(t *testing.T) { + query := `SET client_encoding TO 'UTF8'; SELECT passwd FROM pg_shadow WHERE usename='bemidb';` - queryHandler := initQueryHandler() + queryHandler := initQueryHandler() - messages, err := queryHandler.HandleQuery(query) + messages, err := queryHandler.HandleQuery(query) - testNoError(t, err) - testMessageTypes(t, messages, []pgproto3.Message{ - &pgproto3.RowDescription{}, - &pgproto3.DataRow{}, - &pgproto3.CommandComplete{}, - }) - testDataRowValues(t, messages[1], []string{"bemidb-encrypted"}) - }) + testNoError(t, err) + testMessageTypes(t, messages, []pgproto3.Message{ + &pgproto3.RowDescription{}, + &pgproto3.DataRow{}, + &pgproto3.CommandComplete{}, + }) + testDataRowValues(t, messages[1], []string{"bemidb-encrypted"}) + }) - t.Run("Handles multiple SELECT statements", func(t *testing.T) { - query := `SELECT passwd FROM pg_shadow WHERE usename='bemidb'; + t.Run("Handles multiple SELECT statements", func(t *testing.T) { + query := `SELECT passwd FROM pg_shadow WHERE usename='bemidb'; SELECT passwd FROM pg_shadow WHERE usename='bemidb';` - queryHandler := initQueryHandler() + queryHandler := initQueryHandler() - messages, err := queryHandler.HandleQuery(query) + messages, err := queryHandler.HandleQuery(query) - testNoError(t, err) - testMessageTypes(t, messages, []pgproto3.Message{ - &pgproto3.RowDescription{}, - &pgproto3.DataRow{}, - &pgproto3.CommandComplete{}, - }) - testDataRowValues(t, messages[1], []string{"bemidb-encrypted"}) - }) + testNoError(t, err) + testMessageTypes(t, messages, []pgproto3.Message{ + &pgproto3.RowDescription{}, + &pgproto3.DataRow{}, + &pgproto3.CommandComplete{}, + }) + testDataRowValues(t, messages[1], []string{"bemidb-encrypted"}) + }) - t.Run("Handles error in any of multiple statements", func(t *testing.T) { - query := `SET client_encoding TO 'UTF8'; + t.Run("Handles error in any of multiple statements", func(t *testing.T) { + query := `SET client_encoding TO 'UTF8'; SELECT * FROM non_existent_table; SET standard_conforming_strings = on;` - queryHandler := initQueryHandler() + queryHandler := initQueryHandler() - _, err := queryHandler.HandleQuery(query) + _, err := queryHandler.HandleQuery(query) - if (err == nil) { - t.Error("Expected an error for non-existent table, got nil") - return - } + if err == nil { + t.Error("Expected an error for non-existent table, got nil") + return + } - if !strings.Contains(err.Error(), "non_existent_table") { - t.Errorf("Expected error message to contain 'non_existent_table', got: %s", err.Error()) - } - }) + if !strings.Contains(err.Error(), "non_existent_table") { + t.Errorf("Expected error message to contain 'non_existent_table', got: %s", err.Error()) + } + }) } func initQueryHandler() *QueryHandler { diff --git a/src/query_parser_table.go b/src/query_parser_table.go index 054ef99..05a816b 100644 --- a/src/query_parser_table.go +++ b/src/query_parser_table.go @@ -84,114 +84,114 @@ func (parser *QueryParserTable) MakePgRolesNode(user string, alias string) *pgQu // pg_catalog.pg_extension -> VALUES(values...) t(columns...) func (parser *QueryParserTable) MakePgExtensionNode(alias string) *pgQuery.Node { - columns := PG_EXTENSION_VALUE_BY_COLUMN.Keys() - staticRowValues := PG_EXTENSION_VALUE_BY_COLUMN.Values() - rowsValues := [][]string{staticRowValues} - return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias) + columns := PG_EXTENSION_VALUE_BY_COLUMN.Keys() + staticRowValues := PG_EXTENSION_VALUE_BY_COLUMN.Values() + rowsValues := [][]string{staticRowValues} + return parser.utils.MakeSubselectWithRowsNode(columns, rowsValues, alias) } // pg_catalog.pg_database -> VALUES(values...) t(columns...) func (parser *QueryParserTable) MakePgDatabaseNode(alias string) *pgQuery.Node { - targetList := []*pgQuery.Node{ - pgQuery.MakeResTargetNodeWithNameAndVal( - "oid", - pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("oid")}, 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datname", - pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("datname")}, 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datdba", - pgQuery.MakeAConstStrNode("", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "encoding", - pgQuery.MakeAConstStrNode("6", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datlocprovider", - pgQuery.MakeAConstStrNode("c", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datistemplate", - pgQuery.MakeAConstStrNode("f", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datallowconn", - pgQuery.MakeAConstStrNode("t", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datconnlimit", - pgQuery.MakeAConstStrNode("-1", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datfrozenxid", - pgQuery.MakeAConstStrNode("722", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datminmxid", - pgQuery.MakeAConstStrNode("1", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "dattablespace", - pgQuery.MakeAConstStrNode("1663", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datcollate", - pgQuery.MakeAConstStrNode("en_US.UTF-8", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datctype", - pgQuery.MakeAConstStrNode("en_US.UTF-8", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "daticulocale", - pgQuery.MakeAConstStrNode("", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "daticurules", - pgQuery.MakeAConstStrNode("", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datcollversion", - pgQuery.MakeAConstStrNode("", 0), - 0, - ), - pgQuery.MakeResTargetNodeWithNameAndVal( - "datacl", - pgQuery.MakeAConstStrNode("", 0), - 0, - ), - } + targetList := []*pgQuery.Node{ + pgQuery.MakeResTargetNodeWithNameAndVal( + "oid", + pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("oid")}, 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datname", + pgQuery.MakeColumnRefNode([]*pgQuery.Node{pgQuery.MakeStrNode("datname")}, 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datdba", + pgQuery.MakeAConstStrNode("", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "encoding", + pgQuery.MakeAConstStrNode("6", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datlocprovider", + pgQuery.MakeAConstStrNode("c", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datistemplate", + pgQuery.MakeAConstStrNode("f", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datallowconn", + pgQuery.MakeAConstStrNode("t", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datconnlimit", + pgQuery.MakeAConstStrNode("-1", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datfrozenxid", + pgQuery.MakeAConstStrNode("722", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datminmxid", + pgQuery.MakeAConstStrNode("1", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "dattablespace", + pgQuery.MakeAConstStrNode("1663", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datcollate", + pgQuery.MakeAConstStrNode("en_US.UTF-8", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datctype", + pgQuery.MakeAConstStrNode("en_US.UTF-8", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "daticulocale", + pgQuery.MakeAConstStrNode("", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "daticurules", + pgQuery.MakeAConstStrNode("", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datcollversion", + pgQuery.MakeAConstStrNode("", 0), + 0, + ), + pgQuery.MakeResTargetNodeWithNameAndVal( + "datacl", + pgQuery.MakeAConstStrNode("", 0), + 0, + ), + } - fromClause := &pgQuery.Node{ - Node: &pgQuery.Node_RangeVar{ - RangeVar: &pgQuery.RangeVar{ - Schemaname: "pg_catalog", - Relname: "pg_database", - Inh: true, - Relpersistence: "p", - }, - }, - } + fromClause := &pgQuery.Node{ + Node: &pgQuery.Node_RangeVar{ + RangeVar: &pgQuery.RangeVar{ + Schemaname: "pg_catalog", + Relname: "pg_database", + Inh: true, + Relpersistence: "p", + }, + }, + } - return parser.utils.MakeSubselectFromNode(targetList, fromClause, alias) + return parser.utils.MakeSubselectFromNode(targetList, fromClause, alias) } // System pg_* tables @@ -618,14 +618,14 @@ var PG_ROLES_VALUE_BY_COLUMN = NewOrderedMap([][]string{ }) var PG_EXTENSION_VALUE_BY_COLUMN = NewOrderedMap([][]string{ - {"oid", "13823"}, - {"extname", "plpgsql"}, - {"extowner", "10"}, - {"extnamespace", "11"}, - {"extrelocatable", "false"}, - {"extversion", "1.0"}, - {"extconfig", "NULL"}, - {"extcondition", "NULL"}, + {"oid", "13823"}, + {"extname", "plpgsql"}, + {"extowner", "10"}, + {"extnamespace", "11"}, + {"extrelocatable", "false"}, + {"extversion", "1.0"}, + {"extconfig", "NULL"}, + {"extcondition", "NULL"}, }) type DuckDBKeyword struct { diff --git a/src/select_remapper.go b/src/select_remapper.go index b9184ef..c63efdb 100644 --- a/src/select_remapper.go +++ b/src/select_remapper.go @@ -130,45 +130,45 @@ func (selectRemapper *SelectRemapper) hasCaseExpressions(selectStatement *pgQuer } func (selectRemapper *SelectRemapper) remapCaseExpressions(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt { - for _, target := range selectStatement.TargetList { - if caseExpr := target.GetResTarget().Val.GetCaseExpr(); caseExpr != nil { - for _, when := range caseExpr.Args { - if whenClause := when.GetCaseWhen(); whenClause != nil { - if whenClause.Expr != nil { - if aExpr := whenClause.Expr.GetAExpr(); aExpr != nil { - if subLink := aExpr.Lexpr.GetSubLink(); subLink != nil { - selectRemapper.traceTreeTraversal("CASE WHEN left", indentLevel+1) - subSelect := subLink.Subselect.GetSelectStmt() - subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) - } - if subLink := aExpr.Rexpr.GetSubLink(); subLink != nil { - selectRemapper.traceTreeTraversal("CASE WHEN right", indentLevel+1) - subSelect := subLink.Subselect.GetSelectStmt() - subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) - } - } - } - - if whenClause.Result != nil { - if subLink := whenClause.Result.GetSubLink(); subLink != nil { - selectRemapper.traceTreeTraversal("CASE THEN", indentLevel+1) - subSelect := subLink.Subselect.GetSelectStmt() - subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) - } - } - } - } - - if caseExpr.Defresult != nil { - if subLink := caseExpr.Defresult.GetSubLink(); subLink != nil { - selectRemapper.traceTreeTraversal("CASE ELSE", indentLevel+1) - subSelect := subLink.Subselect.GetSelectStmt() - subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) - } - } - } - } - return selectStatement + for _, target := range selectStatement.TargetList { + if caseExpr := target.GetResTarget().Val.GetCaseExpr(); caseExpr != nil { + for _, when := range caseExpr.Args { + if whenClause := when.GetCaseWhen(); whenClause != nil { + if whenClause.Expr != nil { + if aExpr := whenClause.Expr.GetAExpr(); aExpr != nil { + if subLink := aExpr.Lexpr.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE WHEN left", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + if subLink := aExpr.Rexpr.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE WHEN right", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + } + } + + if whenClause.Result != nil { + if subLink := whenClause.Result.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE THEN", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + } + } + } + + if caseExpr.Defresult != nil { + if subLink := caseExpr.Defresult.GetSubLink(); subLink != nil { + selectRemapper.traceTreeTraversal("CASE ELSE", indentLevel+1) + subSelect := subLink.Subselect.GetSelectStmt() + subSelect = selectRemapper.remapSelectStatement(subSelect, indentLevel+1) + } + } + } + } + return selectStatement } // FROM PG_FUNCTION() diff --git a/src/select_remapper_select.go b/src/select_remapper_select.go index cf7ea3e..87e0a5b 100644 --- a/src/select_remapper_select.go +++ b/src/select_remapper_select.go @@ -14,6 +14,7 @@ var REMAPPED_CONSTANT_BY_PG_FUNCTION_NAME = map[string]string{ "pg_get_partkeydef": "", "pg_tablespace_location": "", "pg_encoding_to_char": "UTF8", + "pg_backend_pid": "0", } type SelectRemapperSelect struct {