From f0eb74b15c4c66a20bc9e942c28073a6746e7d67 Mon Sep 17 00:00:00 2001 From: Hu# Date: Wed, 10 Apr 2024 18:25:52 +0800 Subject: [PATCH] *: make unit test great (#7952) ref tikv/pd#7969 Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- Makefile | 18 +- client/resource_group/controller/util_test.go | 2 - pkg/audit/audit_test.go | 3 - pkg/autoscaling/calculation_test.go | 5 - pkg/autoscaling/prometheus_test.go | 6 - pkg/balancer/balancer_test.go | 3 - pkg/cache/cache_test.go | 6 - pkg/codec/codec_test.go | 2 - pkg/core/rangetree/range_tree_test.go | 2 - pkg/core/storelimit/limit_test.go | 2 - pkg/encryption/config_test.go | 4 - pkg/encryption/crypter_test.go | 5 - pkg/encryption/master_key_test.go | 8 - pkg/encryption/region_crypter_test.go | 8 - pkg/errs/errs_test.go | 2 - pkg/mock/mockhbstream/mockhbstream_test.go | 1 - pkg/movingaverage/avg_over_time_test.go | 5 - pkg/movingaverage/max_filter_test.go | 1 - pkg/movingaverage/moving_average_test.go | 2 - pkg/movingaverage/weight_allocator_test.go | 1 - pkg/ratelimit/concurrency_limiter_test.go | 1 - pkg/ratelimit/controller_test.go | 4 - pkg/ratelimit/limiter_test.go | 3 - pkg/ratelimit/ratelimiter_test.go | 1 - pkg/slice/slice_test.go | 4 - pkg/utils/apiutil/apiutil_test.go | 3 - pkg/utils/assertutil/assertutil_test.go | 1 - pkg/utils/grpcutil/grpcutil_test.go | 1 - pkg/utils/jsonutil/jsonutil_test.go | 1 - pkg/utils/keyutil/util_test.go | 1 - pkg/utils/logutil/log_test.go | 2 - pkg/utils/metricutil/metricutil_test.go | 1 - pkg/utils/netutil/address_test.go | 2 - pkg/utils/reflectutil/tag_test.go | 3 - pkg/utils/requestutil/context_test.go | 2 - pkg/utils/typeutil/comparison_test.go | 5 - pkg/utils/typeutil/conversion_test.go | 3 - pkg/utils/typeutil/duration_test.go | 2 - pkg/utils/typeutil/size_test.go | 2 - pkg/utils/typeutil/string_slice_test.go | 2 - pkg/utils/typeutil/time_test.go | 3 - server/api/health_test.go | 2 +- server/api/member_test.go | 21 +- tools/go.mod | 1 + tools/go.sum | 4 + tools/pd-ut/README.md | 66 ++ tools/pd-ut/ut.go | 803 ++++++++++++++++++ tools/pd-ut/xprog.go | 119 +++ 48 files changed, 1010 insertions(+), 139 deletions(-) create mode 100644 tools/pd-ut/README.md create mode 100644 tools/pd-ut/ut.go create mode 100644 tools/pd-ut/xprog.go diff --git a/Makefile b/Makefile index 0d02189f508..d78ddcdd65e 100644 --- a/Makefile +++ b/Makefile @@ -80,7 +80,7 @@ BUILD_BIN_PATH := $(ROOT_PATH)/bin build: pd-server pd-ctl pd-recover -tools: pd-tso-bench pd-heartbeat-bench regions-dump stores-dump pd-api-bench +tools: pd-tso-bench pd-heartbeat-bench regions-dump stores-dump pd-api-bench pd-ut PD_SERVER_DEP := ifeq ($(SWAGGER), 1) @@ -108,7 +108,6 @@ pd-server-basic: .PHONY: pre-build build tools pd-server pd-server-basic # Tools - pd-ctl: cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ctl pd-ctl/main.go pd-tso-bench: @@ -127,8 +126,12 @@ regions-dump: cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/regions-dump regions-dump/main.go stores-dump: cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/stores-dump stores-dump/main.go +pd-ut: pd-xprog + cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-ut pd-ut/ut.go +pd-xprog: + cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_TOOL_CGO_ENABLED) go build -tags xprog -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/xprog pd-ut/xprog.go -.PHONY: pd-ctl pd-tso-bench pd-recover pd-analysis pd-heartbeat-bench simulator regions-dump stores-dump pd-api-bench +.PHONY: pd-ctl pd-tso-bench pd-recover pd-analysis pd-heartbeat-bench simulator regions-dump stores-dump pd-api-bench pd-ut #### Docker image #### @@ -225,6 +228,12 @@ failpoint-disable: install-tools #### Test #### +ut: pd-ut + @$(FAILPOINT_ENABLE) + ./bin/pd-ut run --race + @$(CLEAN_UT_BINARY) + @$(FAILPOINT_DISABLE) + PACKAGE_DIRECTORIES := $(subst $(PD_PKG)/,,$(PACKAGES)) TEST_PKGS := $(filter $(shell find . -iname "*_test.go" -exec dirname {} \; | \ sort -u | sed -e "s/^\./github.com\/tikv\/pd/"),$(PACKAGES)) @@ -303,6 +312,8 @@ split: clean: failpoint-disable clean-test clean-build +CLEAN_UT_BINARY := find . -name '*.test.bin'| xargs rm -f + clean-test: # Cleaning test tmp... rm -rf /tmp/test_pd* @@ -310,6 +321,7 @@ clean-test: rm -rf /tmp/test_etcd* rm -f $(REAL_CLUSTER_TEST_PATH)/playground.log go clean -testcache + @$(CLEAN_UT_BINARY) clean-build: # Cleaning building files... diff --git a/client/resource_group/controller/util_test.go b/client/resource_group/controller/util_test.go index a89ea08b955..10fa7c345a5 100644 --- a/client/resource_group/controller/util_test.go +++ b/client/resource_group/controller/util_test.go @@ -27,7 +27,6 @@ type example struct { } func TestDurationJSON(t *testing.T) { - t.Parallel() re := require.New(t) example := &example{} @@ -41,7 +40,6 @@ func TestDurationJSON(t *testing.T) { } func TestDurationTOML(t *testing.T) { - t.Parallel() re := require.New(t) example := &example{} diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index 8098b36975e..9066d81ebe3 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -32,7 +32,6 @@ import ( ) func TestLabelMatcher(t *testing.T) { - t.Parallel() re := require.New(t) matcher := &LabelMatcher{"testSuccess"} labels1 := &BackendLabels{Labels: []string{"testFail", "testSuccess"}} @@ -42,7 +41,6 @@ func TestLabelMatcher(t *testing.T) { } func TestPrometheusHistogramBackend(t *testing.T) { - t.Parallel() re := require.New(t) serviceAuditHistogramTest := prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -90,7 +88,6 @@ func TestPrometheusHistogramBackend(t *testing.T) { } func TestLocalLogBackendUsingFile(t *testing.T) { - t.Parallel() re := require.New(t) backend := NewLocalLogBackend(true) fname := testutil.InitTempFileLogger("info") diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index 05a348af59d..5d0c2ba1126 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -29,7 +29,6 @@ import ( ) func TestGetScaledTiKVGroups(t *testing.T) { - t.Parallel() re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -214,7 +213,6 @@ func (q *mockQuerier) Query(options *QueryOptions) (QueryResult, error) { } func TestGetTotalCPUUseTime(t *testing.T) { - t.Parallel() re := require.New(t) querier := &mockQuerier{} instances := []instance{ @@ -237,7 +235,6 @@ func TestGetTotalCPUUseTime(t *testing.T) { } func TestGetTotalCPUQuota(t *testing.T) { - t.Parallel() re := require.New(t) querier := &mockQuerier{} instances := []instance{ @@ -260,7 +257,6 @@ func TestGetTotalCPUQuota(t *testing.T) { } func TestScaleOutGroupLabel(t *testing.T) { - t.Parallel() re := require.New(t) var jsonStr = []byte(` { @@ -303,7 +299,6 @@ func TestScaleOutGroupLabel(t *testing.T) { } func TestStrategyChangeCount(t *testing.T) { - t.Parallel() re := require.New(t) var count uint64 = 2 strategy := &Strategy{ diff --git a/pkg/autoscaling/prometheus_test.go b/pkg/autoscaling/prometheus_test.go index b4cf9aefd91..3ee6cb94e37 100644 --- a/pkg/autoscaling/prometheus_test.go +++ b/pkg/autoscaling/prometheus_test.go @@ -180,7 +180,6 @@ func (c *normalClient) Do(_ context.Context, req *http.Request) (response *http. } func TestRetrieveCPUMetrics(t *testing.T) { - t.Parallel() re := require.New(t) client := &normalClient{ mockData: make(map[string]*response), @@ -225,7 +224,6 @@ func (c *emptyResponseClient) Do(_ context.Context, req *http.Request) (r *http. } func TestEmptyResponse(t *testing.T) { - t.Parallel() re := require.New(t) client := &emptyResponseClient{} querier := NewPrometheusQuerier(client) @@ -253,7 +251,6 @@ func (c *errorHTTPStatusClient) Do(_ context.Context, req *http.Request) (r *htt } func TestErrorHTTPStatus(t *testing.T) { - t.Parallel() re := require.New(t) client := &errorHTTPStatusClient{} querier := NewPrometheusQuerier(client) @@ -279,7 +276,6 @@ func (c *errorPrometheusStatusClient) Do(_ context.Context, req *http.Request) ( } func TestErrorPrometheusStatus(t *testing.T) { - t.Parallel() re := require.New(t) client := &errorPrometheusStatusClient{} querier := NewPrometheusQuerier(client) @@ -290,7 +286,6 @@ func TestErrorPrometheusStatus(t *testing.T) { } func TestGetInstanceNameFromAddress(t *testing.T) { - t.Parallel() re := require.New(t) testCases := []struct { address string @@ -328,7 +323,6 @@ func TestGetInstanceNameFromAddress(t *testing.T) { } func TestGetDurationExpression(t *testing.T) { - t.Parallel() re := require.New(t) testCases := []struct { duration time.Duration diff --git a/pkg/balancer/balancer_test.go b/pkg/balancer/balancer_test.go index 996b4f1da35..2c760c6220c 100644 --- a/pkg/balancer/balancer_test.go +++ b/pkg/balancer/balancer_test.go @@ -22,7 +22,6 @@ import ( ) func TestBalancerPutAndDelete(t *testing.T) { - t.Parallel() re := require.New(t) balancers := []Balancer[uint32]{ NewRoundRobin[uint32](), @@ -56,7 +55,6 @@ func TestBalancerPutAndDelete(t *testing.T) { } func TestBalancerDuplicate(t *testing.T) { - t.Parallel() re := require.New(t) balancers := []Balancer[uint32]{ NewRoundRobin[uint32](), @@ -77,7 +75,6 @@ func TestBalancerDuplicate(t *testing.T) { } func TestRoundRobin(t *testing.T) { - t.Parallel() re := require.New(t) balancer := NewRoundRobin[uint32]() for i := 0; i < 100; i++ { diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index dbef41d2754..43e97dfa2b0 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -25,7 +25,6 @@ import ( ) func TestExpireRegionCache(t *testing.T) { - t.Parallel() re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -121,7 +120,6 @@ func sortIDs(ids []uint64) []uint64 { } func TestLRUCache(t *testing.T) { - t.Parallel() re := require.New(t) cache := newLRU(3) @@ -199,7 +197,6 @@ func TestLRUCache(t *testing.T) { } func TestFifoCache(t *testing.T) { - t.Parallel() re := require.New(t) cache := NewFIFO(3) cache.Put(1, "1") @@ -227,7 +224,6 @@ func TestFifoCache(t *testing.T) { } func TestFifoFromLastSameElems(t *testing.T) { - t.Parallel() re := require.New(t) type testStruct struct { value string @@ -260,7 +256,6 @@ func TestFifoFromLastSameElems(t *testing.T) { } func TestTwoQueueCache(t *testing.T) { - t.Parallel() re := require.New(t) cache := newTwoQueue(3) cache.Put(1, "1") @@ -345,7 +340,6 @@ func (pq PriorityQueueItemTest) ID() uint64 { } func TestPriorityQueue(t *testing.T) { - t.Parallel() re := require.New(t) testData := []PriorityQueueItemTest{0, 1, 2, 3, 4, 5} pq := NewPriorityQueue(0) diff --git a/pkg/codec/codec_test.go b/pkg/codec/codec_test.go index f734d2e528e..50bf552a60d 100644 --- a/pkg/codec/codec_test.go +++ b/pkg/codec/codec_test.go @@ -21,7 +21,6 @@ import ( ) func TestDecodeBytes(t *testing.T) { - t.Parallel() re := require.New(t) key := "abcdefghijklmnopqrstuvwxyz" for i := 0; i < len(key); i++ { @@ -32,7 +31,6 @@ func TestDecodeBytes(t *testing.T) { } func TestTableID(t *testing.T) { - t.Parallel() re := require.New(t) key := EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff")) re.Equal(int64(0xff), key.TableID()) diff --git a/pkg/core/rangetree/range_tree_test.go b/pkg/core/rangetree/range_tree_test.go index 0664a7bdbef..6955947cb1b 100644 --- a/pkg/core/rangetree/range_tree_test.go +++ b/pkg/core/rangetree/range_tree_test.go @@ -85,7 +85,6 @@ func bucketDebrisFactory(startKey, endKey []byte, item RangeItem) []RangeItem { } func TestRingPutItem(t *testing.T) { - t.Parallel() re := require.New(t) bucketTree := NewRangeTree(2, bucketDebrisFactory) bucketTree.Update(newSimpleBucketItem([]byte("002"), []byte("100"))) @@ -120,7 +119,6 @@ func TestRingPutItem(t *testing.T) { } func TestDebris(t *testing.T) { - t.Parallel() re := require.New(t) ringItem := newSimpleBucketItem([]byte("010"), []byte("090")) var overlaps []RangeItem diff --git a/pkg/core/storelimit/limit_test.go b/pkg/core/storelimit/limit_test.go index 75865330311..e11618767a1 100644 --- a/pkg/core/storelimit/limit_test.go +++ b/pkg/core/storelimit/limit_test.go @@ -45,7 +45,6 @@ func TestStoreLimit(t *testing.T) { } func TestSlidingWindow(t *testing.T) { - t.Parallel() re := require.New(t) capacity := int64(defaultWindowSize) s := NewSlidingWindows() @@ -92,7 +91,6 @@ func TestSlidingWindow(t *testing.T) { } func TestWindow(t *testing.T) { - t.Parallel() re := require.New(t) capacity := int64(100 * 10) s := newWindow(capacity) diff --git a/pkg/encryption/config_test.go b/pkg/encryption/config_test.go index 6f7e4a41b03..4134d46c2f3 100644 --- a/pkg/encryption/config_test.go +++ b/pkg/encryption/config_test.go @@ -23,7 +23,6 @@ import ( ) func TestAdjustDefaultValue(t *testing.T) { - t.Parallel() re := require.New(t) config := &Config{} err := config.Adjust() @@ -35,21 +34,18 @@ func TestAdjustDefaultValue(t *testing.T) { } func TestAdjustInvalidDataEncryptionMethod(t *testing.T) { - t.Parallel() re := require.New(t) config := &Config{DataEncryptionMethod: "unknown"} re.Error(config.Adjust()) } func TestAdjustNegativeRotationDuration(t *testing.T) { - t.Parallel() re := require.New(t) config := &Config{DataKeyRotationPeriod: typeutil.NewDuration(time.Duration(int64(-1)))} re.Error(config.Adjust()) } func TestAdjustInvalidMasterKeyType(t *testing.T) { - t.Parallel() re := require.New(t) config := &Config{MasterKey: MasterKeyConfig{Type: "unknown"}} re.Error(config.Adjust()) diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index 12a851d1563..9ac72bd7813 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -24,7 +24,6 @@ import ( ) func TestEncryptionMethodSupported(t *testing.T) { - t.Parallel() re := require.New(t) re.Error(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) re.Error(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) @@ -34,7 +33,6 @@ func TestEncryptionMethodSupported(t *testing.T) { } func TestKeyLength(t *testing.T) { - t.Parallel() re := require.New(t) _, err := KeyLength(encryptionpb.EncryptionMethod_PLAINTEXT) re.Error(err) @@ -52,7 +50,6 @@ func TestKeyLength(t *testing.T) { } func TestNewIv(t *testing.T) { - t.Parallel() re := require.New(t) ivCtr, err := NewIvCTR() re.NoError(err) @@ -63,7 +60,6 @@ func TestNewIv(t *testing.T) { } func TestNewDataKey(t *testing.T) { - t.Parallel() re := require.New(t) for _, method := range []encryptionpb.EncryptionMethod{ encryptionpb.EncryptionMethod_AES128_CTR, @@ -82,7 +78,6 @@ func TestNewDataKey(t *testing.T) { } func TestAesGcmCrypter(t *testing.T) { - t.Parallel() re := require.New(t) key, err := hex.DecodeString("ed568fbd8c8018ed2d042a4e5d38d6341486922d401d2022fb81e47c900d3f07") re.NoError(err) diff --git a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index 4bc08dab7a5..31962e9e99d 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -24,7 +24,6 @@ import ( ) func TestPlaintextMasterKey(t *testing.T) { - t.Parallel() re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_Plaintext{ @@ -50,7 +49,6 @@ func TestPlaintextMasterKey(t *testing.T) { } func TestEncrypt(t *testing.T) { - t.Parallel() re := require.New(t) keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" // #nosec G101 key, err := hex.DecodeString(keyHex) @@ -66,7 +64,6 @@ func TestEncrypt(t *testing.T) { } func TestDecrypt(t *testing.T) { - t.Parallel() re := require.New(t) keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" // #nosec G101 key, err := hex.DecodeString(keyHex) @@ -83,7 +80,6 @@ func TestDecrypt(t *testing.T) { } func TestNewFileMasterKeyMissingPath(t *testing.T) { - t.Parallel() re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -97,7 +93,6 @@ func TestNewFileMasterKeyMissingPath(t *testing.T) { } func TestNewFileMasterKeyMissingFile(t *testing.T) { - t.Parallel() re := require.New(t) dir := t.TempDir() path := dir + "/key" @@ -113,7 +108,6 @@ func TestNewFileMasterKeyMissingFile(t *testing.T) { } func TestNewFileMasterKeyNotHexString(t *testing.T) { - t.Parallel() re := require.New(t) dir := t.TempDir() path := dir + "/key" @@ -130,7 +124,6 @@ func TestNewFileMasterKeyNotHexString(t *testing.T) { } func TestNewFileMasterKeyLengthMismatch(t *testing.T) { - t.Parallel() re := require.New(t) dir := t.TempDir() path := dir + "/key" @@ -147,7 +140,6 @@ func TestNewFileMasterKeyLengthMismatch(t *testing.T) { } func TestNewFileMasterKey(t *testing.T) { - t.Parallel() re := require.New(t) key := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" // #nosec G101 dir := t.TempDir() diff --git a/pkg/encryption/region_crypter_test.go b/pkg/encryption/region_crypter_test.go index 5fd9778a8c0..b1ca558063c 100644 --- a/pkg/encryption/region_crypter_test.go +++ b/pkg/encryption/region_crypter_test.go @@ -70,7 +70,6 @@ func (m *testKeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { } func TestNilRegion(t *testing.T) { - t.Parallel() re := require.New(t) m := newTestKeyManager() region, err := EncryptRegion(nil, m) @@ -81,7 +80,6 @@ func TestNilRegion(t *testing.T) { } func TestEncryptRegionWithoutKeyManager(t *testing.T) { - t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -98,7 +96,6 @@ func TestEncryptRegionWithoutKeyManager(t *testing.T) { } func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { - t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -117,7 +114,6 @@ func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { } func TestEncryptRegion(t *testing.T) { - t.Parallel() re := require.New(t) startKey := []byte("abc") endKey := []byte("xyz") @@ -152,7 +148,6 @@ func TestEncryptRegion(t *testing.T) { } func TestDecryptRegionNotEncrypted(t *testing.T) { - t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -170,7 +165,6 @@ func TestDecryptRegionNotEncrypted(t *testing.T) { } func TestDecryptRegionWithoutKeyManager(t *testing.T) { - t.Parallel() re := require.New(t) region := &metapb.Region{ Id: 10, @@ -186,7 +180,6 @@ func TestDecryptRegionWithoutKeyManager(t *testing.T) { } func TestDecryptRegionWhileKeyMissing(t *testing.T) { - t.Parallel() re := require.New(t) keyID := uint64(3) m := newTestKeyManager() @@ -207,7 +200,6 @@ func TestDecryptRegionWhileKeyMissing(t *testing.T) { } func TestDecryptRegion(t *testing.T) { - t.Parallel() re := require.New(t) keyID := uint64(1) startKey := []byte("abc") diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index d76c02dc110..1dcabc32d9a 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -97,7 +97,6 @@ func TestError(t *testing.T) { } func TestErrorEqual(t *testing.T) { - t.Parallel() re := require.New(t) err1 := ErrSchedulerNotFound.FastGenByArgs() err2 := ErrSchedulerNotFound.FastGenByArgs() @@ -134,7 +133,6 @@ func TestZapError(t *testing.T) { } func TestErrorWithStack(t *testing.T) { - t.Parallel() re := require.New(t) conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) diff --git a/pkg/mock/mockhbstream/mockhbstream_test.go b/pkg/mock/mockhbstream/mockhbstream_test.go index a8e88f61aee..aa1ca85279b 100644 --- a/pkg/mock/mockhbstream/mockhbstream_test.go +++ b/pkg/mock/mockhbstream/mockhbstream_test.go @@ -29,7 +29,6 @@ import ( ) func TestActivity(t *testing.T) { - t.Parallel() re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/pkg/movingaverage/avg_over_time_test.go b/pkg/movingaverage/avg_over_time_test.go index 43553d9d608..4a54e33d449 100644 --- a/pkg/movingaverage/avg_over_time_test.go +++ b/pkg/movingaverage/avg_over_time_test.go @@ -23,7 +23,6 @@ import ( ) func TestPulse(t *testing.T) { - t.Parallel() re := require.New(t) aot := NewAvgOverTime(5 * time.Second) // warm up @@ -43,7 +42,6 @@ func TestPulse(t *testing.T) { } func TestPulse2(t *testing.T) { - t.Parallel() re := require.New(t) dur := 5 * time.Second aot := NewAvgOverTime(dur) @@ -57,7 +55,6 @@ func TestPulse2(t *testing.T) { } func TestChange(t *testing.T) { - t.Parallel() re := require.New(t) aot := NewAvgOverTime(5 * time.Second) @@ -91,7 +88,6 @@ func TestChange(t *testing.T) { } func TestMinFilled(t *testing.T) { - t.Parallel() re := require.New(t) interval := 10 * time.Second rate := 1.0 @@ -108,7 +104,6 @@ func TestMinFilled(t *testing.T) { } func TestUnstableInterval(t *testing.T) { - t.Parallel() re := require.New(t) aot := NewAvgOverTime(5 * time.Second) re.Equal(0., aot.Get()) diff --git a/pkg/movingaverage/max_filter_test.go b/pkg/movingaverage/max_filter_test.go index bba770cecc2..7d3906ec93c 100644 --- a/pkg/movingaverage/max_filter_test.go +++ b/pkg/movingaverage/max_filter_test.go @@ -21,7 +21,6 @@ import ( ) func TestMaxFilter(t *testing.T) { - t.Parallel() re := require.New(t) var empty float64 = 0 data := []float64{2, 1, 3, 4, 1, 1, 3, 3, 2, 0, 5} diff --git a/pkg/movingaverage/moving_average_test.go b/pkg/movingaverage/moving_average_test.go index 49c20637c20..fd0a1a9fcf3 100644 --- a/pkg/movingaverage/moving_average_test.go +++ b/pkg/movingaverage/moving_average_test.go @@ -72,7 +72,6 @@ func checkInstantaneous(re *require.Assertions, ma MovingAvg) { } func TestMedianFilter(t *testing.T) { - t.Parallel() re := require.New(t) var empty float64 = 0 data := []float64{2, 4, 2, 800, 600, 6, 3} @@ -92,7 +91,6 @@ type testCase struct { } func TestMovingAvg(t *testing.T) { - t.Parallel() re := require.New(t) var empty float64 = 0 data := []float64{1, 1, 1, 1, 5, 1, 1, 1} diff --git a/pkg/movingaverage/weight_allocator_test.go b/pkg/movingaverage/weight_allocator_test.go index 631a71f10c9..405d8f72876 100644 --- a/pkg/movingaverage/weight_allocator_test.go +++ b/pkg/movingaverage/weight_allocator_test.go @@ -21,7 +21,6 @@ import ( ) func TestWeightAllocator(t *testing.T) { - t.Parallel() re := require.New(t) checkSumFunc := func(wa *WeightAllocator, length int) { diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index e77c79c8ebc..a397b6ac50f 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -26,7 +26,6 @@ import ( ) func TestConcurrencyLimiter(t *testing.T) { - t.Parallel() re := require.New(t) cl := NewConcurrencyLimiter(10) for i := 0; i < 10; i++ { diff --git a/pkg/ratelimit/controller_test.go b/pkg/ratelimit/controller_test.go index 48a5ee2054b..50be24540e4 100644 --- a/pkg/ratelimit/controller_test.go +++ b/pkg/ratelimit/controller_test.go @@ -78,7 +78,6 @@ func runMulitLabelLimiter(t *testing.T, limiter *Controller, testCase []labelCas } func TestControllerWithConcurrencyLimiter(t *testing.T) { - t.Parallel() re := require.New(t) limiter := NewController(context.Background(), "grpc", nil) defer limiter.Close() @@ -191,7 +190,6 @@ func TestControllerWithConcurrencyLimiter(t *testing.T) { } func TestBlockList(t *testing.T) { - t.Parallel() re := require.New(t) opts := []Option{AddLabelAllowList()} limiter := NewController(context.Background(), "grpc", nil) @@ -213,7 +211,6 @@ func TestBlockList(t *testing.T) { } func TestControllerWithQPSLimiter(t *testing.T) { - t.Parallel() re := require.New(t) limiter := NewController(context.Background(), "grpc", nil) defer limiter.Close() @@ -323,7 +320,6 @@ func TestControllerWithQPSLimiter(t *testing.T) { } func TestControllerWithTwoLimiters(t *testing.T) { - t.Parallel() re := require.New(t) limiter := NewController(context.Background(), "grpc", nil) defer limiter.Close() diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index fabb9d98917..36f339b47ac 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -40,7 +40,6 @@ func (r *releaseUtil) append(d DoneFunc) { } func TestWithConcurrencyLimiter(t *testing.T) { - t.Parallel() re := require.New(t) limiter := newLimiter() @@ -103,7 +102,6 @@ func TestWithConcurrencyLimiter(t *testing.T) { } func TestWithQPSLimiter(t *testing.T) { - t.Parallel() re := require.New(t) limiter := newLimiter() status := limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) @@ -177,7 +175,6 @@ func TestWithQPSLimiter(t *testing.T) { } func TestWithTwoLimiters(t *testing.T) { - t.Parallel() re := require.New(t) cfg := &DimensionConfig{ QPS: 100, diff --git a/pkg/ratelimit/ratelimiter_test.go b/pkg/ratelimit/ratelimiter_test.go index 35b355e7b21..f16bb6a83d2 100644 --- a/pkg/ratelimit/ratelimiter_test.go +++ b/pkg/ratelimit/ratelimiter_test.go @@ -22,7 +22,6 @@ import ( ) func TestRateLimiter(t *testing.T) { - t.Parallel() re := require.New(t) limiter := NewRateLimiter(100, 100) diff --git a/pkg/slice/slice_test.go b/pkg/slice/slice_test.go index 1fe3fe79dcf..019cd49c46a 100644 --- a/pkg/slice/slice_test.go +++ b/pkg/slice/slice_test.go @@ -22,7 +22,6 @@ import ( ) func TestSlice(t *testing.T) { - t.Parallel() re := require.New(t) testCases := []struct { a []int @@ -45,7 +44,6 @@ func TestSlice(t *testing.T) { } func TestSliceContains(t *testing.T) { - t.Parallel() re := require.New(t) ss := []string{"a", "b", "c"} re.True(slice.Contains(ss, "a")) @@ -61,7 +59,6 @@ func TestSliceContains(t *testing.T) { } func TestSliceRemoveGenericTypes(t *testing.T) { - t.Parallel() re := require.New(t) ss := []string{"a", "b", "c"} ss = slice.Remove(ss, "a") @@ -77,7 +74,6 @@ func TestSliceRemoveGenericTypes(t *testing.T) { } func TestSliceRemove(t *testing.T) { - t.Parallel() re := require.New(t) is := []int64{} diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index 106d3fb21cb..aee21621dd2 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -26,7 +26,6 @@ import ( ) func TestJsonRespondErrorOk(t *testing.T) { - t.Parallel() re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, @@ -45,7 +44,6 @@ func TestJsonRespondErrorOk(t *testing.T) { } func TestJsonRespondErrorBadInput(t *testing.T) { - t.Parallel() re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, @@ -71,7 +69,6 @@ func TestJsonRespondErrorBadInput(t *testing.T) { } func TestGetIPPortFromHTTPRequest(t *testing.T) { - t.Parallel() re := require.New(t) testCases := []struct { diff --git a/pkg/utils/assertutil/assertutil_test.go b/pkg/utils/assertutil/assertutil_test.go index 84bd21cef05..076cdd2ac93 100644 --- a/pkg/utils/assertutil/assertutil_test.go +++ b/pkg/utils/assertutil/assertutil_test.go @@ -22,7 +22,6 @@ import ( ) func TestNilFail(t *testing.T) { - t.Parallel() re := require.New(t) var failErr error checker := NewChecker() diff --git a/pkg/utils/grpcutil/grpcutil_test.go b/pkg/utils/grpcutil/grpcutil_test.go index 21b7e1a4acb..2cbff4f3ebc 100644 --- a/pkg/utils/grpcutil/grpcutil_test.go +++ b/pkg/utils/grpcutil/grpcutil_test.go @@ -37,7 +37,6 @@ func TestToTLSConfig(t *testing.T) { } }() - t.Parallel() re := require.New(t) tlsConfig := TLSConfig{ KeyPath: path.Join(certPath, "pd-server-key.pem"), diff --git a/pkg/utils/jsonutil/jsonutil_test.go b/pkg/utils/jsonutil/jsonutil_test.go index a046fbaf70a..1e8c21917ba 100644 --- a/pkg/utils/jsonutil/jsonutil_test.go +++ b/pkg/utils/jsonutil/jsonutil_test.go @@ -31,7 +31,6 @@ type testJSONStructLevel2 struct { } func TestJSONUtil(t *testing.T) { - t.Parallel() re := require.New(t) father := &testJSONStructLevel1{ Name: "father", diff --git a/pkg/utils/keyutil/util_test.go b/pkg/utils/keyutil/util_test.go index 374faa1f797..7bcb0a49c6f 100644 --- a/pkg/utils/keyutil/util_test.go +++ b/pkg/utils/keyutil/util_test.go @@ -21,7 +21,6 @@ import ( ) func TestKeyUtil(t *testing.T) { - t.Parallel() re := require.New(t) startKey := []byte("a") endKey := []byte("b") diff --git a/pkg/utils/logutil/log_test.go b/pkg/utils/logutil/log_test.go index 7d4be7a88bd..650ba62fe9d 100644 --- a/pkg/utils/logutil/log_test.go +++ b/pkg/utils/logutil/log_test.go @@ -23,7 +23,6 @@ import ( ) func TestStringToZapLogLevel(t *testing.T) { - t.Parallel() re := require.New(t) re.Equal(zapcore.FatalLevel, StringToZapLogLevel("fatal")) re.Equal(zapcore.ErrorLevel, StringToZapLogLevel("ERROR")) @@ -35,7 +34,6 @@ func TestStringToZapLogLevel(t *testing.T) { } func TestRedactLog(t *testing.T) { - t.Parallel() re := require.New(t) testCases := []struct { name string diff --git a/pkg/utils/metricutil/metricutil_test.go b/pkg/utils/metricutil/metricutil_test.go index b817eb0112d..acac9ce4d49 100644 --- a/pkg/utils/metricutil/metricutil_test.go +++ b/pkg/utils/metricutil/metricutil_test.go @@ -23,7 +23,6 @@ import ( ) func TestCamelCaseToSnakeCase(t *testing.T) { - t.Parallel() re := require.New(t) inputs := []struct { name string diff --git a/pkg/utils/netutil/address_test.go b/pkg/utils/netutil/address_test.go index faa3e2e1d04..127c9a6d0f7 100644 --- a/pkg/utils/netutil/address_test.go +++ b/pkg/utils/netutil/address_test.go @@ -22,7 +22,6 @@ import ( ) func TestResolveLoopBackAddr(t *testing.T) { - t.Parallel() re := require.New(t) nodes := []struct { address string @@ -40,7 +39,6 @@ func TestResolveLoopBackAddr(t *testing.T) { } func TestIsEnableHttps(t *testing.T) { - t.Parallel() re := require.New(t) re.False(IsEnableHTTPS(http.DefaultClient)) httpClient := &http.Client{ diff --git a/pkg/utils/reflectutil/tag_test.go b/pkg/utils/reflectutil/tag_test.go index f613f1f81b6..3e49e093912 100644 --- a/pkg/utils/reflectutil/tag_test.go +++ b/pkg/utils/reflectutil/tag_test.go @@ -35,7 +35,6 @@ type testStruct3 struct { } func TestFindJSONFullTagByChildTag(t *testing.T) { - t.Parallel() re := require.New(t) key := "enable" result := FindJSONFullTagByChildTag(reflect.TypeOf(testStruct1{}), key) @@ -51,7 +50,6 @@ func TestFindJSONFullTagByChildTag(t *testing.T) { } func TestFindSameFieldByJSON(t *testing.T) { - t.Parallel() re := require.New(t) input := map[string]any{ "name": "test2", @@ -65,7 +63,6 @@ func TestFindSameFieldByJSON(t *testing.T) { } func TestFindFieldByJSONTag(t *testing.T) { - t.Parallel() re := require.New(t) t1 := testStruct1{} t2 := testStruct2{} diff --git a/pkg/utils/requestutil/context_test.go b/pkg/utils/requestutil/context_test.go index 298fc1ff8a3..e6bdcd7be46 100644 --- a/pkg/utils/requestutil/context_test.go +++ b/pkg/utils/requestutil/context_test.go @@ -24,7 +24,6 @@ import ( ) func TestRequestInfo(t *testing.T) { - t.Parallel() re := require.New(t) ctx := context.Background() _, ok := RequestInfoFrom(ctx) @@ -53,7 +52,6 @@ func TestRequestInfo(t *testing.T) { } func TestEndTime(t *testing.T) { - t.Parallel() re := require.New(t) ctx := context.Background() _, ok := EndTimeFrom(ctx) diff --git a/pkg/utils/typeutil/comparison_test.go b/pkg/utils/typeutil/comparison_test.go index b296405b3d5..b53e961b4ee 100644 --- a/pkg/utils/typeutil/comparison_test.go +++ b/pkg/utils/typeutil/comparison_test.go @@ -23,7 +23,6 @@ import ( ) func TestMinUint64(t *testing.T) { - t.Parallel() re := require.New(t) re.Equal(uint64(1), MinUint64(1, 2)) re.Equal(uint64(1), MinUint64(2, 1)) @@ -31,7 +30,6 @@ func TestMinUint64(t *testing.T) { } func TestMaxUint64(t *testing.T) { - t.Parallel() re := require.New(t) re.Equal(uint64(2), MaxUint64(1, 2)) re.Equal(uint64(2), MaxUint64(2, 1)) @@ -39,7 +37,6 @@ func TestMaxUint64(t *testing.T) { } func TestMinDuration(t *testing.T) { - t.Parallel() re := require.New(t) re.Equal(time.Second, MinDuration(time.Minute, time.Second)) re.Equal(time.Second, MinDuration(time.Second, time.Minute)) @@ -47,7 +44,6 @@ func TestMinDuration(t *testing.T) { } func TestEqualFloat(t *testing.T) { - t.Parallel() re := require.New(t) f1 := rand.Float64() re.True(Float64Equal(f1, f1*1.000)) @@ -55,7 +51,6 @@ func TestEqualFloat(t *testing.T) { } func TestAreStringSlicesEquivalent(t *testing.T) { - t.Parallel() re := require.New(t) re.True(AreStringSlicesEquivalent(nil, nil)) re.True(AreStringSlicesEquivalent([]string{}, nil)) diff --git a/pkg/utils/typeutil/conversion_test.go b/pkg/utils/typeutil/conversion_test.go index 14b2f9dea52..7b17cfcbe2c 100644 --- a/pkg/utils/typeutil/conversion_test.go +++ b/pkg/utils/typeutil/conversion_test.go @@ -23,7 +23,6 @@ import ( ) func TestBytesToUint64(t *testing.T) { - t.Parallel() re := require.New(t) str := "\x00\x00\x00\x00\x00\x00\x03\xe8" a, err := BytesToUint64([]byte(str)) @@ -32,7 +31,6 @@ func TestBytesToUint64(t *testing.T) { } func TestUint64ToBytes(t *testing.T) { - t.Parallel() re := require.New(t) var a uint64 = 1000 b := Uint64ToBytes(a) @@ -41,7 +39,6 @@ func TestUint64ToBytes(t *testing.T) { } func TestJSONToUint64Slice(t *testing.T) { - t.Parallel() re := require.New(t) type testArray struct { Array []uint64 `json:"array"` diff --git a/pkg/utils/typeutil/duration_test.go b/pkg/utils/typeutil/duration_test.go index cff7c3cd66c..42db0be8cdc 100644 --- a/pkg/utils/typeutil/duration_test.go +++ b/pkg/utils/typeutil/duration_test.go @@ -27,7 +27,6 @@ type example struct { } func TestDurationJSON(t *testing.T) { - t.Parallel() re := require.New(t) example := &example{} @@ -41,7 +40,6 @@ func TestDurationJSON(t *testing.T) { } func TestDurationTOML(t *testing.T) { - t.Parallel() re := require.New(t) example := &example{} diff --git a/pkg/utils/typeutil/size_test.go b/pkg/utils/typeutil/size_test.go index 57c246953e4..f2772ff8c6d 100644 --- a/pkg/utils/typeutil/size_test.go +++ b/pkg/utils/typeutil/size_test.go @@ -23,7 +23,6 @@ import ( ) func TestSizeJSON(t *testing.T) { - t.Parallel() re := require.New(t) b := ByteSize(265421587) o, err := json.Marshal(b) @@ -40,7 +39,6 @@ func TestSizeJSON(t *testing.T) { } func TestParseMbFromText(t *testing.T) { - t.Parallel() re := require.New(t) testCases := []struct { body []string diff --git a/pkg/utils/typeutil/string_slice_test.go b/pkg/utils/typeutil/string_slice_test.go index 9a197eb68e4..9177cee0eb9 100644 --- a/pkg/utils/typeutil/string_slice_test.go +++ b/pkg/utils/typeutil/string_slice_test.go @@ -22,7 +22,6 @@ import ( ) func TestStringSliceJSON(t *testing.T) { - t.Parallel() re := require.New(t) b := StringSlice([]string{"zone", "rack"}) o, err := json.Marshal(b) @@ -36,7 +35,6 @@ func TestStringSliceJSON(t *testing.T) { } func TestEmpty(t *testing.T) { - t.Parallel() re := require.New(t) ss := StringSlice([]string{}) b, err := json.Marshal(ss) diff --git a/pkg/utils/typeutil/time_test.go b/pkg/utils/typeutil/time_test.go index 7a5baf55afa..b8078f63fa8 100644 --- a/pkg/utils/typeutil/time_test.go +++ b/pkg/utils/typeutil/time_test.go @@ -23,7 +23,6 @@ import ( ) func TestParseTimestamp(t *testing.T) { - t.Parallel() re := require.New(t) for i := 0; i < 3; i++ { t := time.Now().Add(time.Second * time.Duration(rand.Int31n(1000))) @@ -39,7 +38,6 @@ func TestParseTimestamp(t *testing.T) { } func TestSubTimeByWallClock(t *testing.T) { - t.Parallel() re := require.New(t) for i := 0; i < 100; i++ { r := rand.Int63n(1000) @@ -63,7 +61,6 @@ func TestSubTimeByWallClock(t *testing.T) { } func TestSmallTimeDifference(t *testing.T) { - t.Parallel() re := require.New(t) t1, err := time.Parse("2006-01-02 15:04:05.999", "2021-04-26 00:44:25.682") re.NoError(err) diff --git a/server/api/health_test.go b/server/api/health_test.go index 6d2caec12cd..89a9627bc37 100644 --- a/server/api/health_test.go +++ b/server/api/health_test.go @@ -26,7 +26,7 @@ import ( ) func checkSliceResponse(re *require.Assertions, body []byte, cfgs []*config.Config, unhealthy string) { - got := []Health{} + var got []Health re.NoError(json.Unmarshal(body, &got)) re.Len(cfgs, len(got)) diff --git a/server/api/member_test.go b/server/api/member_test.go index 65c0ff67360..5d692cda7d9 100644 --- a/server/api/member_test.go +++ b/server/api/member_test.go @@ -158,26 +158,7 @@ func (suite *memberTestSuite) changeLeaderPeerUrls(leader *pdpb.Member, id uint6 resp.Body.Close() } -type resignTestSuite struct { - suite.Suite - cfgs []*config.Config - servers []*server.Server - clean testutil.CleanupFunc -} - -func TestResignTestSuite(t *testing.T) { - suite.Run(t, new(resignTestSuite)) -} - -func (suite *resignTestSuite) SetupSuite() { - suite.cfgs, suite.servers, suite.clean = mustNewCluster(suite.Require(), 1) -} - -func (suite *resignTestSuite) TearDownSuite() { - suite.clean() -} - -func (suite *resignTestSuite) TestResignMyself() { +func (suite *memberTestSuite) TestResignMyself() { re := suite.Require() addr := suite.cfgs[0].ClientUrls + apiPrefix + "/api/v1/leader/resign" resp, err := testDialClient.Post(addr, "", nil) diff --git a/tools/go.mod b/tools/go.mod index 8287f834471..9d8728f7034 100644 --- a/tools/go.mod +++ b/tools/go.mod @@ -31,6 +31,7 @@ require ( github.com/tikv/pd v0.0.0-00010101000000-000000000000 github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 go.etcd.io/etcd v0.5.0-alpha.5.0.20240320135013-950cd5fbe6ca + go.uber.org/automaxprocs v1.5.3 go.uber.org/goleak v1.3.0 go.uber.org/zap v1.27.0 golang.org/x/text v0.14.0 diff --git a/tools/go.sum b/tools/go.sum index ac6cc75903e..d7c7a4801b1 100644 --- a/tools/go.sum +++ b/tools/go.sum @@ -385,6 +385,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b h1:0LFwY6Q3gMACTjAbMZBjXAqTOzOwFaj2Ld6cjeQ7Rig= github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -514,6 +516,8 @@ go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= +go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= go.uber.org/dig v1.9.0 h1:pJTDXKEhRqBI8W7rU7kwT5EgyRZuSMVSFcZolOvKK9U= go.uber.org/dig v1.9.0/go.mod h1:X34SnWGr8Fyla9zQNO2GSO2D+TIuqB14OS8JhYocIyw= go.uber.org/fx v1.12.0 h1:+1+3Cz9M0dFMPy9SW9XUIUHye8bnPUm7q7DroNGWYG4= diff --git a/tools/pd-ut/README.md b/tools/pd-ut/README.md new file mode 100644 index 00000000000..77b59bea4f7 --- /dev/null +++ b/tools/pd-ut/README.md @@ -0,0 +1,66 @@ +# pd-ut + +pd-ut is a tool to run unit tests for PD. + +## Build + +1. [Go](https://golang.org/) Version 1.21 or later +2. In the root directory of the [PD project](https://github.com/tikv/pd), use the `make pd-ut` command to compile and generate `bin/pd-ut` + +## Usage + +This section describes how to use the pd-ut tool. + +### brief run all tests +```shell +make ut +``` + + +### run by pd-ut + +- You should `make failpoint-enable` before running the tests. +- And after running the tests, you should `make failpoint-disable` and `make clean-test` to disable the failpoint and clean the environment. + +#### Flags description + +```shell +// run all tests +pd-ut + +// show usage +pd-ut -h + +// list all packages +pd-ut list + +// list test cases of a single package +pd-ut list $package + +// list test cases that match a pattern +pd-ut list $package 'r:$regex' + +// run all tests +pd-ut run + +// run test all cases of a single package +pd-ut run $package + +// run test cases of a single package +pd-ut run $package $test + +// run test cases that match a pattern +pd-ut run $package 'r:$regex' + +// build all test package +pd-ut build + +// build a test package +pd-ut build xxx + +// write the junitfile +pd-ut run --junitfile xxx + +// test with race flag +pd-ut run --race +``` diff --git a/tools/pd-ut/ut.go b/tools/pd-ut/ut.go new file mode 100644 index 00000000000..5d50bbd51a8 --- /dev/null +++ b/tools/pd-ut/ut.go @@ -0,0 +1,803 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "encoding/xml" + "errors" + "fmt" + "io" + "log" + "math/rand" + "os" + "os/exec" + "path" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + + // Set the correct value when it runs inside docker. + _ "go.uber.org/automaxprocs" +) + +func usage() bool { + msg := `// run all tests +pd-ut + +// show usage +pd-ut -h + +// list all packages +pd-ut list + +// list test cases of a single package +pd-ut list $package + +// list test cases that match a pattern +pd-ut list $package 'r:$regex' + +// run all tests +pd-ut run + +// run test all cases of a single package +pd-ut run $package + +// run test cases of a single package +pd-ut run $package $test + +// run test cases that match a pattern +pd-ut run $package 'r:$regex' + +// build all test package +pd-ut build + +// build a test package +pd-ut build xxx + +// write the junitfile +pd-ut run --junitfile xxx + +// test with race flag +pd-ut run --race` + + fmt.Println(msg) + return true +} + +const modulePath = "github.com/tikv/pd" + +var ( + // runtime + p int + buildParallel int + workDir string + // arguments + race bool + junitFile string +) + +func main() { + race = handleFlag("--race") + junitFile = stripFlag("--junitfile") + + // Get the correct count of CPU if it's in docker. + p = runtime.GOMAXPROCS(0) + // We use 2 * p for `go build` to make it faster. + buildParallel = p * 2 + var err error + workDir, err = os.Getwd() + if err != nil { + fmt.Println("os.Getwd() error", err) + } + + var isSucceed bool + // run all tests + if len(os.Args) == 1 { + isSucceed = cmdRun() + } + + if len(os.Args) >= 2 { + switch os.Args[1] { + case "list": + isSucceed = cmdList(os.Args[2:]...) + case "build": + isSucceed = cmdBuild(os.Args[2:]...) + case "run": + isSucceed = cmdRun(os.Args[2:]...) + default: + isSucceed = usage() + } + } + if !isSucceed { + os.Exit(1) + } +} + +func cmdList(args ...string) bool { + pkgs, err := listPackages() + if err != nil { + log.Println("list package error", err) + return false + } + + // list all packages + if len(args) == 0 { + for _, pkg := range pkgs { + fmt.Println(pkg) + } + return false + } + + // list test case of a single package + if len(args) == 1 || len(args) == 2 { + pkg := args[0] + pkgs = filter(pkgs, func(s string) bool { return s == pkg }) + if len(pkgs) != 1 { + fmt.Println("package not exist", pkg) + return false + } + + err := buildTestBinary(pkg) + if err != nil { + log.Println("build package error", pkg, err) + return false + } + exist, err := testBinaryExist(pkg) + if err != nil { + log.Println("check test binary existence error", err) + return false + } + if !exist { + fmt.Println("no test case in ", pkg) + return false + } + + res := listTestCases(pkg, nil) + + if len(args) == 2 { + res, err = filterTestCases(res, args[1]) + if err != nil { + log.Println("filter test cases error", err) + return false + } + } + + for _, x := range res { + fmt.Println(x.test) + } + } + return true +} + +func cmdBuild(args ...string) bool { + pkgs, err := listPackages() + if err != nil { + log.Println("list package error", err) + return false + } + + // build all packages + if len(args) == 0 { + err := buildTestBinaryMulti(pkgs) + if err != nil { + fmt.Println("build package error", pkgs, err) + return false + } + return true + } + + // build test binary of a single package + if len(args) >= 1 { + pkg := args[0] + err := buildTestBinary(pkg) + if err != nil { + log.Println("build package error", pkg, err) + return false + } + } + return true +} + +func cmdRun(args ...string) bool { + var err error + pkgs, err := listPackages() + if err != nil { + fmt.Println("list packages error", err) + return false + } + tasks := make([]task, 0, 5000) + start := time.Now() + // run all tests + if len(args) == 0 { + err := buildTestBinaryMulti(pkgs) + if err != nil { + fmt.Println("build package error", pkgs, err) + return false + } + + for _, pkg := range pkgs { + exist, err := testBinaryExist(pkg) + if err != nil { + fmt.Println("check test binary existence error", err) + return false + } + if !exist { + fmt.Println("no test case in ", pkg) + continue + } + + tasks = listTestCases(pkg, tasks) + } + } + + // run tests for a single package + if len(args) == 1 { + pkg := args[0] + err := buildTestBinary(pkg) + if err != nil { + log.Println("build package error", pkg, err) + return false + } + exist, err := testBinaryExist(pkg) + if err != nil { + log.Println("check test binary existence error", err) + return false + } + + if !exist { + fmt.Println("no test case in ", pkg) + return false + } + tasks = listTestCases(pkg, tasks) + } + + // run a single test + if len(args) == 2 { + pkg := args[0] + err := buildTestBinary(pkg) + if err != nil { + log.Println("build package error", pkg, err) + return false + } + exist, err := testBinaryExist(pkg) + if err != nil { + log.Println("check test binary existence error", err) + return false + } + if !exist { + fmt.Println("no test case in ", pkg) + return false + } + + tasks = listTestCases(pkg, tasks) + tasks, err = filterTestCases(tasks, args[1]) + if err != nil { + log.Println("filter test cases error", err) + return false + } + } + + fmt.Printf("building task finish, parallelism=%d, count=%d, takes=%v\n", buildParallel, len(tasks), time.Since(start)) + + taskCh := make(chan task, 100) + works := make([]numa, p) + var wg sync.WaitGroup + for i := 0; i < p; i++ { + wg.Add(1) + go works[i].worker(&wg, taskCh) + } + + shuffle(tasks) + + start = time.Now() + for _, task := range tasks { + taskCh <- task + } + close(taskCh) + wg.Wait() + fmt.Println("run all tasks takes", time.Since(start)) + + if junitFile != "" { + out := collectTestResults(works) + f, err := os.Create(junitFile) + if err != nil { + fmt.Println("create junit file fail:", err) + return false + } + if err := write(f, out); err != nil { + fmt.Println("write junit file error:", err) + return false + } + } + + for _, work := range works { + if work.Fail { + return false + } + } + return true +} + +// stripFlag strip the '--flag xxx' from the command line os.Args +// Example of the os.Args changes +// Before: ut run pkg TestXXX --junitfile yyy +// After: ut run pkg TestXXX +// The value of the flag is returned. +func stripFlag(flag string) string { + var res string + tmp := os.Args[:0] + // Iter to the flag + var i int + for ; i < len(os.Args); i++ { + if os.Args[i] == flag { + i++ + break + } + tmp = append(tmp, os.Args[i]) + } + // Handle the flag + if i < len(os.Args) { + res = os.Args[i] + i++ + } + // Iter the remain flags + for ; i < len(os.Args); i++ { + tmp = append(tmp, os.Args[i]) + } + + os.Args = tmp + return res +} + +func handleFlag(f string) (found bool) { + tmp := os.Args[:0] + for i := 0; i < len(os.Args); i++ { + if os.Args[i] == f { + found = true + continue + } + tmp = append(tmp, os.Args[i]) + } + os.Args = tmp + return +} + +type task struct { + pkg string + test string +} + +func (t *task) String() string { + return t.pkg + " " + t.test +} + +func listTestCases(pkg string, tasks []task) []task { + newCases := listNewTestCases(pkg) + for _, c := range newCases { + tasks = append(tasks, task{pkg, c}) + } + + return tasks +} + +func filterTestCases(tasks []task, arg1 string) ([]task, error) { + if strings.HasPrefix(arg1, "r:") { + r, err := regexp.Compile(arg1[2:]) + if err != nil { + return nil, err + } + tmp := tasks[:0] + for _, task := range tasks { + if r.MatchString(task.test) { + tmp = append(tmp, task) + } + } + return tmp, nil + } + tmp := tasks[:0] + for _, task := range tasks { + if strings.Contains(task.test, arg1) { + tmp = append(tmp, task) + } + } + return tmp, nil +} + +func listPackages() ([]string, error) { + cmd := exec.Command("go", "list", "./...") + ss, err := cmdToLines(cmd) + if err != nil { + return nil, withTrace(err) + } + + ret := ss[:0] + for _, s := range ss { + if !strings.HasPrefix(s, modulePath) { + continue + } + pkg := s[len(modulePath)+1:] + if skipDIR(pkg) { + continue + } + ret = append(ret, pkg) + } + return ret, nil +} + +type numa struct { + Fail bool + results []testResult +} + +func (n *numa) worker(wg *sync.WaitGroup, ch chan task) { + defer wg.Done() + for t := range ch { + res := n.runTestCase(t.pkg, t.test) + if res.Failure != nil { + fmt.Println("[FAIL] ", t.pkg, t.test) + fmt.Fprintf(os.Stderr, "err=%s\n%s", res.err.Error(), res.Failure.Contents) + n.Fail = true + } + n.results = append(n.results, res) + } +} + +type testResult struct { + JUnitTestCase + d time.Duration + err error +} + +func (n *numa) runTestCase(pkg string, fn string) testResult { + res := testResult{ + JUnitTestCase: JUnitTestCase{ + ClassName: path.Join(modulePath, pkg), + Name: fn, + }, + } + + var buf bytes.Buffer + var err error + var start time.Time + for i := 0; i < 3; i++ { + cmd := n.testCommand(pkg, fn) + cmd.Dir = path.Join(workDir, pkg) + // Combine the test case output, so the run result for failed cases can be displayed. + cmd.Stdout = &buf + cmd.Stderr = &buf + + start = time.Now() + err = cmd.Run() + if err != nil { + var exitError *exec.ExitError + if errors.As(err, &exitError) { + // Retry 3 times to get rid of the weird error: + switch err.Error() { + case "signal: segmentation fault (core dumped)": + buf.Reset() + continue + case "signal: trace/breakpoint trap (core dumped)": + buf.Reset() + continue + } + if strings.Contains(buf.String(), "panic during panic") { + buf.Reset() + continue + } + } + } + break + } + if err != nil { + res.Failure = &JUnitFailure{ + Message: "Failed", + Contents: buf.String(), + } + res.err = err + } + + res.d = time.Since(start) + res.Time = formatDurationAsSeconds(res.d) + return res +} + +func collectTestResults(workers []numa) JUnitTestSuites { + version := goVersion() + // pkg => test cases + pkgs := make(map[string][]JUnitTestCase) + durations := make(map[string]time.Duration) + + // The test result in workers are shuffled, so group by the packages here + for _, n := range workers { + for _, res := range n.results { + cases, ok := pkgs[res.ClassName] + if !ok { + cases = make([]JUnitTestCase, 0, 10) + } + cases = append(cases, res.JUnitTestCase) + pkgs[res.ClassName] = cases + durations[res.ClassName] += res.d + } + } + + suites := JUnitTestSuites{} + // Turn every package result to a suite. + for pkg, cases := range pkgs { + suite := JUnitTestSuite{ + Tests: len(cases), + Failures: failureCases(cases), + Time: formatDurationAsSeconds(durations[pkg]), + Name: pkg, + Properties: packageProperties(version), + TestCases: cases, + } + suites.Suites = append(suites.Suites, suite) + } + return suites +} + +func failureCases(input []JUnitTestCase) int { + sum := 0 + for _, v := range input { + if v.Failure != nil { + sum++ + } + } + return sum +} + +func (n *numa) testCommand(pkg string, fn string) *exec.Cmd { + args := make([]string, 0, 10) + exe := "./" + testFileName(pkg) + args = append(args, "-test.cpu", "1") + if !race { + args = append(args, []string{"-test.timeout", "2m"}...) + } else { + // it takes a longer when race is enabled. so it is set more timeout value. + args = append(args, []string{"-test.timeout", "5m"}...) + } + + // core.test -test.run TestClusteredPrefixColum + args = append(args, "-test.run", "^"+fn+"$") + + return exec.Command(exe, args...) +} + +func skipDIR(pkg string) bool { + skipDir := []string{"tests", "bin", "cmd", "tools"} + for _, ignore := range skipDir { + if strings.HasPrefix(pkg, ignore) { + return true + } + } + return false +} + +// buildTestBinaryMulti is much faster than build the test packages one by one. +func buildTestBinaryMulti(pkgs []string) error { + // go test --exec=xprog --tags=tso_function_test,deadlock -vet=off --count=0 $(pkgs) + xprogPath := path.Join(workDir, "bin/xprog") + packages := make([]string, 0, len(pkgs)) + for _, pkg := range pkgs { + packages = append(packages, path.Join(modulePath, pkg)) + } + + p := strconv.Itoa(buildParallel) + cmd := exec.Command("go", "test", "-p", p, "--exec", xprogPath, "-vet", "off", "--tags=tso_function_test,deadlock") + cmd.Args = append(cmd.Args, packages...) + cmd.Dir = workDir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return withTrace(err) + } + return nil +} + +func buildTestBinary(pkg string) error { + //nolint:gosec + cmd := exec.Command("go", "test", "-c", "-vet", "off", "--tags=tso_function_test,deadlock", "-o", testFileName(pkg), "-v") + if race { + cmd.Args = append(cmd.Args, "-race") + } + cmd.Dir = path.Join(workDir, pkg) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return withTrace(err) + } + return nil +} + +func testBinaryExist(pkg string) (bool, error) { + _, err := os.Stat(testFileFullPath(pkg)) + if err != nil { + var pathError *os.PathError + if errors.As(err, &pathError) { + return false, nil + } + } + return true, withTrace(err) +} + +func testFileName(pkg string) string { + _, file := path.Split(pkg) + return file + ".test.bin" +} + +func testFileFullPath(pkg string) string { + return path.Join(workDir, pkg, testFileName(pkg)) +} + +func listNewTestCases(pkg string) []string { + exe := "./" + testFileName(pkg) + + // core.test -test.list Test + cmd := exec.Command(exe, "-test.list", "Test") + cmd.Dir = path.Join(workDir, pkg) + var buf bytes.Buffer + cmd.Stdout = &buf + err := cmd.Run() + res := strings.Split(buf.String(), "\n") + if err != nil && len(res) == 0 { + fmt.Println("err ==", err) + } + return filter(res, func(s string) bool { + return strings.HasPrefix(s, "Test") && s != "TestT" && s != "TestBenchDaily" + }) +} + +func cmdToLines(cmd *exec.Cmd) ([]string, error) { + res, err := cmd.Output() + if err != nil { + return nil, withTrace(err) + } + ss := bytes.Split(res, []byte{'\n'}) + ret := make([]string, len(ss)) + for i, s := range ss { + ret[i] = string(s) + } + return ret, nil +} + +func filter(input []string, f func(string) bool) []string { + ret := input[:0] + for _, s := range input { + if f(s) { + ret = append(ret, s) + } + } + return ret +} + +func shuffle(tasks []task) { + for i := 0; i < len(tasks); i++ { + pos := rand.Intn(len(tasks)) + tasks[i], tasks[pos] = tasks[pos], tasks[i] + } +} + +type errWithStack struct { + err error + buf []byte +} + +func (e *errWithStack) Error() string { + return e.err.Error() + "\n" + string(e.buf) +} + +func withTrace(err error) error { + if err == nil { + return err + } + var errStack *errWithStack + if errors.As(err, &errStack) { + return err + } + var stack [4096]byte + sz := runtime.Stack(stack[:], false) + return &errWithStack{err, stack[:sz]} +} + +func formatDurationAsSeconds(d time.Duration) string { + return fmt.Sprintf("%f", d.Seconds()) +} + +func packageProperties(goVersion string) []JUnitProperty { + return []JUnitProperty{ + {Name: "go.version", Value: goVersion}, + } +} + +// goVersion returns the version as reported by the go binary in PATH. This +// version will not be the same as runtime.Version, which is always the version +// of go used to build the gotestsum binary. +// +// To skip the os/exec call set the GOVERSION environment variable to the +// desired value. +func goVersion() string { + if version, ok := os.LookupEnv("GOVERSION"); ok { + return version + } + cmd := exec.Command("go", "version") + out, err := cmd.Output() + if err != nil { + return "unknown" + } + return strings.TrimPrefix(strings.TrimSpace(string(out)), "go version ") +} + +func write(out io.Writer, suites JUnitTestSuites) error { + doc, err := xml.MarshalIndent(suites, "", "\t") + if err != nil { + return err + } + _, err = out.Write([]byte(xml.Header)) + if err != nil { + return err + } + _, err = out.Write(doc) + return err +} + +// JUnitTestSuites is a collection of JUnit test suites. +type JUnitTestSuites struct { + XMLName xml.Name `xml:"testsuites"` + Suites []JUnitTestSuite +} + +// JUnitTestSuite is a single JUnit test suite which may contain many +// testcases. +type JUnitTestSuite struct { + XMLName xml.Name `xml:"testsuite"` + Tests int `xml:"tests,attr"` + Failures int `xml:"failures,attr"` + Time string `xml:"time,attr"` + Name string `xml:"name,attr"` + Properties []JUnitProperty `xml:"properties>property,omitempty"` + TestCases []JUnitTestCase +} + +// JUnitTestCase is a single test case with its result. +type JUnitTestCase struct { + XMLName xml.Name `xml:"testcase"` + ClassName string `xml:"classname,attr"` + Name string `xml:"name,attr"` + Time string `xml:"time,attr"` + SkipMessage *JUnitSkipMessage `xml:"skipped,omitempty"` + Failure *JUnitFailure `xml:"failure,omitempty"` +} + +// JUnitSkipMessage contains the reason why a testcase was skipped. +type JUnitSkipMessage struct { + Message string `xml:"message,attr"` +} + +// JUnitProperty represents a key/value pair used to define properties. +type JUnitProperty struct { + Name string `xml:"name,attr"` + Value string `xml:"value,attr"` +} + +// JUnitFailure contains data related to a failed test. +type JUnitFailure struct { + Message string `xml:"message,attr"` + Type string `xml:"type,attr"` + Contents string `xml:",chardata"` +} diff --git a/tools/pd-ut/xprog.go b/tools/pd-ut/xprog.go new file mode 100644 index 00000000000..cf3e9b295e2 --- /dev/null +++ b/tools/pd-ut/xprog.go @@ -0,0 +1,119 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//go:build xprog +// +build xprog + +package main + +import ( + "bufio" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +func main() { + // See https://github.com/golang/go/issues/15513#issuecomment-773994959 + // go test --exec=xprog ./... + // Command line args looks like: + // '$CWD/xprog /tmp/go-build2662369829/b1382/aggfuncs.test -test.paniconexit0 -test.timeout=10m0s' + // This program moves the test binary /tmp/go-build2662369829/b1382/aggfuncs.test to someplace else for later use. + + // Extract the current work directory + cwd := os.Args[0] + cwd = cwd[:len(cwd)-len("bin/xprog")] + + testBinaryPath := os.Args[1] + dir, _ := filepath.Split(testBinaryPath) + + // Extract the package info from /tmp/go-build2662369829/b1382/importcfg.link + pkg := getPackageInfo(dir) + + const prefix = "github.com/tikv/pd/" + if !strings.HasPrefix(pkg, prefix) { + os.Exit(-3) + } + + // github.com/tikv/pd/server/api/api.test => server/api/api + pkg = pkg[len(prefix) : len(pkg)-len(".test")] + + _, file := filepath.Split(pkg) + + // The path of the destination file looks like $CWD/server/api/api.test.bin + newName := filepath.Join(cwd, pkg, file+".test.bin") + + if err1 := os.Rename(testBinaryPath, newName); err1 != nil { + // Rename fail, handle error like "invalid cross-device linkcd tools/check" + err1 = MoveFile(testBinaryPath, newName) + if err1 != nil { + os.Exit(-4) + } + } +} + +func getPackageInfo(dir string) string { + // Read the /tmp/go-build2662369829/b1382/importcfg.link file to get the package information + f, err := os.Open(filepath.Join(dir, "importcfg.link")) + if err != nil { + os.Exit(-1) + } + defer f.Close() + + r := bufio.NewReader(f) + line, _, err := r.ReadLine() + if err != nil { + os.Exit(-2) + } + start := strings.IndexByte(string(line), ' ') + end := strings.IndexByte(string(line), '=') + pkg := string(line[start+1 : end]) + return pkg +} + +func MoveFile(srcPath, dstPath string) error { + inputFile, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("couldn't open source file: %s", err) + } + outputFile, err := os.Create(dstPath) + if err != nil { + inputFile.Close() + return fmt.Errorf("couldn't open dst file: %s", err) + } + defer outputFile.Close() + _, err = io.Copy(outputFile, inputFile) + inputFile.Close() + if err != nil { + return fmt.Errorf("writing to output file failed: %s", err) + } + + // Handle the permissions + si, err := os.Stat(srcPath) + if err != nil { + return fmt.Errorf("stat error: %s", err) + } + err = os.Chmod(dstPath, si.Mode()) + if err != nil { + return fmt.Errorf("chmod error: %s", err) + } + + // The copy was successful, so now delete the original file + err = os.Remove(srcPath) + if err != nil { + return fmt.Errorf("failed removing original file: %s", err) + } + return nil +}