diff --git a/go/tools/generate_test_main.go b/go/tools/generate_test_main.go index a2632e7bc3..dc82aee7ce 100644 --- a/go/tools/generate_test_main.go +++ b/go/tools/generate_test_main.go @@ -25,18 +25,20 @@ import ( "log" "os" "runtime" + "strconv" "strings" "text/template" ) // Cases holds template data. type Cases struct { - Package string - RunDir string - TestNames []string - BenchmarkNames []string - HasTestMain bool - Version string + Package string + RunDir string + TestNames []string + BenchmarkNames []string + HasTestMain bool + Version17 bool + Version18OrNewer bool } var codeTpl = ` @@ -44,11 +46,11 @@ package main import ( "flag" "os" -{{if eq .Version "go1.7.5"}} +{{if .Version17}} "regexp" {{end}} "testing" -{{if eq .Version "go1.8"}} +{{if .Version18OrNewer}} "testing/internal/testdeps" {{end}} @@ -79,14 +81,14 @@ func main() { } } -{{if eq .Version "go1.8"}} +{{if .Version18OrNewer}} m := testing.MainStart(testdeps.TestDeps{}, tests, benchmarks, nil) {{if not .HasTestMain}} os.Exit(m.Run()) {{else}} undertest.TestMain(m) {{end}} -{{else}} +{{else if .Version17}} {{if not .HasTestMain}} testing.Main(regexp.MatchString, tests, benchmarks, nil) {{else}} @@ -185,9 +187,46 @@ func main() { } } - cases.Version = runtime.Version() + goVersion := parseVersion(runtime.Version()) + if goVersion.Less(version{1, 7}) { + log.Fatalf("go version %s not supported", runtime.Version()) + } else if goVersion.Less(version{1, 8}) { + cases.Version17 = true + } else { + cases.Version18OrNewer = true + } + tpl := template.Must(template.New("source").Parse(codeTpl)) if err := tpl.Execute(outFile, &cases); err != nil { log.Fatalf("template.Execute(%v): %v", cases, err) } } + +type version []int + +func parseVersion(s string) version { + strParts := strings.Split(s[len("go"):], ".") + intParts := make([]int, len(strParts)) + for i, s := range strParts { + v, err := strconv.Atoi(s) + if err != nil { + panic(err) + } + intParts[i] = v + } + return intParts +} + +func (x version) Less(y version) bool { + n := len(x) + if len(y) < n { + n = len(y) + } + for i := 0; i < n; i++ { + cmp := x[i] - y[i] + if cmp != 0 { + return cmp < 0 + } + } + return len(x) < len(y) +}