Skip to content

Commit

Permalink
Support upstream via unix socket
Browse files Browse the repository at this point in the history
  • Loading branch information
alebedev87 committed Dec 2, 2022
1 parent a5bc10f commit a78a4de
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 36 deletions.
32 changes: 26 additions & 6 deletions cmd/kube-rbac-proxy/app/kube-rbac-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,11 @@ type completedProxyRunOptions struct {
secureListenAddress string
healthzPath string

upstreamURL *url.URL
upstreamForceH2C bool
upstreamCABundle *x509.CertPool
upstreamURL *url.URL
upstreamForceH2C bool
upstreamUnixSocket string
upstreamCABundle *x509.CertPool
upstreamClientCert *tls.Certificate

auth *proxy.Config
tls *options.TLSConfig
Expand Down Expand Up @@ -211,12 +213,18 @@ func Complete(o *options.ProxyRunOptions) (*completedProxyRunOptions, error) {
ignorePaths: o.IgnorePaths,
}

if o.Upstream == "" && o.UpstreamUnixSocket != "" {
o.Upstream = "http://localhost"
}

completed.upstreamURL, err = url.Parse(o.Upstream)
if err != nil {
return nil, fmt.Errorf("failed to parse upstream URL: %w", err)
}

if upstreamCAPath := o.UpstreamCAFile; len(upstreamCAPath) > 0 {
completed.upstreamUnixSocket = o.UpstreamUnixSocket

if upstreamCAPath := o.TLS.UpstreamCAFile; len(upstreamCAPath) > 0 {
upstreamCAPEM, err := os.ReadFile(upstreamCAPath)
if err != nil {
return nil, err
Expand All @@ -229,6 +237,14 @@ func Complete(o *options.ProxyRunOptions) (*completedProxyRunOptions, error) {
completed.upstreamCABundle = upstreamCACertPool
}

if len(o.TLS.UpstreamClientCertFile) > 0 {
certKeyPair, err := tls.LoadX509KeyPair(o.TLS.UpstreamClientCertFile, o.TLS.UpstreamClientKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to read upstream client cert/key: %w", err)
}
completed.upstreamClientCert = &certKeyPair
}

completed.auth = o.Auth
completed.tls = o.TLS

Expand Down Expand Up @@ -296,15 +312,16 @@ func Run(cfg *completedProxyRunOptions) error {
sarAuthorizer,
)

upstreamTransport, err := initTransport(cfg.upstreamCABundle, cfg.tls.UpstreamClientCertFile, cfg.tls.UpstreamClientKeyFile)
upstreamTransport, err := initTransport(cfg.upstreamCABundle, cfg.upstreamClientCert, cfg.upstreamUnixSocket)
if err != nil {
return fmt.Errorf("failed to set up upstream TLS connection: %w", err)
return fmt.Errorf("failed to init upstream transport: %w", err)
}

proxy := httputil.NewSingleHostReverseProxy(cfg.upstreamURL)
proxy.Transport = upstreamTransport

if cfg.upstreamForceH2C {
upstreamUnixSocket := cfg.upstreamUnixSocket
// Force http/2 for connections to the upstream i.e. do not start with HTTP1.1 UPGRADE req to
// initialize http/2 session.
// See https://github.com/golang/go/issues/14141#issuecomment-219212895 for more context
Expand All @@ -314,6 +331,9 @@ func Run(cfg *completedProxyRunOptions) error {
// Do disable TLS.
// In combination with the schema check above. We could enforce h2c against the upstream server
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
if upstreamUnixSocket != "" {
return net.Dial("unix", upstreamUnixSocket)
}
return net.Dial(netw, addr)
},
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/kube-rbac-proxy/app/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type ProxyRunOptions struct {

Upstream string
UpstreamForceH2C bool
UpstreamCAFile string
UpstreamUnixSocket string
Auth *proxy.Config
TLS *TLSConfig
KubeconfigLocation string
Expand All @@ -50,6 +50,7 @@ type TLSConfig struct {
CipherSuites []string
ReloadInterval time.Duration

UpstreamCAFile string
UpstreamClientCertFile string
UpstreamClientKeyFile string
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func (o *ProxyRunOptions) Flags() k8sapiflag.NamedFlagSets {
flagset.StringVar(&o.SecureListenAddress, "secure-listen-address", "", "The address the kube-rbac-proxy HTTPs server should listen on.")
flagset.StringVar(&o.Upstream, "upstream", "", "The upstream URL to proxy to once requests have successfully been authenticated and authorized.")
flagset.BoolVar(&o.UpstreamForceH2C, "upstream-force-h2c", false, "Force h2c to communiate with the upstream. This is required when the upstream speaks h2c(http/2 cleartext - insecure variant of http/2) only. For example, go-grpc server in the insecure mode, such as helm's tiller w/o TLS, speaks h2c only")
flagset.StringVar(&o.UpstreamCAFile, "upstream-ca-file", "", "The CA the upstream uses for TLS connection. This is required when the upstream uses TLS and its own CA certificate")
flagset.StringVar(&o.UpstreamUnixSocket, "upstream-unix-socket", "", "The upstream unix socket to proxy to once requests have successfully been authenticated and authorized.")
flagset.StringVar(&o.ConfigFileName, "config-file", "", "Configuration file to configure kube-rbac-proxy.")
flagset.StringSliceVar(&o.AllowPaths, "allow-paths", nil, "Comma-separated list of paths against which kube-rbac-proxy pattern-matches the incoming request. If the request doesn't match, kube-rbac-proxy responds with a 404 status code. If omitted, the incoming request path isn't checked. Cannot be used with --ignore-paths.")
flagset.StringSliceVar(&o.IgnorePaths, "ignore-paths", nil, "Comma-separated list of paths against which kube-rbac-proxy pattern-matches the incoming request. If the requst matches, it will proxy the request without performing an authentication or authorization check. Cannot be used with --allow-paths.")
Expand All @@ -90,6 +91,7 @@ func (o *ProxyRunOptions) Flags() k8sapiflag.NamedFlagSets {
flagset.StringVar(&o.TLS.MinVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version supported. Value must match version names from https://golang.org/pkg/crypto/tls/#pkg-constants.")
flagset.StringSliceVar(&o.TLS.CipherSuites, "tls-cipher-suites", nil, "Comma-separated list of cipher suites for the server. Values are from tls package constants (https://golang.org/pkg/crypto/tls/#pkg-constants). If omitted, the default Go cipher suites will be used")
flagset.DurationVar(&o.TLS.ReloadInterval, "tls-reload-interval", time.Minute, "The interval at which to watch for TLS certificate changes, by default set to 1 minute.")
flagset.StringVar(&o.TLS.UpstreamCAFile, "upstream-ca-file", "", "The CA the upstream uses for TLS connection. This is required when the upstream uses TLS and its own CA certificate")
flagset.StringVar(&o.TLS.UpstreamClientCertFile, "upstream-client-cert-file", "", "If set, the client will be used to authenticate the proxy to upstream. Requires --upstream-client-key-file to be set, too.")
flagset.StringVar(&o.TLS.UpstreamClientKeyFile, "upstream-client-key-file", "", "The key matching the certificate from --upstream-client-cert-file. If set, requires --upstream-client-cert-file to be set, too.")

Expand Down
87 changes: 64 additions & 23 deletions cmd/kube-rbac-proxy/app/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,89 @@ limitations under the License.
package app

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"time"
)

func initTransport(upstreamCAPool *x509.CertPool, upstreamClientCertPath, upstreamClientKeyPath string) (http.RoundTripper, error) {
if upstreamCAPool == nil {
func initTransport(upstreamCAPool *x509.CertPool, upstreamClientCert *tls.Certificate, upstreamUnixSocket string) (http.RoundTripper, error) {
if upstreamCAPool == nil && upstreamClientCert == nil && upstreamUnixSocket == "" {
return http.DefaultTransport, nil
}

var certKeyPair tls.Certificate
if len(upstreamClientCertPath) > 0 {
var err error
certKeyPair, err = tls.LoadX509KeyPair(upstreamClientCertPath, upstreamClientKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read upstream client cert/key: %w", err)
}
builder := newHTTPTransportBuilder()

if upstreamCAPool != nil {
builder.withRootCAs(upstreamCAPool)
}

// http.Transport sourced from go 1.10.7
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
if upstreamClientCert != nil {
builder.withClientCerts(*upstreamClientCert)
}

if upstreamUnixSocket != "" {
builder.withUnixDialContext(upstreamUnixSocket)
}

return builder.build(), nil
}

type httpTransportBuilder struct {
tlsClientConfig *tls.Config
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}

func newHTTPTransportBuilder() *httpTransportBuilder {
return &httpTransportBuilder{
dialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
}
}

func (b *httpTransportBuilder) withRootCAs(certs *x509.CertPool) *httpTransportBuilder {
if b.tlsClientConfig != nil {
b.tlsClientConfig.RootCAs = certs
} else {
b.tlsClientConfig = &tls.Config{RootCAs: certs}
}
return b
}

func (b *httpTransportBuilder) withClientCerts(certs ...tls.Certificate) *httpTransportBuilder {
if b.tlsClientConfig != nil {
b.tlsClientConfig.Certificates = certs
} else {
b.tlsClientConfig = &tls.Config{Certificates: certs}
}
return b
}

func (b *httpTransportBuilder) withUnixDialContext(socket string) *httpTransportBuilder {
b.dialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext(ctx, "unix", socket)
}
return b
}

func (b *httpTransportBuilder) build() *http.Transport {
// http.Transport sourced from go 1.10.7
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: b.dialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{
RootCAs: upstreamCAPool,
},
TLSClientConfig: b.tlsClientConfig,
}

if certKeyPair.Certificate != nil {
transport.TLSClientConfig.Certificates = []tls.Certificate{certKeyPair}
}

return transport, nil
}
72 changes: 67 additions & 5 deletions cmd/kube-rbac-proxy/app/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ import (
"k8s.io/client-go/util/keyutil"
)

const (
okResponse = "ok"
)

func TestInitTransportWithDefault(t *testing.T) {
roundTripper, err := initTransport(nil, "", "")
roundTripper, err := initTransport(nil, nil, "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -57,7 +61,7 @@ func TestInitTransportWithCustomCA(t *testing.T) {
upstreamCAPool := x509.NewCertPool()
upstreamCAPool.AppendCertsFromPEM(upstreamCAPEM)

roundTripper, err := initTransport(upstreamCAPool, "", "")
roundTripper, err := initTransport(upstreamCAPool, nil, "")
if err != nil {
t.Fatalf("want err to be nil, but got %v", err)
}
Expand All @@ -69,8 +73,7 @@ func TestInitTransportWithCustomCA(t *testing.T) {

func testHTTPHandler(w http.ResponseWriter, req *http.Request) {
if len(req.TLS.PeerCertificates) > 0 {
_, _ = w.Write([]byte("ok"))
return
okTestHTTPHandler(w, req)
} else {
reqDump, _ := httputil.DumpRequest(req, false)
resp := fmt.Sprintf("got request without client certificates:\n%s\n", reqDump)
Expand Down Expand Up @@ -128,10 +131,14 @@ func TestInitTransportWithClientCertAuth(t *testing.T) {
if err := keyutil.WriteKey(clientKeyPath, clientKey); err != nil {
t.Fatalf("failed to write client key: %v", err)
}
certKeyPair, err := tls.X509KeyPair(clientCert, clientKey)
if err != nil {
t.Fatalf("failed to read client cert/key: %v", err)
}

serverCA := x509.NewCertPool()
serverCA.AppendCertsFromPEM(cert)
roundTripper, err := initTransport(serverCA, clientCertPath, clientKeyPath)
roundTripper, err := initTransport(serverCA, &certKeyPair, "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -158,6 +165,57 @@ func TestInitTransportWithClientCertAuth(t *testing.T) {
}
}

func TestInitTransportWithUnixSocket(t *testing.T) {
// start http server listening on unix socket
unixSocketPath := "/tmp/kube-rbac-proxy-test.sock"
l, err := net.Listen("unix", unixSocketPath)
if err != nil {
t.Fatalf("failed to listen on unix socket: %v", err)
}
defer l.Close()
go func() {
srv := &http.Server{
Handler: http.HandlerFunc(okTestHTTPHandler),
}
if err := srv.Serve(l); err != http.ErrServerClosed {
t.Logf("failed to run the test server: %v", err)
}
}()

// get http transport with unix socket
roundTripper, err := initTransport(nil, nil, unixSocketPath)
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
}

// send test HTTP request to the server using the transport
httpReq, err := http.NewRequest(http.MethodPost, "http://localhost", nil)
if err != nil {
t.Fatalf("failed to create an HTTP request: %v", err)
}

resp, err := roundTripper.RoundTrip(httpReq)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Logf("failed to read response body: %v", err)
}

if resp.StatusCode != http.StatusOK {
t.Logf("response with failure logs:\n%s", respBody)
t.Errorf("expected the response code to be '%d', but it is '%d'", http.StatusOK, resp.StatusCode)
} else {
if string(respBody) != okResponse {
t.Errorf("expected the reponse body to be %q, but it is %q", okResponse, string(respBody))
}
}
}

func generateClientCert(t *testing.T) ([]byte, []byte, *x509.CertPool, error) {
t.Helper()

Expand Down Expand Up @@ -202,3 +260,7 @@ func generateClientCert(t *testing.T) ([]byte, []byte, *x509.CertPool, error) {

return certPEM, privKeyPEM, caPool, nil
}

func okTestHTTPHandler(w http.ResponseWriter, req *http.Request) {
_, _ = w.Write([]byte(okResponse))
}

0 comments on commit a78a4de

Please sign in to comment.