diff --git a/go.mod b/go.mod index 291583e..7f43a54 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.20 require ( github.com/dop251/goja v0.0.0-20231027120936-b396bb4c349d github.com/mstoykov/k6-taskqueue-lib v0.1.0 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 go.k6.io/k6 v0.48.0 ) @@ -29,7 +30,6 @@ require ( github.com/onsi/gomega v1.20.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/serenize/snaker v0.0.0-20201027110005-a7ad2135616e // indirect - github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/afero v1.1.2 // indirect go.opentelemetry.io/otel v1.19.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 // indirect diff --git a/timers/timers.go b/timers/timers.go index 971024a..01c438a 100644 --- a/timers/timers.go +++ b/timers/timers.go @@ -5,6 +5,7 @@ import ( "time" "github.com/mstoykov/k6-taskqueue-lib/taskqueue" + "github.com/sirupsen/logrus" "github.com/dop251/goja" "go.k6.io/k6/js/modules" @@ -212,10 +213,21 @@ func (e *Timers) closeTaskQueue() { // so that we do not execute it twice e.taskQueueCh = nil - // wait for this to happen so we don't need to hit the event loop again - // instead this just closes the queue - ch <- struct{}{} - <-ch + select { + case ch <- struct{}{}: + // wait for this to happen so we don't need to hit the event loop again + // instead this just closes the queue + <-ch + case <-e.vu.Context().Done(): // still shortcircuit if the context is done as we might block otherwise + } +} + +// logger is helper to get a logger either from the state or the initenv +func (e *Timers) logger() logrus.FieldLogger { + if state := e.vu.State(); state != nil { + return state.Logger + } + return e.vu.InitEnv().Logger } func (e *Timers) setupTaskQueueCloserOnIterationEnd() { @@ -229,23 +241,28 @@ func (e *Timers) setupTaskQueueCloserOnIterationEnd() { // lets report timers won't be executed and clean the fields for the next execution // we need to do this on the event loop as we don't want to have a race q.Queue(func() error { - logger := e.vu.State().Logger + logger := e.logger() for _, timer := range e.queue.queue { - logger.Warnf("%s %d was stopped because the VU iteration was interrupted", timer.name, timer.id) + logger.Warnf("%s %d was stopped because the VU iteration was interrupted", + timer.name, timer.id) } // TODO: use `clear` when we only support go 1.21 and above e.timers = make(map[uint64]time.Time) + e.queue.stopTimer() e.queue = new(timerQueue) e.taskQueue = nil return nil }) + q.Close() case <-ch: + e.timers = make(map[uint64]time.Time) + e.queue.stopTimer() + e.queue = new(timerQueue) e.taskQueue = nil + q.Close() close(ch) } - e.queue.stopTimer() - q.Close() }() } diff --git a/timers/timers_test.go b/timers/timers_test.go index 1af7549..681270b 100644 --- a/timers/timers_test.go +++ b/timers/timers_test.go @@ -1,6 +1,7 @@ package timers import ( + "context" "testing" "time" @@ -138,3 +139,54 @@ func TestSetIntervalOrder(t *testing.T) { log = log[:0] } } + +func TestSetTimeoutContextCancel(t *testing.T) { + t.Parallel() + runtime := modulestest.NewRuntime(t) + err := runtime.SetupModuleSystem(map[string]any{"k6/x/timers": New()}, nil, nil) + require.NoError(t, err) + + rt := runtime.VU.Runtime() + var log []string + interruptChannel := make(chan struct{}) + require.NoError(t, rt.Set("print", func(s string) { log = append(log, s) })) + require.NoError(t, rt.Set("interrupt", func() { + select { + case interruptChannel <- struct{}{}: + default: + } + })) + + _, err = rt.RunString(`globalThis.setTimeout = require("k6/x/timers").setTimeout;`) + require.NoError(t, err) + + for i := 0; i < 2000; i++ { + ctx, cancel := context.WithCancel(context.Background()) + runtime.CancelContext = cancel + runtime.VU.CtxField = ctx + runtime.VU.RuntimeField.ClearInterrupt() + const interruptMsg = "definitely an interrupt" + go func() { + <-interruptChannel + time.Sleep(time.Millisecond) + runtime.CancelContext() + runtime.VU.RuntimeField.Interrupt(interruptMsg) + }() + _, err = runtime.RunOnEventLoop(` + (async () => { + let poll = async (resolve, reject) => { + await (async () => 5); + setTimeout(poll, 1, resolve, reject); + interrupt(); + } + setTimeout(async () => { + await new Promise(poll) + }, 0) + })() + `) + if err != nil { + require.ErrorContains(t, err, interruptMsg) + } + require.Empty(t, log) + } +}