Skip to content

Commit

Permalink
Conditionally enable mtls for the allocator. (#1645)
Browse files Browse the repository at this point in the history
* Removed the requirement to provide certificates when mTLS is disabled.

Co-authored-by: Nikhil Athreya <nathreya@google.com>
  • Loading branch information
devloop0 and nathreya-google authored Jun 29, 2020
1 parent c203a81 commit 6998d88
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 64 deletions.
138 changes: 74 additions & 64 deletions cmd/allocator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,64 +88,66 @@ func main() {
return err
})

h := newServiceHandler(kubeClient, agonesClient, health)

// creates a new file watcher for client certificate folder
watcher, err := fsnotify.NewWatcher()
if err != nil {
logger.WithError(err).Fatal("could not create watcher for client certs")
}
defer watcher.Close() // nolint: errcheck
if err := watcher.Add(certDir); err != nil {
logger.WithError(err).Fatalf("cannot watch folder %s for secret changes", certDir)
}

watcherTLS, err := fsnotify.NewWatcher()
if err != nil {
logger.WithError(err).Fatal("could not create watcher for tls certs")
}
defer watcherTLS.Close() // nolint: errcheck
if err := watcherTLS.Add(tlsDir); err != nil {
logger.WithError(err).Fatalf("cannot watch folder %s for secret changes", tlsDir)
}
h := newServiceHandler(kubeClient, agonesClient, health, conf.MTLSDisabled)

listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sslPort))
if err != nil {
logger.WithError(err).Fatalf("failed to listen on TCP port %s", sslPort)
}

// Watching for the events in certificate directory for updating certificates, when there is a change
go func() {
for {
select {
// watch for events
case event := <-watcherTLS.Events:
tlsCert, err := readTLSCert()
if err != nil {
logger.WithError(err).Error("could not load TLS cert; keeping old one")
} else {
h.tlsMutex.Lock()
h.tlsCert = tlsCert
h.tlsMutex.Unlock()
}
logger.Infof("Tls directory change event %v", event)
case event := <-watcher.Events:
h.certMutex.Lock()
caCertPool, err := getCACertPool(certDir)
if err != nil {
logger.WithError(err).Error("could not load CA certs; keeping old ones")
} else {
h.caCertPool = caCertPool
}
logger.Infof("Certificate directory change event %v", event)
h.certMutex.Unlock()
if !h.mTLSDisabled {
// creates a new file watcher for client certificate folder
watcher, err := fsnotify.NewWatcher()
if err != nil {
logger.WithError(err).Fatal("could not create watcher for client certs")
}
defer watcher.Close() // nolint: errcheck
if err := watcher.Add(certDir); err != nil {
logger.WithError(err).Fatalf("cannot watch folder %s for secret changes", certDir)
}

// watch for errors
case err := <-watcher.Errors:
logger.WithError(err).Error("error watching for certificate directory")
}
watcherTLS, err := fsnotify.NewWatcher()
if err != nil {
logger.WithError(err).Fatal("could not create watcher for tls certs")
}
}()
defer watcherTLS.Close() // nolint: errcheck
if err := watcherTLS.Add(tlsDir); err != nil {
logger.WithError(err).Fatalf("cannot watch folder %s for secret changes", tlsDir)
}

// Watching for the events in certificate directory for updating certificates, when there is a change
go func() {
for {
select {
// watch for events
case event := <-watcherTLS.Events:
tlsCert, err := readTLSCert()
if err != nil {
logger.WithError(err).Error("could not load TLS cert; keeping old one")
} else {
h.tlsMutex.Lock()
h.tlsCert = tlsCert
h.tlsMutex.Unlock()
}
logger.Infof("Tls directory change event %v", event)
case event := <-watcher.Events:
h.certMutex.Lock()
caCertPool, err := getCACertPool(certDir)
if err != nil {
logger.WithError(err).Error("could not load CA certs; keeping old ones")
} else {
h.caCertPool = caCertPool
}
logger.Infof("Certificate directory change event %v", event)
h.certMutex.Unlock()

// watch for errors
case err := <-watcher.Errors:
logger.WithError(err).Error("error watching for certificate directory")
}
}
}()
}

opts := h.getServerOptions()

Expand All @@ -165,7 +167,7 @@ func main() {
logger.WithError(err).Fatal("allocation service crashed")
}

func newServiceHandler(kubeClient kubernetes.Interface, agonesClient versioned.Interface, health healthcheck.Handler) *serviceHandler {
func newServiceHandler(kubeClient kubernetes.Interface, agonesClient versioned.Interface, health healthcheck.Handler, mTLSDisabled bool) *serviceHandler {
defaultResync := 30 * time.Second
agonesInformerFactory := externalversions.NewSharedInformerFactory(agonesClient, defaultResync)
kubeInformerFactory := informers.NewSharedInformerFactory(kubeClient, defaultResync)
Expand All @@ -182,6 +184,7 @@ func newServiceHandler(kubeClient kubernetes.Interface, agonesClient versioned.I
allocationCallback: func(gsa *allocationv1.GameServerAllocation) (k8sruntime.Object, error) {
return allocator.Allocate(gsa, stop)
},
mTLSDisabled: mTLSDisabled,
}

kubeInformerFactory.Start(stop)
Expand All @@ -190,21 +193,23 @@ func newServiceHandler(kubeClient kubernetes.Interface, agonesClient versioned.I
logger.WithError(err).Fatal("starting allocator failed.")
}

caCertPool, err := getCACertPool(certDir)
if err != nil {
logger.WithError(err).Fatal("could not load CA certs.")
}
h.certMutex.Lock()
h.caCertPool = caCertPool
h.certMutex.Unlock()
if !h.mTLSDisabled {
caCertPool, err := getCACertPool(certDir)
if err != nil {
logger.WithError(err).Fatal("could not load CA certs.")
}
h.certMutex.Lock()
h.caCertPool = caCertPool
h.certMutex.Unlock()

tlsCert, err := readTLSCert()
if err != nil {
logger.WithError(err).Fatal("could not load TLS certs.")
tlsCert, err := readTLSCert()
if err != nil {
logger.WithError(err).Fatal("could not load TLS certs.")
}
h.tlsMutex.Lock()
h.tlsCert = tlsCert
h.tlsMutex.Unlock()
}
h.tlsMutex.Lock()
h.tlsCert = tlsCert
h.tlsMutex.Unlock()

return &h
}
Expand All @@ -220,6 +225,9 @@ func readTLSCert() (*tls.Certificate, error) {
// getServerOptions returns a list of GRPC server options.
// Current options are TLS certs and opencensus stats handler.
func (h *serviceHandler) getServerOptions() []grpc.ServerOption {
if h.mTLSDisabled {
return []grpc.ServerOption{grpc.StatsHandler(&ocgrpc.ServerHandler{})}
}

cfg := &tls.Config{
GetCertificate: h.getTLSCert,
Expand Down Expand Up @@ -323,6 +331,8 @@ type serviceHandler struct {

tlsMutex sync.RWMutex
tlsCert *tls.Certificate

mTLSDisabled bool
}

// Allocate implements the Allocate gRPC method definition
Expand Down
6 changes: 6 additions & 0 deletions cmd/allocator/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ const (
enablePrometheusMetricsFlag = "prometheus-exporter"
projectIDFlag = "gcp-project-id"
stackdriverLabels = "stackdriver-labels"
mTLSDisabledFlag = "disable-mtls"
)

func init() {
registerMetricViews()
}

type config struct {
MTLSDisabled bool
PrometheusMetrics bool
Stackdriver bool
GCPProjectID string
Expand All @@ -51,11 +53,13 @@ func parseEnvFlags() config {
viper.SetDefault(enableStackdriverMetricsFlag, false)
viper.SetDefault(projectIDFlag, "")
viper.SetDefault(stackdriverLabels, "")
viper.SetDefault(mTLSDisabledFlag, false)

pflag.Bool(enablePrometheusMetricsFlag, viper.GetBool(enablePrometheusMetricsFlag), "Flag to activate metrics of Agones. Can also use PROMETHEUS_EXPORTER env variable.")
pflag.Bool(enableStackdriverMetricsFlag, viper.GetBool(enableStackdriverMetricsFlag), "Flag to activate stackdriver monitoring metrics for Agones. Can also use STACKDRIVER_EXPORTER env variable.")
pflag.String(projectIDFlag, viper.GetString(projectIDFlag), "GCP ProjectID used for Stackdriver, if not specified ProjectID from Application Default Credentials would be used. Can also use GCP_PROJECT_ID env variable.")
pflag.String(stackdriverLabels, viper.GetString(stackdriverLabels), "A set of default labels to add to all stackdriver metrics generated. By default metadata are automatically added using Kubernetes API and GCP metadata enpoint.")
pflag.Bool(mTLSDisabledFlag, viper.GetBool(mTLSDisabledFlag), "Flag to enable/disable mTLS in the allocator.")
runtime.FeaturesBindFlags()
pflag.Parse()

Expand All @@ -64,6 +68,7 @@ func parseEnvFlags() config {
runtime.Must(viper.BindEnv(enableStackdriverMetricsFlag))
runtime.Must(viper.BindEnv(projectIDFlag))
runtime.Must(viper.BindEnv(stackdriverLabels))
runtime.Must(viper.BindEnv(mTLSDisabledFlag))
runtime.Must(viper.BindPFlags(pflag.CommandLine))
runtime.Must(runtime.FeaturesBindEnv())

Expand All @@ -74,6 +79,7 @@ func parseEnvFlags() config {
Stackdriver: viper.GetBool(enableStackdriverMetricsFlag),
GCPProjectID: viper.GetString(projectIDFlag),
StackdriverLabels: viper.GetString(stackdriverLabels),
MTLSDisabled: viper.GetBool(mTLSDisabledFlag),
}
}

Expand Down
2 changes: 2 additions & 0 deletions install/helm/agones/templates/service/allocation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ spec:
value: {{ .Values.agones.metrics.stackdriverProjectID | quote }}
- name: STACKDRIVER_LABELS
value: {{ .Values.agones.metrics.stackdriverLabels | quote }}
- name: DISABLE_MTLS
value: {{ .Values.agones.allocator.disableMTLS | quote }}
- name: POD_NAME
valueFrom:
fieldRef:
Expand Down
1 change: 1 addition & 0 deletions install/helm/agones/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ agones:
serviceType: LoadBalancer
annotations: {}
generateTLS: true
disableMTLS: false
image:
registry: gcr.io/agones-images
tag: 1.7.0
Expand Down
2 changes: 2 additions & 0 deletions install/yaml/install.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,8 @@ spec:
value: ""
- name: STACKDRIVER_LABELS
value: ""
- name: DISABLE_MTLS
value: "false"
- name: POD_NAME
valueFrom:
fieldRef:
Expand Down
1 change: 1 addition & 0 deletions pkg/gameserverallocations/allocator.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ func (c *Allocator) allocateFromRemoteCluster(gsa *allocationv1.GameServerAlloca

// createRemoteClusterDialOption creates a grpc client dial option with proper certs to make a remote call.
func (c *Allocator) createRemoteClusterDialOption(namespace string, connectionInfo *multiclusterv1.ClusterConnectionInfo) (grpc.DialOption, error) {
// TODO: disableMTLS works for a single cluster; still need to address how the flag interacts with multi-cluster authentication.
clientCert, clientKey, caCert, err := c.getClientCertificates(namespace, connectionInfo.SecretName)
if err != nil {
return nil, err
Expand Down

0 comments on commit 6998d88

Please sign in to comment.