diff --git a/worker/main.go b/worker/main.go index 2e7d391..bb8c65d 100644 --- a/worker/main.go +++ b/worker/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "encoding/json" "errors" "flag" @@ -9,6 +8,7 @@ import ( "math/rand" "net" "os" + "sync" "time" "github.com/aws/aws-sdk-go/service/s3" @@ -118,24 +118,24 @@ func connectToServer(serverAddress string) error { // PerfTest runs a performance test as configured in testConfig func PerfTest(testConfig *common.TestCaseConfiguration, Workqueue *Workqueue, workerID string) time.Duration { workChannel := make(chan WorkItem, len(*Workqueue.Queue)) - doneChannel := make(chan bool) + notifyChan := make(chan struct{}) + wg := &sync.WaitGroup{} + wg.Add(testConfig.ParallelClients) startTime := time.Now().UTC() promTestStart.WithLabelValues(testConfig.Name).Set(float64(startTime.UnixNano() / int64(1000000))) // promTestGauge.WithLabelValues(testConfig.Name).Inc() for worker := 0; worker < testConfig.ParallelClients; worker++ { - go DoWork(workChannel, doneChannel) + go DoWork(workChannel, notifyChan, wg) } log.Infof("Started %d parallel clients", testConfig.ParallelClients) if testConfig.Runtime != 0 { - workUntilTimeout(Workqueue, workChannel, time.Duration(testConfig.Runtime)) + workUntilTimeout(Workqueue, workChannel, notifyChan, time.Duration(testConfig.Runtime)) } else { workUntilOps(Workqueue, workChannel, testConfig.OpsDeadline, testConfig.ParallelClients) } // Wait for all the goroutines to finish - for i := 0; i < testConfig.ParallelClients; i++ { - <-doneChannel - } + wg.Wait() log.Info("All clients finished") endTime := time.Now().UTC() promTestEnd.WithLabelValues(testConfig.Name).Set(float64(endTime.UnixNano() / int64(1000000))) @@ -161,15 +161,14 @@ func PerfTest(testConfig *common.TestCaseConfiguration, Workqueue *Workqueue, wo return endTime.Sub(startTime) } -func workUntilTimeout(Workqueue *Workqueue, workChannel chan WorkItem, runtime time.Duration) { - workContext, WorkCancel = context.WithCancel(context.Background()) +func workUntilTimeout(Workqueue *Workqueue, workChannel chan WorkItem, notifyChan chan<- struct{}, runtime time.Duration) { timer := time.NewTimer(runtime) for { for _, work := range *Workqueue.Queue { select { case <-timer.C: log.Debug("Reached Runtime end") - WorkCancel() + close(notifyChan) return case workChannel <- work: } diff --git a/worker/workItems.go b/worker/workItems.go index bced28e..4d5248d 100644 --- a/worker/workItems.go +++ b/worker/workItems.go @@ -2,10 +2,10 @@ package main import ( "bytes" - "context" "fmt" "math/rand" "sort" + "sync" "time" log "github.com/sirupsen/logrus" @@ -79,15 +79,6 @@ func GetNextOperation(Queue *Workqueue) string { return Queue.OperationValues[0].Key } -func init() { - workContext = context.Background() -} - -var workContext context.Context - -// WorkCancel is the function to stop the execution of jobs -var WorkCancel context.CancelFunc - // IncreaseOperationValue increases the given operation's value by the set amount func IncreaseOperationValue(operation string, value float64, Queue *Workqueue) error { for i := range Queue.OperationValues { @@ -229,18 +220,17 @@ func (op Stopper) Clean() error { // DoWork processes the workitems in the workChannel until // either the time runs out or a stopper is found -func DoWork(workChannel chan WorkItem, doneChannel chan bool) { +func DoWork(workChannel <-chan WorkItem, notifyChan <-chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() for { select { - case <-workContext.Done(): + case <-notifyChan: log.Debugf("Runtime over - Got timeout from work context") - doneChannel <- true return case work := <-workChannel: switch work.(type) { case Stopper: log.Debug("Found the end of the work Queue - stopping") - doneChannel <- true return } err := work.Do()