
Commit 9d6a841

Merge pull request #37 from deliveroo/PRIC-1537_Allow_pipeline_command_to_accept_list_of_files

PRIC-1356 Allow paddle to download individual files

Adriano Medeiros dos Santos authored Jul 15, 2019
2 parents aae53f5 + 428ac6a, commit 9d6a841
Showing 6 changed files with 147 additions and 38 deletions.
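
In short, the change replaces the old --files/-f flag on paddle data get with --keys/-k: instead of matching bare file names anywhere under the prefix, each requested key is resolved against the full S3 key (listing prefix + key), so nested paths such as folder/file3.csv work, and requesting a key that does not exist fails the command. Updated usage, taken from the new help text in cli/data/get.go:

$ paddle data get -b experimental --bucket roo-pipeline --keys file1.csv,file2.csv trained-model/version1 dest/path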
61 changes: 37 additions & 24 deletions cli/data/get.go
@@ -35,7 +35,7 @@ var (
 	getBranch     string
 	getCommitPath string
 	getBucket     string
-	getFiles      []string
+	getKeys       []string
 	getSubdir     string
 )

@@ -54,7 +54,7 @@ var getCmd = &cobra.Command{
 Example:
 $ paddle data get -b experimental --bucket roo-pipeline --subdir version1 trained-model/version1 dest/path
-$ paddle data get -b experimental --bucket roo-pipeline --files file1.csv,file2.csv trained-model/version1 dest/path
+$ paddle data get -b experimental --bucket roo-pipeline --keys file1.csv,file2.csv trained-model/version1 dest/path
 `,
 	Run: func(cmd *cobra.Command, args []string) {
 		if getBucket == "" {
@@ -69,19 +69,19 @@
 			path: fmt.Sprintf("%s/%s/%s", args[0], getBranch, getCommitPath),
 		}

-		copyPathToDestination(source, args[1], getFiles, getSubdir)
+		copyPathToDestination(source, args[1], getKeys, getSubdir)
 	},
 }

 func init() {
 	getCmd.Flags().StringVarP(&getBranch, "branch", "b", "master", "Branch to work on")
 	getCmd.Flags().StringVar(&getBucket, "bucket", "", "Bucket to use")
 	getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)")
-	getCmd.Flags().StringSliceVarP(&getFiles, "files", "f", []string{}, "A list of files to download separated by comma")
+	getCmd.Flags().StringSliceVarP(&getKeys, "keys", "k", []string{}, "A list of keys to download separated by comma")
 	getCmd.Flags().StringVarP(&getSubdir, "subdir", "d", "", "Custom subfolder name for export path")
 }

-func copyPathToDestination(source S3Path, destination string, files []string, subdir string) {
+func copyPathToDestination(source S3Path, destination string, keys []string, subdir string) {
 	session := session.Must(session.NewSessionWithOptions(session.Options{
 		SharedConfigState: session.SharedConfigEnable,
 	}))
@@ -101,7 +101,7 @@
 	}

 	fmt.Println("Copying " + source.path + " to " + destination)
-	copy(session, source, destination, files)
+	copy(session, source, destination, keys)
 }

 func readHEAD(session *session.Session, source S3Path) string {
@@ -127,7 +127,7 @@ func parseDestination(destination string, subdir string) string {
 	return destination
 }

-func copy(session *session.Session, source S3Path, destination string, files []string) {
+func copy(session *session.Session, source S3Path, destination string, keys []string) {
 	query := &s3.ListObjectsV2Input{
 		Bucket: aws.String(source.bucket),
 		Prefix: aws.String(source.path),
@@ -141,7 +141,7 @@
 		return
 	}

-	copyToLocalFiles(svc, response.Contents, source, destination, files)
+	copyToLocalFiles(svc, response.Contents, source, destination, keys)

 	// Check if more results
 	query.ContinuationToken = response.NextContinuationToken
@@ -152,14 +152,18 @@
 	}
 }

-func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source S3Path, destination string, files []string) {
+func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source S3Path, destination string, keys []string) {
 	var (
-		wg           = new(sync.WaitGroup)
-		sem          = make(chan struct{}, s3ParallelGets)
-		downloadList = filterObjects(objects, files)
+		wg  = new(sync.WaitGroup)
+		sem = make(chan struct{}, s3ParallelGets)
 	)

-	wg.Add(len(objects))
+	downloadList, err := filterObjects(source, objects, keys)
+	if err != nil {
+		exitErrorf("Error downloading keys: %v", err)
+	}
+
+	wg.Add(len(downloadList))

 	for _, key := range downloadList {
 		go process(s3Client, source, destination, *key.Key, sem, wg)
@@ -168,20 +172,29 @@
 	wg.Wait()
 }

-func filterObjects(objects []*s3.Object, files []string) []*s3.Object {
-	var downloadList []*s3.Object
-	if len(files) == 0 {
-		return objects
+func filterObjects(source S3Path, objects []*s3.Object, keys []string) ([]*s3.Object, error) {
+	var (
+		downloadList []*s3.Object
+		objsByKey    = make(map[string]*s3.Object)
+		keysNotFound []string
+	)
+	if len(keys) == 0 {
+		return objects, nil
 	}
+	for _, obj := range objects {
+		objsByKey[*obj.Key] = obj
+	}
-	for _, key := range objects {
-		_, file := filepath.Split(*key.Key)
-		for _, value := range files {
-			if value == file {
-				downloadList = append(downloadList, key)
-			}
+	for _, key := range keys {
+		if obj, contains := objsByKey[source.path+key]; contains {
+			downloadList = append(downloadList, obj)
+			continue
 		}
+		keysNotFound = append(keysNotFound, key)
 	}
+	if len(keysNotFound) > 0 {
+		return nil, errors.New("couldn't find " + strings.Join(keysNotFound, ","))
+	}
-	return downloadList
+	return downloadList, nil
 }

 func process(s3Client *s3.S3, src S3Path, basePath string, filePath string, sem chan struct{}, wg *sync.WaitGroup) {
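
For readers skimming the diff, the heart of the change is the new matching rule in filterObjects: keys are looked up by full object key (prefix + key) and any unmatched key aborts the call. Below is a minimal standalone sketch of that rule, with plain strings standing in for *s3.Object values and a hypothetical filterKeys helper; it is an illustration under those assumptions, not code from the repository:

package main

import (
	"errors"
	"fmt"
	"strings"
)

// filterKeys mirrors the new filterObjects logic with plain strings:
// an empty keys list means "take everything"; otherwise each key must
// match an object at prefix+key, and missing keys abort with an error.
func filterKeys(prefix string, objectKeys []string, keys []string) ([]string, error) {
	if len(keys) == 0 {
		return objectKeys, nil
	}
	byKey := make(map[string]bool, len(objectKeys))
	for _, k := range objectKeys {
		byKey[k] = true
	}
	var matched, notFound []string
	for _, k := range keys {
		if byKey[prefix+k] {
			matched = append(matched, prefix+k)
			continue
		}
		notFound = append(notFound, k)
	}
	if len(notFound) > 0 {
		return nil, errors.New("couldn't find " + strings.Join(notFound, ","))
	}
	return matched, nil
}

func main() {
	objs := []string{"path/file1.csv", "path/folder/file3.csv"}
	fmt.Println(filterKeys("path/", objs, []string{"file1.csv", "folder/file3.csv"}))
	// [path/file1.csv path/folder/file3.csv] <nil>
	fmt.Println(filterKeys("path/", objs, []string{"missing.csv"}))
	// [] couldn't find missing.csv
}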
54 changes: 41 additions & 13 deletions cli/data/get_test.go
@@ -7,30 +7,58 @@ import (

 func TestFilterObjects(t *testing.T) {
 	var (
-		file1 = "file1.csv"
-		file2 = "file2.csv"
-		obj1  = &s3.Object{Key: &file1}
-		obj2  = &s3.Object{Key: &file2}
-		files = []string{"file1.csv"}
+		key1   = "path/file1.csv"
+		key2   = "path/file2.csv"
+		key3   = "path/folder/file3.csv"
+		obj1   = &s3.Object{Key: &key1}
+		obj2   = &s3.Object{Key: &key2}
+		obj3   = &s3.Object{Key: &key3}
+		keys   = []string{"file1.csv", "file2.csv", "folder/file3.csv"}
+		s3Path = S3Path{bucket: "bucket", path: "path/"}
 	)

-	result := filterObjects([]*s3.Object{obj1, obj2}, files)
+	result, err := filterObjects(s3Path, []*s3.Object{obj1, obj2, obj3}, keys)
+	if err != nil {
+		t.Errorf("It should filter objects properly, but %v", err)
+	}

-	if len(result) != 1 {
-		t.Errorf("Failed to filter files, got: %v, want: %v.", len(result), len(files))
+	if len(result) != 3 {
+		t.Errorf("Failed to filter keys got: %v, want: 3", len(result))
 	}
 }

-func TestFilterObjectsWithEmptyFiles(t *testing.T) {
+func TestFilterObjectsWithNoKeys(t *testing.T) {
 	var (
-		file = "file.csv"
-		obj  = &s3.Object{Key: &file}
+		key    = "path/file.csv"
+		obj    = &s3.Object{Key: &key}
+		s3Path = S3Path{bucket: "bucket", path: "path/"}
 	)

-	result := filterObjects([]*s3.Object{obj}, []string{})
+	result, err := filterObjects(s3Path, []*s3.Object{obj}, []string{})
+	if err != nil {
+		t.Errorf("It should filter objects properly, but %v", err)
+	}

 	length := len(result)
 	if length != 1 {
-		t.Errorf("Failed to filter files, got: %v, want: %v.", length, length)
+		t.Errorf("It should return all objects, but got: %v, want: 1.", length)
 	}
 }
+
+func TestFilterObjectsUsingNonExistentKeys(t *testing.T) {
+	var (
+		key    = "path/f1.csv"
+		obj    = &s3.Object{Key: &key}
+		s3Path = S3Path{bucket: "bucket", path: "path/"}
+		keys   = []string{"f2.csv", "f3.csv"}
+	)
+
+	result, err := filterObjects(s3Path, []*s3.Object{obj}, keys)
+	if result != nil {
+		t.Error("It should not return a list of S3 objects")
+	}
+
+	if err == nil {
+		t.Error("It should return an error")
+	}
+}
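
Assuming a standard Go toolchain and the repository's layout, these tests can be run on their own with the stock test runner, e.g. go test ./cli/data -run TestFilterObjects.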
1 change: 1 addition & 0 deletions cli/pipeline/pipeline_definition.go
@@ -19,6 +19,7 @@ type PipelineDefinitionStep struct {
 		Branch string   `yaml:"branch" json:"branch"`
 		Path   string   `yaml:"path" json:"path"`
 		Bucket string   `yaml:"bucket" json:"bucket"`
+		Keys   []string `yaml:"keys" json:"keys"`
 		Subdir string   `yaml:"subdir" json:"subdir"`
 	} `yaml:"inputs" json:"inputs"`
 	Commands []string `yaml:"commands" json:"commands"`
10 changes: 9 additions & 1 deletion cli/pipeline/template.go
@@ -132,7 +132,7 @@ spec:
     - "-c"
     - "mkdir -p $INPUT_PATH $OUTPUT_PATH &&
       {{ range $index, $input := .Step.Inputs }}
-      paddle data get {{ $input.Step }}/{{ $input.Version }} $INPUT_PATH -b {{ $input.Branch | sanitizeName }} -p {{ $input.Path }} {{ $input.Bucket | bucketParam }} {{ $input.Subdir | subdirParam }} &&
+      paddle data get {{ $input.Step }}/{{ $input.Version }} $INPUT_PATH -b {{ $input.Branch | sanitizeName }} -p {{ $input.Path }} {{ $input.Bucket | bucketParam }} {{ $input.Keys | keysParam }} {{ $input.Subdir | subdirParam }} &&
       {{ end }}
       touch /data/first-step.txt &&
       echo first step finished &&
@@ -209,6 +209,7 @@ func (p PodDefinition) compile() *bytes.Buffer {
 	fmap := template.FuncMap{
 		"sanitizeName": sanitizeName,
 		"bucketParam":  p.bucketParam,
+		"keysParam":    p.keysParam,
 		"subdirParam":  p.subdirParam,
 	}
 	tmpl := template.Must(template.New("podTemplate").Funcs(fmap).Parse(podTemplate))
@@ -270,6 +271,13 @@ func (p *PodDefinition) bucketParam(bucket string) string {
 	return ""
 }

+func (p *PodDefinition) keysParam(keys []string) string {
+	if len(keys) != 0 {
+		return "--keys " + strings.Join(keys, ",")
+	}
+	return ""
+}
+
 func (p *PodDefinition) subdirParam(subdir string) string {
 	if subdir != "" {
 		return "-d " + subdir
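
Note how keysParam follows the same pattern as the existing bucketParam and subdirParam helpers: a non-empty list renders a single flag (for example, keysParam([]string{"file1.json", "file2.json"}) yields "--keys file1.json,file2.json"), while an empty list renders an empty string, so inputs that declare no keys leave the generated pod command unchanged.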
33 changes: 33 additions & 0 deletions cli/pipeline/template_test.go
@@ -100,3 +100,36 @@ func TestEnv(t *testing.T) {
 		t.Errorf("Did not find env var")
 	}
 }
+
+func TestKeys(t *testing.T) {
+	data, err := ioutil.ReadFile("test/sample_keys.yml")
+	if err != nil {
+		panic(err.Error())
+	}
+
+	pipeline := ParsePipeline(data)
+	podDefinition := NewPodDefinition(pipeline, &pipeline.Steps[0])
+
+	keys := podDefinition.Step.Inputs[0].Keys
+	if len(keys) != 3 {
+		t.Errorf("Failed to parse keys, got: %v, want: 3.", len(keys))
+	}
+
+	stepPodBuffer := podDefinition.compile()
+
+	pod := &v1.Pod{}
+	yaml.NewYAMLOrJSONDecoder(stepPodBuffer, 4096).Decode(pod)
+
+	if pod.Name != "sample-keys-version1-step1-master" {
+		t.Errorf("Pod name is %s", pod.Name)
+	}
+
+	if pod.Spec.Containers[0].Image != pipeline.Steps[0].Image {
+		t.Errorf("First image is %s", pod.Spec.Containers[0].Image)
+	}
+
+	command := pod.Spec.Containers[1].Command[2]
+	if !strings.Contains(command, "--keys file1.json,file2.json,folder/file3.json") {
+		t.Errorf("Failed to build paddle get, keys flag is missing")
+	}
+}
26 changes: 26 additions & 0 deletions cli/pipeline/test/sample_keys.yml
@@ -0,0 +1,26 @@
+pipeline: sample-keys
+bucket: "{{ s3_bucket_name | default('canoe-sample-pipeline') }}"
+namespace: modeltraining
+
+steps:
+  -
+    step: step1
+    version: version1
+    inputs:
+      -
+        step: step1
+        version: version1
+        branch: master
+        path: HEAD
+        keys:
+          - file1.json
+          - file2.json
+          - folder/file3.json
+    image: 219541440308.dkr.ecr.eu-west-1.amazonaws.com/paddlecontainer:latest
+    branch: master
+    commands:
+      - echo executing sample-pipeline-data > ${OUTPUT_PATH}/sample-pipeline-data-model.txt
+    resources:
+      cpu: 2
+      memory: 2Gi
+      storage-mb: 1000
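
Given this sample, the compiled pod command should contain a paddle invocation roughly like the following (TestKeys above only asserts the --keys fragment; the surrounding flags are inferred from the template, with bucketParam rendering nothing because the input declares no bucket):

paddle data get step1/version1 $INPUT_PATH -b master -p HEAD --keys file1.json,file2.json,folder/file3.json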
