diff --git a/go/pkg/client/batch_retries_test.go b/go/pkg/client/batch_retries_test.go index 2ee93cc1..48f00eb3 100644 --- a/go/pkg/client/batch_retries_test.go +++ b/go/pkg/client/batch_retries_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "sort" + "sync" "testing" "time" @@ -278,6 +279,9 @@ type sleepyBatchServer struct { numErrors int // A counter of DEADLINE_EXCEEDED errors the server has returned thus far. updateRequests int readRequests int + // These are required to pass thread sanitizer tests. + mu sync.Mutex + wg sync.WaitGroup } func (s *sleepyBatchServer) FindMissingBlobs(ctx context.Context, req *repb.FindMissingBlobsRequest) (*repb.FindMissingBlobsResponse, error) { @@ -289,24 +293,32 @@ func (s *sleepyBatchServer) GetTree(req *repb.GetTreeRequest, stream regrpc.Cont } func (s *sleepyBatchServer) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (*repb.BatchReadBlobsResponse, error) { + defer s.wg.Done() + s.mu.Lock() s.readRequests++ s.numErrors++ if s.numErrors < 4 { + s.mu.Unlock() time.Sleep(s.timeout) return &repb.BatchReadBlobsResponse{}, nil } // Should not be reached. + s.mu.Unlock() return nil, status.Error(codes.Unimplemented, "") } func (s *sleepyBatchServer) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRequest) (*repb.BatchUpdateBlobsResponse, error) { + defer s.wg.Done() + s.mu.Lock() s.updateRequests++ s.numErrors++ if s.numErrors < 4 { + s.mu.Unlock() time.Sleep(s.timeout) return &repb.BatchUpdateBlobsResponse{}, nil } // Should not be reached. + s.mu.Unlock() return nil, status.Error(codes.Unimplemented, "") } @@ -322,6 +334,7 @@ func TestBatchReadBlobsDeadlineExceededRetries(t *testing.T) { ctx := context.Background() retrier := client.RetryTransient() retrier.Backoff = retry.Immediately(retry.Attempts(3)) + fake.wg.Add(3) client, err := client.NewClient(ctx, instance, client.DialParams{ Service: listener.Addr().String(), NoSecurity: true, @@ -335,6 +348,7 @@ func TestBatchReadBlobsDeadlineExceededRetries(t *testing.T) { digests := []digest.Digest{digest.TestNew("a", 1)} _, err = client.BatchDownloadBlobs(ctx, digests) + fake.wg.Wait() if err == nil { t.Errorf("client.BatchDownloadBlobs(ctx, digests) = nil; expected DeadlineExceeded error got nil") } else if s, ok := status.FromError(err); ok && s.Code() != codes.DeadlineExceeded { @@ -358,6 +372,7 @@ func TestBatchUpdateBlobsDeadlineExceededRetries(t *testing.T) { ctx := context.Background() retrier := client.RetryTransient() retrier.Backoff = retry.Immediately(retry.Attempts(3)) + fake.wg.Add(3) client, err := client.NewClient(ctx, instance, client.DialParams{ Service: listener.Addr().String(), NoSecurity: true, @@ -371,6 +386,7 @@ func TestBatchUpdateBlobsDeadlineExceededRetries(t *testing.T) { blobs := map[digest.Digest][]byte{digest.TestNew("a", 1): []byte{1}} err = client.BatchWriteBlobs(ctx, blobs) + fake.wg.Wait() if err == nil { t.Errorf("client.BatchWriteBlobs(ctx, blobs) = nil; expected DeadlineExceeded error got nil") } else if s, ok := status.FromError(err); ok && s.Code() != codes.DeadlineExceeded {