Skip to content

Commit

Permalink
fixup! aws/credentials: Rework credential_process provider
Browse files Browse the repository at this point in the history
  • Loading branch information
YakDriver committed Oct 19, 2018
1 parent 52b550b commit acbf376
Showing 1 changed file with 52 additions and 38 deletions.
90 changes: 52 additions & 38 deletions aws/credentials/processcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ package processcreds

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -99,7 +98,7 @@ const (
DefaultMaxBufSize = 512

// DefaultTimeout limits the time a process can run, in milliseconds.
DefaultTimeout = time.Duration(500)
DefaultTimeout = 500
)

// ProcessProvider satisfies the credentials.Provider interface, and is a client to
Expand Down Expand Up @@ -131,7 +130,7 @@ type ProcessProvider struct {
MaxBufSize int

// Timeout limits the time a process can run, in milliseconds.
Timeout time.Duration
Timeout int
}

// NewCredentials returns a pointer to a new Credentials object wrapping the
Expand Down Expand Up @@ -160,7 +159,7 @@ type credentialProcessResponse struct {

// Retrieve executes the 'credential_process' and returns the credentials.
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
if p.Duration == time.Duration(0) && p.Timeout == time.Duration(0) && p.MaxBufSize == 0 {
if p.Duration == time.Duration(0) && p.Timeout == 0 && p.MaxBufSize == 0 {
p.Duration = DefaultDuration
p.Timeout = DefaultTimeout
p.MaxBufSize = DefaultMaxBufSize
Expand Down Expand Up @@ -233,7 +232,7 @@ type LimitedBuffer struct {
// Write to the LimitedBuffer
func (b *LimitedBuffer) Write(p []byte) (int, error) {
if len(p)+b.buff.Len() > b.maxSize {
return -1, fmt.Errorf("buffer overflow")
return -1, fmt.Errorf("buffer size (%v) exceeded: %v", b.maxSize, len(p)+b.buff.Len())
}
b.buff.Write(p)
return len(p), nil
Expand All @@ -255,16 +254,11 @@ func NewBuffer(start, max int) *LimitedBuffer {
// executeCredentialProcess executes the `process` command on the OS and
// returns the results or an error.
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {

command := strings.TrimSpace(p.Process)
if command == "" {
return nil, fmt.Errorf("process must be a non-empty string")
}

// Avoid very long or hung processes
ctx, cancel := context.WithTimeout(context.Background(), p.Timeout*time.Millisecond)
defer cancel()

var cmdargs []string
var env []string
if runtime.GOOS == "windows" {
Expand All @@ -276,6 +270,7 @@ func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
if _, ok := os.LookupEnv("PATH"); !ok {
env = append(env, fmt.Sprintf("%s=%s", "PATH", "C:\\Windows\\system32"))
}

} else {
cmdargs = []string{"/bin/sh", "-c"}
if _, ok := os.LookupEnv("PATH"); !ok {
Expand All @@ -295,45 +290,64 @@ func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
}

// Setup the command
cmd := exec.CommandContext(ctx, cmdargs[0], cmdargs[1:]...)
cmd := exec.Command(cmdargs[0], cmdargs[1:]...)
cmd.Stderr = pw
cmd.Stdout = pw
cmd.Env = cmdEnv

output := NewBuffer(DefaultInitialBufSize, p.MaxBufSize)

// Write everything we read from the pipe to the output buffer
tee := io.TeeReader(pr, output)

copyDoneCh := make(chan struct{})
go readOutput(tee, copyDoneCh)
var readErr error
read := make(chan bool, 1)
go func() {
// Write everything we read from the pipe to the output buffer
tee := io.TeeReader(pr, output)

// Start the command
err = cmd.Start()
if err == nil {
err = cmd.Wait()
}
cancel()
// blocks until the pipe is closed - pw.Close()
_, readErr = ioutil.ReadAll(tee)

// Close the write-end of the pipe
pw.Close()
if readErr != nil {
read <- false
return
}
read <- true
}()

var execErr error
exec := make(chan bool, 1)
go func() {
defer pw.Close() // important!

// Start the command
execErr = cmd.Start()
if execErr == nil {
execErr = cmd.Wait()
}

select {
case <-ctx.Done():
}
if execErr != nil {
exec <- false
return
}

if err != nil {
return nil, fmt.Errorf("error executing process '%s': %v. Output: %s",
command, err, output.Bytes())
exec <- true
}()

finished := false
for !finished {
select {
case readResult := <-read:
if !readResult {
return nil, readErr
}
finished = true
case execResult := <-exec:
if !execResult {
return nil, execErr
}
case <-time.After(time.Duration(p.Timeout) * time.Millisecond):
return nil, fmt.Errorf("process timed out")
}
}

return output.Bytes(), nil
}

func readOutput(r io.Reader, doneCh chan<- struct{}) {
defer close(doneCh)
_, err := ioutil.ReadAll(r)
if err != nil {
panic("error reading")
}
}

0 comments on commit acbf376

Please sign in to comment.