Skip to content
This repository has been archived by the owner on Dec 9, 2022. It is now read-only.

Commit

Permalink
PRIC-1356 Allow paddle to download individual files
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianomedeirossantos committed Jul 3, 2019
1 parent f763e98 commit 36e3486
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 51 deletions.
63 changes: 38 additions & 25 deletions cli/data/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var (
getBranch string
getCommitPath string
getBucket string
getFiles []string
getKeys []string
)

const (
Expand All @@ -53,7 +53,7 @@ var getCmd = &cobra.Command{
Example:
$ paddle data get -b experimental --bucket roo-pipeline trained-model/version1 dest/path
$ paddle data get -b experimental --bucket roo-pipeline --files file1.csv,file2.csv trained-model/version1 dest/path
$ paddle data get -b experimental --bucket roo-pipeline --keys file1.csv,file2.csv trained-model/version1 dest/path
`,
Run: func(cmd *cobra.Command, args []string) {
if getBucket == "" {
Expand All @@ -68,18 +68,18 @@ $ paddle data get -b experimental --bucket roo-pipeline --files file1.csv,file2.
path: fmt.Sprintf("%s/%s/%s", args[0], getBranch, getCommitPath),
}

copyPathToDestination(source, args[1], getFiles)
copyPathToDestination(source, args[1], getKeys)
},
}

func init() {
getCmd.Flags().StringVarP(&getBranch, "branch", "b", "master", "Branch to work on")
getCmd.Flags().StringVar(&getBucket, "bucket", "", "Bucket to use")
getCmd.Flags().StringVarP(&getCommitPath, "path", "p", "HEAD", "Path to fetch (instead of HEAD)")
getCmd.Flags().StringSliceVarP(&getFiles, "files", "f", []string{}, "A list of files to download separated by comma")
getCmd.Flags().StringSliceVarP(&getKeys, "keys", "k", []string{}, "A list of keys to download separated by comma")
}

func copyPathToDestination(source S3Path, destination string, files []string) {
func copyPathToDestination(source S3Path, destination string, keys []string) {
session := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
Expand All @@ -96,7 +96,7 @@ func copyPathToDestination(source S3Path, destination string, files []string) {
}

fmt.Println("Copying " + source.path + " to " + destination)
copy(session, source, destination, files)
copy(session, source, destination, keys)
}

func readHEAD(session *session.Session, source S3Path) string {
Expand All @@ -113,7 +113,7 @@ func readHEAD(session *session.Session, source S3Path) string {
return buf.String()
}

func copy(session *session.Session, source S3Path, destination string, files []string) {
func copy(session *session.Session, source S3Path, destination string, keys []string) {
query := &s3.ListObjectsV2Input{
Bucket: aws.String(source.bucket),
Prefix: aws.String(source.path),
Expand All @@ -127,7 +127,7 @@ func copy(session *session.Session, source S3Path, destination string, files []s
return
}

copyToLocalFiles(svc, response.Contents, source, destination, files)
copyToLocalFiles(svc, response.Contents, source, destination, keys)

// Check if more results
query.ContinuationToken = response.NextContinuationToken
Expand All @@ -138,14 +138,18 @@ func copy(session *session.Session, source S3Path, destination string, files []s
}
}

func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source S3Path, destination string, files []string) {
func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source S3Path, destination string, keys []string) {
var (
wg = new(sync.WaitGroup)
sem = make(chan struct{}, s3ParallelGets)
downloadList = filterObjects(objects, files)
wg = new(sync.WaitGroup)
sem = make(chan struct{}, s3ParallelGets)
)

wg.Add(len(objects))
downloadList, err := filterObjects(source, objects, keys)
if err != nil {
exitErrorf("Error downloading keys: %v", err)
}

wg.Add(len(downloadList))

for _, key := range downloadList {
go process(s3Client, source, destination, *key.Key, sem, wg)
Expand All @@ -154,20 +158,29 @@ func copyToLocalFiles(s3Client *s3.S3, objects []*s3.Object, source S3Path, dest
wg.Wait()
}

func filterObjects(objects []*s3.Object, files []string) []*s3.Object {
var downloadList []*s3.Object
if len(files) == 0 {
return objects
}
for _, key := range objects {
_, file := filepath.Split(*key.Key)
for _, value := range files {
if value == file {
downloadList = append(downloadList, key)
}
func filterObjects(source S3Path, objects []*s3.Object, keys []string) ([]*s3.Object, error) {
var (
downloadList []*s3.Object
objsByKey = make(map[string]*s3.Object)
keysNotFound []string
)
if len(keys) == 0 {
return objects, nil
}
for _, obj := range objects {
objsByKey[*obj.Key] = obj
}
for _, key := range keys {
if obj, contains := objsByKey[source.path+key]; contains {
downloadList = append(downloadList, obj)
continue
}
keysNotFound = append(keysNotFound, key)
}
if len(keysNotFound) > 0 {
return nil, errors.New("couldn't find " + strings.Join(keysNotFound, ","))
}
return downloadList
return downloadList, nil
}

func process(s3Client *s3.S3, src S3Path, basePath string, filePath string, sem chan struct{}, wg *sync.WaitGroup) {
Expand Down
54 changes: 41 additions & 13 deletions cli/data/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,58 @@ import (

func TestFilterObjects(t *testing.T) {
var (
file1 = "file1.csv"
file2 = "file2.csv"
obj1 = &s3.Object{Key: &file1}
obj2 = &s3.Object{Key: &file2}
files = []string{"file1.csv"}
key1 = "path/file1.csv"
key2 = "path/file2.csv"
key3 = "path/folder/file3.csv"
obj1 = &s3.Object{Key: &key1}
obj2 = &s3.Object{Key: &key2}
obj3 = &s3.Object{Key: &key3}
keys = []string{"file1.csv", "file2.csv", "folder/file3.csv"}
s3Path = S3Path{bucket: "bucket", path: "path/"}
)

result := filterObjects([]*s3.Object{obj1, obj2}, files)
result, err := filterObjects(s3Path, []*s3.Object{obj1, obj2, obj3}, keys)
if err != nil {
t.Errorf("It should filter objects properly, but %v", err)
}

if len(result) != 1 {
t.Errorf("Failed to filter files, got: %v, want: %v.", len(result), len(files))
if len(result) != 3 {
t.Errorf("Failed to filter keys got: %v, want: 3", len(result))
}
}

func TestFilterObjectsWithEmptyFiles(t *testing.T) {
func TestFilterObjectsWithNoKeys(t *testing.T) {
var (
file = "file.csv"
obj = &s3.Object{Key: &file}
key = "path/file.csv"
obj = &s3.Object{Key: &key}
s3Path = S3Path{bucket: "bucket", path: "path/"}
)

result := filterObjects([]*s3.Object{obj}, []string{})
result, err := filterObjects(s3Path, []*s3.Object{obj}, []string{})
if err != nil {
t.Errorf("It should filter objects properly, but %v", err)
}

length := len(result)
if length != 1 {
t.Errorf("Failed to filter files, got: %v, want: %v.", length, length)
t.Errorf("It should return all objects, but got: %v, want: 1.", length)
}
}

func TestFilterObjectsUsingNonExistentKeys(t *testing.T) {
var (
key = "path/f1.csv"
obj = &s3.Object{Key: &key}
s3Path = S3Path{bucket: "bucket", path: "path/"}
keys = []string{"f2.csv", "f3.csv"}
)

result, err := filterObjects(s3Path, []*s3.Object{obj}, keys)
if result != nil {
t.Error("It should not return a list of S3 objects")
}

if err == nil {
t.Error("It should return an error")
}
}
11 changes: 6 additions & 5 deletions cli/pipeline/pipeline_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ type PipelineDefinitionStep struct {
Branch string `yaml:"branch" json:"branch"`
Image string `yaml:"image" json:"image"`
Inputs []struct {
Step string `yaml:"step" json:"step"`
Version string `yaml:"version" json:"version"`
Branch string `yaml:"branch" json:"branch"`
Path string `yaml:"path" json:"path"`
Bucket string `yaml:"bucket" json:"bucket"`
Step string `yaml:"step" json:"step"`
Version string `yaml:"version" json:"version"`
Branch string `yaml:"branch" json:"branch"`
Path string `yaml:"path" json:"path"`
Bucket string `yaml:"bucket" json:"bucket"`
Keys []string `yaml:"keys" json:"keys"`
} `yaml:"inputs" json:"inputs"`
Commands []string `yaml:"commands" json:"commands"`
Resources struct {
Expand Down
8 changes: 4 additions & 4 deletions cli/pipeline/pipeline_definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestParsePipeline(t *testing.T) {
if err != nil {
panic(err.Error())
}
pipeline := parsePipeline(data)
pipeline := ParsePipeline(data)

if len(pipeline.Steps) != 2 {
t.Errorf("excepted two steps, got %d", len(pipeline.Steps))
Expand All @@ -26,7 +26,7 @@ func TestOverrideTag(t *testing.T) {
if err != nil {
panic(err.Error())
}
pipeline := parsePipeline(data)
pipeline := ParsePipeline(data)

pipeline.Steps[0].OverrideTag("")

Expand All @@ -46,7 +46,7 @@ func TestOverrideVersion(t *testing.T) {
if err != nil {
panic(err.Error())
}
pipeline := parsePipeline(data)
pipeline := ParsePipeline(data)

pipeline.Steps[0].OverrideVersion("", true)

Expand All @@ -70,7 +70,7 @@ func TestOverrideBranch(t *testing.T) {
if err != nil {
panic(err.Error())
}
pipeline := parsePipeline(data)
pipeline := ParsePipeline(data)

pipeline.Steps[0].OverrideBranch("", true)

Expand Down
10 changes: 9 additions & 1 deletion cli/pipeline/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ spec:
- "-c"
- "mkdir -p $INPUT_PATH $OUTPUT_PATH &&
{{ range $index, $input := .Step.Inputs }}
paddle data get {{ $input.Step }}/{{ $input.Version }} $INPUT_PATH -b {{ $input.Branch | sanitizeName }} -p {{ $input.Path }} {{ $input.Bucket | bucketParam }} &&
paddle data get {{ $input.Step }}/{{ $input.Version }} $INPUT_PATH -b {{ $input.Branch | sanitizeName }} -p {{ $input.Path }} {{ $input.Bucket | bucketParam }} {{$input.Keys | keysParam}} &&
{{ end }}
touch /data/first-step.txt &&
echo first step finished &&
Expand Down Expand Up @@ -209,6 +209,7 @@ func (p PodDefinition) compile() *bytes.Buffer {
fmap := template.FuncMap{
"sanitizeName": sanitizeName,
"bucketParam": p.bucketParam,
"keysParam": p.keysParam,
}
tmpl := template.Must(template.New("podTemplate").Funcs(fmap).Parse(podTemplate))
buffer := new(bytes.Buffer)
Expand Down Expand Up @@ -269,6 +270,13 @@ func (p *PodDefinition) bucketParam(bucket string) string {
return ""
}

// keysParam renders the --keys CLI flag for the given input keys, or an
// empty string when no keys were configured for the input.
func (p *PodDefinition) keysParam(keys []string) string {
	if len(keys) == 0 {
		return ""
	}
	return "--keys " + strings.Join(keys, ",")
}

func sanitizeName(name string) string {
str := strings.ToLower(name)
str = strings.Replace(str, "_", "-", -1)
Expand Down
40 changes: 37 additions & 3 deletions cli/pipeline/template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pipeline

import (
"io/ioutil"
"strings"
"testing"

"k8s.io/api/core/v1"
Expand All @@ -13,7 +14,7 @@ func TestCompileTemplate(t *testing.T) {
if err != nil {
panic(err.Error())
}
pipeline := parsePipeline(data)
pipeline := ParsePipeline(data)

podDefinition := NewPodDefinition(pipeline, &pipeline.Steps[0])

Expand All @@ -36,7 +37,7 @@ func TestSecrets(t *testing.T) {
if err != nil {
panic(err.Error())
}
pipeline := parsePipeline(data)
pipeline := ParsePipeline(data)

podDefinition := NewPodDefinition(pipeline, &pipeline.Steps[0])
secrets := []string{"ENV_VAR:secret_store:key_name"}
Expand Down Expand Up @@ -68,7 +69,7 @@ func TestEnv(t *testing.T) {
if err != nil {
panic(err.Error())
}
pipeline := parsePipeline(data)
pipeline := ParsePipeline(data)

podDefinition := NewPodDefinition(pipeline, &pipeline.Steps[0])
env := []string{"ENV_VAR:env_value"}
Expand All @@ -91,3 +92,36 @@ func TestEnv(t *testing.T) {
t.Errorf("Did not find env var")
}
}

// TestKeys checks that input keys declared in the pipeline definition are
// parsed and rendered into the pod's "paddle data get" command as a
// --keys flag.
func TestKeys(t *testing.T) {
	data, err := ioutil.ReadFile("test/sample_keys.yml")
	if err != nil {
		panic(err.Error())
	}

	pipeline := ParsePipeline(data)
	podDefinition := NewPodDefinition(pipeline, &pipeline.Steps[0])

	if inputKeys := podDefinition.Step.Inputs[0].Keys; len(inputKeys) != 2 {
		t.Errorf("Failed to parse keys, got: %v, want: 2.", len(inputKeys))
	}

	pod := &v1.Pod{}
	yaml.NewYAMLOrJSONDecoder(podDefinition.compile(), 4096).Decode(pod)

	if pod.Name != "sample-steps-passing-version1-step1-master" {
		t.Errorf("Pod name is %s", pod.Name)
	}

	if pod.Spec.Containers[0].Image != pipeline.Steps[0].Image {
		t.Errorf("First image is %s", pod.Spec.Containers[0].Image)
	}

	if command := pod.Spec.Containers[1].Command[2]; !strings.Contains(command, "--keys file1.json,file2.json") {
		t.Errorf("Failed to build paddle get, keys flag is missing")
	}
}
25 changes: 25 additions & 0 deletions cli/pipeline/test/sample_keys.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
pipeline: sample-keys
bucket: "{{ s3_bucket_name | default('canoe-sample-pipeline') }}"
namespace: modeltraining

steps:
  -
    step: step1
    version: version1
    inputs:
      -
        step: step1
        version: version1
        branch: master
        path: HEAD
        keys:
          - file1.json
          - file2.json
    image: 219541440308.dkr.ecr.eu-west-1.amazonaws.com/paddlecontainer:latest
    branch: master
    commands:
      - echo executing sample-pipeline-data > ${OUTPUT_PATH}/sample-pipeline-data-model.txt
    resources:
      cpu: 2
      memory: 2Gi
      storage-mb: 1000


0 comments on commit 36e3486

Please sign in to comment.