diff --git a/client.go b/client.go index 4651ba5..18d507c 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package main import ( "context" + "sync" "time" "golang.org/x/sync/errgroup" @@ -66,27 +67,25 @@ func (client *Client) Serve(ctx context.Context) error { defer close(respCh) g, ctx := errgroup.WithContext(ctx) - ctx, finalized := context.WithCancel(ctx) - defer finalized() - - g.Go(func() error { - for { - select { - case resp := <-respCh: - client.Stats.Count(resp.Err, resp.Elapsed) - if client.ResponseHandler != nil { - client.ResponseHandler(resp) - } - case <-ctx.Done(): - return ctx.Err() + + respWait := sync.WaitGroup{} + + go func() { + for resp := range respCh { + client.Stats.Count(resp.Err, resp.Elapsed) + if client.ResponseHandler != nil { + client.ResponseHandler(resp) } + respWait.Done() } - }) + }() if client.Concurrency < 1 { client.Concurrency = 1 } + logger.Debug().Str("endpoint", client.Endpoint).Int("concurrency", client.Concurrency).Msg("starting client") + for i := 0; i < client.Concurrency; i++ { g.Go(func() error { // Consume requests @@ -97,21 +96,24 @@ func (client *Client) Serve(ctx context.Context) error { for { select { case <-ctx.Done(): - logger.Debug().Str("endpoint", client.Endpoint).Msg("shutting down client") + logger.Debug().Str("endpoint", client.Endpoint).Msg("aborting client") return nil case req := <-client.In: if req.ID == -1 { // Final request received, shutdown - finalized() + logger.Debug().Str("endpoint", client.Endpoint).Msg("received final request, shutting down") return nil } + respWait.Add(1) respCh <- req.Do(t) } } }) } - return g.Wait() + err := g.Wait() + respWait.Wait() + return err } var id requestID @@ -122,8 +124,11 @@ type Clients []*Client // serving will end cleanly. func (c Clients) Finalize() { for _, client := range c { - client.In <- Request{ - ID: -1, + for i := 0; i < client.Concurrency; i++ { + // Signal each client instance to shut down + client.In <- Request{ + ID: -1, + } } } } diff --git a/main.go b/main.go index 1b50051..84f011d 100644 --- a/main.go +++ b/main.go @@ -123,9 +123,6 @@ func run(ctx context.Context, options Options) error { timeout = d } - ctx, done := context.WithCancel(ctx) - defer done() - g, ctx := errgroup.WithContext(ctx) // Launch clients @@ -145,7 +142,7 @@ func run(ctx context.Context, options Options) error { if len(options.Verbose) > 0 { r.MismatchedResponse = func(resps []Response) { - logger.Info().Msgf("mismatched responses: %v", resps) + logger.Info().Msgf("mismatched responses: %s", Responses(resps).String()) } } diff --git a/request.go b/request.go index ebc2aab..80a61c9 100644 --- a/request.go +++ b/request.go @@ -1,9 +1,6 @@ package main import ( - "bytes" - "fmt" - "strings" "time" ) @@ -31,46 +28,3 @@ func (req *Request) Do(t Transport) Response { Elapsed: time.Now().Sub(timeStarted), } } - -type Response struct { - client *Client - Request *Request - - ID requestID - Body []byte - Err error - - Elapsed time.Duration -} - -func (r *Response) Equal(other Response) bool { - if r.Err == nil && other.Err == nil { - return bytes.Equal(r.Body, other.Body) - } - if r.Err != nil && other.Err != nil { - return r.Err.Error() == other.Err.Error() && bytes.Equal(r.Body, other.Body) - } - return false -} - -type Responses []*Response - -func (r Responses) String() string { - var buf strings.Builder - - // TODO: Sort before printing - last := r[0] - for i, resp := range r { - fmt.Fprintf(&buf, "\t%s", resp.Elapsed) - - if resp.Err.Error() != last.Err.Error() { - fmt.Fprintf(&buf, "[%d: error mismatch: %s != %s]", i, resp.Err, last.Err) - } - - if !bytes.Equal(resp.Body, last.Body) { - fmt.Fprintf(&buf, "[%d: body mismatch: %s != %s]", i, resp.Body[:20], last.Body[:20]) - } - } - - return buf.String() -} diff --git a/response.go b/response.go new file mode 100644 index 0000000..b9a82a2 --- /dev/null +++ b/response.go @@ -0,0 +1,54 @@ +package main + +import ( + "bytes" + "fmt" + "strings" + "time" +) + +type Response struct { + client *Client + Request *Request + + ID requestID + Body []byte + Err error + + Elapsed time.Duration +} + +func (r *Response) Equal(other Response) bool { + if r.Err == nil && other.Err == nil { + return bytes.Equal(r.Body, other.Body) + } + if r.Err != nil && other.Err != nil { + return r.Err.Error() == other.Err.Error() && bytes.Equal(r.Body, other.Body) + } + return false +} + +type Responses []Response + +func (resps Responses) String() string { + var buf strings.Builder + + // TODO: Sort before printing + last := resps[0] + for i, resp := range resps { + fmt.Fprintf(&buf, "\t%s", resp.Elapsed) + + if resp.Err == nil && last.Err == nil { + if !bytes.Equal(resp.Body, last.Body) { + fmt.Fprintf(&buf, "[%d: body mismatch: %s != %s]", i, resp.Body[:50], last.Body[:50]) + } + } else if resp.Err != nil && last.Err != nil && resp.Err.Error() != last.Err.Error() { + fmt.Fprintf(&buf, "[%d: error mismatch: %s != %s]", i, resp.Err, last.Err) + } else { + fmt.Fprintf(&buf, "[%d: error mismatch: %s != %s]", i, resp.Err, last.Err) + } + + } + + return buf.String() +}