Skip to content

Commit

Permalink
Add support for functions in where expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Dec 17, 2024
1 parent 3987690 commit e39d20e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,11 @@ func TestHandleQuery(t *testing.T) {
"description": {"type"},
"values": {""},
},
// WHERE pg functions
"SELECT gss_authenticated, encrypted FROM (SELECT false, false, false, false, false WHERE false) t(pid, gss_authenticated, principal, encrypted, credentials_delegated) WHERE pid = pg_backend_pid()": {
"description": {"gss_authenticated", "encrypted"},
"values": {},
},
}

for query, responses := range responsesByQuery {
Expand Down
28 changes: 28 additions & 0 deletions src/select_remapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ func (selectRemapper *SelectRemapper) remapSelectStatement(selectStatement *pgQu
return selectStatement
}

// WHERE
if selectStatement.WhereClause != nil {
selectStatement = selectRemapper.remapWhereExpression(selectStatement, selectStatement.WhereClause, indentLevel)
}

// FROM
if len(selectStatement.FromClause) > 0 {
// SELECT
Expand Down Expand Up @@ -304,6 +309,29 @@ func (selectRemapper *SelectRemapper) remapTypeCastsInNode(node *pgQuery.Node) *
return node
}

func (selectRemapper *SelectRemapper) remapWhereExpression(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node, indentLevel int) *pgQuery.SelectStmt {
selectRemapper.traceTreeTraversal("WHERE expression", indentLevel)

if aExpr := node.GetAExpr(); aExpr != nil {
if aExpr.Lexpr != nil {
selectRemapper.traceTreeTraversal("WHERE expression left", indentLevel+1)
selectStatement = selectRemapper.remapWhereExpression(selectStatement, aExpr.Lexpr, indentLevel+1)
}
if aExpr.Rexpr != nil {
selectRemapper.traceTreeTraversal("WHERE expression right", indentLevel+1)
selectStatement = selectRemapper.remapWhereExpression(selectStatement, aExpr.Rexpr, indentLevel+1)
}
}

if funcCall := node.GetFuncCall(); funcCall != nil {
if constantNode := selectRemapper.remapperSelect.remappedToConstant(funcCall); constantNode != nil {
node.Node = constantNode.Node
}
}

return selectStatement
}

func (selectRemapper *SelectRemapper) remapJoinExpressions(selectStatement *pgQuery.SelectStmt, node *pgQuery.Node, indentLevel int) *pgQuery.Node {
selectRemapper.traceTreeTraversal("JOIN left", indentLevel)
leftJoinNode := node.GetJoinExpr().Larg
Expand Down

0 comments on commit e39d20e

Please sign in to comment.