diff --git a/internal/cmd/shell/shell.go b/internal/cmd/shell/shell.go index 7293a908..e7d27ffa 100644 --- a/internal/cmd/shell/shell.go +++ b/internal/cmd/shell/shell.go @@ -172,8 +172,9 @@ second argument: "-P", port, } + styledBranch := formatMySQLBranch(database, branch) m := &mysql{} - err = m.Run(ctx, mysqlArgs...) + err = m.Run(ctx, styledBranch, mysqlArgs...) return err }, @@ -189,6 +190,16 @@ second argument: return cmd } +func formatMySQLBranch(database, branch string) string { + branchStyled := printer.BoldBlue(branch) + if branch == "main" { + branchStyled = printer.BoldRed(branch) + } + + return printer.Bold(fmt.Sprintf("%s/%s> ", database, branchStyled)) + +} + // createLoginFile creates a temporary file to store the username and password, so we don't have to // pass them as `mysql` command-line arguments. func createLoginFile(username, password string) (string, error) { @@ -210,12 +221,15 @@ type mysql struct { } // Run runs the `mysql` client with the given arguments. -func (m *mysql) Run(ctx context.Context, args ...string) error { +func (m *mysql) Run(ctx context.Context, styledBranch string, args ...string) error { c := exec.CommandContext(ctx, "mysql", args...) if m.Dir != "" { c.Dir = m.Dir } + c.Env = append(os.Environ(), + fmt.Sprintf("MYSQL_PS1=%s", styledBranch)) + c.Stdout = os.Stdout c.Stderr = os.Stderr c.Stdin = os.Stdin diff --git a/internal/printer/printer.go b/internal/printer/printer.go index 28794ab2..0bd9a554 100644 --- a/internal/printer/printer.go +++ b/internal/printer/printer.go @@ -229,6 +229,11 @@ func BoldBlue(msg string) string { return color.New(color.FgBlue).Add(color.Bold).Sprint(msg) } +// BoldRed returns a string formatted with red and bold. +func BoldRed(msg string) string { + return color.New(color.FgRed).Add(color.Bold).Sprint(msg) +} + // Bold returns a string formatted with bold. func Bold(msg string) string { // the 'color' package already handles IsTTY gracefully