diff --git a/parallel.go b/parallel.go index 0027619b4c..0349023ba2 100644 --- a/parallel.go +++ b/parallel.go @@ -2,7 +2,6 @@ package testcontainers import ( "context" - "errors" "fmt" "sync" ) @@ -32,24 +31,27 @@ func (gpe ParallelContainersError) Error() string { return fmt.Sprintf("%v", gpe.Errors) } +// parallelContainersResult represents result. +type parallelContainersResult struct { + ParallelContainersRequestError + Container Container +} + func parallelContainersRunner( ctx context.Context, requests <-chan GenericContainerRequest, - errorsCh chan<- ParallelContainersRequestError, - containers chan<- Container, + results chan<- parallelContainersResult, wg *sync.WaitGroup, ) { defer wg.Done() for req := range requests { c, err := GenericContainer(ctx, req) + res := parallelContainersResult{Container: c} if err != nil { - errorsCh <- ParallelContainersRequestError{ - Request: req, - Error: errors.Join(err, TerminateContainer(c)), - } - continue + res.Request = req + res.Error = err } - containers <- c + results <- res } } @@ -65,41 +67,26 @@ func ParallelContainers(ctx context.Context, reqs ParallelContainerRequest, opt } tasksChan := make(chan GenericContainerRequest, tasksChanSize) - errsChan := make(chan ParallelContainersRequestError) - resChan := make(chan Container) - waitRes := make(chan struct{}) - - containers := make([]Container, 0) - errors := make([]ParallelContainersRequestError, 0) + resultsChan := make(chan parallelContainersResult, tasksChanSize) + done := make(chan struct{}) - wg := sync.WaitGroup{} + var wg sync.WaitGroup wg.Add(tasksChanSize) // run workers for i := 0; i < tasksChanSize; i++ { - go parallelContainersRunner(ctx, tasksChan, errsChan, resChan, &wg) + go parallelContainersRunner(ctx, tasksChan, resultsChan, &wg) } + var errs []ParallelContainersRequestError + containers := make([]Container, 0, len(reqs)) go func() { - for { - select { - case c, ok := <-resChan: - if !ok { - resChan = nil - } else { - containers = append(containers, c) - } - case e, ok := <-errsChan: - if !ok { - errsChan = nil - } else { - errors = append(errors, e) - } - } - - if resChan == nil && errsChan == nil { - waitRes <- struct{}{} - break + defer close(done) + for res := range resultsChan { + if res.Error != nil { + errs = append(errs, res.ParallelContainersRequestError) + } else { + containers = append(containers, res.Container) } } }() @@ -108,14 +95,15 @@ func ParallelContainers(ctx context.Context, reqs ParallelContainerRequest, opt tasksChan <- req } close(tasksChan) + wg.Wait() - close(resChan) - close(errsChan) - <-waitRes + close(resultsChan) + + <-done - if len(errors) != 0 { - return containers, ParallelContainersError{Errors: errors} + if len(errs) != 0 { + return containers, ParallelContainersError{Errors: errs} } return containers, nil diff --git a/parallel_test.go b/parallel_test.go index f937b1e56d..25f919e99d 100644 --- a/parallel_test.go +++ b/parallel_test.go @@ -2,7 +2,6 @@ package testcontainers import ( "context" - "errors" "fmt" "testing" "time" @@ -99,23 +98,18 @@ func TestParallelContainers(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { res, err := ParallelContainers(context.Background(), tc.reqs, ParallelContainersOptions{}) - if err != nil { - require.NotZero(t, tc.expErrors) - var e ParallelContainersError - errors.As(err, &e) - if len(e.Errors) != tc.expErrors { - t.Fatalf("expected errors: %d, got: %d\n", tc.expErrors, len(e.Errors)) - } - } - for _, c := range res { - c := c CleanupContainer(t, c) } - if len(res) != tc.resLen { - t.Fatalf("expected containers: %d, got: %d\n", tc.resLen, len(res)) + if tc.expErrors != 0 { + require.Error(t, err) + var errs ParallelContainersError + require.ErrorAs(t, err, &errs) + require.Len(t, errs.Errors, tc.expErrors) } + + require.Len(t, res, tc.resLen) }) } } @@ -157,11 +151,8 @@ func TestParallelContainersWithReuse(t *testing.T) { ctx := context.Background() res, err := ParallelContainers(ctx, parallelRequest, ParallelContainersOptions{}) - if err != nil { - var e ParallelContainersError - errors.As(err, &e) - t.Fatalf("expected errors: %d, got: %d\n", 0, len(e.Errors)) + for _, c := range res { + CleanupContainer(t, c) } - // Container is reused, only terminate first container - CleanupContainer(t, res[0]) + require.NoError(t, err) }