Skip to content

Commit

Permalink
cmd/tool/slowest: AST rewrite for slowest tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dnephin committed May 15, 2020
1 parent 2cc4019 commit ddc609f
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 75 deletions.
13 changes: 13 additions & 0 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package cmd

// Next splits args into the next positional argument and any remaining args.
func Next(args []string) (string, []string) {
switch len(args) {
case 0:
return "", nil
case 1:
return args[0], nil
default:
return args[0], args[1:]
}
}
29 changes: 21 additions & 8 deletions cmd/tool/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,32 @@ package tool

import (
"fmt"
"os"

"gotest.tools/gotestsum/cmd"
"gotest.tools/gotestsum/cmd/tool/slowest"
)

// Run one of the tool commands.
func Run(name string, args []string) error {
if len(args) == 0 {
// TOOD: print help
return fmt.Errorf("invalid command: %v", name)
}
switch args[0] {
next, rest := cmd.Next(args)
switch next {
case "":
fmt.Println(usage(name))
return nil
case "slowest":
return slowest.Run(name+" "+args[0], args[1:])
return slowest.Run(name+" "+next, rest)
default:
fmt.Fprintln(os.Stderr, usage(name))
return fmt.Errorf("invalid command: %v %v", name, next)
}
// TOOD: print help
return fmt.Errorf("invalid command: %v", name)
}

func usage(name string) string {
return fmt.Sprintf(`Usage: %s COMMAND [flags]
Commands: slowest
Use '%s COMMAND --help' for command specific help.
`, name, name)
}
141 changes: 141 additions & 0 deletions cmd/tool/slowest/ast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package slowest

import (
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"strings"

"golang.org/x/tools/go/packages"
"gotest.tools/gotestsum/log"
"gotest.tools/gotestsum/testjson"
)

func writeTestSkip(tcs []testjson.TestCase, skipStmt ast.Stmt) error {
fset := token.NewFileSet()
cfg := packages.Config{
Mode: modeAll(),
Tests: true,
Fset: fset,
// FIXME: BuildFlags: strings.Split(os.Getenv("GOFLAGS"), " "),
}
pkgNames, index := testNamesByPkgName(tcs)
pkgs, err := packages.Load(&cfg, pkgNames...)
if err != nil {
return fmt.Errorf("failed to load packages: %w", err)
}

for _, pkg := range pkgs {
if len(pkg.Errors) > 0 {
return errPkgLoad(pkg)
}
tcs, ok := index[pkg.PkgPath]
if !ok {
log.Debugf("skipping %v, no slow tests", pkg.PkgPath)
continue
}

log.Debugf("rewriting %v for %d test cases", pkg.PkgPath, len(tcs))
for _, file := range pkg.Syntax {
path := fset.File(file.Pos()).Name()
log.Debugf("looking for test cases in: %v", path)
if !rewriteAST(file, tcs, skipStmt) {
continue
}
if err := writeFile(path, file, fset); err != nil {
return fmt.Errorf("failed to write ast to file %v: %w", path, err)
}
}
}
return errTestCasesNotFound(index)
}

// TODO: sometimes this writes the new AST with strange indentation. It appears
// to be non-deterministic. Given the same input, it only happens sometimes.
func writeFile(path string, file *ast.File, fset *token.FileSet) error {
fh, err := os.Create(path)
if err != nil {
return err
}
defer fh.Close()
return format.Node(fh, fset, file)
}

func parseSkipStatement(text string) (ast.Stmt, error) {
// Add some required boilerplate around the statement to make it a valid file
text = "package stub\nfunc Stub() {\n" + text + "\n}\n"
file, err := parser.ParseFile(token.NewFileSet(), "fragment", text, 0)
if err != nil {
return nil, err
}
stmt := file.Decls[0].(*ast.FuncDecl).Body.List[0]
return stmt, nil
}

func rewriteAST(file *ast.File, testNames set, skipStmt ast.Stmt) bool {
var modified bool
for _, decl := range file.Decls {
fd, ok := decl.(*ast.FuncDecl)
if !ok {
continue
}
name := fd.Name.Name // TODO: can this be nil?
if _, ok := testNames[name]; !ok {
continue
}

fd.Body.List = append([]ast.Stmt{skipStmt}, fd.Body.List...)
modified = true
delete(testNames, name)
}
return modified
}

type set map[string]struct{}

// FIXME: this should drop subtests from the index, so that errTestCasesNotFound
// does not report an error when we can't find the test function.
func testNamesByPkgName(tcs []testjson.TestCase) ([]string, map[string]set) {
pkgs := make([]string, 0, len(tcs))
index := make(map[string]set)
for _, tc := range tcs {
if len(index[tc.Package]) == 0 {
pkgs = append(pkgs, tc.Package)
index[tc.Package] = make(map[string]struct{})
}
index[tc.Package][tc.Test] = struct{}{}
}
return pkgs, index
}

func errPkgLoad(pkg *packages.Package) error {
buf := new(strings.Builder)
for _, err := range pkg.Errors {
buf.WriteString("\n" + err.Error())
}
return fmt.Errorf("failed to load package %v %v", pkg.PkgPath, buf.String())
}

func errTestCasesNotFound(index map[string]set) error {
var missed []string
for pkg, tcs := range index {
for tc := range tcs {
missed = append(missed, fmt.Sprintf("%v.%v", pkg, tc))
}
}
if len(missed) == 0 {
return nil
}
return fmt.Errorf("failed to find source for test cases: %v", strings.Join(missed, ","))
}

func modeAll() packages.LoadMode {
mode := packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles
mode = mode | packages.NeedImports | packages.NeedDeps
mode = mode | packages.NeedTypes | packages.NeedTypesSizes
mode = mode | packages.NeedSyntax | packages.NeedTypesInfo
return mode
}
67 changes: 44 additions & 23 deletions cmd/tool/slowest/slowest.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@ import (
"time"

"github.com/spf13/pflag"
"gotest.tools/gotestsum/log"
"gotest.tools/gotestsum/testjson"
)

// Run the command
func Run(name string, args []string) error {
flags, opts := setupFlags(name)
if err := flags.Parse(args); err != nil {
switch err := flags.Parse(args); {
case err == pflag.ErrHelp:
return nil
case err != nil:
flags.Usage()
return err
}
return run(opts)
Expand All @@ -27,31 +33,58 @@ func setupFlags(name string) (*pflag.FlagSet, *options) {
fmt.Fprintf(os.Stderr, `Usage:
%s [flags]
By default this command will print the list of tests slower than threshold to stdout.
If --skip-stmt is set, instead of printing the list of stdout, the AST for the
Go source code in the working directory tree will be modified. The --skip-stmt
will be added to Go test files as the first statement in all the test functions
which are slower than threshold.
Example - use testing.Short():
skip_stmt='if testing.Short() { t.Skip("too slow for short run") }'
go test -json -short ./... | %s --skip-stmt "$skip_stmt"
Flags:
`, name)
`, name, name)
flags.PrintDefaults()
}
flags.DurationVar(&opts.threshold, "threshold", 100*time.Millisecond,
"tests faster than this threshold will be omitted from the output")
flags.StringVar(&opts.skipStatement, "skip-stmt", "",
"add this go statement to slow tests, instead of printing the list of slow tests")
flags.BoolVar(&opts.debug, "debug", false,
"enable debug logging.")
return flags, opts
}

type options struct {
threshold time.Duration
threshold time.Duration
skipStatement string
debug bool
}

func run(opts *options) error {
if opts.debug {
log.SetLevel(log.DebugLevel)
}
exec, err := testjson.ScanTestOutput(testjson.ScanConfig{
Stdout: os.Stdin,
Stderr: bytes.NewReader(nil),
Handler: eventHandler{},
Stdout: os.Stdin,
Stderr: bytes.NewReader(nil),
})
if err != nil {
return err
return fmt.Errorf("failed to scan testjson: %w", err)
}

tcs := slowTestCases(exec, opts.threshold)
if opts.skipStatement != "" {
skipStmt, err := parseSkipStatement(opts.skipStatement)
if err != nil {
return fmt.Errorf("failed to parse skip expr: %w", err)
}
return writeTestSkip(tcs, skipStmt)
}
for _, tc := range slowTestCases(exec, opts.threshold) {
// TODO: allow elapsed time unit to be configurable
fmt.Printf("%s %s %d\n", tc.Package, tc.Test, tc.Elapsed.Milliseconds())
for _, tc := range tcs {
fmt.Printf("%s %s %v\n", tc.Package, tc.Test, tc.Elapsed)
}

return nil
Expand All @@ -60,7 +93,7 @@ func run(opts *options) error {
// slowTestCases returns a slice of all tests with an elapsed time greater than
// threshold. The slice is sorted by Elapsed time in descending order (slowest
// test first).
// TODO: may be shared with testjson Summary
// FIXME: use medium elapsed time when there are multiple instances of the same test
func slowTestCases(exec *testjson.Execution, threshold time.Duration) []testjson.TestCase {
if threshold == 0 {
return nil
Expand All @@ -70,7 +103,6 @@ func slowTestCases(exec *testjson.Execution, threshold time.Duration) []testjson
for _, pkg := range pkgs {
tests = append(tests, exec.Package(pkg).TestCases()...)
}
// TODO: use median test runtime
sort.Slice(tests, func(i, j int) bool {
return tests[i].Elapsed > tests[j].Elapsed
})
Expand All @@ -79,14 +111,3 @@ func slowTestCases(exec *testjson.Execution, threshold time.Duration) []testjson
})
return tests[:end]
}

type eventHandler struct{}

func (h eventHandler) Err(text string) error {
_, err := fmt.Fprintln(os.Stdout, text)
return err
}

func (h eventHandler) Event(_ testjson.TestEvent, _ *testjson.Execution) error {
return nil
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ require (
github.com/spf13/pflag v1.0.3
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
golang.org/x/sync v0.0.0-20190423024810-112230192c58
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect
golang.org/x/tools v0.0.0-20190624222133-a101b041ded4
gotest.tools/v3 v3.0.2
)

Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEha
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190624222133-a101b041ded4 h1:1mMox4TgefDwqluYCv677yNXwlfTkija4owZve/jr78=
golang.org/x/tools v0.0.0-20190624222133-a101b041ded4/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
gotest.tools/v3 v3.0.2 h1:kG1BFyqVHuQoVQiR1bWGnfz/fmHvvuiSPIV7rvl360E=
gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk=
15 changes: 2 additions & 13 deletions internal/junitxml/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ func TestWrite(t *testing.T) {

func createExecution(t *testing.T) *testjson.Execution {
exec, err := testjson.ScanTestOutput(testjson.ScanConfig{
Stdout: readTestData(t, "out"),
Stderr: readTestData(t, "err"),
Handler: &noopHandler{},
Stdout: readTestData(t, "out"),
Stderr: readTestData(t, "err"),
})
assert.NilError(t, err)
return exec
Expand All @@ -40,16 +39,6 @@ func readTestData(t *testing.T, stream string) io.Reader {
return bytes.NewReader(raw)
}

type noopHandler struct{}

func (s *noopHandler) Event(testjson.TestEvent, *testjson.Execution) error {
return nil
}

func (s *noopHandler) Err(string) error {
return nil
}

func TestGoVersion(t *testing.T) {
t.Run("unknown", func(t *testing.T) {
defer env.Patch(t, "PATH", "/bogus")()
Expand Down
Loading

0 comments on commit ddc609f

Please sign in to comment.