diff --git a/cmd/talosctl/pkg/talos/action/tracker.go b/cmd/talosctl/pkg/talos/action/tracker.go index eb1bbae648..a944f02799 100644 --- a/cmd/talosctl/pkg/talos/action/tracker.go +++ b/cmd/talosctl/pkg/talos/action/tracker.go @@ -26,7 +26,6 @@ import ( "google.golang.org/grpc/status" "github.com/siderolabs/talos/cmd/talosctl/cmd/common" - "github.com/siderolabs/talos/cmd/talosctl/pkg/talos/global" "github.com/siderolabs/talos/cmd/talosctl/pkg/talos/helpers" machineapi "github.com/siderolabs/talos/pkg/machinery/api/machine" "github.com/siderolabs/talos/pkg/machinery/client" @@ -92,7 +91,7 @@ type Tracker struct { timeout time.Duration isTerminal bool debug bool - cliContext *global.Args + clientExecutor ClientExecutor } // TrackerOption is the functional option for the Tracker. @@ -119,9 +118,16 @@ func WithDebug(debug bool) TrackerOption { } } +// WithTerminalOverride sets the terminal override. +func WithTerminalOverride(isTerminal bool) TrackerOption { + return func(t *Tracker) { + t.isTerminal = isTerminal + } +} + // NewTracker creates a new Tracker. func NewTracker( - cliContext *global.Args, + clientExecutor ClientExecutor, expectedEventFn func(event client.EventResult) bool, actionFn func(ctx context.Context, c *client.Client) (string, error), opts ...TrackerOption, @@ -129,11 +135,11 @@ func NewTracker( tracker := Tracker{ expectedEventFn: expectedEventFn, actionFn: actionFn, - nodeToLatestStatusUpdate: make(map[string]reporter.Update, len(cliContext.Nodes)), + nodeToLatestStatusUpdate: make(map[string]reporter.Update, len(clientExecutor.NodeList())), reporter: reporter.New(), reportCh: make(chan nodeUpdate), isTerminal: isatty.IsTerminal(os.Stderr.Fd()), - cliContext: cliContext, + clientExecutor: clientExecutor, } for _, option := range opts { @@ -143,6 +149,12 @@ func NewTracker( return &tracker } +// ClientExecutor is the interface for the client executor. +type ClientExecutor interface { + WithClient(action func(context.Context, *client.Client) error, dialOptions ...grpc.DialOption) error + NodeList() []string +} + // Run executes the action on nodes and tracks its progress by watching events with retries. // After receiving the expected event, if provided, it tracks the progress by running the post check with retries. // @@ -152,7 +164,7 @@ func (a *Tracker) Run() error { var eg errgroup.Group - err := a.cliContext.WithClient(func(ctx context.Context, c *client.Client) error { + err := a.clientExecutor.WithClient(func(ctx context.Context, c *client.Client) error { ctx, cancel := context.WithTimeout(ctx, a.timeout) defer cancel() @@ -170,7 +182,7 @@ func (a *Tracker) Run() error { var trackEg errgroup.Group - for _, node := range a.cliContext.Nodes { + for _, node := range a.clientExecutor.NodeList() { var ( dmesg *circular.Buffer err error diff --git a/cmd/talosctl/pkg/talos/global/client.go b/cmd/talosctl/pkg/talos/global/client.go index d6412202a5..07006480e9 100644 --- a/cmd/talosctl/pkg/talos/global/client.go +++ b/cmd/talosctl/pkg/talos/global/client.go @@ -28,6 +28,11 @@ type Args struct { Endpoints []string } +// NodeList returns the list of nodes to run the command against. +func (c *Args) NodeList() []string { + return c.Nodes +} + // WithClientNoNodes wraps common code to initialize Talos client and provide cancellable context. // // WithClientNoNodes doesn't set any node information on the request context.