Skip to content

Commit

Permalink
Treat enablement of TLS separately for server and client config (#2501)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergeybykov authored and meiliang86 committed Feb 23, 2022
1 parent 1a1aa2f commit e7852f7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 14 deletions.
9 changes: 6 additions & 3 deletions common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,12 @@ func (c *Config) String() string {
return maskedYaml
}

func (r *GroupTLS) IsEnabled() bool {
return r.Server.KeyFile != "" || r.Server.KeyData != "" ||
len(r.Client.RootCAFiles) > 0 || len(r.Client.RootCAData) > 0 ||
func (r *GroupTLS) IsServerEnabled() bool {
return r.Server.KeyFile != "" || r.Server.KeyData != ""
}

func (r *GroupTLS) IsClientEnabled() bool {
return len(r.Client.RootCAFiles) > 0 || len(r.Client.RootCAData) > 0 ||
r.Client.ForceTLS
}

Expand Down
8 changes: 4 additions & 4 deletions common/rpc/encryption/localStoreTlsProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (s *localStoreTlsProvider) GetInternodeClientConfig() (*tls.Config, error)
return newClientTLSConfig(s.internodeClientCertProvider, client.ServerName,
s.settings.Internode.Server.RequireClientAuth, false, !client.DisableHostVerification)
},
s.settings.Internode.IsEnabled(),
s.settings.Internode.IsClientEnabled(),
)
}

Expand All @@ -143,7 +143,7 @@ func (s *localStoreTlsProvider) GetFrontendClientConfig() (*tls.Config, error) {
useTLS = true
} else {
client = &s.settings.Frontend.Client
useTLS = s.settings.Frontend.IsEnabled()
useTLS = s.settings.Frontend.IsClientEnabled()
}
return s.getOrCreateConfig(
&s.cachedFrontendClientConfig,
Expand All @@ -161,7 +161,7 @@ func (s *localStoreTlsProvider) GetFrontendServerConfig() (*tls.Config, error) {
func() (*tls.Config, error) {
return newServerTLSConfig(s.frontendCertProvider, s.frontendPerHostCertProviderMap, &s.settings.Frontend, s.logger)
},
s.settings.Frontend.IsEnabled())
s.settings.Frontend.IsServerEnabled())
}

func (s *localStoreTlsProvider) GetInternodeServerConfig() (*tls.Config, error) {
Expand All @@ -170,7 +170,7 @@ func (s *localStoreTlsProvider) GetInternodeServerConfig() (*tls.Config, error)
func() (*tls.Config, error) {
return newServerTLSConfig(s.internodeCertProvider, nil, &s.settings.Internode, s.logger)
},
s.settings.Internode.IsEnabled())
s.settings.Internode.IsServerEnabled())
}

func (s *localStoreTlsProvider) GetExpiringCerts(timeWindow time.Duration,
Expand Down
22 changes: 15 additions & 7 deletions common/rpc/encryption/tls_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,27 @@ func (s *tlsConfigTest) SetupTest() {
func (s *tlsConfigTest) TestIsEnabled() {

emptyCfg := config.GroupTLS{}
s.False(emptyCfg.IsEnabled())
s.False(emptyCfg.IsServerEnabled())
s.False(emptyCfg.IsClientEnabled())
cfg := config.GroupTLS{Server: config.ServerTLS{KeyFile: "foo"}}
s.True(cfg.IsEnabled())
s.True(cfg.IsServerEnabled())
s.False(cfg.IsClientEnabled())
cfg = config.GroupTLS{Server: config.ServerTLS{KeyData: "foo"}}
s.True(cfg.IsEnabled())
s.True(cfg.IsServerEnabled())
s.False(cfg.IsClientEnabled())
cfg = config.GroupTLS{Client: config.ClientTLS{RootCAFiles: []string{"bar"}}}
s.True(cfg.IsEnabled())
s.False(cfg.IsServerEnabled())
s.True(cfg.IsClientEnabled())
cfg = config.GroupTLS{Client: config.ClientTLS{RootCAData: []string{"bar"}}}
s.True(cfg.IsEnabled())
s.False(cfg.IsServerEnabled())
s.True(cfg.IsClientEnabled())
cfg = config.GroupTLS{Client: config.ClientTLS{ForceTLS: true}}
s.True(cfg.IsEnabled())
s.False(cfg.IsServerEnabled())
s.True(cfg.IsClientEnabled())
cfg = config.GroupTLS{Client: config.ClientTLS{ForceTLS: false}}
s.False(cfg.IsEnabled())
s.False(cfg.IsServerEnabled())
s.False(cfg.IsClientEnabled())

}

func (s *tlsConfigTest) TestIsSystemWorker() {
Expand Down
24 changes: 24 additions & 0 deletions common/rpc/test/rpc_localstore_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ type localStoreRPCSuite struct {
internodeDynamicTLSFactory *TestFactory
internodeMutualTLSRPCRefreshFactory *TestFactory
frontendMutualTLSRPCRefreshFactory *TestFactory
frontendConfigRootCAForceTLSFactory *TestFactory

internodeCertDir string
frontendCertDir string
Expand All @@ -101,6 +102,7 @@ type localStoreRPCSuite struct {
frontendConfigMutualTLS config.GroupTLS
frontendConfigPerHostOverrides config.GroupTLS
frontendConfigRootCAOnly config.GroupTLS
frontendConfigRootCAForceTLS config.GroupTLS
frontendConfigAltRootCAOnly config.GroupTLS
frontendConfigSystemWorker config.WorkerTLS
frontendConfigMutualTLSRefresh config.GroupTLS
Expand Down Expand Up @@ -201,6 +203,9 @@ func (s *localStoreRPCSuite) SetupSuite() {
RootCAData: []string{convertFileToBase64(s.frontendChain.CaPubFile)},
},
}
s.frontendConfigRootCAForceTLS = s.frontendConfigRootCAOnly
s.frontendConfigRootCAForceTLS.Client.ForceTLS = true

s.frontendConfigAltRootCAOnly = config.GroupTLS{
Server: config.ServerTLS{
RequireClientAuth: true,
Expand Down Expand Up @@ -319,6 +324,13 @@ func (s *localStoreRPCSuite) setupFrontend() {
},
}

localStoreRootCAForceTLS := &config.Global{
Membership: s.membershipConfig,
TLS: config.RootTLS{
Frontend: s.frontendConfigRootCAForceTLS,
},
}

provider, err := encryption.NewTLSConfigProviderFromConfig(localStoreMutualTLS.TLS, nil, s.logger, nil)
s.NoError(err)
frontendMutualTLSFactory := rpc.NewFactory(rpcTestCfgDefault, "tester", s.logger, provider, dynamicconfig.NewNoopCollection())
Expand Down Expand Up @@ -355,6 +367,12 @@ func (s *localStoreRPCSuite) setupFrontend() {
s.internodeDynamicTLSFactory = i(dynamicServerTLSFactory)

s.frontendMutualTLSRPCRefreshFactory = f(frontendMutualTLSRefreshFactory)

provider, err = encryption.NewTLSConfigProviderFromConfig(localStoreRootCAForceTLS.TLS, nil, s.logger, nil)
s.NoError(err)
frontendRootCAForceTLSFactory := rpc.NewFactory(rpcTestCfgDefault, "tester", s.logger, provider, dynamicconfig.NewNoopCollection())
s.NotNil(frontendServerTLSFactory)
s.frontendConfigRootCAForceTLSFactory = f(frontendRootCAForceTLSFactory)
}

func (s *localStoreRPCSuite) setupInternode() {
Expand Down Expand Up @@ -789,3 +807,9 @@ func runRingpopTLSTest(s suite.Suite, logger log.Logger, serverA *TestFactory, s
s.NoError(err)
}
}

func (s *localStoreRPCSuite) TestClientForceTLS() {
options, err := s.frontendConfigRootCAForceTLSFactory.RPCFactory.GetFrontendGRPCServerOptions()
s.NoError(err)
s.Nil(options)
}

0 comments on commit e7852f7

Please sign in to comment.