diff --git a/main.go b/main.go index 30a2288..aa1cddc 100644 --- a/main.go +++ b/main.go @@ -40,9 +40,7 @@ func (s *SQLVet) reportError(format string, a ...interface{}) { // Vet performs static analysis func (s *SQLVet) Vet() { queries, err := vet.CheckDir( - vet.VetContext{ - Schema: s.Schema, - }, + vet.NewContext(s.Schema.Tables), s.ProjectRoot, s.Cfg.BuildFlags, s.Cfg.SqlFuncMatchers, diff --git a/pkg/schema/postgres.go b/pkg/schema/postgres.go index 7d62331..e6cbee1 100644 --- a/pkg/schema/postgres.go +++ b/pkg/schema/postgres.go @@ -91,51 +91,50 @@ func parsePostgresSchema(schemaInput string) (map[string]Table, error) { continue } - if resTarget.Name != "" { - table.Columns[resTarget.Name] = Column{ - Name: resTarget.Name, - } - continue + if col, ok := GetResTargetColumn(resTarget); ok { + table.Columns[col.Name] = col } + } - if resTarget.Val == nil { - continue - } + tables[tableName] = table + } + } - colRef := resTarget.Val.GetColumnRef() - if colRef == nil { - // parse only column references when no alias is provided - continue - } + return tables, nil +} - var colField *pg_query.Node - if len(colRef.Fields) > 0 { - colField = colRef.Fields[len(colRef.Fields)-1] - } +func GetResTargetColumn(resTarget *pg_query.ResTarget) (col Column, ok bool) { + if resTarget.Name != "" { + return Column{Name: resTarget.Name}, true + } - if colField == nil { - continue - } + if resTarget.Val == nil { + return + } - if colField.GetAStar() != nil { - // SELECT * - force parsing explicit columns for simplicity - continue - } + colRef := resTarget.Val.GetColumnRef() + if colRef == nil { + // parse only column references when no alias is provided + return + } - if colField.GetString_() == nil { - continue - } + var colField *pg_query.Node + if len(colRef.Fields) > 0 { + colField = colRef.Fields[len(colRef.Fields)-1] + } - colName := colField.GetString_().GetStr() - table.Columns[colName] = Column{ - Name: colName, - Type: "", // type not set, never used for validation - } - } + if colField == nil { + return + } - tables[tableName] = table - } + if colField.GetAStar() != nil { + // SELECT * - force parsing explicit columns for simplicity + return } - return tables, nil + if colField.GetString_() == nil { + return + } + + return Column{Name: colField.GetString_().GetStr()}, true } diff --git a/pkg/vet/gosource.go b/pkg/vet/gosource.go index 62cb1c2..d6ea941 100644 --- a/pkg/vet/gosource.go +++ b/pkg/vet/gosource.go @@ -142,7 +142,7 @@ func handleQuery(ctx VetContext, qs *QuerySite) { } var queryParams []QueryParam - queryParams, qs.Err = ValidateSqlQuery(ctx, qs.Query) + queryParams, qs.Err = ValidateSqlQuery(NewContext(ctx.Schema.Tables), qs.Query) if qs.Err != nil { return diff --git a/pkg/vet/vet.go b/pkg/vet/vet.go index 4b6125f..0d3fc29 100644 --- a/pkg/vet/vet.go +++ b/pkg/vet/vet.go @@ -11,8 +11,21 @@ import ( "github.com/houqp/sqlvet/pkg/schema" ) +type Schema struct { + Tables map[string]schema.Table +} + +func NewContext(tables map[string]schema.Table) VetContext { + return VetContext{ + Schema: Schema{Tables: tables}, + InnerSchema: Schema{Tables: map[string]schema.Table{}}, + } +} + type VetContext struct { - Schema *schema.Db + Schema Schema + InnerSchema Schema + UsedTables []TableUsed } type TableUsed struct { @@ -31,9 +44,32 @@ type QueryParam struct { // TODO: also store related column type info for analysis } +type PostponedNodes struct { + RangeSubselectNodes []*pg_query.RangeSubselect +} + +func (p *PostponedNodes) Parse(ctx VetContext, parseRe *ParseResult) (err error) { + for _, r := range p.RangeSubselectNodes { + if err = parseRangeSubselect(ctx, r, parseRe); err != nil { + return err + } + } + return nil +} + +func (p *PostponedNodes) Append(other *PostponedNodes) { + if other == nil { + return + } + p.RangeSubselectNodes = append(p.RangeSubselectNodes, other.RangeSubselectNodes...) +} + type ParseResult struct { Columns []ColumnUsed + Tables []TableUsed Params []QueryParam + + PostponedNodes *PostponedNodes } // insert query param based on parameter number and avoid deduplications @@ -130,6 +166,7 @@ func getUsedTablesFromJoinArg(arg *pg_query.Node) []TableUsed { } // extract used tables from FROM clause and JOIN clauses +// TODO ? maybe this should be moved to parseExpression() and be collected in ParseResult func getUsedTablesFromSelectStmt(fromClauseList []*pg_query.Node) []TableUsed { usedTables := []TableUsed{} @@ -153,67 +190,12 @@ func getUsedTablesFromSelectStmt(fromClauseList []*pg_query.Node) []TableUsed { return usedTables } -func getUsedColumnsFromJoinQuals(quals *pg_query.Node) []ColumnUsed { - usedCols := []ColumnUsed{} - - if quals.GetAExpr() != nil { - joinCond := quals.GetAExpr() - if lColRef := joinCond.GetLexpr().GetColumnRef(); lColRef != nil { - cu := columnRefToColumnUsed(lColRef) - if cu != nil { - usedCols = append(usedCols, *cu) - } - } - if rColRef := joinCond.GetRexpr().GetColumnRef(); rColRef != nil { - cu := columnRefToColumnUsed(rColRef) - if cu != nil { - usedCols = append(usedCols, *cu) - } - } - } - - return usedCols -} - -// todo this rewrite seems especially dubious -func getUsedColumnsFromJoinExpr(expr *pg_query.Node) []ColumnUsed { - usedCols := []ColumnUsed{} - if expr.GetJoinExpr() == nil { - return usedCols - } - joinExpr := expr.GetJoinExpr() - if larg := joinExpr.Larg; larg != nil { - usedCols = append(usedCols, getUsedColumnsFromJoinExpr(larg)...) - } - if rarg := joinExpr.Rarg; rarg != nil { - usedCols = append(usedCols, getUsedColumnsFromJoinExpr(rarg)...) - } - usedCols = append(usedCols, getUsedColumnsFromJoinQuals(joinExpr.Quals)...) - - return usedCols -} - -func getUsedColumnsFromJoinClauses(fromClauseList []*pg_query.Node) []ColumnUsed { - usedCols := []ColumnUsed{} - - if len(fromClauseList) <= 0 { - // skip because no table is referenced in the query, which means there - // is no Join clause - return usedCols - } - - for _, fromItem := range fromClauseList { - switch { - case fromItem.GetRangeVar() != nil: - // SELECT without JOIN - continue - case fromItem.GetJoinExpr() != nil: - // SELECT with one or more JOINs - usedCols = append(usedCols, getUsedColumnsFromJoinExpr(fromItem)...) - } +func parseFromClause(ctx VetContext, clause *pg_query.Node, parseRe *ParseResult) error { + err := parseExpression(ctx, clause, parseRe) + if err != nil { + err = fmt.Errorf("invalid FROM clause: %w", err) } - - return usedCols + return err } func getUsedColumnsFromReturningList(returningList []*pg_query.Node) []ColumnUsed { @@ -242,7 +224,7 @@ func getUsedColumnsFromReturningList(returningList []*pg_query.Node) []ColumnUse } func validateTable(ctx VetContext, tname string, notReadOnly bool) error { - if ctx.Schema == nil { + if ctx.Schema.Tables == nil { return nil } t, ok := ctx.Schema.Tables[tname] @@ -256,16 +238,19 @@ func validateTable(ctx VetContext, tname string, notReadOnly bool) error { } func validateTableColumns(ctx VetContext, tables []TableUsed, cols []ColumnUsed) error { - if ctx.Schema == nil { + if ctx.Schema.Tables == nil || ctx.InnerSchema.Tables == nil { return nil } var ok bool usedTables := map[string]schema.Table{} for _, tu := range tables { - usedTables[tu.Name], ok = ctx.Schema.Tables[tu.Name] + usedTables[tu.Name], ok = ctx.InnerSchema.Tables[tu.Name] if !ok { - return fmt.Errorf("invalid table name: %s", tu.Name) + usedTables[tu.Name], ok = ctx.Schema.Tables[tu.Name] + if !ok { + return fmt.Errorf("invalid table name: %s", tu.Name) + } } if tu.Alias != "" { usedTables[tu.Alias] = usedTables[tu.Name] @@ -410,18 +395,7 @@ func parseExpression(ctx VetContext, clause *pg_query.Node, parseRe *ParseResult } case clause.GetSubLink() != nil: // WHERE id IN (SELECT id FROM foo) - subselect := clause.GetSubLink().GetSubselect() - if subselect.GetSelectStmt() == nil { - return fmt.Errorf( - "unsupported subquery type: %v", subselect) - } - queryParams, err := validateSelectStmt(ctx, subselect.GetSelectStmt()) - if err != nil { - return err - } - if len(queryParams) > 0 { - AddQueryParams(&parseRe.Params, queryParams) - } + return parseSublink(ctx, clause.GetSubLink(), parseRe) case clause.GetCoalesceExpr() != nil: // TODO should this be [0]? return parseExpression(ctx, clause.GetCoalesceExpr().GetArgs()[0], parseRe) @@ -429,16 +403,97 @@ func parseExpression(ctx VetContext, clause *pg_query.Node, parseRe *ParseResult return parseWindowDef(ctx, clause.GetWindowDef(), parseRe) case clause.GetSortBy() != nil: return parseExpression(ctx, clause.GetSortBy().Node, parseRe) + case clause.GetJoinExpr() != nil: + return parseJoinExpr(ctx, clause.GetJoinExpr(), parseRe) + case clause.GetRangeVar() != nil: + parseRe.Tables = append(parseRe.Tables, rangeVarToTableUsed(clause.GetRangeVar())) + return nil + case clause.GetRangeSubselect() != nil: + // LEFT JOIN LATERAL (SELECT id FROM foo) AS bar ON true + if parseRe.PostponedNodes == nil { + parseRe.PostponedNodes = &PostponedNodes{} + } + parseRe.PostponedNodes.RangeSubselectNodes = append(parseRe.PostponedNodes.RangeSubselectNodes, clause.GetRangeSubselect()) + return nil default: return fmt.Errorf( - "unsupported expression, found node of type: %v", - reflect.TypeOf(clause), + "unsupported expression, found node of type: %v (%s)", + reflect.TypeOf(clause.Node), clause.String(), ) } return nil } +func parseSublink(ctx VetContext, clause *pg_query.SubLink, parseRe *ParseResult) error { + subSelect := clause.GetSubselect() + if subSelect.GetSelectStmt() == nil { + return fmt.Errorf( + "unsupported sublink subselect type: %v", subSelect) + } + queryParams, _, err := validateSelectStmt(ctx, subSelect.GetSelectStmt()) + if err != nil { + return err + } + if len(queryParams) > 0 { + AddQueryParams(&parseRe.Params, queryParams) + } + + return nil +} + +func parseRangeSubselect(ctx VetContext, clause *pg_query.RangeSubselect, parseRe *ParseResult) error { + subQuery := clause.GetSubquery() + if subQuery.GetSelectStmt() == nil { + return fmt.Errorf("unsupported range subselect subquery type: %v", clause) + } + + queryParams, targetCols, err := validateSelectStmt(ctx, subQuery.GetSelectStmt()) + if err != nil { + return err + } + + if len(queryParams) > 0 { + AddQueryParams(&parseRe.Params, queryParams) + } + + if clause.Alias == nil { + return nil + } + + t := schema.Table{ + Name: clause.Alias.Aliasname, + ReadOnly: true, + Columns: map[string]schema.Column{}, + } + + for _, col := range targetCols { + t.Columns[col.Name] = col + } + + ctx.InnerSchema.Tables[t.Name] = t + parseRe.Tables = append(parseRe.Tables, TableUsed{Name: t.Name}) + + return nil +} + +func parseJoinExpr(ctx VetContext, clause *pg_query.JoinExpr, parseRe *ParseResult) error { + err := parseExpression(ctx, clause.Larg, parseRe) + if err != nil { + return err + } + err = parseExpression(ctx, clause.Rarg, parseRe) + if err != nil { + return err + } + err = parseExpression(ctx, clause.Quals, parseRe) + if err != nil { + return err + } + + return nil +} + // find used column names from where clause func parseWhereClause(ctx VetContext, clause *pg_query.Node, parseRe *ParseResult) error { err := parseExpression(ctx, clause, parseRe) @@ -476,21 +531,53 @@ func getUsedColumnsFromSortClause(sortList []*pg_query.Node) []ColumnUsed { return usedCols } -func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) ([]QueryParam, error) { - usedTables := getUsedTablesFromSelectStmt(stmt.FromClause) - +func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) (queryParams []QueryParam, targetCols []schema.Column, err error) { usedCols := []ColumnUsed{} - queryParams := []QueryParam{} + + postponed := PostponedNodes{} + for _, fromClause := range stmt.FromClause { + re := &ParseResult{} + err := parseFromClause(ctx, fromClause, re) + if err != nil { + return nil, nil, err + } + if len(re.Columns) > 0 { + usedCols = append(usedCols, re.Columns...) + } + if len(re.Tables) > 0 { + ctx.UsedTables = append(ctx.UsedTables, re.Tables...) + } + if len(re.Params) > 0 { + AddQueryParams(&queryParams, re.Params) + } + + postponed.Append(re.PostponedNodes) + } + + re := &ParseResult{} + if err := postponed.Parse(ctx, re); err != nil { + return nil, nil, err + } + if len(re.Columns) > 0 { + usedCols = append(usedCols, re.Columns...) + } + if len(re.Tables) > 0 { + ctx.UsedTables = append(ctx.UsedTables, re.Tables...) + } + if len(re.Params) > 0 { + AddQueryParams(&queryParams, re.Params) + } for _, item := range stmt.TargetList { - if item.GetResTarget() == nil { + resTarget := item.GetResTarget() + if resTarget == nil { continue } re := &ParseResult{} - err := parseExpression(ctx, item.GetResTarget().Val, re) + err := parseExpression(ctx, resTarget.Val, re) if err != nil { - return nil, err + return nil, nil, err } if len(re.Columns) > 0 { usedCols = append(usedCols, re.Columns...) @@ -498,15 +585,17 @@ func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) ([]QueryParam if len(re.Params) > 0 { AddQueryParams(&queryParams, re.Params) } - } - usedCols = append(usedCols, getUsedColumnsFromJoinClauses(stmt.FromClause)...) + if col, ok := schema.GetResTargetColumn(resTarget); ok { + targetCols = append(targetCols, col) + } + } if stmt.WhereClause != nil { re := &ParseResult{} err := parseWhereClause(ctx, stmt.WhereClause, re) if err != nil { - return nil, err + return nil, nil, err } if len(re.Columns) > 0 { usedCols = append(usedCols, re.Columns...) @@ -524,7 +613,7 @@ func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) ([]QueryParam re := &ParseResult{} err := parseExpression(ctx, stmt.HavingClause, re) if err != nil { - return nil, err + return nil, nil, err } if len(re.Columns) > 0 { usedCols = append(usedCols, re.Columns...) @@ -539,7 +628,7 @@ func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) ([]QueryParam // TODO: should this be [0]? err := parseExpression(ctx, stmt.WindowClause[0], re) if err != nil { - return nil, err + return nil, nil, err } usedCols = append(usedCols, re.Columns...) AddQueryParams(&queryParams, re.Params) @@ -549,7 +638,7 @@ func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) ([]QueryParam usedCols = append(usedCols, getUsedColumnsFromSortClause(stmt.SortClause)...) } - return queryParams, validateTableColumns(ctx, usedTables, usedCols) + return queryParams, targetCols, validateTableColumns(ctx, ctx.UsedTables, usedCols) } func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, error) { @@ -658,8 +747,19 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam */ usedTables = append(usedTables, getUsedTablesFromSelectStmt(selectStmt.FromClause)...) - usedCols = append( - usedCols, getUsedColumnsFromJoinClauses(selectStmt.FromClause)...) + for _, fromClause := range selectStmt.FromClause { + re := &ParseResult{} + err := parseFromClause(ctx, fromClause, re) + if err != nil { + return nil, err + } + if len(re.Columns) > 0 { + usedCols = append(usedCols, re.Columns...) + } + if len(re.Params) > 0 { + AddQueryParams(&queryParams, re.Params) + } + } if selectStmt.WhereClause != nil { re := &ParseResult{} @@ -692,7 +792,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam return nil, fmt.Errorf( "unsupported subquery type in value list: %s", reflect.TypeOf(tv)) } - qparams, err := validateSelectStmt(ctx, tv.GetSelectStmt()) + qparams, _, err := validateSelectStmt(ctx, tv.GetSelectStmt()) if err != nil { return nil, fmt.Errorf("invalid SELECT query in value list: %w", err) } @@ -769,7 +869,8 @@ func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) { var raw *pg_query.RawStmt = tree.Stmts[0] switch { case raw.Stmt.GetSelectStmt() != nil: - return validateSelectStmt(ctx, raw.Stmt.GetSelectStmt()) + qparams, _, err := validateSelectStmt(ctx, raw.Stmt.GetSelectStmt()) + return qparams, err case raw.Stmt.GetUpdateStmt() != nil: return validateUpdateStmt(ctx, raw.Stmt.GetUpdateStmt()) case raw.Stmt.GetInsertStmt() != nil: diff --git a/pkg/vet/vet_test.go b/pkg/vet/vet_test.go index 7ccf978..05902a1 100644 --- a/pkg/vet/vet_test.go +++ b/pkg/vet/vet_test.go @@ -48,13 +48,19 @@ var mockDbSchema = &schema.Db{ "created_at": { Name: "created_at", }, + "baz_count": { + Name: "baz_count", + Type: "int", + }, }, ReadOnly: true, }, }, } -var mockCtx = vet.VetContext{Schema: mockDbSchema} +func mockCtx() vet.VetContext { + return vet.NewContext(mockDbSchema.Tables) +} func TestInsert(t *testing.T) { testCases := []struct { @@ -108,7 +114,7 @@ func TestInsert(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - _, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + _, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) if err != nil { vet.DebugQuery(tcase.Query) } @@ -239,7 +245,7 @@ func TestInvalidInsert(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - _, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + _, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) if err == nil { vet.DebugQuery(tcase.Query) } @@ -349,7 +355,7 @@ func TestInvalidSelect(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - qparams, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + qparams, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) if err == nil { vet.DebugQuery(tcase.Query) } @@ -404,19 +410,48 @@ func TestSelect(t *testing.T) { `SELECT id FROM foo WHERE value IS NULL`, }, { - "select with multiple joins", - `SELECT id + "select with join", + `SELECT id, coalesce(count,0) + FROM foo + LEFT JOIN bar b ON b.id = foo.id + WHERE value IS NULL`, + }, + { + "select with multiple joins with sub select", + `SELECT id, coalesce(bzz.created_at,0), coalesce(bzzz.created_at,0) FROM foo LEFT JOIN bar b ON b.id = foo.id LEFT JOIN foo f ON f.id = foo.id LEFT JOIN baz bz ON bz.id = foo.id + LEFT JOIN LATERAL (SELECT created_at from baz) bzz ON true + LEFT JOIN LATERAL (SELECT created_at from baz) AS bzzz ON true + WHERE value IS NULL`, + }, + { + "select with single left join", + `SELECT id, f.id, coalesce(bzz.created_at,0) + FROM foo as f + LEFT JOIN LATERAL ( + SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) AS b_created_at + FROM baz b + ) bzz ON true + WHERE value IS NULL`, + }, + { + "select with single left join and linked where", + `SELECT id, f.id, coalesce(bzz.created_at,0) + FROM foo as f + LEFT JOIN LATERAL ( + SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) as b_created_at + FROM baz b + WHERE f.id = b.id) bzz ON true WHERE value IS NULL`, }, } for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - qparams, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + qparams, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) if err != nil { vet.DebugQuery(tcase.Query) } @@ -463,7 +498,7 @@ func TestUpdate(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - qparams, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + qparams, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) if err != nil { vet.DebugQuery(tcase.Query) } @@ -523,7 +558,7 @@ func TestInvalidUpdate(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - qparams, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + qparams, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) assert.Equal(t, tcase.Err, err) assert.Equal(t, 0, len(qparams)) }) @@ -555,7 +590,7 @@ func TestDelete(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - qparams, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + qparams, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) if err != nil { vet.DebugQuery(tcase.Query) } @@ -629,7 +664,7 @@ func TestInvalidDelete(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - qparams, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + qparams, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) assert.Equal(t, tcase.Err, err) assert.Equal(t, 0, len(qparams)) }) @@ -676,7 +711,7 @@ func TestQueryParams(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.Name, func(t *testing.T) { - qparams, err := vet.ValidateSqlQuery(mockCtx, tcase.Query) + qparams, err := vet.ValidateSqlQuery(mockCtx(), tcase.Query) if err != nil { vet.DebugQuery(tcase.Query) }