Skip to content

Commit

Permalink
fix: correctly identify infixed concats as potential SQL injections (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
audunmo authored Jul 25, 2023
1 parent 2292ed5 commit bf7feda
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 15 deletions.
54 changes: 50 additions & 4 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,46 @@ func GetChar(n ast.Node) (byte, error) {
return 0, fmt.Errorf("Unexpected AST node type: %T", n)
}

// GetStringRecursive will recursively walk down a tree of *ast.BinaryExpr. It will then concat the results, and return.
// Unlike the other getters, it does _not_ raise an error for unknown ast.Node types. At the base, the recursion will hit a non-BinaryExpr type,
// either BasicLit or other, so it's not an error case. It will only error if `strconv.Unquote` errors. This matters, because there's
// currently functionality that relies on error values being returned by GetString if and when it hits a non-basiclit string node type,
// hence for cases where recursion is needed, we use this separate function, so that we can still be backwards compatbile.
//
// This was added to handle a SQL injection concatenation case where the injected value is infixed between two strings, not at the start or end. See example below
//
// Do note that this will omit non-string values. So for example, if you were to use this node:
// ```go
// q := "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1" // will result in "SELECT * FROM foo WHERE ” AND 1=1"

func GetStringRecursive(n ast.Node) (string, error) {
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
return strconv.Unquote(node.Value)
}

if expr, ok := n.(*ast.BinaryExpr); ok {
x, err := GetStringRecursive(expr.X)
if err != nil {
return "", err
}

y, err := GetStringRecursive(expr.Y)
if err != nil {
return "", err
}

return x + y, nil
}

return "", nil
}

// GetString will read and return a string value from an ast.BasicLit
func GetString(n ast.Node) (string, error) {
if node, ok := n.(*ast.BasicLit); ok && node.Kind == token.STRING {
return strconv.Unquote(node.Value)
}

return "", fmt.Errorf("Unexpected AST node type: %T", n)
}

Expand Down Expand Up @@ -201,22 +236,21 @@ func GetCallStringArgsValues(n ast.Node, _ *Context) []string {
return values
}

// GetIdentStringValues return the string values of an Ident if they can be resolved
func GetIdentStringValues(ident *ast.Ident) []string {
func getIdentStringValues(ident *ast.Ident, stringFinder func(ast.Node) (string, error)) []string {
values := []string{}
obj := ident.Obj
if obj != nil {
switch decl := obj.Decl.(type) {
case *ast.ValueSpec:
for _, v := range decl.Values {
value, err := GetString(v)
value, err := stringFinder(v)
if err == nil {
values = append(values, value)
}
}
case *ast.AssignStmt:
for _, v := range decl.Rhs {
value, err := GetString(v)
value, err := stringFinder(v)
if err == nil {
values = append(values, value)
}
Expand All @@ -226,6 +260,18 @@ func GetIdentStringValues(ident *ast.Ident) []string {
return values
}

// getIdentStringRecursive returns the string of values of an Ident if they can be resolved
// The difference between this and GetIdentStringValues is that it will attempt to resolve the strings recursively,
// if it is passed a *ast.BinaryExpr. See GetStringRecursive for details
func GetIdentStringValuesRecursive(ident *ast.Ident) []string {
return getIdentStringValues(ident, GetStringRecursive)
}

// GetIdentStringValues return the string values of an Ident if they can be resolved
func GetIdentStringValues(ident *ast.Ident) []string {
return getIdentStringValues(ident, GetString)
}

// GetBinaryExprOperands returns all operands of a binary expression by traversing
// the expression tree
func GetBinaryExprOperands(be *ast.BinaryExpr) []ast.Node {
Expand Down
51 changes: 50 additions & 1 deletion rules/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,32 @@ func (s *sqlStrConcat) ID() string {
return s.MetaData.ID
}

// findInjectionInBranch walks diwb a set if expressions, and will create new issues if it finds SQL injections
// This method assumes you've already verified that the branch contains SQL syntax
func (s *sqlStrConcat) findInjectionInBranch(ctx *gosec.Context, branch []ast.Expr) *ast.BinaryExpr {
for _, node := range branch {
be, ok := node.(*ast.BinaryExpr)
if !ok {
continue
}

operands := gosec.GetBinaryExprOperands(be)

for _, op := range operands {
if _, ok := op.(*ast.BasicLit); ok {
continue
}

if ident, ok := op.(*ast.Ident); ok && s.checkObject(ident, ctx) {
continue
}

return be
}
}
return nil
}

// see if we can figure out what it is
func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
if n.Obj != nil {
Expand Down Expand Up @@ -140,6 +166,28 @@ func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*issu
}
}

// Handle the case where an injection occurs as an infixed string concatenation, ie "SELECT * FROM foo WHERE name = '" + os.Args[0] + "' AND 1=1"
if id, ok := query.(*ast.Ident); ok {
var match bool
for _, str := range gosec.GetIdentStringValuesRecursive(id) {
if s.MatchPatterns(str) {
match = true
break
}
}

if !match {
return nil, nil
}

switch decl := id.Obj.Decl.(type) {
case *ast.AssignStmt:
if injection := s.findInjectionInBranch(ctx, decl.Rhs); injection != nil {
return ctx.NewIssue(injection, s.ID(), s.What, s.Severity, s.Confidence), nil
}
}
}

return nil, nil
}

Expand All @@ -157,6 +205,7 @@ func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*issue.Issue, erro
return s.checkQuery(sqlQueryCall, ctx)
}
}

return nil, nil
}

Expand All @@ -165,7 +214,7 @@ func NewSQLStrConcat(id string, _ gosec.Config) (gosec.Rule, []ast.Node) {
rule := &sqlStrConcat{
sqlStatement: sqlStatement{
patterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE)( |\n|\r|\t)"),
},
MetaData: issue.MetaData{
ID: id,
Expand Down
52 changes: 42 additions & 10 deletions testutils/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -1712,6 +1712,28 @@ func main() {
// SampleCodeG202 - SQL query string building via string concatenation
SampleCodeG202 = []CodeSample{
{[]string{`
// infixed concatenation
package main
import (
"database/sql"
"os"
)
func main(){
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
q := "INSERT INTO foo (name) VALUES ('" + os.Args[0] + "')"
rows, err := db.Query(q)
if err != nil {
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()},
{[]string{`
package main
import (
Expand All @@ -1729,7 +1751,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// case insensitive match
package main
Expand All @@ -1748,7 +1771,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// context match
package main
Expand All @@ -1768,7 +1792,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// DB transaction check
package main
Expand Down Expand Up @@ -1796,7 +1821,8 @@ func main(){
if err := tx.Commit(); err != nil {
panic(err)
}
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// multiple string concatenation
package main
Expand All @@ -1815,7 +1841,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// false positive
package main
Expand All @@ -1834,7 +1861,8 @@ func main(){
panic(err)
}
defer rows.Close()
}`}, 0, gosec.NewConfig()}, {[]string{`
}`}, 0, gosec.NewConfig()},
{[]string{`
package main
import (
Expand All @@ -1856,7 +1884,8 @@ func main(){
}
defer rows.Close()
}
`}, 0, gosec.NewConfig()}, {[]string{`
`}, 0, gosec.NewConfig()},
{[]string{`
package main
const gender = "M"
Expand All @@ -1882,7 +1911,8 @@ func main(){
}
defer rows.Close()
}
`}, 0, gosec.NewConfig()}, {[]string{`
`}, 0, gosec.NewConfig()},
{[]string{`
// ExecContext match
package main
Expand All @@ -1903,7 +1933,8 @@ func main() {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
// Exec match
package main
Expand All @@ -1923,7 +1954,8 @@ func main() {
panic(err)
}
fmt.Println(result)
}`}, 1, gosec.NewConfig()}, {[]string{`
}`}, 1, gosec.NewConfig()},
{[]string{`
package main
import (
Expand Down

0 comments on commit bf7feda

Please sign in to comment.