diff --git a/.changelog/11531.txt b/.changelog/11531.txt new file mode 100644 index 000000000000..f84c986ec0b9 --- /dev/null +++ b/.changelog/11531.txt @@ -0,0 +1,3 @@ +```release-note:improvement +client: Download up to 3 artifacts concurrently +``` diff --git a/client/allocrunner/taskrunner/artifact_hook.go b/client/allocrunner/taskrunner/artifact_hook.go index 481c098e2155..627ee6e42749 100644 --- a/client/allocrunner/taskrunner/artifact_hook.go +++ b/client/allocrunner/taskrunner/artifact_hook.go @@ -3,6 +3,7 @@ package taskrunner import ( "context" "fmt" + "sync" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/client/allocrunner/interfaces" @@ -25,32 +26,19 @@ func newArtifactHook(e ti.EventEmitter, logger log.Logger) *artifactHook { return h } -func (*artifactHook) Name() string { - // Copied in client/state when upgrading from <0.9 schemas, so if you - // change it here you also must change it there. - return "artifacts" -} - -func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { - if len(req.Task.Artifacts) == 0 { - resp.Done = true - return nil - } - - // Initialize hook state to store download progress - resp.State = make(map[string]string, len(req.Task.Artifacts)) - - h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts)) - - for _, artifact := range req.Task.Artifacts { +func (h *artifactHook) doWork(req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse, jobs chan *structs.TaskArtifact, errorChannel chan error, wg *sync.WaitGroup, responseStateMutex *sync.Mutex) { + defer wg.Done() + for artifact := range jobs { aid := artifact.Hash() if req.PreviousState[aid] != "" { h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource) + responseStateMutex.Lock() resp.State[aid] = req.PreviousState[aid] + responseStateMutex.Unlock() continue } - h.logger.Debug("downloading artifact", 
"artifact", artifact.GetterSource) + h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource, "aid", aid) //XXX add ctx to GetArtifact to allow cancelling long downloads if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil { @@ -60,13 +48,76 @@ func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestar ) herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped)) - return herr + errorChannel <- herr + continue } // Mark artifact as downloaded to avoid re-downloading due to // retries caused by subsequent artifacts failing. Any // non-empty value works. + responseStateMutex.Lock() resp.State[aid] = "1" + responseStateMutex.Unlock() + } +} + +func (*artifactHook) Name() string { + // Copied in client/state when upgrading from <0.9 schemas, so if you + // change it here you also must change it there. + return "artifacts" +} + +func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { + if len(req.Task.Artifacts) == 0 { + resp.Done = true + return nil + } + + // Initialize hook state to store download progress + resp.State = make(map[string]string, len(req.Task.Artifacts)) + + // responseStateMutex guards resp.State, which the download workers write to concurrently + responseStateMutex := &sync.Mutex{} + + h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts)) + + // maxConcurrency is the number of worker goroutines started below to download artifacts in parallel + maxConcurrency := 3 + + // jobsChannel is a buffered channel (capacity maxConcurrency) through which every artifact is handed to the workers + jobsChannel := make(chan *structs.TaskArtifact, maxConcurrency) + + // errorChannel is a buffered channel (capacity maxConcurrency) on which workers report failed downloads + errorChannel := make(chan error, maxConcurrency) + + // Start the workers; errorChannel is closed once every worker has returned, which ends the drain loop below + go func() { + defer close(errorChannel) + 
var wg sync.WaitGroup + for i := 0; i < maxConcurrency; i++ { + wg.Add(1) + go h.doWork(req, resp, jobsChannel, errorChannel, &wg, responseStateMutex) + } + wg.Wait() + }() + + // Feed every artifact to the workers, then close jobsChannel so the workers' range loops terminate + go func() { + defer close(jobsChannel) + for _, artifact := range req.Task.Artifacts { + jobsChannel <- artifact + } + }() + + // Drain errorChannel until it is closed; if several downloads fail, only the last error received is kept + var err error + for e := range errorChannel { + err = e + } + + // errorChannel is closed, so all workers have finished; return the recorded error, if any + if err != nil { + return err } resp.Done = true diff --git a/client/allocrunner/taskrunner/artifact_hook_test.go b/client/allocrunner/taskrunner/artifact_hook_test.go index 0a3f21e50c37..c135b5cb4141 100644 --- a/client/allocrunner/taskrunner/artifact_hook_test.go +++ b/client/allocrunner/taskrunner/artifact_hook_test.go @@ -2,6 +2,7 @@ package taskrunner import ( "context" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -73,11 +74,7 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) { artifactHook := newArtifactHook(me, testlog.HCLogger(t)) // Create a source directory with 1 of the 2 artifacts - srcdir, err := ioutil.TempDir("", "nomadtest-src") - require.NoError(t, err) - defer func() { - require.NoError(t, os.RemoveAll(srcdir)) - }() + srcdir := t.TempDir() // Only create one of the 2 artifacts to cause an error on first run. file1 := filepath.Join(srcdir, "foo.txt") @@ -159,3 +156,212 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) { require.True(t, resp.Done) require.Len(t, resp.State, 2) } + +// TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess exercises the concurrent artifact +// download path with 7 artifacts and asserts that all of them are downloaded without errors. 
+func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) { + t.Parallel() + + me := &mockEmitter{} + artifactHook := newArtifactHook(me, testlog.HCLogger(t)) + + // Create a source directory with all 7 artifacts + srcdir := t.TempDir() + + numOfFiles := 7 + for i := 0; i < numOfFiles; i++ { + file := filepath.Join(srcdir, fmt.Sprintf("file%d.txt", i)) + require.NoError(t, ioutil.WriteFile(file, []byte{byte(i)}, 0644)) + } + + // Test server to serve the artifacts + ts := httptest.NewServer(http.FileServer(http.Dir(srcdir))) + defer ts.Close() + + // Create the target directory. + destdir, err := ioutil.TempDir("", "nomadtest-dest") + require.NoError(t, err) + defer func() { + require.NoError(t, os.RemoveAll(destdir)) + }() + + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""), + TaskDir: &allocdir.TaskDir{Dir: destdir}, + Task: &structs.Task{ + Artifacts: []*structs.TaskArtifact{ + { + GetterSource: ts.URL + "/file0.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file1.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file2.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file3.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file4.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file5.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file6.txt", + GetterMode: structs.GetterModeAny, + }, + }, + }, + } + + resp := interfaces.TaskPrestartResponse{} + + // Run the hook + err = artifactHook.Prestart(context.Background(), req, &resp) + + require.NoError(t, err) + require.True(t, resp.Done) + require.Len(t, resp.State, 7) + require.Len(t, me.events, 1) + require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type) + + // Assert all files downloaded properly + files, err := filepath.Glob(filepath.Join(destdir, "*.txt")) + 
require.NoError(t, err) + require.Len(t, files, 7) + sort.Strings(files) + require.Contains(t, files[0], "file0.txt") + require.Contains(t, files[1], "file1.txt") + require.Contains(t, files[2], "file2.txt") + require.Contains(t, files[3], "file3.txt") + require.Contains(t, files[4], "file4.txt") + require.Contains(t, files[5], "file5.txt") + require.Contains(t, files[6], "file6.txt") +} + +// TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure asserts that the artifact hook +// records the successful downloads when one artifact fails. The first run fails and the +// second run should succeed without downloading already downloaded files again. +func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) { + t.Parallel() + + me := &mockEmitter{} + artifactHook := newArtifactHook(me, testlog.HCLogger(t)) + + // Create a source directory with 3 of the 4 artifacts + srcdir := t.TempDir() + + file1 := filepath.Join(srcdir, "file1.txt") + require.NoError(t, ioutil.WriteFile(file1, []byte{'1'}, 0644)) + + file2 := filepath.Join(srcdir, "file2.txt") + require.NoError(t, ioutil.WriteFile(file2, []byte{'2'}, 0644)) + + file3 := filepath.Join(srcdir, "file3.txt") + require.NoError(t, ioutil.WriteFile(file3, []byte{'3'}, 0644)) + + // Test server to serve the artifacts + ts := httptest.NewServer(http.FileServer(http.Dir(srcdir))) + defer ts.Close() + + // Create the target directory. 
+ destdir, err := ioutil.TempDir("", "nomadtest-dest") + require.NoError(t, err) + defer func() { + require.NoError(t, os.RemoveAll(destdir)) + }() + + req := &interfaces.TaskPrestartRequest{ + TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""), + TaskDir: &allocdir.TaskDir{Dir: destdir}, + Task: &structs.Task{ + Artifacts: []*structs.TaskArtifact{ + { + GetterSource: ts.URL + "/file0.txt", // not created yet, so this download fails on the first run + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file1.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file2.txt", + GetterMode: structs.GetterModeAny, + }, + { + GetterSource: ts.URL + "/file3.txt", + GetterMode: structs.GetterModeAny, + }, + }, + }, + } + + resp := interfaces.TaskPrestartResponse{} + + // On first run all files will be downloaded except file0.txt + err = artifactHook.Prestart(context.Background(), req, &resp) + + require.Error(t, err) + require.True(t, structs.IsRecoverable(err)) + require.Len(t, resp.State, 3) + require.False(t, resp.Done) + require.Len(t, me.events, 1) + require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type) + + // delete the source files that were already downloaded, so the retry fails if it tries to download them again + require.NoError(t, os.Remove(file1)) + require.NoError(t, os.Remove(file2)) + require.NoError(t, os.Remove(file3)) + + // create the previously missing source file + file0 := filepath.Join(srcdir, "file0.txt") + require.NoError(t, ioutil.WriteFile(file0, []byte{'0'}, 0644)) + + // Mock TaskRunner by copying state from resp to req and reset resp. 
+ req.PreviousState = helper.CopyMapStringString(resp.State) + + resp = interfaces.TaskPrestartResponse{} + + // Retry the download and assert it succeeds + err = artifactHook.Prestart(context.Background(), req, &resp) + require.NoError(t, err) + require.True(t, resp.Done) + require.Len(t, resp.State, 4) + + // Assert all files downloaded properly. NOTE(review): unlike the success test, len(files) is not checked before indexing + files, err := filepath.Glob(filepath.Join(destdir, "*.txt")) + require.NoError(t, err) + sort.Strings(files) + require.Contains(t, files[0], "file0.txt") + require.Contains(t, files[1], "file1.txt") + require.Contains(t, files[2], "file2.txt") + require.Contains(t, files[3], "file3.txt") + + // verify the file contents too, since files will also be created for failed downloads. NOTE(review): require.Equal takes (expected, actual); the arguments below appear reversed - confirm + data0, err := ioutil.ReadFile(files[0]) + require.NoError(t, err) + require.Equal(t, data0, []byte{'0'}) + + data1, err := ioutil.ReadFile(files[1]) + require.NoError(t, err) + require.Equal(t, data1, []byte{'1'}) + + data2, err := ioutil.ReadFile(files[2]) + require.NoError(t, err) + require.Equal(t, data2, []byte{'2'}) + + data3, err := ioutil.ReadFile(files[3]) + require.NoError(t, err) + require.Equal(t, data3, []byte{'3'}) + + require.True(t, resp.Done) + require.Len(t, resp.State, 4) +}