Skip to content

Commit

Permalink
feat: Change Password
Browse files Browse the repository at this point in the history
  • Loading branch information
shueybubbles committed Sep 14, 2023
1 parent 0a4cd19 commit d3b508f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 23 deletions.
38 changes: 32 additions & 6 deletions cmd/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"strconv"
"strings"

mssql "github.com/microsoft/go-mssqldb"
"github.com/microsoft/go-mssqldb/azuread"
"github.com/microsoft/go-sqlcmd/internal/localizer"
"github.com/microsoft/go-sqlcmd/pkg/console"
Expand Down Expand Up @@ -70,7 +71,9 @@ type SQLCmdArguments struct {
RemoveControlCharacters *int
EchoInput bool
QueryTimeout int
EnableColumnEnryption bool
EnableColumnEncryption bool
ChangePassword string
ChangePasswordAndExit string
// Keep Help at the end of the list
Help bool
}
Expand Down Expand Up @@ -423,7 +426,9 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) {
_ = rootCmd.Flags().IntP(removeControlCharacters, "k", 0, localizer.Sprintf("%s Remove control characters from output. Pass 1 to substitute a space per character, 2 for a space per consecutive characters", "-k [1|2]"))
rootCmd.Flags().BoolVarP(&args.EchoInput, "echo-input", "e", false, localizer.Sprintf("Echo input"))
rootCmd.Flags().IntVarP(&args.QueryTimeout, "query-timeout", "t", 0, "Query timeout")
rootCmd.Flags().BoolVarP(&args.EnableColumnEnryption, "enable-column-encryption", "g", false, localizer.Sprintf("Enable column encryption"))
rootCmd.Flags().BoolVarP(&args.EnableColumnEncryption, "enable-column-encryption", "g", false, localizer.Sprintf("Enable column encryption"))
rootCmd.Flags().StringVarP(&args.ChangePassword, "change-password", "z", "", localizer.Sprintf("New password"))
rootCmd.Flags().StringVarP(&args.ChangePasswordAndExit, "change-password-exit", "Z", "", localizer.Sprintf("New password and exit"))
}

func setScriptVariable(v string) string {
Expand Down Expand Up @@ -682,11 +687,17 @@ func setConnect(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments, vars *sq
connect.ExitOnError = args.ExitOnError
connect.ErrorSeverityLevel = args.ErrorSeverityLevel
connect.DedicatedAdminConnection = args.DedicatedAdminConnection
connect.EnableColumnEnryption = args.EnableColumnEnryption
connect.EnableColumnEncryption = args.EnableColumnEncryption
if len(args.ChangePassword) > 0 {
connect.ChangePassword = args.ChangePassword
}
if len(args.ChangePasswordAndExit) > 0 {
connect.ChangePassword = args.ChangePasswordAndExit
}
}

func isConsoleInitializationRequired(connect *sqlcmd.ConnectSettings, args *SQLCmdArguments) bool {
iactive := args.InputFile == nil && args.Query == ""
iactive := args.InputFile == nil && args.Query == "" && len(args.ChangePasswordAndExit) == 0
return iactive || connect.RequiresPassword()
}

Expand Down Expand Up @@ -749,8 +760,23 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) {
// connect using no overrides
err = s.ConnectDb(nil, line == nil)
if err != nil {
s.WriteError(s.GetError(), err)
return 1, err
switch e := err.(type) {
// 18488 == password must be changed on connection
case mssql.Error:
if e.Number == 18488 && line != nil && len(args.Password) == 0 && len(args.ChangePassword) == 0 && len(args.ChangePasswordAndExit) == 0 {
b, _ := line.ReadPassword(localizer.Sprintf("Enter new password:"))
s.Connect.ChangePassword = string(b)
err = s.ConnectDb(nil, true)
}
}
if err != nil {
s.WriteError(s.GetError(), err)
return 1, err
}
}

if len(args.ChangePasswordAndExit) > 0 {
return 0, nil
}

script := vars.StartupScriptFile()
Expand Down
4 changes: 2 additions & 2 deletions cmd/sqlcmd/sqlcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ func TestValidCommandLineToArgsConversion(t *testing.T) {
{[]string{"-i", `"comma,text.sql"`}, func(args SQLCmdArguments) bool {
return args.InputFile[0] == "comma,text.sql"
}},
{[]string{"-k", "-X", "-r"}, func(args SQLCmdArguments) bool {
return args.warnOnBlockedCmd() && !args.useEnvVars() && args.getControlCharacterBehavior() == sqlcmd.ControlRemove && *args.ErrorsToStderr == 0
{[]string{"-k", "-X", "-r", "-z", "something"}, func(args SQLCmdArguments) bool {
return args.warnOnBlockedCmd() && !args.useEnvVars() && args.getControlCharacterBehavior() == sqlcmd.ControlRemove && *args.ErrorsToStderr == 0 && args.ChangePassword == "something"
}},
}

Expand Down
34 changes: 20 additions & 14 deletions pkg/sqlcmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/microsoft/go-mssqldb/azuread"
"github.com/microsoft/go-mssqldb/msdsn"
)

// ConnectSettings specifies the settings for SQL connections and queries
Expand Down Expand Up @@ -54,7 +55,9 @@ type ConnectSettings struct {
// DedicatedAdminConnection forces the connection to occur over tcp on the dedicated admin port. Requires Browser service access
DedicatedAdminConnection bool
// EnableColumnEncryption enables support for transparent column encryption
EnableColumnEnryption bool
EnableColumnEncryption bool
// ChangePassword is the new password for the user to set during login
ChangePassword string
}

func (c ConnectSettings) authenticationMethod() string {
Expand Down Expand Up @@ -113,50 +116,53 @@ func (connect ConnectSettings) ConnectionString() (connectionString string, err
return "", &InvalidServerName
}
serverName = pipeParts[0]
query.Add("pipe", pipeParts[2])
query.Add(msdsn.Pipe, pipeParts[2])
}
if port > 0 {
connectionURL.Host = fmt.Sprintf("%s:%d", serverName, port)
} else {
connectionURL.Host = serverName
}
if connect.Database != "" {
query.Add("database", connect.Database)
query.Add(msdsn.Database, connect.Database)
}

if connect.TrustServerCertificate {
query.Add("trustservercertificate", "true")
query.Add(msdsn.TrustServerCertificate, "true")
}
if connect.ApplicationIntent != "" && connect.ApplicationIntent != "default" {
query.Add("applicationintent", connect.ApplicationIntent)
query.Add(msdsn.ApplicationIntent, connect.ApplicationIntent)
}
if connect.LoginTimeoutSeconds > 0 {
query.Add("dial timeout", fmt.Sprint(connect.LoginTimeoutSeconds))
query.Add(msdsn.DialTimeout, fmt.Sprint(connect.LoginTimeoutSeconds))
}
if connect.PacketSize > 0 {
query.Add("packet size", fmt.Sprint(connect.PacketSize))
query.Add(msdsn.PacketSize, fmt.Sprint(connect.PacketSize))
}
if connect.WorkstationName != "" {
query.Add("workstation id", connect.WorkstationName)
query.Add(msdsn.WorkstationID, connect.WorkstationName)
}
if connect.Encrypt != "" && connect.Encrypt != "default" {
query.Add("encrypt", connect.Encrypt)
query.Add(msdsn.Encrypt, connect.Encrypt)
}
if connect.LogLevel > 0 {
query.Add("log", fmt.Sprint(connect.LogLevel))
query.Add(msdsn.LogParam, fmt.Sprint(connect.LogLevel))
}
if protocol != "" {
query.Add("protocol", protocol)
query.Add(msdsn.Protocol, protocol)
}
if connect.ApplicationName != "" {
query.Add(`app name`, connect.ApplicationName)
query.Add(msdsn.AppName, connect.ApplicationName)
}
if connect.DedicatedAdminConnection {
query.Set("protocol", "admin")
query.Set(msdsn.Protocol, "admin")
}
if connect.EnableColumnEnryption {
if connect.EnableColumnEncryption {
query.Set("columnencryption", "true")
}
if connect.ChangePassword != "" {
query.Set(msdsn.ChangePassword, connect.ChangePassword)
}
connectionURL.RawQuery = query.Encode()
return connectionURL.String(), nil
}
3 changes: 2 additions & 1 deletion pkg/sqlcmd/sqlcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
_ "github.com/microsoft/go-mssqldb/aecmk/localcert"
"github.com/microsoft/go-mssqldb/msdsn"
"github.com/microsoft/go-sqlcmd/internal/color"
"github.com/microsoft/go-sqlcmd/internal/localizer"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
)
Expand Down Expand Up @@ -312,7 +313,7 @@ func (s *Sqlcmd) promptPassword() (string, error) {
if s.lineIo == nil {
return "", nil
}
pwd, err := s.lineIo.ReadPassword("Password:")
pwd, err := s.lineIo.ReadPassword(localizer.Sprintf("Password:"))
if err != nil {
return "", err
}
Expand Down

0 comments on commit d3b508f

Please sign in to comment.