Skip to content
This repository has been archived by the owner on Sep 21, 2021. It is now read-only.

Commit

Permalink
Add support for sqlx
Browse files Browse the repository at this point in the history
  • Loading branch information
richo-stripe authored and richo committed Jun 8, 2016
1 parent 452e37e commit a0f208b
Showing 1 changed file with 66 additions and 3 deletions.
69 changes: 66 additions & 3 deletions safesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ import (
"golang.org/x/tools/go/ssa/ssautil"
)

type Config struct {
Sqlx bool
DatabaseSql bool
}

const SQLX string = "github.com/jmoiron/sqlx"

func main() {
var verbose, quiet bool
flag.BoolVar(&verbose, "v", false, "Verbose mode")
Expand All @@ -28,6 +35,11 @@ func main() {
flag.PrintDefaults()
}

config := Config{
Sqlx: false,
DatabaseSql: false,
}

flag.Parse()
pkgs := flag.Args()
if len(pkgs) == 0 {
Expand All @@ -38,21 +50,53 @@ func main() {
c := loader.Config{
FindPackage: FindPackage,
}
c.Import("database/sql")
for _, pkg := range pkgs {
c.Import(pkg)
}
p, err := c.Load()

imports := GetImports(p)

if _, exist := imports["database/sql"]; exist {
if verbose {
fmt.Println("Enabling support for database/sql")
}
config.DatabaseSql = true
}

if _, exist := imports["github.com/jmoiron/sqlx"]; exist {
if verbose {
fmt.Println("Enabling support for sqlx")
}
config.Sqlx = true
}

if !(config.Sqlx || config.DatabaseSql) {
fmt.Printf("No packages in %v include a supported database driver", pkgs)
os.Exit(2)
}

if err != nil {
fmt.Printf("error loading packages %v: %v\n", pkgs, err)
os.Exit(2)
}

GetImports(p)

s := ssautil.CreateProgram(p, 0)
s.Build()

qms := FindQueryMethods(p.Package("database/sql").Pkg, s)
qms := make([]*QueryMethod, 0)

if config.DatabaseSql {
qms = append(qms, FindQueryMethods(p.Package("database/sql").Pkg, s)...)
}
if config.Sqlx {
qms = append(qms, FindQueryMethods(p.Package(SQLX).Pkg, s)...)
}

if verbose {
fmt.Println("database/sql functions that accept queries:")
fmt.Println("database driver functions that accept queries:")
for _, m := range qms {
fmt.Printf("- %s (param %d)\n", m.Func, m.Param)
}
Expand All @@ -75,6 +119,7 @@ func main() {
}

bad := FindNonConstCalls(res.CallGraph, qms)

if len(bad) == 0 {
if !quiet {
fmt.Println(`You're safe from SQL injection! Yay \o/`)
Expand Down Expand Up @@ -164,6 +209,17 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
return mains
}

func GetImports(p *loader.Program) map[string]interface{} {
packages := make(map[string]interface{})
for _, info := range p.AllPackages {
// Invert the map so we can do lookups more easily
if info.Importable {
packages[info.Pkg.Path()] = nil
}
}
return packages
}

// FindNonConstCalls returns the set of callsites of the given set of methods
// for which the "query" parameter is not a compile-time constant.
func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction {
Expand Down Expand Up @@ -196,6 +252,13 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
}
v := args[m.Param]
if _, ok := v.(*ssa.Const); !ok {
// This is super lurky, but sqlx wants to hand query objects about under
// the hood. We could do clever taint analysis, but it's easier
// to just bless the innards of sqlx internally, and rely on it
// to do Reasonable Things under the hood.
if edge.Caller.Func.Pkg.Pkg.Path() == SQLX {
continue
}
bad = append(bad, edge.Site)
}
}
Expand Down

0 comments on commit a0f208b

Please sign in to comment.