Skip to content

Commit

Permalink
Ensure consistent CASE expression typing
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Dec 20, 2024
1 parent f1074a2 commit 5b1ce01
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const VERSION = "0.28.0"
const VERSION = "0.28.1"

func main() {
config := LoadConfig()
Expand Down
4 changes: 4 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,10 @@ func TestHandleQuery(t *testing.T) {
"description": {"id", "name", "is_superuser", "can_create_role"},
"values": {},
},
"SELECT roles.oid AS id, roles.rolname AS name, roles.rolsuper AS is_superuser, CASE WHEN roles.rolsuper THEN true ELSE roles.rolcreaterole END AS can_create_role FROM pg_catalog.pg_roles roles WHERE rolname = current_user": {
"description": {"id", "name", "is_superuser", "can_create_role"},
"values": {},
},
// WHERE pg functions
"SELECT gss_authenticated, encrypted FROM (SELECT false, false, false, false, false WHERE false) t(pid, gss_authenticated, principal, encrypted, credentials_delegated) WHERE pid = pg_backend_pid()": {
"description": {"gss_authenticated", "encrypted"},
Expand Down
100 changes: 100 additions & 0 deletions src/query_parser_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package main

import (
"strings"

pgQuery "github.com/pganalyze/pg_query_go/v5"
)

type QueryParserType struct {
config *Config
utils *QueryParserUtils
}

func NewQueryParserType(config *Config) *QueryParserType {
return &QueryParserType{
config: config,
utils: NewQueryParserUtils(config),
}
}

func (parser *QueryParserType) MakeTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node {
return &pgQuery.Node{
Node: &pgQuery.Node_TypeCast{
TypeCast: &pgQuery.TypeCast{
Arg: arg,
TypeName: &pgQuery.TypeName{
Names: []*pgQuery.Node{
pgQuery.MakeStrNode(typeName),
},
Location: 0,
},
},
},
}
}

func (parser *QueryParserType) inferNodeType(node *pgQuery.Node) string {
if typeCast := node.GetTypeCast(); typeCast != nil {
return typeCast.TypeName.Names[0].GetString_().Sval
}

if aConst := node.GetAConst(); aConst != nil {
switch {
case aConst.GetBoolval() != nil:
return "boolean"
case aConst.GetIval() != nil:
return "int8"
case aConst.GetSval() != nil:
return "text"
}
}
return ""
}

func (parser *QueryParserType) MakeCaseTypeCastNode(arg *pgQuery.Node, typeName string) *pgQuery.Node {
if existingType := parser.inferNodeType(arg); existingType == typeName {
return arg
}
return parser.MakeTypeCastNode(arg, typeName)
}

func (parser *QueryParserType) RemapTypeCast(node *pgQuery.Node) *pgQuery.Node {
if node.GetTypeCast() != nil {
typeCast := node.GetTypeCast()
if len(typeCast.TypeName.Names) > 0 {
typeName := typeCast.TypeName.Names[0].GetString_().Sval
if typeName == "regclass" {
return typeCast.Arg
}

if typeName == "text" {
return parser.MakeListValueFromArray(typeCast.Arg)
}
}
}
return node
}

func (parser *QueryParserType) MakeListValueFromArray(node *pgQuery.Node) *pgQuery.Node {
arrayStr := node.GetAConst().GetSval().Sval
arrayStr = strings.Trim(arrayStr, "{}")
elements := strings.Split(arrayStr, ",")

funcCall := &pgQuery.FuncCall{
Funcname: []*pgQuery.Node{
pgQuery.MakeStrNode("list_value"),
},
}

for _, elem := range elements {
funcCall.Args = append(funcCall.Args,
pgQuery.MakeAConstStrNode(elem, 0))
}

return &pgQuery.Node{
Node: &pgQuery.Node_FuncCall{
FuncCall: funcCall,
},
}
}
58 changes: 25 additions & 33 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var KNOWN_SET_STATEMENTS = NewSet([]string{

type SelectRemapper struct {
parserTable *QueryParserTable
parserType *QueryParserType
remapperTable *SelectRemapperTable
remapperWhere *SelectRemapperWhere
remapperSelect *SelectRemapperSelect
Expand All @@ -29,6 +30,7 @@ type SelectRemapper struct {
func NewSelectRemapper(config *Config, icebergReader *IcebergReader, duckdb *Duckdb) *SelectRemapper {
return &SelectRemapper{
parserTable: NewQueryParserTable(config),
parserType: NewQueryParserType(config),
remapperTable: NewSelectRemapperTable(config, icebergReader, duckdb),
remapperWhere: NewSelectRemapperWhere(config),
remapperSelect: NewSelectRemapperSelect(config),
Expand Down Expand Up @@ -127,6 +129,9 @@ func (selectRemapper *SelectRemapper) hasCaseExpressions(selectStatement *pgQuer
func (selectRemapper *SelectRemapper) remapCaseExpressions(selectStatement *pgQuery.SelectStmt, indentLevel int) *pgQuery.SelectStmt {
for _, target := range selectStatement.TargetList {
if caseExpr := target.GetResTarget().Val.GetCaseExpr(); caseExpr != nil {

selectRemapper.ensureConsistentCaseTypes(caseExpr)

for _, when := range caseExpr.Args {
if whenClause := when.GetCaseWhen(); whenClause != nil {
if whenClause.Expr != nil {
Expand Down Expand Up @@ -166,6 +171,25 @@ func (selectRemapper *SelectRemapper) remapCaseExpressions(selectStatement *pgQu
return selectStatement
}

func (selectRemapper *SelectRemapper) ensureConsistentCaseTypes(caseExpr *pgQuery.CaseExpr) {
if len(caseExpr.Args) > 0 {
if when := caseExpr.Args[0].GetCaseWhen(); when != nil && when.Result != nil {
if typeName := selectRemapper.parserType.inferNodeType(when.Result); typeName != "" {
// WHEN
for i := 1; i < len(caseExpr.Args); i++ {
if whenClause := caseExpr.Args[i].GetCaseWhen(); whenClause != nil && whenClause.Result != nil {
whenClause.Result = selectRemapper.parserType.MakeCaseTypeCastNode(whenClause.Result, typeName)
}
}
// ELSE
if caseExpr.Defresult != nil {
caseExpr.Defresult = selectRemapper.parserType.MakeCaseTypeCastNode(caseExpr.Defresult, typeName)
}
}
}
}
}

// FROM PG_FUNCTION()
func (selectRemapper *SelectRemapper) remapTableFunction(fromNode *pgQuery.Node, indentLevel int) *pgQuery.Node {
selectRemapper.traceTreeTraversal("FROM function()", indentLevel)
Expand Down Expand Up @@ -358,39 +382,7 @@ func (selectRemapper *SelectRemapper) remapSelect(selectStatement *pgQuery.Selec
}

func (selectRemapper *SelectRemapper) remapTypecast(node *pgQuery.Node) *pgQuery.Node {
if node.GetTypeCast() != nil {
typeCast := node.GetTypeCast()
if len(typeCast.TypeName.Names) > 0 {
typeName := typeCast.TypeName.Names[0].GetString_().Sval
if typeName == "regclass" {
return typeCast.Arg
}

if typeName == "text" {
arrayStr := typeCast.Arg.GetAConst().GetSval().Sval
arrayStr = strings.Trim(arrayStr, "{}")
elements := strings.Split(arrayStr, ",")

funcCall := &pgQuery.FuncCall{
Funcname: []*pgQuery.Node{
pgQuery.MakeStrNode("list_value"),
},
}

for _, elem := range elements {
funcCall.Args = append(funcCall.Args,
pgQuery.MakeAConstStrNode(elem, 0))
}

return &pgQuery.Node{
Node: &pgQuery.Node_FuncCall{
FuncCall: funcCall,
},
}
}
}
}
return node
return selectRemapper.parserType.RemapTypeCast(node)
}

func (selectRemapper *SelectRemapper) traceTreeTraversal(label string, indentLevel int) {
Expand Down
2 changes: 2 additions & 0 deletions src/select_remapper_where.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ func (remapper *SelectRemapperWhere) RemapWhereClauseForTable(qSchemaTable Query
if remapper.parserTable.IsTableFromPgCatalog(qSchemaTable) {
switch qSchemaTable.Table {
case PG_TABLE_PG_NAMESPACE:
// FROM pg_catalog.pg_namespace -> FROM pg_catalog.pg_namespace WHERE nspname != 'main'
withoutMainSchemaWhereCondition := remapper.parserWhere.MakeExpressionNode("nspname", "!=", "main")
return remapper.parserWhere.AppendWhereCondition(selectStatement, withoutMainSchemaWhereCondition)
case PG_TABLE_PG_STATIO_USER_TABLES:
// FROM pg_catalog.pg_statio_user_tables -> FROM pg_catalog.pg_statio_user_tables WHERE false
falseWhereCondition := remapper.parserWhere.MakeFalseConditionNode()
return remapper.parserWhere.OverrideWhereCondition(selectStatement, falseWhereCondition)
}
Expand Down

0 comments on commit 5b1ce01

Please sign in to comment.