diff --git a/config/config.go b/config/config.go index 884fcbb42..d3fbd039a 100644 --- a/config/config.go +++ b/config/config.go @@ -31,7 +31,7 @@ type CertSource struct { type Listen struct { Addr string - Scheme string + Proto string ReadTimeout time.Duration WriteTimeout time.Duration CertSource CertSource diff --git a/config/load.go b/config/load.go index 29d8f6055..b68686db2 100644 --- a/config/load.go +++ b/config/load.go @@ -211,13 +211,19 @@ func parseListen(cfg string, cs map[string]CertSource, readTimeout, writeTimeout l = Listen{ Addr: opts[0], - Scheme: "http", + Proto: "http", ReadTimeout: readTimeout, WriteTimeout: writeTimeout, } + var csName string for k, v := range kvParse(cfg) { switch k { + case "proto": + l.Proto = v + if l.Proto != "http" && l.Proto != "https" && l.Proto != "tcp+sni" { + return Listen{}, fmt.Errorf("unknown protocol %q", v) + } case "rt": // read timeout d, err := time.ParseDuration(v) if err != nil { @@ -231,14 +237,23 @@ func parseListen(cfg string, cs map[string]CertSource, readTimeout, writeTimeout } l.WriteTimeout = d case "cs": // cert source + csName = v c, ok := cs[v] if !ok { - return Listen{}, fmt.Errorf("unknown certificate source %s", v) + return Listen{}, fmt.Errorf("unknown certificate source %q", v) } l.CertSource = c - l.Scheme = "https" + l.Proto = "https" } } + + if csName != "" && l.Proto != "https" { + return Listen{}, fmt.Errorf("cert source requires proto 'https'") + } + if csName == "" && l.Proto == "https" { + return Listen{}, fmt.Errorf("proto 'https' requires cert source") + } + return } @@ -247,13 +262,13 @@ func parseLegacyListen(cfg string, readTimeout, writeTimeout time.Duration) (l L l = Listen{ Addr: opts[0], - Scheme: "http", + Proto: "http", ReadTimeout: readTimeout, WriteTimeout: writeTimeout, } if len(opts) > 1 { - l.Scheme = "https" + l.Proto = "https" l.CertSource.Type = "file" l.CertSource.CertPath = opts[1] } diff --git a/config/load_test.go b/config/load_test.go index 525b05e9b..f58b104ae 100644 --- a/config/load_test.go +++ b/config/load_test.go @@ -13,7 +13,7 @@ import ( func TestFromProperties(t *testing.T) { in := ` proxy.cs = cs=name;type=path;cert=foo;clientca=bar;refresh=99s;hdr=a: b;caupgcn=furb -proxy.addr = :1234 +proxy.addr = :1234;proto=tcp+sni proxy.localip = 4.4.4.4 proxy.strategy = rr proxy.matcher = prefix @@ -55,7 +55,7 @@ ui.title = fabfab aws.apigw.cert.cn = furb ` out := &Config{ - ListenerValue: []string{":1234"}, + ListenerValue: []string{":1234;proto=tcp+sni"}, CertSourcesValue: []map[string]string{{"cs": "name", "type": "path", "cert": "foo", "clientca": "bar", "refresh": "99s", "hdr": "a: b", "caupgcn": "furb"}}, CertSources: map[string]CertSource{ "name": CertSource{ @@ -111,7 +111,7 @@ aws.apigw.cert.cn = furb Listen: []Listen{ { Addr: ":1234", - Scheme: "http", + Proto: "tcp+sni", ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, }, @@ -171,7 +171,7 @@ func TestParseScheme(t *testing.T) { func TestParseListen(t *testing.T) { cs := map[string]CertSource{ - "name": CertSource{Type: "foo"}, + "name": CertSource{Name: "name", Type: "foo"}, } tests := []struct { @@ -186,19 +186,29 @@ func TestParseListen(t *testing.T) { }, { ":123", - Listen{Addr: ":123", Scheme: "http"}, + Listen{Addr: ":123", Proto: "http"}, + "", + }, + { + ":123;proto=http", + Listen{Addr: ":123", Proto: "http"}, + "", + }, + { + ":123;proto=tcp+sni", + Listen{Addr: ":123", Proto: "tcp+sni"}, "", }, { ":123;rt=5s;wt=5s", - Listen{Addr: ":123", Scheme: "http", ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second}, + Listen{Addr: ":123", Proto: "http", ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second}, "", }, { ":123;pathA;pathB;pathC", Listen{ - Addr: ":123", - Scheme: "https", + Addr: ":123", + Proto: "https", CertSource: CertSource{ Type: "file", CertPath: "pathA", @@ -211,14 +221,52 @@ func TestParseListen(t *testing.T) { { ":123;cs=name", Listen{ - Addr: ":123", - Scheme: "https", + Addr: ":123", + Proto: "https", CertSource: CertSource{ + Name: "name", Type: "foo", }, }, "", }, + { + ":123;cs=name;proto=https", + Listen{ + Addr: ":123", + Proto: "https", + CertSource: CertSource{ + Name: "name", + Type: "foo", + }, + }, + "", + }, + { + ":123;proto=https", + Listen{}, + "proto 'https' requires cert source", + }, + { + ":123;cs=name;proto=http", + Listen{}, + "cert source requires proto 'https'", + }, + { + ":123;cs=name;proto=tcp+sni", + Listen{}, + "cert source requires proto 'https'", + }, + { + ":123;proto=foo", + Listen{}, + "unknown protocol \"foo\"", + }, + { + ":123;cs=foo", + Listen{}, + "unknown certificate source \"foo\"", + }, } for i, tt := range tests { diff --git a/demo/server/server.go b/demo/server/server.go index 56d2122a6..dda645f6c 100644 --- a/demo/server/server.go +++ b/demo/server/server.go @@ -45,12 +45,15 @@ import ( func main() { var addr, consul, name, prefix, proto, token string + var certFile, keyFile string flag.StringVar(&addr, "addr", "127.0.0.1:5000", "host:port of the service") flag.StringVar(&consul, "consul", "127.0.0.1:8500", "host:port of the consul agent") flag.StringVar(&name, "name", filepath.Base(os.Args[0]), "name of the service") flag.StringVar(&prefix, "prefix", "", "comma-sep list of host/path prefixes to register") flag.StringVar(&proto, "proto", "http", "protocol for endpoints: http or ws") flag.StringVar(&token, "token", "", "consul ACL token") + flag.StringVar(&certFile, "cert", "", "path to cert file") + flag.StringVar(&keyFile, "key", "", "path to key file") flag.Parse() if prefix == "" { @@ -73,19 +76,26 @@ func main() { } } + // register consul health check endpoint + http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "OK") + }) + // start http server go func() { log.Printf("Listening on %s serving %s", addr, prefix) - if err := http.ListenAndServe(addr, nil); err != nil { + + var err error + if certFile != "" { + err = http.ListenAndServeTLS(addr, certFile, keyFile, nil) + } else { + err = http.ListenAndServe(addr, nil) + } + if err != nil { log.Fatal(err) } }() - // register consul health check endpoint - http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "OK") - }) - // build urlprefix-host/path tag list // e.g. urlprefix-/foo, urlprefix-/bar, ... var tags []string @@ -103,6 +113,21 @@ func main() { log.Fatal(err) } + var check *api.AgentServiceCheck + if certFile != "" { + check = &api.AgentServiceCheck{ + TCP: addr, + Interval: "2s", + Timeout: "1s", + } + } else { + check = &api.AgentServiceCheck{ + HTTP: "http://" + addr + "/health", + Interval: "1s", + Timeout: "1s", + } + } + // register service with health check serviceID := name + "-" + addr service := &api.AgentServiceRegistration{ @@ -111,11 +136,7 @@ func main() { Port: port, Address: host, Tags: tags, - Check: &api.AgentServiceCheck{ - HTTP: "http://" + addr + "/health", - Interval: "1s", - Timeout: "1s", - }, + Check: check, } config := &api.Config{Address: consul, Scheme: "http", Token: token} diff --git a/fabio.properties b/fabio.properties index 336c8171c..ea5f73951 100644 --- a/fabio.properties +++ b/fabio.properties @@ -162,22 +162,38 @@ # proxy.cs = -# proxy.addr configures the HTTP and HTTPS listeners. +# proxy.addr configures listeners. # # Each listener is configured with and address and a # list of optional arguments in the form of # # [host]:port;opt=arg;opt[=arg];... # +# Each listener has a protocol which is configured +# with the 'proto' option for which it routes and +# forwards traffic. +# +# The supported protocols are: +# +# * http for HTTP based protocols +# * https for HTTPS based protocols +# * tcp+sni for an SNI aware TCP proxy +# +# If no 'proto' option is specified then the protocol +# is either 'http' or 'https' depending on whether a +# certificate source is configured via the 'cs' option +# which contains the name of the certificate source. +# +# The TCP+SNI proxy analyzes the ClientHello message +# of TLS connections to extract the server name +# extension and then forwards the encrypted traffic +# to the destination without decrypting the traffic. +# # General options: # # read timeout: rt= # write timeout: wt= # -# HTTPS listeners require a certificate source which is -# configured by setting the 'cs' option to the name of -# a certificate source. -# # Examples: # # # HTTP listener on port 9999 @@ -195,6 +211,9 @@ # # HTTPS listener on port 443 with certificate source # proxy.addr = :443;cs=some-name # +# # TCP listener on port 443 with SNI routing +# proxy.addr = :443;proto=tcp+sni +# # The default is # # proxy.addr = :9999 diff --git a/listen.go b/listen.go index 41fa93ba5..ba14e6258 100644 --- a/listen.go +++ b/listen.go @@ -22,9 +22,16 @@ func init() { } // startListeners runs one or more listeners for the handler -func startListeners(listen []config.Listen, wait time.Duration, h http.Handler) { +func startListeners(listen []config.Listen, wait time.Duration, h http.Handler, tcph proxy.TCPProxy) { for _, l := range listen { - go listenAndServe(l, h) + switch l.Proto { + case "tcp+sni": + go listenAndServeTCP(l, tcph) + case "http", "https": + go listenAndServeHTTP(l, h) + default: + panic("invalid protocol: " + l.Proto) + } } // wait for shutdown signal @@ -39,7 +46,35 @@ func startListeners(listen []config.Listen, wait time.Duration, h http.Handler) log.Print("[INFO] Down") } -func listenAndServe(l config.Listen, h http.Handler) { +func listenAndServeTCP(l config.Listen, h proxy.TCPProxy) { + log.Print("[INFO] TCP+SNI proxy listening on ", l.Addr) + ln, err := net.Listen("tcp", l.Addr) + if err != nil { + exit.Fatal("[FATAL] ", err) + } + defer ln.Close() + + // close the socket on exit to terminate the accept loop + go func() { + <-quit + ln.Close() + }() + + for { + conn, err := ln.Accept() + if err != nil { + select { + case <-quit: + return + default: + exit.Fatal("[FATAL] ", err) + } + } + go h.Serve(conn) + } +} + +func listenAndServeHTTP(l config.Listen, h http.Handler) { srv := &http.Server{ Handler: h, Addr: l.Addr, @@ -47,7 +82,7 @@ func listenAndServe(l config.Listen, h http.Handler) { WriteTimeout: l.WriteTimeout, } - if l.Scheme == "https" { + if l.Proto == "https" { src, err := cert.NewSource(l.CertSource) if err != nil { exit.Fatal("[FATAL] ", err) diff --git a/listen_test.go b/listen_test.go index 6ca76099c..fd12bc35b 100644 --- a/listen_test.go +++ b/listen_test.go @@ -39,11 +39,11 @@ func TestGracefulShutdown(t *testing.T) { // start proxy with graceful shutdown period long enough // to complete one more request. var wg sync.WaitGroup - l := config.Listen{Addr: "127.0.0.1:57777"} + l := config.Listen{Addr: "127.0.0.1:57777", Proto: "http"} wg.Add(1) go func() { defer wg.Done() - startListeners([]config.Listen{l}, 250*time.Millisecond, proxy.New(http.DefaultTransport, config.Proxy{})) + startListeners([]config.Listen{l}, 250*time.Millisecond, proxy.NewHTTPProxy(http.DefaultTransport, config.Proxy{}), nil) }() // trigger shutdown after some time diff --git a/main.go b/main.go index 482162c2a..f93ebb72d 100644 --- a/main.go +++ b/main.go @@ -53,16 +53,19 @@ func main() { registry.Default.Deregister() }) + httpProxy := newHTTPProxy(cfg) + tcpProxy := proxy.NewTCPSNIProxy(cfg.Proxy) + initRuntime(cfg) initMetrics(cfg) initBackend(cfg) go watchBackend() startAdmin(cfg) - startListeners(cfg.Listen, cfg.Proxy.ShutdownWait, newProxy(cfg)) + startListeners(cfg.Listen, cfg.Proxy.ShutdownWait, httpProxy, tcpProxy) exit.Wait() } -func newProxy(cfg *config.Config) *proxy.Proxy { +func newHTTPProxy(cfg *config.Config) http.Handler { if err := route.SetPickerStrategy(cfg.Proxy.Strategy); err != nil { exit.Fatal("[FATAL] ", err) } @@ -82,7 +85,7 @@ func newProxy(cfg *config.Config) *proxy.Proxy { }).Dial, } - return proxy.New(tr, cfg.Proxy) + return proxy.NewHTTPProxy(tr, cfg.Proxy) } func startAdmin(cfg *config.Config) { diff --git a/proxy/clienthello.go b/proxy/clienthello.go new file mode 100644 index 000000000..f132e95aa --- /dev/null +++ b/proxy/clienthello.go @@ -0,0 +1,313 @@ +package proxy + +// record types +const ( + handshakeRecord = 0x16 + clientHelloType = 0x01 +) + +// readServerName returns the server name from a TLS ClientHello message which +// has the server_name extension (SNI). ok is set to true if the ClientHello +// message was parsed successfully. If the server_name extension was not set +// and empty string is returned as serverName. +func readServerName(data []byte) (serverName string, ok bool) { + if m, ok := readClientHello(data); ok { + return m.serverName, true + } + return "", false +} + +// readClientHello +func readClientHello(data []byte) (m *clientHelloMsg, ok bool) { + if len(data) < 9 { + // println("buf too short") + return nil, false + } + + // TLS record header + // ----------------- + // byte 0: rec type (should be 0x16 == Handshake) + // byte 1-2: version (should be 0x3000 < v < 0x3003) + // byte 3-4: rec len + recType := data[0] + if recType != handshakeRecord { + // println("no handshake ") + return nil, false + } + + recLen := int(data[3])<<8 | int(data[4]) + if recLen == 0 || recLen > len(data)-5 { + // println("rec too short") + return nil, false + } + + // Handshake record header + // ----------------------- + // byte 5: hs msg type (should be 0x01 == client_hello) + // byte 6-8: hs msg len + hsType := data[5] + if hsType != clientHelloType { + // println("no client_hello") + return nil, false + } + + hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8]) + if hsLen == 0 || hsLen > len(data)-9 { + // println("handshake rec too short") + return nil, false + } + + // byte 9- : client hello msg + // + // m.unmarshal parses the entire handshake message and + // not just the client hello. Therefore, we need to pass + // data from byte 5 instead of byte 9. (see comment below) + m = new(clientHelloMsg) + if !m.unmarshal(data[5:]) { + // println("client_hello unmarshal failed") + return nil, false + } + return m, true +} + +// The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go +// with some parts commented out. It does enough work to parse a TLS client hello +// message and extract the server name extension since this is all we care about. +// +// Copyright (c) 2016 The Go Authors + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + // extensionStatusRequest uint16 = 5 + // extensionSupportedCurves uint16 = 10 + // extensionSupportedPoints uint16 = 11 + // extensionSignatureAlgorithms uint16 = 13 + // extensionALPN uint16 = 16 + // extensionSCT uint16 = 18 // https://tools.ietf.org/html/rfc6962#section-6 + // extensionSessionTicket uint16 = 35 + // extensionNextProtoNeg uint16 = 13172 // not IANA assigned + // extensionRenegotiationInfo uint16 = 0xff01 +) + +type clientHelloMsg struct { + raw []byte + vers uint16 + random []byte + sessionId []byte + cipherSuites []uint16 + compressionMethods []uint8 + nextProtoNeg bool + serverName string + ocspStapling bool + scts bool + // supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + //signatureAndHashes []signatureAndHash + secureRenegotiation []byte + secureRenegotiationSupported bool + alpnProtocols []string +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + if len(data) < 42 { + return false + } + m.raw = data + m.vers = uint16(data[4])<<8 | uint16(data[5]) + m.random = data[6:38] + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return false + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + if len(data) < 2 { + return false + } + // cipherSuiteLen is the number of bytes of cipher suite numbers. Since + // they are uint16s, the number must be even. + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + return false + } + // numCipherSuites := cipherSuiteLen / 2 + // m.cipherSuites = make([]uint16, numCipherSuites) + // for i := 0; i < numCipherSuites; i++ { + // m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) + // if m.cipherSuites[i] == scsvRenegotiation { + // m.secureRenegotiationSupported = true + // } + // } + data = data[2+cipherSuiteLen:] + if len(data) < 1 { + return false + } + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + return false + } + m.compressionMethods = data[1 : 1+compressionMethodsLen] + + data = data[1+compressionMethodsLen:] + + m.nextProtoNeg = false + m.serverName = "" + m.ocspStapling = false + m.ticketSupported = false + m.sessionTicket = nil + // m.signatureAndHashes = nil + m.alpnProtocols = nil + m.scts = false + + if len(data) == 0 { + // ClientHello is optionally followed by extension data + return true + } + if len(data) < 2 { + return false + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if extensionsLength != len(data) { + return false + } + + for len(data) != 0 { + if len(data) < 4 { + return false + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return false + } + + switch extension { + case extensionServerName: + d := data[:length] + if len(d) < 2 { + return false + } + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + return false + } + for len(d) > 0 { + if len(d) < 3 { + return false + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + return false + } + if nameType == 0 { + m.serverName = string(d[:nameLen]) + break + } + d = d[nameLen:] + } + // case extensionNextProtoNeg: + // if length > 0 { + // return false + // } + // m.nextProtoNeg = true + // case extensionStatusRequest: + // m.ocspStapling = length > 0 && data[0] == statusTypeOCSP + // case extensionSupportedCurves: + // // http://tools.ietf.org/html/rfc4492#section-5.5.1 + // if length < 2 { + // return false + // } + // l := int(data[0])<<8 | int(data[1]) + // if l%2 == 1 || length != l+2 { + // return false + // } + // numCurves := l / 2 + // m.supportedCurves = make([]CurveID, numCurves) + // d := data[2:] + // for i := 0; i < numCurves; i++ { + // m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1]) + // d = d[2:] + // } + // case extensionSupportedPoints: + // // http://tools.ietf.org/html/rfc4492#section-5.5.2 + // if length < 1 { + // return false + // } + // l := int(data[0]) + // if length != l+1 { + // return false + // } + // m.supportedPoints = make([]uint8, l) + // copy(m.supportedPoints, data[1:]) + // case extensionSessionTicket: + // // http://tools.ietf.org/html/rfc5077#section-3.2 + // m.ticketSupported = true + // m.sessionTicket = data[:length] + // case extensionSignatureAlgorithms: + // // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + // if length < 2 || length&1 != 0 { + // return false + // } + // l := int(data[0])<<8 | int(data[1]) + // if l != length-2 { + // return false + // } + // n := l / 2 + // d := data[2:] + // m.signatureAndHashes = make([]signatureAndHash, n) + // for i := range m.signatureAndHashes { + // m.signatureAndHashes[i].hash = d[0] + // m.signatureAndHashes[i].signature = d[1] + // d = d[2:] + // } + // case extensionRenegotiationInfo: + // if length == 0 { + // return false + // } + // d := data[:length] + // l := int(d[0]) + // d = d[1:] + // if l != len(d) { + // return false + // } + + // m.secureRenegotiation = d + // m.secureRenegotiationSupported = true + // case extensionALPN: + // if length < 2 { + // return false + // } + // l := int(data[0])<<8 | int(data[1]) + // if l != length-2 { + // return false + // } + // d := data[2:length] + // for len(d) != 0 { + // stringLen := int(d[0]) + // d = d[1:] + // if stringLen == 0 || stringLen > len(d) { + // return false + // } + // m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) + // d = d[stringLen:] + // } + // case extensionSCT: + // m.scts = true + // if length != 0 { + // return false + // } + } + data = data[length:] + } + + return true +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 201a173dd..ab0eae68b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -9,22 +9,22 @@ import ( gometrics "github.com/rcrowley/go-metrics" ) -// Proxy is a dynamic reverse proxy. -type Proxy struct { +// httpProxy is a dynamic reverse proxy for HTTP and HTTPS protocols. +type httpProxy struct { tr http.RoundTripper cfg config.Proxy requests gometrics.Timer } -func New(tr http.RoundTripper, cfg config.Proxy) *Proxy { - return &Proxy{ +func NewHTTPProxy(tr http.RoundTripper, cfg config.Proxy) http.Handler { + return &httpProxy{ tr: tr, cfg: cfg, requests: gometrics.GetOrRegisterTimer("requests", gometrics.DefaultRegistry), } } -func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (p *httpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if ShuttingDown() { http.Error(w, "shutting down", http.StatusServiceUnavailable) return diff --git a/proxy/proxy_integration_test.go b/proxy/proxy_integration_test.go index 34c2185a0..47e94af89 100644 --- a/proxy/proxy_integration_test.go +++ b/proxy/proxy_integration_test.go @@ -23,7 +23,7 @@ func TestProxyProducesCorrectXffHeader(t *testing.T) { route.SetTable(table) tr := &http.Transport{Dial: (&net.Dialer{}).Dial} - proxy := New(tr, config.Proxy{LocalIP: "1.1.1.1", ClientIPHeader: "X-Forwarded-For"}) + proxy := NewHTTPProxy(tr, config.Proxy{LocalIP: "1.1.1.1", ClientIPHeader: "X-Forwarded-For"}) req := &http.Request{ RequestURI: "/", @@ -43,7 +43,7 @@ func TestProxyNoRouteStaus(t *testing.T) { route.SetTable(make(route.Table)) tr := &http.Transport{Dial: (&net.Dialer{}).Dial} cfg := config.Proxy{NoRouteStatus: 999} - proxy := New(tr, cfg) + proxy := NewHTTPProxy(tr, cfg) req := &http.Request{ RequestURI: "/", URL: &url.URL{}, diff --git a/proxy/tcp_sni_proxy.go b/proxy/tcp_sni_proxy.go new file mode 100644 index 000000000..6f5045297 --- /dev/null +++ b/proxy/tcp_sni_proxy.go @@ -0,0 +1,88 @@ +package proxy + +import ( + "fmt" + "io" + "log" + "net" + + "github.com/eBay/fabio/config" + "github.com/eBay/fabio/route" +) + +type TCPProxy interface { + Serve(conn net.Conn) +} + +func NewTCPSNIProxy(cfg config.Proxy) TCPProxy { + return &tcpSNIProxy{cfg: cfg} +} + +type tcpSNIProxy struct { + cfg config.Proxy +} + +func (p *tcpSNIProxy) Serve(in net.Conn) { + defer in.Close() + + if ShuttingDown() { + return + } + + // capture client hello + data := make([]byte, 1024) + n, err := in.Read(data) + if err != nil { + return + } + data = data[:n] + + serverName, ok := readServerName(data) + if !ok { + // println("handshake failed") + fmt.Fprintln(in, "handshake failed") + return + } + + if serverName == "" { + // println("server name missing") + fmt.Fprintln(in, "server_name missing") + return + } + // println(serverName) + + t := route.GetTable().LookupHost(serverName) + if t == nil { + log.Print("[WARN] No route for ", serverName) + return + } + // println(serverName + " -> " + t.URL.Host) + + out, err := net.DialTimeout("tcp", t.URL.Host, p.cfg.DialTimeout) + if err != nil { + log.Println("cannot connect upstream") + return + } + defer out.Close() + // TODO(fs): set timeouts + + // copy client hello + _, err = out.Write(data) + if err != nil { + log.Println("copy client hello failed") + return + } + + errc := make(chan error, 2) + cp := func(dst io.Writer, src io.Reader) { + _, err := io.Copy(dst, src) + errc <- err + } + + go cp(out, in) + go cp(in, out) + err = <-errc + if err != nil && err != io.EOF { + log.Println("error ", err) + } +} diff --git a/route/table.go b/route/table.go index c2744242e..02c053c66 100644 --- a/route/table.go +++ b/route/table.go @@ -247,9 +247,9 @@ func (t Table) Lookup(req *http.Request, trace string) *Target { log.Printf("[TRACE] %s Tracing %s%s", trace, req.Host, req.RequestURI) } - target := t.doLookup(normalizeHost(req), req.RequestURI, trace) + target := t.lookup(normalizeHost(req), req.RequestURI, trace) if target == nil { - target = t.doLookup("", req.RequestURI, trace) + target = t.lookup("", req.RequestURI, trace) } if target != nil && trace != "" { @@ -259,7 +259,11 @@ func (t Table) Lookup(req *http.Request, trace string) *Target { return target } -func (t Table) doLookup(host, path, trace string) *Target { +func (t Table) LookupHost(host string) *Target { + return t.lookup(host, "/", "") +} + +func (t Table) lookup(host, path, trace string) *Target { for _, r := range t[host] { if match(path, r) { n := len(r.Targets)