Skip to content

Commit

Permalink
Merge pull request #42 from xataio/add-config-tls-support
Browse files Browse the repository at this point in the history
Add config tls support
  • Loading branch information
eminano authored Jun 18, 2024
2 parents e218205 + d166597 commit 72e3912
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 75 deletions.
20 changes: 12 additions & 8 deletions cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/spf13/viper"
"github.com/xataio/pgstream/internal/backoff"
"github.com/xataio/pgstream/internal/kafka"
"github.com/xataio/pgstream/internal/tls"
pgschemalog "github.com/xataio/pgstream/pkg/schemalog/postgres"
"github.com/xataio/pgstream/pkg/stream"
kafkacheckpoint "github.com/xataio/pgstream/pkg/wal/checkpointer/kafka"
Expand Down Expand Up @@ -92,10 +93,7 @@ func parseKafkaReaderConfig(kafkaServers []string, kafkaTopic, consumerGroupID s
Topic: kafka.TopicConfig{
Name: kafkaTopic,
},
TLS: &kafka.TLSConfig{
// TODO: add support for TLS configuration
Enabled: false,
},
TLS: parseTLSConfig("PGSTREAM_KAFKA"),
},
ConsumerGroupID: consumerGroupID,
ConsumerGroupStartOffset: viper.GetString("PGSTREAM_KAFKA_READER_CONSUMER_GROUP_START_OFFSET"),
Expand Down Expand Up @@ -143,10 +141,7 @@ func parseKafkaWriterConfig(kafkaServers []string, kafkaTopic string) *kafkaproc
ReplicationFactor: viper.GetInt("PGSTREAM_KAFKA_TOPIC_REPLICATION_FACTOR"),
AutoCreate: viper.GetBool("PGSTREAM_KAFKA_TOPIC_AUTO_CREATE"),
},
TLS: &kafka.TLSConfig{
// TODO: add support for TLS configuration
Enabled: false,
},
TLS: parseTLSConfig("PGSTREAM_KAFKA"),
},
BatchTimeout: viper.GetDuration("PGSTREAM_KAFKA_WRITER_BATCH_TIMEOUT"),
BatchBytes: viper.GetInt64("PGSTREAM_KAFKA_WRITER_BATCH_BYTES"),
Expand Down Expand Up @@ -220,3 +215,12 @@ func parseTranslatorConfig() *translator.Config {
},
}
}

func parseTLSConfig(prefix string) *tls.Config {
return &tls.Config{
Enabled: viper.GetBool(fmt.Sprintf("%s_TLS_ENABLED", prefix)),
CaCertFile: viper.GetString(fmt.Sprintf("%s_TLS_CA_CERT_FILE", prefix)),
ClientCertFile: viper.GetString(fmt.Sprintf("%s_TLS_CLIENT_CERT_FILE", prefix)),
ClientKeyFile: viper.GetString(fmt.Sprintf("%s_TLS_CLIENT_KEY_FILE", prefix)),
}
}
25 changes: 12 additions & 13 deletions internal/kafka/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ import (
"strconv"
"time"

tlslib "github.com/xataio/pgstream/internal/tls"

"github.com/segmentio/kafka-go"
)

type ConnConfig struct {
Servers []string
Topic TopicConfig
TLS *TLSConfig
TLS *tlslib.Config
}

type TopicConfig struct {
Expand Down Expand Up @@ -65,20 +67,17 @@ func withConnection(config *ConnConfig, kafkaOperation func(conn *kafka.Conn) er
return kafkaOperation(controllerConn)
}

func buildDialer(tlsConfig *TLSConfig) (*kafka.Dialer, error) {
func buildDialer(cfg *tlslib.Config) (*kafka.Dialer, error) {
timeout := 10 * time.Second

dialer := &kafka.Dialer{
Timeout: timeout,
DualStack: true,
}
if tlsConfig.Enabled {
var err error
dialer, err = buildTLSDialer(tlsConfig, timeout)
if err != nil {
return nil, fmt.Errorf("building dialer: %w", err)
}
tlsConfig, err := tlslib.NewConfig(cfg)
if err != nil {
return nil, fmt.Errorf("loading TLS configuration: %w", err)
}

return dialer, nil
return &kafka.Dialer{
Timeout: timeout,
DualStack: true,
TLS: tlsConfig,
}, nil
}
13 changes: 6 additions & 7 deletions internal/kafka/kafka_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/segmentio/kafka-go"
tlslib "github.com/xataio/pgstream/internal/tls"
loglib "github.com/xataio/pgstream/pkg/log"
)

Expand Down Expand Up @@ -101,16 +102,14 @@ func createTopic(cfg *ConnConfig) error {
})
}

func buildTransport(tlsConfig *TLSConfig) (kafka.RoundTripper, error) {
transport := kafka.DefaultTransport

if tlsConfig.Enabled {
tlsConfig, err := newTLSConfig(tlsConfig)
func buildTransport(cfg *tlslib.Config) (kafka.RoundTripper, error) {
if cfg.Enabled {
tlsConfig, err := tlslib.NewConfig(cfg)
if err != nil {
return nil, fmt.Errorf("building TLS config: %w", err)
}
transport = &kafka.Transport{TLS: tlsConfig}
return &kafka.Transport{TLS: tlsConfig}, nil
}

return transport, nil
return kafka.DefaultTransport, nil
}
47 changes: 0 additions & 47 deletions internal/kafka/tls.go

This file was deleted.

83 changes: 83 additions & 0 deletions internal/tls/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// SPDX-License-Identifier: Apache-2.0

package tls

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"os"
)

type Config struct {
// Enabled determines if TLS should be used. Defaults to false.
Enabled bool
// File path to the CA PEM certificate to be used for TLS connection. If TLS is
// enabled and no CA cert file is provided, the system certificate pool is
// used as default.
CaCertFile string
// File path to the client PEM certificate
ClientCertFile string
// File path to the client PEM key
ClientKeyFile string
}

func NewConfig(cfg *Config) (*tls.Config, error) {
if !cfg.Enabled {
return nil, nil
}

certPool, err := getCertPool(cfg.CaCertFile)
if err != nil {
return nil, err
}

certificates, err := getCertificates(cfg.ClientCertFile, cfg.ClientKeyFile)
if err != nil {
return nil, err
}

return &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: 0,
Certificates: certificates,
RootCAs: certPool,
}, nil
}

func getCertPool(caCertFile string) (*x509.CertPool, error) {
if caCertFile != "" {
pemCertBytes, err := readFile(caCertFile)
if err != nil {
return nil, fmt.Errorf("reading CA certificate file: %w", err)
}
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(pemCertBytes)
return certPool, nil
}

return x509.SystemCertPool()
}

func getCertificates(clientCertFile, clientKeyFile string) ([]tls.Certificate, error) {
if clientCertFile != "" && clientKeyFile != "" {
cert, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
if err != nil {
return nil, err
}
return []tls.Certificate{cert}, nil
}

return []tls.Certificate{}, nil
}

func readFile(path string) ([]byte, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()

return io.ReadAll(file)
}

0 comments on commit 72e3912

Please sign in to comment.