diff --git a/pkg/cluster/audit/audit.go b/pkg/cluster/audit/audit.go index 16f1329821..004eafea4f 100644 --- a/pkg/cluster/audit/audit.go +++ b/pkg/cluster/audit/audit.go @@ -56,11 +56,10 @@ func ShowAuditList(dir string) error { if fi.IsDir() { continue } - ts, err := base52.Decode(fi.Name()) + t, err := decodeAuditID(fi.Name()) if err != nil { continue } - t := time.Unix(ts/1e9, 0) cmd, err := firstLine(fi.Name()) if err != nil { continue @@ -93,7 +92,7 @@ func ShowAuditLog(dir string, auditID string) error { return errors.Errorf("cannot find the audit log '%s'", auditID) } - ts, err := base52.Decode(auditID) + t, err := decodeAuditID(auditID) if err != nil { return errors.Annotatef(err, "unrecognized audit id '%s'", auditID) } @@ -103,10 +102,23 @@ func ShowAuditLog(dir string, auditID string) error { return errors.Trace(err) } - t := time.Unix(ts/1e9, 0) hint := fmt.Sprintf("- OPERATION TIME: %s -", t.Format("2006-01-02T15:04:05")) line := strings.Repeat("-", len(hint)) _, _ = os.Stdout.WriteString(color.MagentaString("%s\n%s\n%s\n", line, hint, line)) _, _ = os.Stdout.Write(content) return nil } + +//decodeAuditID decodes the auditID to unix timestamp +func decodeAuditID(auditID string) (time.Time, error) { + ts, err := base52.Decode(auditID) + if err != nil { + return time.Time{}, err + } + // compatible with old second based ts + if ts>>32 > 0 { + ts = ts / 1e9 + } + t := time.Unix(ts, 0) + return t, nil +} diff --git a/pkg/cluster/audit/audit_test.go b/pkg/cluster/audit/audit_test.go index 254a4460ec..1c06c24fb4 100644 --- a/pkg/cluster/audit/audit_test.go +++ b/pkg/cluster/audit/audit_test.go @@ -14,15 +14,23 @@ package audit import ( + "fmt" + "io/ioutil" "os" "path" "path/filepath" "runtime" + "strings" + "testing" + "time" . "github.com/pingcap/check" + "github.com/pingcap/tiup/pkg/base52" "golang.org/x/sync/errgroup" ) +func Test(t *testing.T) { TestingT(t) } + var _ = Suite(&testAuditSuite{}) type testAuditSuite struct{} @@ -36,17 +44,29 @@ func auditDir() string { return path.Join(currentDir(), "testdata", "audit") } -func (s *testAuditSuite) SetUpSuite(c *C) { +func resetDir() { _ = os.RemoveAll(auditDir()) _ = os.MkdirAll(auditDir(), 0777) } +func readFakeStdout(f *os.File) string { + _, _ = f.Seek(0, 0) + read, _ := ioutil.ReadAll(f) + return string(read) +} + +func (s *testAuditSuite) SetUpSuite(c *C) { + resetDir() +} + func (s *testAuditSuite) TearDownSuite(c *C) { - _ = os.RemoveAll(auditDir()) + _ = os.RemoveAll(auditDir()) // path.Join(currentDir(), "testdata")) } func (s *testAuditSuite) TestOutputAuditLog(c *C) { dir := auditDir() + resetDir() + var g errgroup.Group for i := 0; i < 20; i++ { g.Go(func() error { return OutputAuditLog(dir, []byte("audit log")) }) @@ -56,10 +76,74 @@ func (s *testAuditSuite) TestOutputAuditLog(c *C) { var paths []string err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { - // simply filter the not relate files. - paths = append(paths, path) + if !info.IsDir() { + paths = append(paths, path) + } return nil }) c.Assert(err, IsNil) c.Assert(len(paths), Equals, 20) } + +func (s *testAuditSuite) TestShowAuditLog(c *C) { + dir := auditDir() + resetDir() + + originStdout := os.Stdout + defer func() { + os.Stdout = originStdout + }() + + fakeStdout := path.Join(currentDir(), "fake-stdout") + defer os.Remove(fakeStdout) + + openStdout := func() *os.File { + _ = os.Remove(fakeStdout) + f, err := os.OpenFile(fakeStdout, os.O_CREATE|os.O_RDWR, 0644) + c.Assert(err, IsNil) + os.Stdout = f + return f + } + + second := int64(1604413577) + nanoSecond := int64(1604413624836105381) + + fname := filepath.Join(dir, base52.Encode(second)) + c.Assert(ioutil.WriteFile(fname, []byte("test with second"), 0644), IsNil) + fname = filepath.Join(dir, base52.Encode(nanoSecond)) + c.Assert(ioutil.WriteFile(fname, []byte("test with nanosecond"), 0644), IsNil) + + f := openStdout() + c.Assert(ShowAuditList(dir), IsNil) + // tabby table size is based on column width, while time.RFC3339 maybe print out timezone like +08:00 or Z(UTC) + // skip the first two lines + list := strings.Join(strings.Split(readFakeStdout(f), "\n")[2:], "\n") + c.Assert(list, Equals, fmt.Sprintf(`ftmpqzww84Q %s test with nanosecond +4F7ZTL %s test with second +`, + time.Unix(nanoSecond/1e9, 0).Format(time.RFC3339), + time.Unix(second, 0).Format(time.RFC3339), + )) + f.Close() + + f = openStdout() + c.Assert(ShowAuditLog(dir, "4F7ZTL"), IsNil) + c.Assert(readFakeStdout(f), Equals, fmt.Sprintf(`--------------------------------------- +- OPERATION TIME: %s - +--------------------------------------- +test with second`, + time.Unix(second, 0).Format("2006-01-02T15:04:05"), + )) + + f.Close() + + f = openStdout() + c.Assert(ShowAuditLog(dir, "ftmpqzww84Q"), IsNil) + c.Assert(readFakeStdout(f), Equals, fmt.Sprintf(`--------------------------------------- +- OPERATION TIME: %s - +--------------------------------------- +test with nanosecond`, + time.Unix(nanoSecond/1e9, 0).Format("2006-01-02T15:04:05"), + )) + f.Close() +}