Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support upstream via unix socket #209

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions cmd/kube-rbac-proxy/app/kube-rbac-proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,11 @@ type completedProxyRunOptions struct {
secureListenAddress string
healthzPath string

upstreamURL *url.URL
upstreamForceH2C bool
upstreamCABundle *x509.CertPool
upstreamURL *url.URL
upstreamForceH2C bool
upstreamUnixSocket string
upstreamCABundle *x509.CertPool
upstreamClientCert *tls.Certificate

auth *proxy.Config
tls *options.TLSConfig
Expand Down Expand Up @@ -211,12 +213,18 @@ func Complete(o *options.ProxyRunOptions) (*completedProxyRunOptions, error) {
ignorePaths: o.IgnorePaths,
}

if o.Upstream == "" && o.UpstreamUnixSocket != "" {
o.Upstream = "http://localhost"
}

completed.upstreamURL, err = url.Parse(o.Upstream)
if err != nil {
return nil, fmt.Errorf("failed to parse upstream URL: %w", err)
}

if upstreamCAPath := o.UpstreamCAFile; len(upstreamCAPath) > 0 {
completed.upstreamUnixSocket = o.UpstreamUnixSocket

if upstreamCAPath := o.TLS.UpstreamCAFile; len(upstreamCAPath) > 0 {
upstreamCAPEM, err := os.ReadFile(upstreamCAPath)
if err != nil {
return nil, err
Expand All @@ -229,6 +237,14 @@ func Complete(o *options.ProxyRunOptions) (*completedProxyRunOptions, error) {
completed.upstreamCABundle = upstreamCACertPool
}

if len(o.TLS.UpstreamClientCertFile) > 0 {
certKeyPair, err := tls.LoadX509KeyPair(o.TLS.UpstreamClientCertFile, o.TLS.UpstreamClientKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to read upstream client cert/key: %w", err)
}
completed.upstreamClientCert = &certKeyPair
}

completed.auth = o.Auth
completed.tls = o.TLS

Expand Down Expand Up @@ -296,15 +312,16 @@ func Run(cfg *completedProxyRunOptions) error {
sarAuthorizer,
)

upstreamTransport, err := initTransport(cfg.upstreamCABundle, cfg.tls.UpstreamClientCertFile, cfg.tls.UpstreamClientKeyFile)
upstreamTransport, err := initTransport(cfg.upstreamCABundle, cfg.upstreamClientCert, cfg.upstreamUnixSocket)
if err != nil {
return fmt.Errorf("failed to set up upstream TLS connection: %w", err)
return fmt.Errorf("failed to init upstream transport: %w", err)
}

proxy := httputil.NewSingleHostReverseProxy(cfg.upstreamURL)
proxy.Transport = upstreamTransport

if cfg.upstreamForceH2C {
upstreamUnixSocket := cfg.upstreamUnixSocket
// Force http/2 for connections to the upstream i.e. do not start with HTTP1.1 UPGRADE req to
// initialize http/2 session.
// See https://github.com/golang/go/issues/14141#issuecomment-219212895 for more context
Expand All @@ -314,6 +331,9 @@ func Run(cfg *completedProxyRunOptions) error {
// Do disable TLS.
// In combination with the schema check above. We could enforce h2c against the upstream server
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
if upstreamUnixSocket != "" {
return net.Dial("unix", upstreamUnixSocket)
}
return net.Dial(netw, addr)
},
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/kube-rbac-proxy/app/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type ProxyRunOptions struct {

Upstream string
UpstreamForceH2C bool
UpstreamCAFile string
UpstreamUnixSocket string
Auth *proxy.Config
TLS *TLSConfig
KubeconfigLocation string
Expand All @@ -50,6 +50,7 @@ type TLSConfig struct {
CipherSuites []string
ReloadInterval time.Duration

UpstreamCAFile string
UpstreamClientCertFile string
UpstreamClientKeyFile string
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func (o *ProxyRunOptions) Flags() k8sapiflag.NamedFlagSets {
flagset.StringVar(&o.SecureListenAddress, "secure-listen-address", "", "The address the kube-rbac-proxy HTTPs server should listen on.")
flagset.StringVar(&o.Upstream, "upstream", "", "The upstream URL to proxy to once requests have successfully been authenticated and authorized.")
flagset.BoolVar(&o.UpstreamForceH2C, "upstream-force-h2c", false, "Force h2c to communiate with the upstream. This is required when the upstream speaks h2c(http/2 cleartext - insecure variant of http/2) only. For example, go-grpc server in the insecure mode, such as helm's tiller w/o TLS, speaks h2c only")
flagset.StringVar(&o.UpstreamCAFile, "upstream-ca-file", "", "The CA the upstream uses for TLS connection. This is required when the upstream uses TLS and its own CA certificate")
flagset.StringVar(&o.UpstreamUnixSocket, "upstream-unix-socket", "", "The upstream unix socket to proxy to once requests have successfully been authenticated and authorized.")
flagset.StringVar(&o.ConfigFileName, "config-file", "", "Configuration file to configure kube-rbac-proxy.")
flagset.StringSliceVar(&o.AllowPaths, "allow-paths", nil, "Comma-separated list of paths against which kube-rbac-proxy pattern-matches the incoming request. If the request doesn't match, kube-rbac-proxy responds with a 404 status code. If omitted, the incoming request path isn't checked. Cannot be used with --ignore-paths.")
flagset.StringSliceVar(&o.IgnorePaths, "ignore-paths", nil, "Comma-separated list of paths against which kube-rbac-proxy pattern-matches the incoming request. If the requst matches, it will proxy the request without performing an authentication or authorization check. Cannot be used with --allow-paths.")
Expand All @@ -90,6 +91,7 @@ func (o *ProxyRunOptions) Flags() k8sapiflag.NamedFlagSets {
flagset.StringVar(&o.TLS.MinVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version supported. Value must match version names from https://golang.org/pkg/crypto/tls/#pkg-constants.")
flagset.StringSliceVar(&o.TLS.CipherSuites, "tls-cipher-suites", nil, "Comma-separated list of cipher suites for the server. Values are from tls package constants (https://golang.org/pkg/crypto/tls/#pkg-constants). If omitted, the default Go cipher suites will be used")
flagset.DurationVar(&o.TLS.ReloadInterval, "tls-reload-interval", time.Minute, "The interval at which to watch for TLS certificate changes, by default set to 1 minute.")
flagset.StringVar(&o.TLS.UpstreamCAFile, "upstream-ca-file", "", "The CA the upstream uses for TLS connection. This is required when the upstream uses TLS and its own CA certificate")
flagset.StringVar(&o.TLS.UpstreamClientCertFile, "upstream-client-cert-file", "", "If set, the client will be used to authenticate the proxy to upstream. Requires --upstream-client-key-file to be set, too.")
flagset.StringVar(&o.TLS.UpstreamClientKeyFile, "upstream-client-key-file", "", "The key matching the certificate from --upstream-client-cert-file. If set, requires --upstream-client-cert-file to be set, too.")

Expand Down
87 changes: 64 additions & 23 deletions cmd/kube-rbac-proxy/app/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,89 @@ limitations under the License.
package app

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"time"
)

func initTransport(upstreamCAPool *x509.CertPool, upstreamClientCertPath, upstreamClientKeyPath string) (http.RoundTripper, error) {
if upstreamCAPool == nil {
func initTransport(upstreamCAPool *x509.CertPool, upstreamClientCert *tls.Certificate, upstreamUnixSocket string) (http.RoundTripper, error) {
if upstreamCAPool == nil && upstreamClientCert == nil && upstreamUnixSocket == "" {
return http.DefaultTransport, nil
}

var certKeyPair tls.Certificate
if len(upstreamClientCertPath) > 0 {
var err error
certKeyPair, err = tls.LoadX509KeyPair(upstreamClientCertPath, upstreamClientKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read upstream client cert/key: %w", err)
}
builder := newHTTPTransportBuilder()

if upstreamCAPool != nil {
builder.withRootCAs(upstreamCAPool)
}

// http.Transport sourced from go 1.10.7
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
if upstreamClientCert != nil {
builder.withClientCerts(*upstreamClientCert)
}

if upstreamUnixSocket != "" {
builder.withUnixDialContext(upstreamUnixSocket)
}

return builder.build(), nil
}

type httpTransportBuilder struct {
tlsClientConfig *tls.Config
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}

func newHTTPTransportBuilder() *httpTransportBuilder {
return &httpTransportBuilder{
dialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
}
}

func (b *httpTransportBuilder) withRootCAs(certs *x509.CertPool) *httpTransportBuilder {
if b.tlsClientConfig != nil {
b.tlsClientConfig.RootCAs = certs
} else {
b.tlsClientConfig = &tls.Config{RootCAs: certs}
}
return b
}

func (b *httpTransportBuilder) withClientCerts(certs ...tls.Certificate) *httpTransportBuilder {
if b.tlsClientConfig != nil {
b.tlsClientConfig.Certificates = certs
} else {
b.tlsClientConfig = &tls.Config{Certificates: certs}
}
return b
}

func (b *httpTransportBuilder) withUnixDialContext(socket string) *httpTransportBuilder {
b.dialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext(ctx, "unix", socket)
}
return b
}

func (b *httpTransportBuilder) build() *http.Transport {
// http.Transport sourced from go 1.10.7
return &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: b.dialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{
RootCAs: upstreamCAPool,
},
TLSClientConfig: b.tlsClientConfig,
}

if certKeyPair.Certificate != nil {
transport.TLSClientConfig.Certificates = []tls.Certificate{certKeyPair}
}

return transport, nil
}
72 changes: 67 additions & 5 deletions cmd/kube-rbac-proxy/app/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ import (
"k8s.io/client-go/util/keyutil"
)

const (
okResponse = "ok"
)

func TestInitTransportWithDefault(t *testing.T) {
roundTripper, err := initTransport(nil, "", "")
roundTripper, err := initTransport(nil, nil, "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -57,7 +61,7 @@ func TestInitTransportWithCustomCA(t *testing.T) {
upstreamCAPool := x509.NewCertPool()
upstreamCAPool.AppendCertsFromPEM(upstreamCAPEM)

roundTripper, err := initTransport(upstreamCAPool, "", "")
roundTripper, err := initTransport(upstreamCAPool, nil, "")
if err != nil {
t.Fatalf("want err to be nil, but got %v", err)
}
Expand All @@ -69,8 +73,7 @@ func TestInitTransportWithCustomCA(t *testing.T) {

func testHTTPHandler(w http.ResponseWriter, req *http.Request) {
if len(req.TLS.PeerCertificates) > 0 {
_, _ = w.Write([]byte("ok"))
return
okTestHTTPHandler(w, req)
} else {
reqDump, _ := httputil.DumpRequest(req, false)
resp := fmt.Sprintf("got request without client certificates:\n%s\n", reqDump)
Expand Down Expand Up @@ -128,10 +131,14 @@ func TestInitTransportWithClientCertAuth(t *testing.T) {
if err := keyutil.WriteKey(clientKeyPath, clientKey); err != nil {
t.Fatalf("failed to write client key: %v", err)
}
certKeyPair, err := tls.X509KeyPair(clientCert, clientKey)
if err != nil {
t.Fatalf("failed to read client cert/key: %v", err)
}

serverCA := x509.NewCertPool()
serverCA.AppendCertsFromPEM(cert)
roundTripper, err := initTransport(serverCA, clientCertPath, clientKeyPath)
roundTripper, err := initTransport(serverCA, &certKeyPair, "")
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
Expand All @@ -158,6 +165,57 @@ func TestInitTransportWithClientCertAuth(t *testing.T) {
}
}

func TestInitTransportWithUnixSocket(t *testing.T) {
// start http server listening on unix socket
unixSocketPath := "/tmp/kube-rbac-proxy-test.sock"
l, err := net.Listen("unix", unixSocketPath)
if err != nil {
t.Fatalf("failed to listen on unix socket: %v", err)
}
defer l.Close()
go func() {
srv := &http.Server{
Handler: http.HandlerFunc(okTestHTTPHandler),
}
if err := srv.Serve(l); err != http.ErrServerClosed {
t.Logf("failed to run the test server: %v", err)
}
}()

// get http transport with unix socket
roundTripper, err := initTransport(nil, nil, unixSocketPath)
if err != nil {
t.Errorf("want err to be nil, but got %v", err)
return
}

// send test HTTP request to the server using the transport
httpReq, err := http.NewRequest(http.MethodPost, "http://localhost", nil)
if err != nil {
t.Fatalf("failed to create an HTTP request: %v", err)
}

resp, err := roundTripper.RoundTrip(httpReq)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Logf("failed to read response body: %v", err)
}

if resp.StatusCode != http.StatusOK {
t.Logf("response with failure logs:\n%s", respBody)
t.Errorf("expected the response code to be '%d', but it is '%d'", http.StatusOK, resp.StatusCode)
} else {
if string(respBody) != okResponse {
t.Errorf("expected the reponse body to be %q, but it is %q", okResponse, string(respBody))
}
}
}

func generateClientCert(t *testing.T) ([]byte, []byte, *x509.CertPool, error) {
t.Helper()

Expand Down Expand Up @@ -202,3 +260,7 @@ func generateClientCert(t *testing.T) ([]byte, []byte, *x509.CertPool, error) {

return certPEM, privKeyPEM, caPool, nil
}

func okTestHTTPHandler(w http.ResponseWriter, req *http.Request) {
_, _ = w.Write([]byte(okResponse))
}