Skip to content

Commit

Permalink
feat: allow hooks to be declared on embedded fields
Browse files Browse the repository at this point in the history
Specifically, on Go embedded fields, not on fields tagged with `embed`.

Fixes #90.
  • Loading branch information
alecthomas committed Dec 27, 2024
1 parent 565ae9b commit 840220c
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 53 deletions.
38 changes: 2 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,8 @@ func main() {

## Hooks: BeforeReset(), BeforeResolve(), BeforeApply(), AfterApply() and the Bind() option

If a node in the grammar has a `BeforeReset(...)`, `BeforeResolve
(...)`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
If a node in the CLI, or any of its embedded fields, has a `BeforeReset(...) error`, `BeforeResolve
(...) error`, `BeforeApply(...) error` and/or `AfterApply(...) error` method, those
methods will be called before values are reset, before validation/assignment,
and after validation/assignment, respectively.

Expand Down Expand Up @@ -341,40 +341,6 @@ func main() {
}
```

Another example of using hooks is load the env-file:

```go
package main

import (
"fmt"
"github.com/alecthomas/kong"
"github.com/joho/godotenv"
)

type EnvFlag string

// BeforeResolve loads env file.
func (c EnvFlag) BeforeReset(ctx *kong.Context, trace *kong.Path) error {
path := string(ctx.FlagValue(trace.Flag).(EnvFlag)) // nolint
path = kong.ExpandPath(path)
if err := godotenv.Load(path); err != nil {
return err
}
return nil
}

var CLI struct {
EnvFile EnvFlag
Flag `env:"FLAG"`
}

func main() {
_ = kong.Parse(&CLI)
fmt.Println(CLI.Flag)
}
```

## Flags

Any [mapped](#mapper---customising-how-the-command-line-is-mapped-to-go-values) field in the command structure _not_ tagged with `cmd` or `arg` will be a flag. Flags are optional by default.
Expand Down
27 changes: 27 additions & 0 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,33 @@ func getMethod(value reflect.Value, name string) reflect.Value {
return method
}

// Get methods from the given value and any embedded fields.
func getMethods(value reflect.Value, name string) []reflect.Value {
// Collect all possible receivers
receivers := []reflect.Value{value}
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() == reflect.Struct {
t := value.Type()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
fieldType := t.Field(i)
if fieldType.IsExported() && fieldType.Anonymous {
receivers = append(receivers, field)
}
}
}
// Search all receivers for methods
var methods []reflect.Value
for _, receiver := range receivers {
if method := getMethod(receiver, name); method.IsValid() {
methods = append(methods, method)
}
}
return methods
}

func callFunction(f reflect.Value, bindings bindings) error {
if f.Kind() != reflect.Func {
return fmt.Errorf("expected function, got %s", f.Type())
Expand Down
30 changes: 13 additions & 17 deletions kong.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,14 @@ func (k *Kong) applyHook(ctx *Context, name string) error {
default:
panic("unsupported Path")
}
method := getMethod(value, name)
if !method.IsValid() {
continue
}
binds := k.bindings.clone()
binds.add(ctx, trace)
binds.add(trace.Node().Vars().CloneWith(k.vars))
binds.merge(ctx.bindings)
if err := callFunction(method, binds); err != nil {
return err
for _, method := range getMethods(value, name) {
binds := k.bindings.clone()
binds.add(ctx, trace)
binds.add(trace.Node().Vars().CloneWith(k.vars))
binds.merge(ctx.bindings)
if err := callFunction(method, binds); err != nil {
return err
}
}
}
// Path[0] will always be the app root.
Expand All @@ -392,13 +390,11 @@ func (k *Kong) applyHookToDefaultFlags(ctx *Context, node *Node, name string) er
if !flag.HasDefault || ctx.values[flag.Value].IsValid() || !flag.Target.IsValid() {
continue
}
method := getMethod(flag.Target, name)
if !method.IsValid() {
continue
}
path := &Path{Flag: flag}
if err := callFunction(method, binds.clone().add(path)); err != nil {
return next(err)
for _, method := range getMethods(flag.Target, name) {
path := &Path{Flag: flag}
if err := callFunction(method, binds.clone().add(path)); err != nil {
return next(err)
}
}
}
return next(nil)
Expand Down
33 changes: 33 additions & 0 deletions kong_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2406,3 +2406,36 @@ func TestProviderMethods(t *testing.T) {
err = kctx.Run(t)
assert.NoError(t, err)
}

type EmbeddedCallback struct {
Embedded bool
}

func (e *EmbeddedCallback) AfterApply() error {
e.Embedded = true
return nil
}

type EmbeddedRoot struct {
EmbeddedCallback
Root bool
}

func (e *EmbeddedRoot) AfterApply() error {
e.Root = true
return nil
}

func TestEmbeddedCallbacks(t *testing.T) {
actual := &EmbeddedRoot{}
k := mustNew(t, actual)
_, err := k.Parse(nil)
assert.NoError(t, err)
expected := &EmbeddedRoot{
EmbeddedCallback: EmbeddedCallback{
Embedded: true,
},
Root: true,
}
assert.Equal(t, expected, actual)
}

0 comments on commit 840220c

Please sign in to comment.