diff --git a/vulncheck/witness.go b/vulncheck/witness.go index 49893e345..57c74ca58 100644 --- a/vulncheck/witness.go +++ b/vulncheck/witness.go @@ -1,7 +1,14 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package vulncheck import ( "container/list" + "fmt" + "sort" + "strings" "sync" ) @@ -10,18 +17,18 @@ import ( // known vulnerabilities. type ImportChain []*PkgNode -// ImportChains performs a BFS search of res.RequireGraph for imports of vulnerable -// packages. Search is performed for each vulnerable package in res.Vulns. The search -// starts at a vulnerable package and goes up until reaching an entry package in -// res.ImportGraph.Entries, hence producing an import chain. During the search, a -// package is visited only once to avoid analyzing every possible import chain. -// Hence, not all possible vulnerable import chains are reported. +// ImportChains lists import chains for each vulnerability in res. The +// reported chains are ordered by how seemingly easy is to understand +// them. Shorter import chains appear earlier in the returned slices. // -// Note that the resulting map produces an import chain for each Vuln. Thus, a Vuln -// with the same PkgPath will have the same list of identified import chains. +// ImportChains does not list all import chains for a vulnerability. +// It performs a BFS search of res.RequireGraph starting at a vulnerable +// package and going up until reaching an entry package in res.ImportGraph.Entries. +// During this search, a package is visited only once to avoid analyzing +// every possible import chain. // -// The reported import chains are ordered by how seemingly easy is to understand -// them. Shorter import chains appear earlier in the returned slices. +// Note that the resulting map produces an import chain for each Vuln. Vulns +// with the same PkgPath will have the same list of identified import chains. func ImportChains(res *Result) map[*Vuln][]ImportChain { // Group vulns per package. vPerPkg := make(map[int][]*Vuln) @@ -122,3 +129,175 @@ type StackEntry struct { // nil when the frame represents an entry point of the stack. Call *CallSite } + +// CallStacks lists call stacks for each vulnerability in res. The listed call +// stacks are ordered by how seemingly easy is to understand them. In general, +// shorter call stacks with less dynamic call sites appear earlier in the returned +// call stack slices. +// +// CallStacks does not report every possible call stack for a vulnerable symbol. +// It performs a BFS search of res.CallGraph starting at the symbol and going up +// until reaching an entry function or method in res.CallGraph.Entries. During +// this search, each function is visited at most once to avoid potential +// exponential explosion, thus skipping some call stacks. +func CallStacks(res *Result) map[*Vuln][]CallStack { + var ( + wg sync.WaitGroup + mu sync.Mutex + ) + stacksPerVuln := make(map[*Vuln][]CallStack) + for _, vuln := range res.Vulns { + vuln := vuln + wg.Add(1) + go func() { + cs := callStacks(vuln.CallSink, res) + // sort call stacks by the estimated value to the user + sort.SliceStable(cs, func(i int, j int) bool { return stackLess(cs[i], cs[j]) }) + mu.Lock() + stacksPerVuln[vuln] = cs + mu.Unlock() + wg.Done() + }() + } + + wg.Wait() + return stacksPerVuln +} + +// callStacks finds representative call stacks +// for vulnerable symbol identified with vulnSinkID. +func callStacks(vulnSinkID int, res *Result) []CallStack { + if vulnSinkID == 0 { + return nil + } + + entries := make(map[int]bool) + for _, e := range res.Calls.Entries { + entries[e] = true + } + + var stacks []CallStack + seen := make(map[int]bool) + + queue := list.New() + queue.PushBack(&callChain{f: res.Calls.Functions[vulnSinkID]}) + + for queue.Len() > 0 { + front := queue.Front() + c := front.Value.(*callChain) + queue.Remove(front) + + f := c.f + if seen[f.ID] { + continue + } + seen[f.ID] = true + + for _, cs := range f.CallSites { + callee := res.Calls.Functions[cs.Parent] + nStack := &callChain{f: callee, call: cs, child: c} + if entries[callee.ID] { + stacks = append(stacks, nStack.CallStack()) + } + queue.PushBack(nStack) + } + } + return stacks +} + +// callChain models a chain of function calls. +type callChain struct { + call *CallSite // nil for entry points + f *FuncNode + child *callChain +} + +// CallStack converts callChain to CallStack type. +func (c *callChain) CallStack() CallStack { + if c == nil { + return nil + } + return append(CallStack{StackEntry{Function: c.f, Call: c.call}}, c.child.CallStack()...) +} + +// weight computes an approximate measure of how easy is to understand the call +// stack when presented to the client as a witness. The smaller the value, the more +// understandable the stack is. Currently defined as the number of unresolved +// call sites in the stack. +func weight(stack CallStack) int { + w := 0 + for _, e := range stack { + if e.Call != nil && !e.Call.Resolved { + w += 1 + } + } + return w +} + +func isStdPackage(pkg string) bool { + if pkg == "" { + return false + } + // std packages do not have a "." in their path. For instance, see + // Contains in pkgsite/+/refs/heads/master/internal/stdlbib/stdlib.go. + if i := strings.IndexByte(pkg, '/'); i != -1 { + pkg = pkg[:i] + } + return !strings.Contains(pkg, ".") +} + +// confidence computes an approximate measure of whether the stack +// is realizeable in practice. Currently, it equals the number of call +// sites in stack that go through standard libraries. Such call stacks +// have been experimentally shown to often result in false positives. +func confidence(stack CallStack) int { + c := 0 + for _, e := range stack { + if isStdPackage(e.Function.PkgPath) { + c += 1 + } + } + return c +} + +// stackLess compares two call stacks in terms of their estimated +// value to the user. Shorter stacks generally come earlier in the ordering. +// +// Two stacks are lexicographically ordered by: +// 1) their estimated level of confidence in being a real call stack, +// 2) their length, and 3) the number of dynamic call sites in the stack. +func stackLess(s1, s2 CallStack) bool { + if c1, c2 := confidence(s1), confidence(s2); c1 != c2 { + return c1 < c2 + } + + if len(s1) != len(s2) { + return len(s1) < len(s2) + } + + if w1, w2 := weight(s1), weight(s2); w1 != w2 { + return w1 < w2 + } + // At this point we just need to make sure the ordering is deterministic. + // TODO(zpavlinovic): is there a more meaningful additional ordering? + return stackStrLess(s1, s2) +} + +// stackStrLess compares string representation of stacks. +func stackStrLess(s1, s2 CallStack) bool { + // Creates a unique string representation of a call stack + // for comparison purposes only. + stackStr := func(stack CallStack) string { + var stackStr []string + for _, cs := range stack { + s := cs.Function.String() + if cs.Call != nil && cs.Call.Pos != nil { + p := cs.Call.Pos + s = fmt.Sprintf("%s[%s:%d:%d:%d]", s, p.Filename, p.Line, p.Column, p.Offset) + } + stackStr = append(stackStr, s) + } + return strings.Join(stackStr, "->") + } + return strings.Compare(stackStr(s1), stackStr(s2)) <= 0 +} diff --git a/vulncheck/witness_test.go b/vulncheck/witness_test.go index eef6cf205..348b86d42 100644 --- a/vulncheck/witness_test.go +++ b/vulncheck/witness_test.go @@ -1,3 +1,7 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package vulncheck import ( @@ -24,6 +28,24 @@ func chainsToString(chains map[*Vuln][]ImportChain) map[string][]string { return m } +// stacksToString converts map *Vuln:stacks to Vuln.Symbol:["f1->...->fN", ...] +// string representation. +func stacksToString(stacks map[*Vuln][]CallStack) map[string][]string { + m := make(map[string][]string) + for v, sts := range stacks { + var stsStr []string + for _, st := range sts { + var stStr []string + for _, call := range st { + stStr = append(stStr, call.Function.Name) + } + stsStr = append(stsStr, strings.Join(stStr, "->")) + } + m[v.Symbol] = stsStr + } + return m +} + func TestImportChains(t *testing.T) { // Package import structure for the test program // entry1 entry2 @@ -61,3 +83,38 @@ func TestImportChains(t *testing.T) { t.Errorf("want %v; got %v", want, got) } } + +func TestCallStacks(t *testing.T) { + // Call graph structure for the test program + // entry1 entry2 + // | | + // interm1(std) | + // | \ / + // | interm2(interface) + // | / | + // vuln1 vuln2 + e1 := &FuncNode{ID: 1, Name: "entry1"} + e2 := &FuncNode{ID: 2, Name: "entry2"} + i1 := &FuncNode{ID: 3, Name: "interm1", PkgPath: "net/http", CallSites: []*CallSite{&CallSite{Parent: 1, Resolved: true}}} + i2 := &FuncNode{ID: 4, Name: "interm2", CallSites: []*CallSite{&CallSite{Parent: 2, Resolved: true}, &CallSite{Parent: 3, Resolved: true}}} + v1 := &FuncNode{ID: 5, Name: "vuln1", CallSites: []*CallSite{&CallSite{Parent: 3, Resolved: true}, &CallSite{Parent: 4, Resolved: false}}} + v2 := &FuncNode{ID: 6, Name: "vuln2", CallSites: []*CallSite{&CallSite{Parent: 4, Resolved: false}}} + + cg := &CallGraph{ + Functions: map[int]*FuncNode{1: e1, 2: e2, 3: i1, 4: i2, 5: v1, 6: v2}, + Entries: []int{1, 2}, + } + vuln1 := &Vuln{CallSink: 5, Symbol: "vuln1"} + vuln2 := &Vuln{CallSink: 6, Symbol: "vuln2"} + res := &Result{Calls: cg, Vulns: []*Vuln{vuln1, vuln2}} + + want := map[string][]string{ + "vuln1": {"entry2->interm2->vuln1", "entry1->interm1->vuln1"}, + "vuln2": {"entry2->interm2->vuln2", "entry1->interm1->interm2->vuln2"}, + } + + stacks := CallStacks(res) + if got := stacksToString(stacks); !reflect.DeepEqual(want, got) { + t.Errorf("want %v; got %v", want, got) + } +}