diff --git a/pkg/cli/cert.go b/pkg/cli/cert.go index 771385a8a153..e10b919f40ee 100644 --- a/pkg/cli/cert.go +++ b/pkg/cli/cert.go @@ -203,9 +203,6 @@ List certificates and keys found in the certificate directory. // runListCerts loads and lists all certs. func runListCerts(cmd *cobra.Command, args []string) error { - if err := security.SetCertPrincipalMap(certCtx.certPrincipalMap); err != nil { - return err - } cm, err := security.NewCertificateManager(baseCfg.SSLCertsDir) if err != nil { return errors.Wrap(err, "cannot load certificates") diff --git a/pkg/cli/context.go b/pkg/cli/context.go index 49b9fc27cdb8..d111dc3f929b 100644 --- a/pkg/cli/context.go +++ b/pkg/cli/context.go @@ -77,6 +77,7 @@ func initCLIDefaults() { cliCtx.cmdTimeout = 0 // no timeout cliCtx.clientConnHost = "" cliCtx.clientConnPort = base.DefaultPort + cliCtx.certPrincipalMap = nil cliCtx.sqlConnURL = "" cliCtx.sqlConnUser = "" cliCtx.sqlConnPasswd = "" @@ -173,8 +174,6 @@ func initCLIDefaults() { authCtx.validityPeriod = 1 * time.Hour - certCtx.certPrincipalMap = nil - initPreFlagsDefaults() // Clear the "Changed" state of all the registered command-line flags. @@ -217,6 +216,9 @@ type cliContext struct { // clientConnPort is the port name/number to use to connect to a server. clientConnPort string + // certPrincipalMap is the cert-principal:db-principal map. + certPrincipalMap []string + // for CLI commands that use the SQL interface, these parameters // determine how to connect to the server. sqlConnURL, sqlConnUser, sqlConnDBName string @@ -405,9 +407,3 @@ var demoCtx struct { insecure bool geoLibsDir string } - -// certCtx captures the command-line parameters of the `cert` command. -// Defaults set by InitCLIDefaults() above. -var certCtx struct { - certPrincipalMap []string -} diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index 5538d8037b23..f84c09db155e 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -240,7 +240,9 @@ func init() { // Every command but start will inherit the following setting. AddPersistentPreRunE(cockroachCmd, func(cmd *cobra.Command, _ []string) error { - extraClientFlagInit() + if err := extraClientFlagInit(); err != nil { + return err + } return setDefaultStderrVerbosity(cmd, log.Severity_WARNING) }) @@ -441,12 +443,11 @@ func init() { f := cmd.Flags() // All certs commands need the certificate directory. StringFlag(f, &baseCfg.SSLCertsDir, cliflags.CertsDir, baseCfg.SSLCertsDir) + // All certs commands get the certificate principal map. + StringSlice(f, &cliCtx.certPrincipalMap, + cliflags.CertPrincipalMap, cliCtx.certPrincipalMap) } - // The list certs command needs the certificate principal map. - StringSlice(listCertsCmd.Flags(), &certCtx.certPrincipalMap, - cliflags.CertPrincipalMap, certCtx.certPrincipalMap) - for _, cmd := range []*cobra.Command{createCACertCmd, createClientCACertCmd} { f := cmd.Flags() // CA certificates have a longer expiration time. @@ -495,6 +496,9 @@ func init() { // Certificate flags. StringFlag(f, &baseCfg.SSLCertsDir, cliflags.CertsDir, baseCfg.SSLCertsDir) + // Certificate principal map. + StringSlice(f, &cliCtx.certPrincipalMap, + cliflags.CertPrincipalMap, cliCtx.certPrincipalMap) } // Auth commands. @@ -878,7 +882,10 @@ func extraServerFlagInit(cmd *cobra.Command) error { return nil } -func extraClientFlagInit() { +func extraClientFlagInit() error { + if err := security.SetCertPrincipalMap(cliCtx.certPrincipalMap); err != nil { + return err + } serverCfg.Addr = net.JoinHostPort(cliCtx.clientConnHost, cliCtx.clientConnPort) serverCfg.AdvertiseAddr = serverCfg.Addr serverCfg.SQLAddr = net.JoinHostPort(cliCtx.clientConnHost, cliCtx.clientConnPort) @@ -894,6 +901,7 @@ func extraClientFlagInit() { if sqlCtx.debugMode { sqlCtx.echo = true } + return nil } func setDefaultStderrVerbosity(cmd *cobra.Command, defaultSeverity log.Severity) error {