Skip to content

Commit

Permalink
Merge pull request #80 from heetch/rog-005-fix-flags
Browse files Browse the repository at this point in the history
backend/flags: avoiding setting values that aren't specified
  • Loading branch information
sixstone-qq authored and philippgille committed Mar 22, 2021
2 parents bfc1de6 + 0782bac commit 08f6fa0
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 148 deletions.
69 changes: 38 additions & 31 deletions backend/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"flag"
"fmt"
"os"
"reflect"
"time"

Expand All @@ -12,11 +13,15 @@ import (
)

// Backend that loads configuration from the command line flags.
type Backend struct{}
type Backend struct {
flags *flag.FlagSet
}

// NewBackend creates a flags backend.
func NewBackend() *Backend {
return new(Backend)
return &Backend{
flags: flag.CommandLine,
}
}

// LoadStruct takes a struct config, define flags based on it and parse the command line args.
Expand All @@ -34,80 +39,90 @@ func (b *Backend) LoadStruct(ctx context.Context, cfg *confita.StructConfig) err
switch {
case f.Value.Type().String() == "time.Duration":
var val time.Duration
flag.DurationVar(&val, f.Key, time.Duration(f.Default.Int()), f.Description)
b.flags.DurationVar(&val, f.Key, time.Duration(f.Default.Int()), f.Description)
if f.Short != "" {
flag.DurationVar(&val, f.Short, time.Duration(f.Default.Int()), shortDesc(f.Description))
b.flags.DurationVar(&val, f.Short, time.Duration(f.Default.Int()), shortDesc(f.Description))
}
// this function must be executed after the flag.Parse call.
defer func() {
// if the user has set the flag, save the value in the field.
if isFlagSet(f) {
if b.isFlagSet(f) {
f.Value.SetInt(int64(val))
}
}()
case k == reflect.Bool:
var val bool
flag.BoolVar(&val, f.Key, f.Default.Bool(), f.Description)
b.flags.BoolVar(&val, f.Key, f.Default.Bool(), f.Description)
if f.Short != "" {
flag.BoolVar(&val, f.Short, f.Default.Bool(), shortDesc(f.Description))
b.flags.BoolVar(&val, f.Short, f.Default.Bool(), shortDesc(f.Description))
}
defer func() {
if isFlagSet(f) {
if b.isFlagSet(f) {
f.Value.SetBool(val)
}
}()
case k >= reflect.Int && k <= reflect.Int64:
var val int
flag.IntVar(&val, f.Key, int(f.Default.Int()), f.Description)
b.flags.IntVar(&val, f.Key, int(f.Default.Int()), f.Description)
if f.Short != "" {
flag.IntVar(&val, f.Short, int(f.Default.Int()), shortDesc(f.Description))
b.flags.IntVar(&val, f.Short, int(f.Default.Int()), shortDesc(f.Description))
}
defer func() {
if isFlagSet(f) {
if b.isFlagSet(f) {
f.Value.SetInt(int64(val))
}
}()
case k >= reflect.Uint && k <= reflect.Uint64:
var val uint64
flag.Uint64Var(&val, f.Key, f.Default.Uint(), f.Description)
b.flags.Uint64Var(&val, f.Key, f.Default.Uint(), f.Description)
if f.Short != "" {
flag.Uint64Var(&val, f.Short, f.Default.Uint(), shortDesc(f.Description))
b.flags.Uint64Var(&val, f.Short, f.Default.Uint(), shortDesc(f.Description))
}
defer func() {
if isFlagSet(f) {
if b.isFlagSet(f) {
f.Value.SetUint(val)
}
}()
case k >= reflect.Float32 && k <= reflect.Float64:
var val float64
flag.Float64Var(&val, f.Key, f.Default.Float(), f.Description)
b.flags.Float64Var(&val, f.Key, f.Default.Float(), f.Description)
if f.Short != "" {
flag.Float64Var(&val, f.Short, f.Default.Float(), shortDesc(f.Description))
b.flags.Float64Var(&val, f.Short, f.Default.Float(), shortDesc(f.Description))
}
defer func() {
if isFlagSet(f) {
if b.isFlagSet(f) {
f.Value.SetFloat(val)
}
}()
case k == reflect.String:
var val string
flag.StringVar(&val, f.Key, f.Default.String(), f.Description)
b.flags.StringVar(&val, f.Key, f.Default.String(), f.Description)
if f.Short != "" {
flag.StringVar(&val, f.Short, f.Default.String(), shortDesc(f.Description))
b.flags.StringVar(&val, f.Short, f.Default.String(), shortDesc(f.Description))
}
defer func() {
if isFlagSet(f) {
if b.isFlagSet(f) {
f.Value.SetString(val)
}
}()
default:
flag.Var(&flagValue{f}, f.Key, f.Description)
b.flags.Var(&flagValue{f}, f.Key, f.Description)
}
}

flag.Parse()
// Note: in the usual case, when b.flags is flag.CommandLine, this will exit
// rather than returning an error.
return b.flags.Parse(os.Args[1:])
}

return nil
func (b *Backend) isFlagSet(config *confita.FieldConfig) bool {
ok := false
b.flags.Visit(func(f *flag.Flag) {
if f.Name == config.Key || f.Name == config.Short {
ok = true
}
})
return ok
}

type flagValue struct {
Expand Down Expand Up @@ -139,11 +154,3 @@ func (b *Backend) Name() string {
func shortDesc(description string) string {
return fmt.Sprintf("%s (short)", description)
}

func isFlagSet(config *confita.FieldConfig) bool {
flagset := make(map[*confita.FieldConfig]bool)
flag.Visit(func(f *flag.Flag) { flagset[config] = true })

_, ok := flagset[config]
return ok
}
Loading

0 comments on commit 08f6fa0

Please sign in to comment.