Skip to content

Commit

Permalink
Add support for tls in rabbitmq scaler (kedacore#967)
Browse files Browse the repository at this point in the history
Signed-off-by: Abhishek Mohite <b518004@iiit-bh.ac.in>
  • Loading branch information
mohite-abhi committed Jan 16, 2023
1 parent c161148 commit 1848877
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio

Here is an overview of all **stable** additions:

- **General**: Add TLS support to RabbitMQ scler ([#967](https://github.com/kedacore/keda/issues/967))
- **General**: Introduce admission webhooks to automatically validate resource changes to prevent misconfiguration and enforce best practices. ([#3755](https://github.com/kedacore/keda/issues/3755))
- **General**: Introduce new ArangoDB Scaler ([#4000](https://github.com/kedacore/keda/issues/4000))
- **Prometheus Metrics**: Introduce scaler latency in Prometheus metrics. ([#4037](https://github.com/kedacore/keda/issues/4037))
Expand Down
58 changes: 55 additions & 3 deletions pkg/scalers/rabbitmq_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package scalers

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"strings"
"net/http"
"net/url"
"regexp"
Expand Down Expand Up @@ -75,6 +78,12 @@ type rabbitMQMetadata struct {
metricName string // custom metric name for trigger
timeout time.Duration // custom http timeout for a specific trigger
scalerIndex int // scaler index

// TLS
ca string //Certificate authority file for TLS client authentication. Optional. If authmode is sasl_ssl, this is required.
cert string //Certificate for client authentication. Optional. If authmode is sasl_ssl, this is required.
key string //Key for client authentication. Optional. If authmode is sasl_ssl, this is required.
enableTLS bool
}

type queueInfo struct {
Expand Down Expand Up @@ -129,7 +138,7 @@ func NewRabbitMQScaler(config *ScalerConfig) (Scaler, error) {
host = hostURI.String()
}

conn, ch, err := getConnectionAndChannel(host)
conn, ch, err := getConnectionAndChannel(host, meta.ca, meta.cert, meta.key, meta.enableTLS)
if err != nil {
return nil, fmt.Errorf("error establishing rabbitmq connection: %w", err)
}
Expand Down Expand Up @@ -167,6 +176,27 @@ func parseRabbitMQMetadata(config *ScalerConfig) (*rabbitMQMetadata, error) {
return nil, fmt.Errorf("no host setting given")
}

meta.enableTLS = false
if val, ok := config.AuthParams["tls"]; ok {
val = strings.TrimSpace(val)
if val == "enable" {
certGiven := config.AuthParams["cert"] != ""
keyGiven := config.AuthParams["key"] != ""
if certGiven && !keyGiven {
return nil, fmt.Errorf("key must be provided with cert")
}
if keyGiven && !certGiven {
return nil, fmt.Errorf("cert must be provided with key")
}
meta.ca = config.AuthParams["ca"]
meta.cert = config.AuthParams["cert"]
meta.key = config.AuthParams["key"]
meta.enableTLS = true
} else if val != "disable" {
return nil, fmt.Errorf("err incorrect value for TLS given: %s", val)
}
}

// If the protocol is auto, check the host scheme.
if meta.protocol == autoProtocol {
parsedURL, err := url.Parse(meta.host)
Expand Down Expand Up @@ -354,8 +384,30 @@ func parseTrigger(meta *rabbitMQMetadata, config *ScalerConfig) (*rabbitMQMetada
return meta, nil
}

func getConnectionAndChannel(host string) (*amqp.Connection, *amqp.Channel, error) {
conn, err := amqp.Dial(host)
func getConnectionAndChannel(host string, caFile string, certFile string, keyFile string, enableTLS bool) (*amqp.Connection, *amqp.Channel, error) {
var conn *amqp.Connection
var err error

if enableTLS {
config := &tls.Config{}

config.RootCAs = x509.NewCertPool()

if caFile != "" {
config.RootCAs = x509.NewCertPool()
config.RootCAs.AppendCertsFromPEM([]byte(caFile))
}

if certFile != "" && keyFile != "" {
if cert, err := tls.LoadX509KeyPair(certFile, keyFile); err == nil {
config.Certificates = append(config.Certificates, cert)
}
}
conn, err = amqp.DialTLS(host, config)
} else {
conn, err = amqp.Dial(host)
}

if err != nil {
return nil, nil, err
}
Expand Down

0 comments on commit 1848877

Please sign in to comment.