diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 368bc126..28e40306 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -811,11 +811,11 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { if err == nil && s.Exitcode == 0 { once := false - if args.InitialQuery != "" { - s.Query = args.InitialQuery - } else if args.Query != "" { + if args.Query != "" { once = true s.Query = args.Query + } else if args.InitialQuery != "" { + s.Query = args.InitialQuery } iactive := args.InputFile == nil && args.Query == "" if iactive || s.Query != "" { diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index b6cf5039..e3c8a1e9 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -361,6 +361,34 @@ func TestQueryAndExit(t *testing.T) { } } +func TestInitQueryAndQueryExecutesQuery(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("Panic occurred: %v", r) + } + }() + o, err := os.CreateTemp("", "sqlcmdmain") + assert.NoError(t, err, "os.CreateTemp") + defer os.Remove(o.Name()) + defer o.Close() + args = newArguments() + args.InitialQuery = "SELECT 1" + args.Query = "SELECT 2" + args.OutputFile = o.Name() + vars := sqlcmd.InitializeVariables(args.useEnvVars()) + vars.Set(sqlcmd.SQLCMDMAXVARTYPEWIDTH, "0") + + setVars(vars, &args) + + exitCode, err := run(vars, &args) + assert.NoError(t, err, "run") + assert.Equal(t, 0, exitCode, "exitCode") + bytes, err := os.ReadFile(o.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "2"+sqlcmd.SqlcmdEol+sqlcmd.SqlcmdEol+oneRowAffected+sqlcmd.SqlcmdEol, string(bytes), "Incorrect output from run") + } +} + // Test to verify fix for issue: https://github.com/microsoft/go-sqlcmd/issues/98 // 1. Verify when -b is passed in (ExitOnError), we don't always get an error (even when input is good) // 2, Verify when the input is actually bad, we do get an error