diff --git a/command/command_test.go b/command/command_test.go new file mode 100644 index 0000000..ebd43f3 --- /dev/null +++ b/command/command_test.go @@ -0,0 +1,39 @@ +package command + +import ( + "flag" + "testing" + + "gopkg.in/urfave/cli.v1" +) + +func TestGetCommand(t *testing.T) { + flags := flag.FlagSet{} + flags.Bool("follow", false, "") + commandText := "tail -f -n 10 logs/rest.log" + flags.Parse([]string{"--", "--follow", "dev", commandText}) + app := cli.NewApp() + context := cli.NewContext(app, &flags, nil) + command, extraArgs := getCommand(context, true) + if !command.IsAdHoc() { + t.Errorf("Command name should be ad-hoc, but got %s", command.name) + } + if command.command != commandText { + t.Errorf("Command should be %s, but got %s", commandText, command.command) + } + if extraArgs != "" { + t.Errorf("Extra args should be empty, but got %s", extraArgs) + } +} + +func TestContainsFollow(t *testing.T) { + flags := flag.FlagSet{} + flags.Bool("follow", false, "") + flags.Parse([]string{"--", "--follow", "dev", "tail -f -n 10 logs/rest.log"}) + app := cli.NewApp() + context := cli.NewContext(app, &flags, nil) + follow := ContainsFollow(context) + if !follow { + t.Error("Should contain follow") + } +} diff --git a/command/init.go b/command/init.go index c94009f..4b3245f 100644 --- a/command/init.go +++ b/command/init.go @@ -20,6 +20,7 @@ const ( defaultSectionName = "DEFAULT" commandNameConfirmationSuffix = "!" + adHocCommandName = "ad-hoc" ) type Command struct { @@ -38,6 +39,10 @@ var commandNames []string var hostNames []string var defaults map[string]map[string]string +func (c Command) IsAdHoc() bool { + return c.name == adHocCommandName +} + func Init(profile string) { cfg := config.DefaultConfig commands = readCommands(cfg, profile) diff --git a/command/run.go b/command/run.go index 53fa9b5..90ebabb 100644 --- a/command/run.go +++ b/command/run.go @@ -15,17 +15,20 @@ import ( ) const ( + optionalFollowArgsIndex = 0 hostsGroupArgsIndex = 0 commandNameArgsIndex = 1 hostsGroupChildrenSuffix = ":children" allHostsGroup = "all" + FollowArgName = "follow" ) // CmdRun runs custom command func CmdRun(c *cli.Context) { CheckUpdate(c) - command, extraArgs := getCommand(c) - hosts := getHosts(c) + follow := ContainsFollow(c) + command, extraArgs := getCommand(c, follow) + hosts := getHosts(c, follow) var confirmation string if command.requiresConfirmation { fmt.Printf("Confirm to run \"%s\" command on %v - yes/no or y/n: ", command.name, hosts) @@ -54,7 +57,7 @@ func CmdRun(c *cli.Context) { cmd = fmt.Sprintf("cd %s && %s", workingDir, cmd) } } - go SSH(sshConfig, host, cmd, ch) + go SSH(sshConfig, host, cmd, follow, ch) } outputs := make([]ExecOutput, 0, len(hosts)) for i := 0; i < len(hosts); i++ { @@ -67,26 +70,43 @@ func CmdRun(c *cli.Context) { fmt.Println(output.output) } } -func getCommand(c *cli.Context) (Command, string) { +func getCommand(c *cli.Context, follow bool) (Command, string) { args := c.Args() - commandName := strings.TrimSpace(args.Get(commandNameArgsIndex)) + actualCommandNameArgsIndex := getActualArgsIndex(commandNameArgsIndex, follow) + commandName := strings.TrimSpace(args.Get(actualCommandNameArgsIndex)) command, ok := commands[commandName] if !ok { - adhocCommand := strings.Join(c.Args().Tail(), " ") + adhocCommand := strings.Join(args[actualCommandNameArgsIndex:], " ") fmt.Printf("%s: custom command \"%s\" is not defined, interpret it as the ad-hoc command: %s\n", c.App.Name, commandName, adhocCommand) - command = Command{"ad-hoc", adhocCommand, "", false} + command = Command{adHocCommandName, adhocCommand, "", false} } extraArgs := "" if ok && c.NArg() > 2 { - extraArgs = strings.Join(c.Args().Tail()[1:], " ") + extraArgsIndex := 1 + if follow { + extraArgsIndex = 2 + } + extraArgs = strings.Join(args.Tail()[extraArgsIndex:], " ") } return command, extraArgs } +func getActualArgsIndex(argsIndex int, follow bool) int { + actualArgsIndex := argsIndex + if follow { + actualArgsIndex = argsIndex + 1 + } + return actualArgsIndex +} +func ContainsFollow(c *cli.Context) bool { + follow := c.Args().Get(optionalFollowArgsIndex) + return follow == "-f" || follow == fmt.Sprintf("--%s", FollowArgName) +} -func getHosts(c *cli.Context) []string { +func getHosts(c *cli.Context, follow bool) []string { args := c.Args() - hostsGroup := strings.TrimSpace(args.Get(hostsGroupArgsIndex)) + actualHostsGroupArgsIndex := getActualArgsIndex(hostsGroupArgsIndex, follow) + hostsGroup := strings.TrimSpace(args.Get(actualHostsGroupArgsIndex)) hosts := getHostsByGroup(c, hostsGroup) sort.Strings(hosts) return hosts diff --git a/command/ssh.go b/command/ssh.go index a863bef..3f6ae30 100644 --- a/command/ssh.go +++ b/command/ssh.go @@ -20,7 +20,7 @@ const ( ) // SSH implements scp connection to the remote instance -func SSH(sshConfig *ssh_config.Config, host, command string, ch chan ExecOutput) { +func SSH(sshConfig *ssh_config.Config, host, command string, follow bool, ch chan ExecOutput) { fmt.Printf("Running command: %s on host %s\n", command, host) user, _ := sshConfig.Get(host, SshConfigUserKey) hostname, _ := sshConfig.Get(host, SshConfigHostnameKey) @@ -49,7 +49,11 @@ func SSH(sshConfig *ssh_config.Config, host, command string, ch chan ExecOutput) } defer session.Close() var output strings.Builder - session.Stdout = &output + if follow { + session.Stdout = os.Stdout + } else { + session.Stdout = &output + } session.Stderr = os.Stderr session.Stdin = os.Stdin if err := session.Run(command); err != nil { diff --git a/commands.go b/commands.go index 0da743b..07735d8 100644 --- a/commands.go +++ b/commands.go @@ -19,7 +19,7 @@ var GlobalFlags = []cli.Flag{ } func getProfile(c *cli.Context) string { - profile := c.String(profileArgName) + profile := c.GlobalString(profileArgName) if profile == "" { profile = os.Getenv("VMX_DEFAULT_PROFILE") } @@ -33,12 +33,17 @@ var Commands = []cli.Command{ Usage: "Run custom command", Description: `Example of usage is below: run logs => run logs command defined in the ~/.vmx/commands`, - Action: command.CmdRun, - Flags: []cli.Flag{}, + Action: command.CmdRun, + Flags: []cli.Flag{ + cli.BoolFlag{ + Name: fmt.Sprintf("%s, f", command.FollowArgName), + Usage: "flag indicates that the provided command will not exit, but will follow the output instead", + }, + }, SkipFlagParsing: true, BashComplete: func(c *cli.Context) { var names []string - if c.NArg() == 0 { + if c.NArg() == 0 || (c.NArg() == 1 && command.ContainsFollow(c)) { names = command.GetHostNames() } else { names = command.GetCommandNames()