Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use methods when there is state. Removed dead code. Default to not export. #16

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type TableWriter struct {
}

// NewTable create a new table writer
func NewTable(wr io.Writer) *TableWriter {
func (u Unicreds) NewTable(wr io.Writer) *TableWriter {
return &TableWriter{wr: wr}
}

Expand Down
3 changes: 2 additions & 1 deletion cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

func TestRender(t *testing.T) {

u := Unicreds{}
tt := []struct {
tableFormat int
output string
Expand Down Expand Up @@ -38,7 +39,7 @@ func TestRender(t *testing.T) {
for _, tv := range tt {
var b bytes.Buffer

table := NewTable(&b)
table := u.NewTable(&b)
table.SetHeaders(tv.headers)
table.SetFormat(tv.tableFormat)
table.BulkWrite(tv.rows)
Expand Down
88 changes: 22 additions & 66 deletions cmd/unicreds/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,17 @@ import (
"fmt"
"io/ioutil"
"os"
"net/http"

"github.com/apex/log"
"github.com/apex/log/handlers/cli"

"github.com/alecthomas/kingpin"
"github.com/aws/aws-sdk-go/aws"
"github.com/versent/unicreds"
)

const (
zoneURL = "http://169.254.169.254/latest/meta-data/placement/availability-zone"
)

var (
app = kingpin.New("unicreds", "A credential/secret storage command line tool.")
debug = app.Flag("debug", "Enable debug mode.").Short('d').Bool()
csv = app.Flag("csv", "Enable csv output for table data.").Short('c').Bool()
app = kingpin.New("unicreds", "A credential/secret storage command line tool.")
csv = app.Flag("csv", "Enable csv output for table data.").Short('c').Bool()

region = app.Flag("region", "Configure the AWS region").Short('r').String()

Expand Down Expand Up @@ -61,45 +54,34 @@ func main() {

command := kingpin.MustParse(app.Parse(os.Args[1:]))

if *region != "" {
// update the aws config overrides if present
setRegion(region)
} else {
// or try to get our region based on instance metadata
r, err := getRegion()
if err != nil {
printFatalError(err)
}
u := unicreds.Unicreds{}

setRegion(r)
err := u.SetRegion(region)
if err != nil {
printFatalError(err)
}

switch command {
case cmdSetup.FullCommand():
err := unicreds.Setup()
if err != nil {
if err := u.Setup(); err != nil {
printFatalError(err)
}
case cmdGet.FullCommand():
cred, err := unicreds.GetSecret(*cmdGetName)
if err != nil {
if err := u.GetSecret(*cmdGetName); err != nil {
printFatalError(err)
}
fmt.Println(cred.Secret)
fmt.Println(u.DecryptedCredentials)
case cmdPut.FullCommand():
version, err := unicreds.ResolveVersion(*cmdPutName, *cmdPutVersion)
if err != nil {
if err := u.ResolveVersion(*cmdPutName, *cmdPutVersion); err != nil {
printFatalError(err)
}

err = unicreds.PutSecret(*alias, *cmdPutName, *cmdPutSecret, version)
if err != nil {
if err := unicreds.PutSecret(*alias, *cmdPutName, *cmdPutSecret, u.Version); err != nil {
printFatalError(err)
}
log.WithFields(log.Fields{"name": *cmdPutName, "version": version}).Info("stored")
log.WithFields(log.Fields{"name": *cmdPutName, "version": u.Version}).Info("stored")
case cmdPutFile.FullCommand():
version, err := unicreds.ResolveVersion(*cmdPutFileName, *cmdPutFileVersion)
if err != nil {
if err := u.ResolveVersion(*cmdPutFileName, *cmdPutFileVersion); err != nil {
printFatalError(err)
}

Expand All @@ -108,76 +90,50 @@ func main() {
printFatalError(err)
}

err = unicreds.PutSecret(*alias, *cmdPutFileName, string(data), version)
if err != nil {
if err = unicreds.PutSecret(*alias, *cmdPutFileName, string(data), u.Version); err != nil {
printFatalError(err)
}
log.WithFields(log.Fields{"name": *cmdPutName, "version": version}).Info("stored")
log.WithFields(log.Fields{"name": *cmdPutName, "version": u.Version}).Info("stored")
case cmdList.FullCommand():
creds, err := unicreds.ListSecrets(*cmdListAll)
if err != nil {
if err := u.ListSecrets(*cmdListAll); err != nil {
printFatalError(err)
}

table := unicreds.NewTable(os.Stdout)
table := u.NewTable(os.Stdout)
table.SetHeaders([]string{"Name", "Version", "Created-At"})

if *csv {
table.SetFormat(unicreds.TableFormatCSV)
}

for _, cred := range creds {
for _, cred := range u.Credentials {
table.Write([]string{cred.Name, cred.Version, cred.CreatedAtDate()})
}
table.Render()
case cmdGetAll.FullCommand():
creds, err := unicreds.GetAllSecrets(true)
if err != nil {
if err := u.GetAllSecrets(true); err != nil {
printFatalError(err)
}

table := unicreds.NewTable(os.Stdout)
table := u.NewTable(os.Stdout)
table.SetHeaders([]string{"Name", "Secret"})

if *csv {
table.SetFormat(unicreds.TableFormatCSV)
}

for _, cred := range creds {
for _, cred := range u.DecryptedCredentials {
table.Write([]string{cred.Name, cred.Secret})
}
table.Render()
case cmdDelete.FullCommand():
err := unicreds.DeleteSecret(*cmdDeleteName)
err := u.DeleteSecret(*cmdDeleteName)
if err != nil {
printFatalError(err)
}
}
}

func getRegion() (*string, error) {
// Use meta-data to get our region
response, err := http.Get(zoneURL)
if err != nil {
return nil, err
}

defer response.Body.Close()
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
return nil, err
}

// Strip last char
r := string(contents[0:len(string(contents))-1])
return &r, nil
}

func setRegion(region *string) {
unicreds.SetDynamoDBConfig(&aws.Config{Region: region})
unicreds.SetKMSConfig(&aws.Config{Region: region})
}

func printFatalError(err error) {
log.WithError(err).Error("failed")
os.Exit(1)
Expand Down
63 changes: 23 additions & 40 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,60 +9,44 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodb"
)

// Decode decode the supplied struct from the dynamodb result map
//
// NOTE: this function needs a lot more validation and refinement.
func Decode(data map[string]*dynamodb.AttributeValue, rawVal interface{}) error {
val := reflect.ValueOf(rawVal)
if val.Kind() != reflect.Ptr {
return errors.New("result must be a pointer")
// decode decode the supplied struct from the dynamodb result map
func decode(name string, data map[string]*dynamodb.AttributeValue, val interface{}) (err error) {
if data == nil {
// If the data is nil, then we don't set anything.
return nil
}

val = val.Elem()
if !val.CanAddr() {
return errors.New("result must be addressable (a pointer)")
fields := make(map[*reflect.StructField]reflect.Value)

v := reflect.ValueOf(val)
if v.Kind() != reflect.Ptr {
return errors.New("result must be a pointer")
}
return decode("ds", data, val)
}

func decode(name string, data map[string]*dynamodb.AttributeValue, val reflect.Value) error {
if data == nil {
// If the data is nil, then we don't set anything.
return nil
v = v.Elem()
if !v.CanAddr() {
return errors.New("result must be addressable (a pointer)")
}

dataVal := reflect.ValueOf(data)
if !dataVal.IsValid() {
d := reflect.ValueOf(data)
if !d.IsValid() {
// If the data value is invalid, then we just set the value
// to be the zero value.
val.Set(reflect.Zero(val.Type()))
v.Set(reflect.Zero(v.Type()))
return nil
}

var err error
dataKind := getKind(val)
switch dataKind {
case reflect.Struct:
err = decodeStruct(name, data, val)
default:
return fmt.Errorf("%s: unsupported type: %s", name, dataKind)
if getKind(v) != reflect.Struct {
return fmt.Errorf("%s: unsupported type: %s", name, getKind(v))
}

return err
}

func decodeStruct(name string, data map[string]*dynamodb.AttributeValue, val reflect.Value) (err error) {

fields := make(map[*reflect.StructField]reflect.Value)

structVal := val
structType := structVal.Type()
structType := v.Type()

for i := 0; i < structType.NumField(); i++ {
fieldType := structType.Field(i)

// Normal struct field, store it away
fields[&fieldType] = structVal.Field(i)
fields[&fieldType] = v.Field(i)
}

for fieldType, field := range fields {
Expand All @@ -73,16 +57,15 @@ func decodeStruct(name string, data map[string]*dynamodb.AttributeValue, val ref
fieldName = tagValue
}

keyVal := data[fieldName]
if keyVal == nil {
if k := data[fieldName]; k == nil {
continue
}

switch getKind(field) {
case reflect.String:
err = decodeString(fieldName, keyVal, field)
err = decodeString(fieldName, data[fieldName], field)
case reflect.Int:
err = decodeInt(fieldName, keyVal, field)
err = decodeInt(fieldName, data[fieldName], field)
default:
return fmt.Errorf("%s: unsupported type: %s", fieldName, getKind(field))
}
Expand Down
2 changes: 1 addition & 1 deletion decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestDecode(t *testing.T) {
},
}

err := Decode(data, &cred)
err := decode("ds", data, &cred)
if err != nil {
fmt.Printf("%+v\n", err)
}
Expand Down
Loading