Skip to content

Commit

Permalink
Add Concurrent Download Support for artifacts (#11531)
Browse files Browse the repository at this point in the history
* add concurrent download support - resolves #11244

* format imports

* mark `wg.Done()` via `defer`

* added tests for successful and failure cases and resolved some goleak

* docs: add changelog for #11531

* test typo fixes and improvements

Co-authored-by: Michael Schurter <mschurter@hashicorp.com>
  • Loading branch information
gowthamgts and schmichael authored Apr 20, 2022
1 parent 8eb569f commit f601cc3
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .changelog/11531.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
client: Download up to 3 artifacts concurrently
```
91 changes: 71 additions & 20 deletions client/allocrunner/taskrunner/artifact_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package taskrunner
import (
"context"
"fmt"
"sync"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
Expand All @@ -25,32 +26,19 @@ func newArtifactHook(e ti.EventEmitter, logger log.Logger) *artifactHook {
return h
}

func (*artifactHook) Name() string {
// Copied in client/state when upgrading from <0.9 schemas, so if you
// change it here you also must change it there.
return "artifacts"
}

func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
if len(req.Task.Artifacts) == 0 {
resp.Done = true
return nil
}

// Initialize hook state to store download progress
resp.State = make(map[string]string, len(req.Task.Artifacts))

h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts))

for _, artifact := range req.Task.Artifacts {
func (h *artifactHook) doWork(req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse, jobs chan *structs.TaskArtifact, errorChannel chan error, wg *sync.WaitGroup, responseStateMutex *sync.Mutex) {
defer wg.Done()
for artifact := range jobs {
aid := artifact.Hash()
if req.PreviousState[aid] != "" {
h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource)
responseStateMutex.Lock()
resp.State[aid] = req.PreviousState[aid]
responseStateMutex.Unlock()
continue
}

h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource)
h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource, "aid", aid)
//XXX add ctx to GetArtifact to allow cancelling long downloads
if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil {

Expand All @@ -60,13 +48,76 @@ func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestar
)
herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))

return herr
errorChannel <- herr
continue
}

// Mark artifact as downloaded to avoid re-downloading due to
// retries caused by subsequent artifacts failing. Any
// non-empty value works.
responseStateMutex.Lock()
resp.State[aid] = "1"
responseStateMutex.Unlock()
}
}

func (*artifactHook) Name() string {
// Copied in client/state when upgrading from <0.9 schemas, so if you
// change it here you also must change it there.
return "artifacts"
}

func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
if len(req.Task.Artifacts) == 0 {
resp.Done = true
return nil
}

// Initialize hook state to store download progress
resp.State = make(map[string]string, len(req.Task.Artifacts))

// responseStateMutex is a lock used to guard against concurrent writes to the above resp.State map
responseStateMutex := &sync.Mutex{}

h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts))

// maxConcurrency denotes the number of workers that will download artifacts in parallel
maxConcurrency := 3

// jobsChannel is a buffered channel which will have all the artifacts that needs to be processed
jobsChannel := make(chan *structs.TaskArtifact, maxConcurrency)

// errorChannel is also a buffered channel that will be used to signal errors
errorChannel := make(chan error, maxConcurrency)

// create workers and process artifacts
go func() {
defer close(errorChannel)
var wg sync.WaitGroup
for i := 0; i < maxConcurrency; i++ {
wg.Add(1)
go h.doWork(req, resp, jobsChannel, errorChannel, &wg, responseStateMutex)
}
wg.Wait()
}()

// Push all artifact requests to job channel
go func() {
defer close(jobsChannel)
for _, artifact := range req.Task.Artifacts {
jobsChannel <- artifact
}
}()

// Iterate over the errorChannel and if there is an error, store it to a variable for future return
var err error
for e := range errorChannel {
err = e
}

// once error channel is closed, we can check and return the error
if err != nil {
return err
}

resp.Done = true
Expand Down
216 changes: 211 additions & 5 deletions client/allocrunner/taskrunner/artifact_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package taskrunner

import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -73,11 +74,7 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
artifactHook := newArtifactHook(me, testlog.HCLogger(t))

// Create a source directory with 1 of the 2 artifacts
srcdir, err := ioutil.TempDir("", "nomadtest-src")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(srcdir))
}()
srcdir := t.TempDir()

// Only create one of the 2 artifacts to cause an error on first run.
file1 := filepath.Join(srcdir, "foo.txt")
Expand Down Expand Up @@ -159,3 +156,212 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
require.True(t, resp.Done)
require.Len(t, resp.State, 2)
}

// TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess asserts that the artifact hook
// download multiple files concurrently. this is a successful test without any errors.
func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) {
t.Parallel()

me := &mockEmitter{}
artifactHook := newArtifactHook(me, testlog.HCLogger(t))

// Create a source directory all 7 artifacts
srcdir := t.TempDir()

numOfFiles := 7
for i := 0; i < numOfFiles; i++ {
file := filepath.Join(srcdir, fmt.Sprintf("file%d.txt", i))
require.NoError(t, ioutil.WriteFile(file, []byte{byte(i)}, 0644))
}

// Test server to serve the artifacts
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
defer ts.Close()

// Create the target directory.
destdir, err := ioutil.TempDir("", "nomadtest-dest")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(destdir))
}()

req := &interfaces.TaskPrestartRequest{
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
TaskDir: &allocdir.TaskDir{Dir: destdir},
Task: &structs.Task{
Artifacts: []*structs.TaskArtifact{
{
GetterSource: ts.URL + "/file0.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file1.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file2.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file3.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file4.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file5.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file6.txt",
GetterMode: structs.GetterModeAny,
},
},
},
}

resp := interfaces.TaskPrestartResponse{}

// start the hook
err = artifactHook.Prestart(context.Background(), req, &resp)

require.NoError(t, err)
require.True(t, resp.Done)
require.Len(t, resp.State, 7)
require.Len(t, me.events, 1)
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)

// Assert all files downloaded properly
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
require.NoError(t, err)
require.Len(t, files, 7)
sort.Strings(files)
require.Contains(t, files[0], "file0.txt")
require.Contains(t, files[1], "file1.txt")
require.Contains(t, files[2], "file2.txt")
require.Contains(t, files[3], "file3.txt")
require.Contains(t, files[4], "file4.txt")
require.Contains(t, files[5], "file5.txt")
require.Contains(t, files[6], "file6.txt")
}

// TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure asserts that the artifact hook
// download multiple files concurrently. first iteration will result in failure and
// second iteration should succeed without downloading already downloaded files.
func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) {
t.Parallel()

me := &mockEmitter{}
artifactHook := newArtifactHook(me, testlog.HCLogger(t))

// Create a source directory with 3 of the 4 artifacts
srcdir := t.TempDir()

file1 := filepath.Join(srcdir, "file1.txt")
require.NoError(t, ioutil.WriteFile(file1, []byte{'1'}, 0644))

file2 := filepath.Join(srcdir, "file2.txt")
require.NoError(t, ioutil.WriteFile(file2, []byte{'2'}, 0644))

file3 := filepath.Join(srcdir, "file3.txt")
require.NoError(t, ioutil.WriteFile(file3, []byte{'3'}, 0644))

// Test server to serve the artifacts
ts := httptest.NewServer(http.FileServer(http.Dir(srcdir)))
defer ts.Close()

// Create the target directory.
destdir, err := ioutil.TempDir("", "nomadtest-dest")
require.NoError(t, err)
defer func() {
require.NoError(t, os.RemoveAll(destdir))
}()

req := &interfaces.TaskPrestartRequest{
TaskEnv: taskenv.NewTaskEnv(nil, nil, nil, nil, destdir, ""),
TaskDir: &allocdir.TaskDir{Dir: destdir},
Task: &structs.Task{
Artifacts: []*structs.TaskArtifact{
{
GetterSource: ts.URL + "/file0.txt", // this request will fail
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file1.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file2.txt",
GetterMode: structs.GetterModeAny,
},
{
GetterSource: ts.URL + "/file3.txt",
GetterMode: structs.GetterModeAny,
},
},
},
}

resp := interfaces.TaskPrestartResponse{}

// On first run all files will be downloaded except file0.txt
err = artifactHook.Prestart(context.Background(), req, &resp)

require.Error(t, err)
require.True(t, structs.IsRecoverable(err))
require.Len(t, resp.State, 3)
require.False(t, resp.Done)
require.Len(t, me.events, 1)
require.Equal(t, structs.TaskDownloadingArtifacts, me.events[0].Type)

// delete the downloaded files so that it'll error if it's downloaded again
require.NoError(t, os.Remove(file1))
require.NoError(t, os.Remove(file2))
require.NoError(t, os.Remove(file3))

// create the missing file
file0 := filepath.Join(srcdir, "file0.txt")
require.NoError(t, ioutil.WriteFile(file0, []byte{'0'}, 0644))

// Mock TaskRunner by copying state from resp to req and reset resp.
req.PreviousState = helper.CopyMapStringString(resp.State)

resp = interfaces.TaskPrestartResponse{}

// Retry the download and assert it succeeds
err = artifactHook.Prestart(context.Background(), req, &resp)
require.NoError(t, err)
require.True(t, resp.Done)
require.Len(t, resp.State, 4)

// Assert all files downloaded properly
files, err := filepath.Glob(filepath.Join(destdir, "*.txt"))
require.NoError(t, err)
sort.Strings(files)
require.Contains(t, files[0], "file0.txt")
require.Contains(t, files[1], "file1.txt")
require.Contains(t, files[2], "file2.txt")
require.Contains(t, files[3], "file3.txt")

// verify the file contents too, since files will also be created for failed downloads
data0, err := ioutil.ReadFile(files[0])
require.NoError(t, err)
require.Equal(t, data0, []byte{'0'})

data1, err := ioutil.ReadFile(files[1])
require.NoError(t, err)
require.Equal(t, data1, []byte{'1'})

data2, err := ioutil.ReadFile(files[2])
require.NoError(t, err)
require.Equal(t, data2, []byte{'2'})

data3, err := ioutil.ReadFile(files[3])
require.NoError(t, err)
require.Equal(t, data3, []byte{'3'})

require.True(t, resp.Done)
require.Len(t, resp.State, 4)
}

0 comments on commit f601cc3

Please sign in to comment.