From 5a026ff40196f6ae57b12136f80b88bd6062950a Mon Sep 17 00:00:00 2001 From: Bobby Powers Date: Fri, 6 Sep 2019 18:29:30 -0700 Subject: [PATCH 1/4] remove support for Go < 1.5 --- package15.go | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 package15.go diff --git a/package15.go b/package15.go deleted file mode 100644 index 4c05efc..0000000 --- a/package15.go +++ /dev/null @@ -1,7 +0,0 @@ -// +build !go1.6 - -package main - -import "os" - -var useVendor = os.Getenv("GO15VENDOREXPERIMENT") == "1" From 666aaddf5d5485090328252c876b6b1945581f51 Mon Sep 17 00:00:00 2001 From: Bobby Powers Date: Fri, 6 Sep 2019 20:43:35 -0700 Subject: [PATCH 2/4] remove vendoring stuff --- package16.go | 7 ------- safesql.go | 50 +++----------------------------------------------- 2 files changed, 3 insertions(+), 54 deletions(-) delete mode 100644 package16.go diff --git a/package16.go b/package16.go deleted file mode 100644 index 409faaf..0000000 --- a/package16.go +++ /dev/null @@ -1,7 +0,0 @@ -// +build go1.6 - -package main - -import "os" - -var useVendor = os.Getenv("GO15VENDOREXPERIMENT") == "0" || os.Getenv("GO15VENDOREXPERIMENT") == "" diff --git a/safesql.go b/safesql.go index adf8bb8..6b74468 100644 --- a/safesql.go +++ b/safesql.go @@ -10,9 +10,6 @@ import ( "go/types" "os" - "path/filepath" - "strings" - "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/loader" "golang.org/x/tools/go/pointer" @@ -58,7 +55,9 @@ func main() { } c := loader.Config{ - FindPackage: FindPackage, + FindPackage: func(ctx *build.Context, path, dir string, mode build.ImportMode) (*build.Package, error) { + return ctx.Import(path, dir, mode) + }, } for _, pkg := range pkgs { c.Import(pkg) @@ -284,46 +283,3 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru return bad } - -// Deal with GO15VENDOREXPERIMENT -func FindPackage(ctxt *build.Context, path, dir string, mode build.ImportMode) (*build.Package, error) { - if !useVendor { - return ctxt.Import(path, dir, mode) - } - - // First, walk up the filesystem from dir looking for vendor directories - var vendorDir string - for tmp := dir; vendorDir == "" && tmp != "/"; tmp = filepath.Dir(tmp) { - dname := filepath.Join(tmp, "vendor", filepath.FromSlash(path)) - fd, err := os.Open(dname) - if err != nil { - continue - } - // Directories are only valid if they contain at least one file - // with suffix ".go" (this also ensures that the file descriptor - // we have is in fact a directory) - names, err := fd.Readdirnames(-1) - if err != nil { - continue - } - for _, name := range names { - if strings.HasSuffix(name, ".go") { - vendorDir = filepath.ToSlash(dname) - break - } - } - } - - if vendorDir != "" { - pkg, err := ctxt.ImportDir(vendorDir, mode) - if err != nil { - return nil, err - } - // Go tries to derive a valid import path for the package, but - // it's wrong (it includes "/vendor/"). Overwrite it here. - pkg.ImportPath = path - return pkg, nil - } - - return ctxt.Import(path, dir, mode) -} From 1d492a3c94bdb3e3dcb30dcb5c399ecfe74756e8 Mon Sep 17 00:00:00 2001 From: Bobby Powers Date: Fri, 6 Sep 2019 18:32:17 -0700 Subject: [PATCH 3/4] build: add .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e215659 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/safesql +/.idea +*~ From 1f8ea970d6469d2c5fda6842b738ab8078532bfe Mon Sep 17 00:00:00 2001 From: Bobby Powers Date: Fri, 6 Sep 2019 20:32:42 -0700 Subject: [PATCH 4/4] use the modern analyzer framework for safesql --- cmd/safesql/main.go | 13 ++ go.mod | 5 + go.sum | 8 + safesql.go | 343 +++++++++++++++++++++++------------- safesql_test.go | 17 ++ testdata/src/a_pass/main.go | 30 ++++ 6 files changed, 289 insertions(+), 127 deletions(-) create mode 100644 cmd/safesql/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 safesql_test.go create mode 100644 testdata/src/a_pass/main.go diff --git a/cmd/safesql/main.go b/cmd/safesql/main.go new file mode 100644 index 0000000..0d487b7 --- /dev/null +++ b/cmd/safesql/main.go @@ -0,0 +1,13 @@ +// Command safesql is a tool for performing static analysis on programs to +// ensure that SQL injection attacks are not possible. It does this by ensuring +// package database/sql is only used with compile-time constant queries. +package main + +import ( + "github.com/bpowers/safesql" + "golang.org/x/tools/go/analysis/singlechecker" +) + +func main() { + singlechecker.Main(safesql.Analyzer) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0791b5e --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/bpowers/safesql + +go 1.13 + +require golang.org/x/tools v0.0.0-20190909030654-5b82db07426d diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..30fb0d8 --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190909030654-5b82db07426d h1:PhtdWYteEBebOX7KXm4qkIAVSUTHQ883/2hRB92r9lk= +golang.org/x/tools v0.0.0-20190909030654-5b82db07426d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/safesql.go b/safesql.go index 6b74468..ed5ce2d 100644 --- a/safesql.go +++ b/safesql.go @@ -1,147 +1,245 @@ // Command safesql is a tool for performing static analysis on programs to // ensure that SQL injection attacks are not possible. It does this by ensuring // package database/sql is only used with compile-time constant queries. -package main +package safesql import ( - "flag" + "errors" "fmt" - "go/build" + "go/ast" + "go/token" "go/types" + "log" "os" + "strings" + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/buildssa" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" "golang.org/x/tools/go/callgraph" - "golang.org/x/tools/go/loader" "golang.org/x/tools/go/pointer" "golang.org/x/tools/go/ssa" - "golang.org/x/tools/go/ssa/ssautil" + "golang.org/x/tools/go/types/typeutil" ) -type sqlPackage struct { - packageName string - paramNames []string - enable bool -} +const Doc = `ensure SQL injection attacks are not possible -var sqlPackages = []sqlPackage{ - { - packageName: "database/sql", - paramNames: []string{"query"}, - }, - { - packageName: "github.com/jinzhu/gorm", - paramNames: []string{"sql", "query"}, - }, - { - packageName: "github.com/jmoiron/sqlx", - paramNames: []string{"query"}, +The safesql analysis reports calls to DB functions are only made with constant strings.` + +var Analyzer = &analysis.Analyzer{ + Name: "safesql", + Doc: Doc, + Run: run, + Requires: []*analysis.Analyzer{ + buildssa.Analyzer, + inspect.Analyzer, }, + FactTypes: []analysis.Fact{new(unsafeCallFact)}, } -func main() { - var verbose, quiet bool - flag.BoolVar(&verbose, "v", false, "Verbose mode") - flag.BoolVar(&quiet, "q", false, "Only print on failure") - flag.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s [-q] [-v] package1 [package2 ...]\n", os.Args[0]) - flag.PrintDefaults() - } - - flag.Parse() - pkgs := flag.Args() - if len(pkgs) == 0 { - flag.Usage() - os.Exit(2) - } +// unsafeCallFact represents a call to a SQL execution function that isn't a +// provably constant string. +type unsafeCallFact struct { + Pos token.Pos +} - c := loader.Config{ - FindPackage: func(ctx *build.Context, path, dir string, mode build.ImportMode) (*build.Package, error) { - return ctx.Import(path, dir, mode) - }, - } - for _, pkg := range pkgs { - c.Import(pkg) - } - p, err := c.Load() +func (*unsafeCallFact) String() string { return "found" } +func (*unsafeCallFact) AFact() {} - if err != nil { - fmt.Printf("error loading packages %v: %v\n", pkgs, err) - os.Exit(2) - } +// run performs the safesql analysis on a single package; it may be called +// multiple times during a single execution of the binary, once per dependency. +func run(pass *analysis.Pass) (interface{}, error) { - imports := getImports(p) - existOne := false - for i := range sqlPackages { - if _, exist := imports[sqlPackages[i].packageName]; exist { - if verbose { - fmt.Printf("Enabling support for %s\n", sqlPackages[i].packageName) - } - sqlPackages[i].enable = true - existOne = true + // package database/sql has a couple helper functions which are thin + // wrappers around other sensitive functions. Instead of handling the + // general case by tracing down callsites of wrapper functions + // recursively, let's just allowlist these DB packages, since it + // happens to be good enough for our use case. + for _, sql := range sqlPackages { + if strings.HasPrefix(pass.Pkg.Path(), sql.packageName) { + return nil, nil } } - if !existOne { - fmt.Printf("No packages in %v include a supported database driver", pkgs) - os.Exit(2) - } - s := ssautil.CreateProgram(p, 0) - s.Build() + log.Printf("-- %s --\n", pass.Pkg.Path()) - qms := make([]*QueryMethod, 0) + // TODO: we should only need one of these + var err error + err = CheckSafeSqlSsa(pass) + err = CheckSafeSqlAst(pass) - for i := range sqlPackages { - if sqlPackages[i].enable { - qms = append(qms, FindQueryMethods(sqlPackages[i], p.Package(sqlPackages[i].packageName).Pkg, s)...) + return nil, err +} + +// This more closely matches the original safesql implementation, but doesn't +// actually work. See the big comment in the middle for details +func CheckSafeSqlSsa(pass *analysis.Pass) error { + // we listed this as a dependency above; it is guaranteed to have run + ssaPass := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + + prog := ssaPass.Pkg.Prog + prog.Build() + + qms := make([]*QueryMethod, 0) + for _, sql := range sqlPackages { + var pkg *ssa.Package + for _, usedPkg := range prog.AllPackages() { + if usedPkg.Pkg.Path() == sql.packageName { + pkg = usedPkg + break + } } + // the SQL package we were worried about isn't used in this module! + if pkg == nil { + continue + } + qms = append(qms, FindQueryMethods(sql, pkg.Pkg, prog)...) } - if verbose { - fmt.Println("database driver functions that accept queries:") - for _, m := range qms { - fmt.Printf("- %s (param %d)\n", m.Func, m.Param) + if pass.Pkg.Path() == "a_pass" { + for _, fn := range ssaPass.SrcFuncs { + log.Printf("srcfunc: %s", fn.Name()) } - fmt.Println() } - mains := FindMains(p, s) - if len(mains) == 0 { - fmt.Println("Did not find any commands (i.e., main functions).") - os.Exit(2) + // the pointer.Analyze function below only works on packages with that + // _literally_ have main functions. + if ssaPass.Pkg.Func("main") == nil { + return nil } - res, err := pointer.Analyze(&pointer.Config{ - Mains: mains, + res, err2 := pointer.Analyze(&pointer.Config{ + Mains: []*ssa.Package{ssaPass.Pkg}, BuildCallGraph: true, + // Log: os.Stdout, }) - if err != nil { - fmt.Printf("error performing pointer analysis: %v\n", err) + if err2 != nil { + fmt.Printf("error performing pointer analysis: %v\n", err2) os.Exit(2) } + // XXX: at this point, the callgraph doesn't contain edges from our SQL + // callsites to e.g. DB.Exec. I think there are two explanations: 1) it + // is a Go modules thing. 2) it is something that broke when moving away from + // the deprecated loader package. I am pretty sure it is the second -- I + // rebuilt my local go as go1.11.13, ran `export GO111MODULE=off` in a terminal + // and ran the test, and still see the same behavior below. + + // for example, when running the test, we see: + // + // fn main -- []*callgraph.Edge{(*callgraph.Edge)(0xc00dbb4d20)} + // n5:a_pass.main --> n6:a_pass.runDbQuery + // fn runDbQuery -- []*callgraph.Edge{} + // + // main is shown to have a single edge, to runDbQuery, and runDbQuery has + // no edges. This is wrong on both accounts - main also has a call to log.Printf, + // and runDbQuery has a call to DB.Exec. + bad := FindNonConstCalls(res.CallGraph, qms) + log.Printf("!! found %v non-const calls", bad) - if len(bad) == 0 { - if !quiet { - fmt.Println(`You're safe from SQL injection! Yay \o/`) - } - return + for _, ci := range bad { + pos := prog.Fset.Position(ci.Pos()) + fmt.Printf("- %s\n", pos) } - if verbose { - fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad)) + var err error + if len(bad) > 0 { + err = fmt.Errorf("found %d safesql errors", len(bad)) } - for _, ci := range bad { - pos := p.Fset.Position(ci.Pos()) - fmt.Printf("- %s\n", pos) + return err +} + +// This was my first approach at a Go 1.13+ version of safesql; the problem +// here is that the AST is very high level; if you have a package-level const +// string, the functions like db.Exec will receive an identifier, not a string +// literal. I guess we could look up the identifier, and see if it resolves +// immediately to a string literal? That might be an easy way to match the +// current behavior, but IDK if it will be easy to extend to more things that +// act as false positives today. +func CheckSafeSqlAst(pass *analysis.Pass) error { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + nodeFilter := []ast.Node{ + &ast.CallExpr{}, } - if verbose { - fmt.Println("Please ensure that all SQL queries you use are compile-time constants.") - fmt.Println("You should always use parameterized queries or prepared statements") - fmt.Println("instead of building queries from strings.") + + nErrors := 0 + inspect.Preorder(nodeFilter, func(n ast.Node) { + call := n.(*ast.CallExpr) + fn, ok := typeutil.Callee(pass.TypesInfo, call).(*types.Func) + if !ok { + // log.Printf("call Fun not a Func? %#v\n", call.Fun) + return + } + + for _, sql := range sqlPackages { + if fn.Pkg() != nil && fn.Pkg().Path() != sql.packageName { + continue + } + + sig := fn.Type().(*types.Signature) + params := sig.Params() + for i := 0; i < params.Len(); i++ { + v := params.At(i) + if _, ok := sql.paramNames[v.Name()]; !ok { + continue + } + arg := call.Args[i] + lit, ok := arg.(*ast.BasicLit) + if !ok { + nErrors++ + // this will trigger even for _identifiers_ that point to static strings + pass.Reportf(arg.Pos(), "SQL query with non-static argument: %s", arg) + continue + } + if lit.Kind != token.STRING { + nErrors++ + pass.Reportf(arg.Pos(), "SQL query with non-string literal: %s", arg) + log.Printf("bad bad") + continue + } + log.Printf("all good") + } + } + }) + + var err error + if nErrors != 0 { + err = errors.New("potentially unsafe SQL queries found") } - os.Exit(1) + + return err +} + +type sqlPackage struct { + packageName string + paramNames map[string]struct{} + enable bool + pkg *ssa.Package +} + +var sqlPackages = []sqlPackage{ + { + packageName: "database/sql", + paramNames: map[string]struct{}{ + "query": {}, + }, + }, + { + packageName: "github.com/jinzhu/gorm", + paramNames: map[string]struct{}{ + "sql": {}, + "query": {}, + }, + }, + { + packageName: "github.com/jmoiron/sqlx", + paramNames: map[string]struct{}{ + "query": {}, + }, + }, } // QueryMethod represents a method on a type which has a string parameter named @@ -174,9 +272,10 @@ func FindQueryMethods(sqlPackages sqlPackage, sql *types.Package, ssa *ssa.Progr } s := m.Type().(*types.Signature) if num, ok := FuncHasQuery(sqlPackages, s); ok { + fn := ssa.FuncValue(m) methods = append(methods, &QueryMethod{ Func: m, - SSA: ssa.FuncValue(m), + SSA: fn, ArgCount: s.Params().Len(), Param: num, }) @@ -192,39 +291,13 @@ func FuncHasQuery(sqlPackages sqlPackage, s *types.Signature) (offset int, ok bo params := s.Params() for i := 0; i < params.Len(); i++ { v := params.At(i) - for _, paramName := range sqlPackages.paramNames { - if v.Name() == paramName { - return i, true - } + if _, ok := sqlPackages.paramNames[v.Name()]; ok { + return i, true } } return 0, false } -// FindMains returns the set of all packages loaded into the given -// loader.Program which contain main functions -func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package { - ips := p.InitialPackages() - mains := make([]*ssa.Package, 0, len(ips)) - for _, info := range ips { - ssaPkg := s.Package(info.Pkg) - if ssaPkg.Func("main") != nil { - mains = append(mains, ssaPkg) - } - } - return mains -} - -func getImports(p *loader.Program) map[string]interface{} { - pkgs := make(map[string]interface{}) - for _, pkg := range p.AllPackages { - if pkg.Importable { - pkgs[pkg.Pkg.Path()] = nil - } - } - return pkgs -} - // FindNonConstCalls returns the set of callsites of the given set of methods // for which the "query" parameter is not a compile-time constant. func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction { @@ -240,10 +313,25 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru okFuncs[m.SSA] = struct{}{} } + for fn, node := range cg.Nodes { + if fn.Name() == "main" || fn.Name() == "runDbQuery" { + fmt.Printf("fn %s -- %#v\n", fn.Name(), node.Out) + for _, out := range node.Out { + fmt.Printf(" %s\n", out) + } + } + } + bad := make([]ssa.CallInstruction, 0) for _, m := range qms { - node := cg.CreateNode(m.SSA) + node := cg.Nodes[m.SSA] + if node == nil { + continue + } + + fmt.Printf("func %s contains callees %#v\n", m.Func, node.In) for _, edge := range node.In { + fmt.Printf("found an edge\n") if _, ok := okFuncs[edge.Site.Parent()]; ok { continue } @@ -268,6 +356,7 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru panic("arg count mismatch") } v := args[m.Param] + fmt.Printf("found the call!!\n") if _, ok := v.(*ssa.Const); !ok { if inter, ok := v.(*ssa.MakeInterface); ok && types.IsInterface(v.(*ssa.MakeInterface).Type()) { diff --git a/safesql_test.go b/safesql_test.go new file mode 100644 index 0000000..864e041 --- /dev/null +++ b/safesql_test.go @@ -0,0 +1,17 @@ +package safesql_test + +import ( + "testing" + + "github.com/bpowers/safesql" + "golang.org/x/tools/go/analysis/analysistest" +) + +func init() { + safesql.Analyzer.Flags.Set("name", "safesql") +} + +func TestFromFileSystem(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, safesql.Analyzer, "a_pass") +} diff --git a/testdata/src/a_pass/main.go b/testdata/src/a_pass/main.go new file mode 100644 index 0000000..7690d62 --- /dev/null +++ b/testdata/src/a_pass/main.go @@ -0,0 +1,30 @@ +package main + +import ( + "context" + "database/sql" + "log" +) + +var ( + ctx context.Context + db *sql.DB +) + +func runDbQuery(db *sql.DB) { + sqlStmt := ` + create table foo (id integer not null primary key, name text); + delete from foo; + ` + + if _, err := db.Exec(sqlStmt); err != nil { + log.Printf("%q: %s\n", err, sqlStmt) + return + } +} + +func main() { + runDbQuery(db) + + log.Printf("holy moley\n") +}