Skip to content

Commit

Permalink
[ADDED] Setting TLS config with callbacks in Connect
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Feb 16, 2024
1 parent a91a735 commit a7632e5
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 49 deletions.
133 changes: 84 additions & 49 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,55 +90,56 @@ const (

// Errors
var (
ErrConnectionClosed = errors.New("nats: connection closed")
ErrConnectionDraining = errors.New("nats: connection draining")
ErrDrainTimeout = errors.New("nats: draining connection timed out")
ErrConnectionReconnecting = errors.New("nats: connection reconnecting")
ErrSecureConnRequired = errors.New("nats: secure connection required")
ErrSecureConnWanted = errors.New("nats: secure connection not available")
ErrBadSubscription = errors.New("nats: invalid subscription")
ErrTypeSubscription = errors.New("nats: invalid subscription type")
ErrBadSubject = errors.New("nats: invalid subject")
ErrBadQueueName = errors.New("nats: invalid queue name")
ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped")
ErrTimeout = errors.New("nats: timeout")
ErrBadTimeout = errors.New("nats: timeout invalid")
ErrAuthorization = errors.New("nats: authorization violation")
ErrAuthExpired = errors.New("nats: authentication expired")
ErrAuthRevoked = errors.New("nats: authentication revoked")
ErrAccountAuthExpired = errors.New("nats: account authentication expired")
ErrNoServers = errors.New("nats: no servers available for connection")
ErrJsonParse = errors.New("nats: connect message, json parse error")
ErrChanArg = errors.New("nats: argument needs to be a channel type")
ErrMaxPayload = errors.New("nats: maximum payload exceeded")
ErrMaxMessages = errors.New("nats: maximum messages delivered")
ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription")
ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed")
ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received")
ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded")
ErrInvalidConnection = errors.New("nats: invalid connection")
ErrInvalidMsg = errors.New("nats: invalid message or message nil")
ErrInvalidArg = errors.New("nats: invalid argument")
ErrInvalidContext = errors.New("nats: invalid context")
ErrNoDeadlineContext = errors.New("nats: context requires a deadline")
ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server")
ErrClientIDNotSupported = errors.New("nats: client ID not supported by this server")
ErrUserButNoSigCB = errors.New("nats: user callback defined without a signature handler")
ErrNkeyButNoSigCB = errors.New("nats: nkey defined without a signature handler")
ErrNoUserCB = errors.New("nats: user callback not defined")
ErrNkeyAndUser = errors.New("nats: user callback and nkey defined")
ErrNkeysNotSupported = errors.New("nats: nkeys not supported by the server")
ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION)
ErrTokenAlreadySet = errors.New("nats: token and token handler both set")
ErrMsgNotBound = errors.New("nats: message is not bound to subscription/connection")
ErrMsgNoReply = errors.New("nats: message does not have a reply")
ErrClientIPNotSupported = errors.New("nats: client IP not supported by this server")
ErrDisconnected = errors.New("nats: server is disconnected")
ErrHeadersNotSupported = errors.New("nats: headers not supported by this server")
ErrBadHeaderMsg = errors.New("nats: message could not decode headers")
ErrNoResponders = errors.New("nats: no responders available for request")
ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded")
ErrConnectionNotTLS = errors.New("nats: connection is not tls")
ErrConnectionClosed = errors.New("nats: connection closed")
ErrConnectionDraining = errors.New("nats: connection draining")
ErrDrainTimeout = errors.New("nats: draining connection timed out")
ErrConnectionReconnecting = errors.New("nats: connection reconnecting")
ErrSecureConnRequired = errors.New("nats: secure connection required")
ErrSecureConnWanted = errors.New("nats: secure connection not available")
ErrBadSubscription = errors.New("nats: invalid subscription")
ErrTypeSubscription = errors.New("nats: invalid subscription type")
ErrBadSubject = errors.New("nats: invalid subject")
ErrBadQueueName = errors.New("nats: invalid queue name")
ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped")
ErrTimeout = errors.New("nats: timeout")
ErrBadTimeout = errors.New("nats: timeout invalid")
ErrAuthorization = errors.New("nats: authorization violation")
ErrAuthExpired = errors.New("nats: authentication expired")
ErrAuthRevoked = errors.New("nats: authentication revoked")
ErrAccountAuthExpired = errors.New("nats: account authentication expired")
ErrNoServers = errors.New("nats: no servers available for connection")
ErrJsonParse = errors.New("nats: connect message, json parse error")
ErrChanArg = errors.New("nats: argument needs to be a channel type")
ErrMaxPayload = errors.New("nats: maximum payload exceeded")
ErrMaxMessages = errors.New("nats: maximum messages delivered")
ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription")
ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed")
ErrClientCertOrRootCAsRequired = errors.New("nats: at least one of certCB or rootCAsCB must be set")
ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received")
ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded")
ErrInvalidConnection = errors.New("nats: invalid connection")
ErrInvalidMsg = errors.New("nats: invalid message or message nil")
ErrInvalidArg = errors.New("nats: invalid argument")
ErrInvalidContext = errors.New("nats: invalid context")
ErrNoDeadlineContext = errors.New("nats: context requires a deadline")
ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server")
ErrClientIDNotSupported = errors.New("nats: client ID not supported by this server")
ErrUserButNoSigCB = errors.New("nats: user callback defined without a signature handler")
ErrNkeyButNoSigCB = errors.New("nats: nkey defined without a signature handler")
ErrNoUserCB = errors.New("nats: user callback not defined")
ErrNkeyAndUser = errors.New("nats: user callback and nkey defined")
ErrNkeysNotSupported = errors.New("nats: nkeys not supported by the server")
ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION)
ErrTokenAlreadySet = errors.New("nats: token and token handler both set")
ErrMsgNotBound = errors.New("nats: message is not bound to subscription/connection")
ErrMsgNoReply = errors.New("nats: message does not have a reply")
ErrClientIPNotSupported = errors.New("nats: client IP not supported by this server")
ErrDisconnected = errors.New("nats: server is disconnected")
ErrHeadersNotSupported = errors.New("nats: headers not supported by this server")
ErrBadHeaderMsg = errors.New("nats: message could not decode headers")
ErrNoResponders = errors.New("nats: no responders available for request")
ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded")
ErrConnectionNotTLS = errors.New("nats: connection is not tls")
)

// GetDefaultOptions returns default configuration options for the client.
Expand Down Expand Up @@ -864,6 +865,40 @@ func Secure(tls ...*tls.Config) Option {
}
}

// ClientTLSConfig is an Option to set the TLS configuration for secure
// connections. It can be used to e.g. set TLS config with cert and root CAs
// from memory. For simple use case of loading cert and CAs from file,
// ClientCert and RootCAs options are more convenient.
// If Secure is not already set this will set it as well.
func ClientTLSConfig(certCB TLSCertHandler, rootCAsCB RootCAsHandler) Option {
return func(o *Options) error {
o.Secure = true

if certCB == nil && rootCAsCB == nil {
return ErrClientCertOrRootCAsRequired
}

// Smoke test the callbacks to fail early
// if they are not valid.
if certCB != nil {
if _, err := certCB(); err != nil {
return err
}
}
if rootCAsCB != nil {
if _, err := rootCAsCB(); err != nil {
return err
}
}
if o.TLSConfig == nil {
o.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
}
o.TLSCertCB = certCB
o.RootCAsCB = rootCAsCB
return nil
}
}

// RootCAs is a helper option to provide the RootCAs pool from a list of filenames.
// If Secure is not already set this will set it as well.
func RootCAs(file ...string) Option {
Expand Down
98 changes: 98 additions & 0 deletions test/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -234,6 +235,103 @@ func TestServerSecureConnections(t *testing.T) {
}
}

func TestClientTLSConfig(t *testing.T) {
s, opts := RunServerWithConfig("./configs/tlsverify.conf")
defer s.Shutdown()

endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
secureURL := fmt.Sprintf("nats://%s", endpoint)

// Make sure this fails
nc, err := nats.Connect(secureURL, nats.Secure())
if err == nil {
nc.Close()
t.Fatal("Should have failed (TLS) connection without client certificate")
}
cert, err := os.ReadFile("./configs/certs/client-cert.pem")
if err != nil {
t.Fatal("Failed to read client certificate")
}
key, err := os.ReadFile("./configs/certs/client-key.pem")
if err != nil {
t.Fatal("Failed to read client key")
}
rootCAs, err := os.ReadFile("./configs/certs/ca.pem")
if err != nil {
t.Fatal("Failed to read root CAs")
}

certCB := func() (tls.Certificate, error) {
cert, err := tls.X509KeyPair(cert, key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("nats: error loading client certificate: %w", err)
}
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return tls.Certificate{}, fmt.Errorf("nats: error parsing client certificate: %w", err)
}
return cert, nil
}

caCB := func() (*x509.CertPool, error) {
pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(rootCAs)
if !ok {
return nil, fmt.Errorf("nats: failed to parse root certificate from")
}
return pool, nil
}

// Check parameters validity
_, err = nats.Connect(secureURL, nats.ClientTLSConfig(nil, nil))
if !errors.Is(err, nats.ErrClientCertOrRootCAsRequired) {
t.Fatalf("Expected error %q, got %q", nats.ErrClientCertOrRootCAsRequired, err)
}

certErr := &tls.CertificateVerificationError{}
// Should fail because of missing CA
_, err = nats.Connect(secureURL,
nats.ClientCert("./configs/certs/client-cert.pem", "./configs/certs/client-key.pem"))
if ok := errors.As(err, &certErr); !ok {
t.Fatalf("Expected error %q, got %q", nats.ErrClientCertOrRootCAsRequired, err)
}

// Should fail because of missing certificate
_, err = nats.Connect(secureURL,
nats.ClientTLSConfig(nil, caCB))
if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") {
t.Fatalf("Expected missing certificate error; got: %s", err)
}

nc, err = nats.Connect(secureURL,
nats.ClientTLSConfig(certCB, caCB))
if err != nil {
t.Fatalf("Failed to create (TLS) connection: %v", err)
}
defer nc.Close()

omsg := []byte("Hello!")
checkRecv := make(chan bool)

received := 0
nc.Subscribe("foo", func(m *nats.Msg) {
received++
if !bytes.Equal(m.Data, omsg) {
t.Fatal("Message received does not match")
}
checkRecv <- true
})
err = nc.Publish("foo", omsg)
if err != nil {
t.Fatalf("Failed to publish on secure (TLS) connection: %v", err)
}
nc.Flush()

if err := Wait(checkRecv); err != nil {
t.Fatal("Failed to receive message")
}
}

func TestClientCertificate(t *testing.T) {
s, opts := RunServerWithConfig("./configs/tlsverify.conf")
defer s.Shutdown()
Expand Down

0 comments on commit a7632e5

Please sign in to comment.