Skip to content

Commit

Permalink
Make generate_test_main.go work with future versions
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Conrod committed Apr 4, 2017
1 parent cfdcbdc commit 30cdcb7
Showing 1 changed file with 50 additions and 11 deletions.
61 changes: 50 additions & 11 deletions go/tools/generate_test_main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,32 @@ import (
"log"
"os"
"runtime"
"strconv"
"strings"
"text/template"
)

// Cases holds template data.
type Cases struct {
Package string
RunDir string
TestNames []string
BenchmarkNames []string
HasTestMain bool
Version string
Package string
RunDir string
TestNames []string
BenchmarkNames []string
HasTestMain bool
Version17 bool
Version18OrNewer bool
}

var codeTpl = `
package main
import (
"flag"
"os"
{{if eq .Version "go1.7.5"}}
{{if .Version17}}
"regexp"
{{end}}
"testing"
{{if eq .Version "go1.8"}}
{{if .Version18OrNewer}}
"testing/internal/testdeps"
{{end}}
Expand Down Expand Up @@ -79,14 +81,14 @@ func main() {
}
}
{{if eq .Version "go1.8"}}
{{if .Version18OrNewer}}
m := testing.MainStart(testdeps.TestDeps{}, tests, benchmarks, nil)
{{if not .HasTestMain}}
os.Exit(m.Run())
{{else}}
undertest.TestMain(m)
{{end}}
{{else}}
{{else if .Version17}}
{{if not .HasTestMain}}
testing.Main(regexp.MatchString, tests, benchmarks, nil)
{{else}}
Expand Down Expand Up @@ -185,9 +187,46 @@ func main() {
}
}

cases.Version = runtime.Version()
goVersion := parseVersion(runtime.Version())
if goVersion.Less(version{1, 7}) {
log.Fatalf("go version %s not supported", runtime.Version())
} else if goVersion.Less(version{1, 8}) {
cases.Version17 = true
} else {
cases.Version18OrNewer = true
}

tpl := template.Must(template.New("source").Parse(codeTpl))
if err := tpl.Execute(outFile, &cases); err != nil {
log.Fatalf("template.Execute(%v): %v", cases, err)
}
}

type version []int

func parseVersion(s string) version {
strParts := strings.Split(s[len("go"):], ".")
intParts := make([]int, len(strParts))
for i, s := range strParts {
v, err := strconv.Atoi(s)
if err != nil {
panic(err)
}
intParts[i] = v
}
return intParts
}

func (x version) Less(y version) bool {
n := len(x)
if len(y) < n {
n = len(y)
}
for i := 0; i < n; i++ {
cmp := x[i] - y[i]
if cmp != 0 {
return cmp < 0
}
}
return len(x) < len(y)
}

0 comments on commit 30cdcb7

Please sign in to comment.