From 3632e83faa70f2bdd08507c54b251ecd22092e6c Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Wed, 25 Sep 2024 21:29:41 -0400 Subject: [PATCH] Transport: Add HTTP3 to HTTP (#3819) --- infra/conf/transport_internet.go | 2 +- transport/internet/http/dialer.go | 175 +++++++++++++++++--------- transport/internet/http/http_test.go | 78 ++++++++++++ transport/internet/http/hub.go | 181 ++++++++++++++++----------- transport/internet/splithttp/hub.go | 8 +- 5 files changed, 316 insertions(+), 128 deletions(-) diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 21cffc0792e0..e89cef0db452 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -650,7 +650,7 @@ func (p TransportProtocol) Build() (string, error) { return "mkcp", nil case "ws", "websocket": return "websocket", nil - case "h2", "http": + case "h2", "h3", "http": return "http", nil case "grpc", "gun": return "grpc", nil diff --git a/transport/internet/http/dialer.go b/transport/internet/http/dialer.go index a1421d3f9f17..31ded010de62 100644 --- a/transport/internet/http/dialer.go +++ b/transport/internet/http/dialer.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" c "github.com/xtls/xray-core/common/ctx" @@ -24,6 +26,13 @@ import ( "golang.org/x/net/http2" ) +// defines the maximum time an idle TCP session can survive in the tunnel, so +// it should be consistent across HTTP versions and with other transports. +const connIdleTimeout = 300 * time.Second + +// consistent with quic-go +const h3KeepalivePeriod = 10 * time.Second + type dialerConf struct { net.Destination *internet.MemoryStreamConfig @@ -48,72 +57,129 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in if tlsConfigs == nil && realityConfigs == nil { return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning() } + isH3 := tlsConfigs != nil && (len(tlsConfigs.NextProtocol) == 1 && tlsConfigs.NextProtocol[0] == "h3") + if isH3 { + dest.Network = net.Network_UDP + } sockopt := streamSettings.SocketSettings if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found { return client, nil } - transport := &http2.Transport{ - DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { - rawHost, rawPort, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - if len(rawPort) == 0 { - rawPort = "443" - } - port, err := net.PortFromString(rawPort) - if err != nil { - return nil, err - } - address := net.ParseAddress(rawHost) + var transport http.RoundTripper + if isH3 { + quicConfig := &quic.Config{ + MaxIdleTimeout: connIdleTimeout, - hctx = c.ContextWithID(hctx, c.IDFromContext(ctx)) - hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx)) - hctx = session.ContextWithTimeoutOnly(hctx, true) + // these two are defaults of quic-go/http3. the default of quic-go (no + // http3) is different, so it is hardcoded here for clarity. + // https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39 + MaxIncomingStreams: -1, + KeepAlivePeriod: h3KeepalivePeriod, + } + roundTripper := &http3.RoundTripper{ + QUICConfig: quicConfig, + TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)), + Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) + if err != nil { + return nil, err + } - pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt) - if err != nil { - errors.LogErrorInner(ctx, err, "failed to dial to " + addr) - return nil, err - } + var udpConn net.PacketConn + var udpAddr *net.UDPAddr - if realityConfigs != nil { - return reality.UClient(pconn, realityConfigs, hctx, dest) - } + switch c := conn.(type) { + case *internet.PacketConnWrapper: + var ok bool + udpConn, ok = c.Conn.(*net.UDPConn) + if !ok { + return nil, errors.New("PacketConnWrapper does not contain a UDP connection") + } + udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String()) + if err != nil { + return nil, err + } + case *net.UDPConn: + udpConn = c + udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String()) + if err != nil { + return nil, err + } + default: + udpConn = &internet.FakePacketConn{c} + udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String()) + if err != nil { + return nil, err + } + } - var cn tls.Interface - if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil { - cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) - } else { - cn = tls.Client(pconn, tlsConfig).(*tls.Conn) - } - if err := cn.HandshakeContext(ctx); err != nil { - errors.LogErrorInner(ctx, err, "failed to dial to " + addr) - return nil, err - } - if !tlsConfig.InsecureSkipVerify { - if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil { + return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg) + }, + } + transport = roundTripper + } else { + transportH2 := &http2.Transport{ + DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { + rawHost, rawPort, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if len(rawPort) == 0 { + rawPort = "443" + } + port, err := net.PortFromString(rawPort) + if err != nil { + return nil, err + } + address := net.ParseAddress(rawHost) + + hctx = c.ContextWithID(hctx, c.IDFromContext(ctx)) + hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx)) + hctx = session.ContextWithTimeoutOnly(hctx, true) + + pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt) + if err != nil { errors.LogErrorInner(ctx, err, "failed to dial to " + addr) return nil, err } - } - negotiatedProtocol := cn.NegotiatedProtocol() - if negotiatedProtocol != http2.NextProtoTLS { - return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError() - } - return cn, nil - }, - } - - if tlsConfigs != nil { - transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest)) - } - - if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 { - transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout) - transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout) + + if realityConfigs != nil { + return reality.UClient(pconn, realityConfigs, hctx, dest) + } + + var cn tls.Interface + if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil { + cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) + } else { + cn = tls.Client(pconn, tlsConfig).(*tls.Conn) + } + if err := cn.HandshakeContext(ctx); err != nil { + errors.LogErrorInner(ctx, err, "failed to dial to " + addr) + return nil, err + } + if !tlsConfig.InsecureSkipVerify { + if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil { + errors.LogErrorInner(ctx, err, "failed to dial to " + addr) + return nil, err + } + } + negotiatedProtocol := cn.NegotiatedProtocol() + if negotiatedProtocol != http2.NextProtoTLS { + return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError() + } + return cn, nil + }, + } + if tlsConfigs != nil { + transportH2.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest)) + } + if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 { + transportH2.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout) + transportH2.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout) + } + transport = transportH2 } client := &http.Client{ @@ -158,9 +224,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me Host: dest.NetAddr(), Path: httpSettings.getNormalizedPath(), }, - Proto: "HTTP/2", - ProtoMajor: 2, - ProtoMinor: 0, Header: httpHeaders, } // Disable any compression method from server. diff --git a/transport/internet/http/http_test.go b/transport/internet/http/http_test.go index 3639eb846420..dd6c852daec8 100644 --- a/transport/internet/http/http_test.go +++ b/transport/internet/http/http_test.go @@ -12,6 +12,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/tls/cert" "github.com/xtls/xray-core/testing/servers/tcp" + "github.com/xtls/xray-core/testing/servers/udp" "github.com/xtls/xray-core/transport/internet" . "github.com/xtls/xray-core/transport/internet/http" "github.com/xtls/xray-core/transport/internet/stat" @@ -92,3 +93,80 @@ func TestHTTPConnection(t *testing.T) { t.Error(r) } } + +func TestH3Connection(t *testing.T) { + port := udp.PickPort() + + listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{ + ProtocolName: "http", + ProtocolSettings: &Config{}, + SecurityType: "tls", + SecuritySettings: &tls.Config{ + NextProtocol: []string{"h3"}, + Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.example.com")))}, + }, + }, func(conn stat.Connection) { + go func() { + defer conn.Close() + + b := buf.New() + defer b.Release() + + for { + if _, err := b.ReadFrom(conn); err != nil { + return + } + _, err := conn.Write(b.Bytes()) + common.Must(err) + } + }() + }) + common.Must(err) + + defer listener.Close() + + time.Sleep(time.Second) + + dctx := context.Background() + conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{ + ProtocolName: "http", + ProtocolSettings: &Config{}, + SecurityType: "tls", + SecuritySettings: &tls.Config{ + NextProtocol: []string{"h3"}, + ServerName: "www.example.com", + AllowInsecure: true, + }, + }) + common.Must(err) + defer conn.Close() + + const N = 1024 + b1 := make([]byte, N) + common.Must2(rand.Read(b1)) + b2 := buf.New() + + nBytes, err := conn.Write(b1) + common.Must(err) + if nBytes != N { + t.Error("write: ", nBytes) + } + + b2.Clear() + common.Must2(b2.ReadFullFrom(conn, N)) + if r := cmp.Diff(b2.Bytes(), b1); r != "" { + t.Error(r) + } + + nBytes, err = conn.Write(b1) + common.Must(err) + if nBytes != N { + t.Error("write: ", nBytes) + } + + b2.Clear() + common.Must2(b2.ReadFullFrom(conn, N)) + if r := cmp.Diff(b2.Bytes(), b1); r != "" { + t.Error(r) + } +} diff --git a/transport/internet/http/hub.go b/transport/internet/http/hub.go index 3421ae649dd8..96fe8f629d43 100644 --- a/transport/internet/http/hub.go +++ b/transport/internet/http/hub.go @@ -2,11 +2,14 @@ package http import ( "context" + gotls "crypto/tls" "io" "net/http" "strings" "time" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" goreality "github.com/xtls/reality" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" @@ -23,10 +26,12 @@ import ( ) type Listener struct { - server *http.Server - handler internet.ConnHandler - local net.Addr - config *Config + server *http.Server + h3server *http3.Server + handler internet.ConnHandler + local net.Addr + config *Config + isH3 bool } func (l *Listener) Addr() net.Addr { @@ -34,7 +39,14 @@ func (l *Listener) Addr() net.Addr { } func (l *Listener) Close() error { - return l.server.Close() + if l.h3server != nil { + if err := l.h3server.Close(); err != nil { + return err + } + } else if l.server != nil { + return l.server.Close() + } + return errors.New("listener does not have an HTTP/3 server or h2 server") } type flushWriter struct { @@ -119,43 +131,33 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request) func Listen(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, handler internet.ConnHandler) (internet.Listener, error) { httpSettings := streamSettings.ProtocolSettings.(*Config) - var listener *Listener - if port == net.Port(0) { // unix - listener = &Listener{ - handler: handler, - local: &net.UnixAddr{ - Name: address.Domain(), - Net: "unix", - }, - config: httpSettings, - } - } else { // tcp - listener = &Listener{ - handler: handler, - local: &net.TCPAddr{ - IP: address.IP(), - Port: int(port), - }, - config: httpSettings, - } - } - - var server *http.Server config := tls.ConfigFromStreamSettings(streamSettings) + var tlsConfig *gotls.Config if config == nil { - h2s := &http2.Server{} - - server = &http.Server{ - Addr: serial.Concat(address, ":", port), - Handler: h2c.NewHandler(listener, h2s), - ReadHeaderTimeout: time.Second * 4, + tlsConfig = &gotls.Config{} + } else { + tlsConfig = config.GetTLSConfig() + } + isH3 := len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3" + listener := &Listener{ + handler: handler, + config: httpSettings, + isH3: isH3, + } + if port == net.Port(0) { // unix + listener.local = &net.UnixAddr{ + Name: address.Domain(), + Net: "unix", + } + } else if isH3 { // udp + listener.local = &net.UDPAddr{ + IP: address.IP(), + Port: int(port), } } else { - server = &http.Server{ - Addr: serial.Concat(address, ":", port), - TLSConfig: config.GetTLSConfig(tls.WithNextProto("h2")), - Handler: listener, - ReadHeaderTimeout: time.Second * 4, + listener.local = &net.TCPAddr{ + IP: address.IP(), + Port: int(port), } } @@ -163,45 +165,84 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti errors.LogWarning(ctx, "accepting PROXY protocol") } - listener.server = server - go func() { - var streamListener net.Listener - var err error - if port == net.Port(0) { // unix - streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{ - Name: address.Domain(), - Net: "unix", - }, streamSettings.SocketSettings) - if err != nil { - errors.LogErrorInner(ctx, err, "failed to listen on ", address) - return - } - } else { // tcp - streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{ - IP: address.IP(), - Port: int(port), - }, streamSettings.SocketSettings) - if err != nil { - errors.LogErrorInner(ctx, err, "failed to listen on ", address, ":", port) - return - } + if isH3 { + Conn, err := internet.ListenSystemPacket(context.Background(), listener.local, streamSettings.SocketSettings) + if err != nil { + return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err) } + h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil) + if err != nil { + return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err) + } + errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port) - if config == nil { - if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { - streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig()) + listener.h3server = &http3.Server{ + Handler: listener, + } + go func() { + if err := listener.h3server.ServeListener(h3listener); err != nil { + errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp") } - err = server.Serve(streamListener) - if err != nil { - errors.LogInfoInner(ctx, err, "stopping serving H2C or REALITY H2") + }() + } else { + var server *http.Server + if config == nil { + h2s := &http2.Server{} + + server = &http.Server{ + Addr: serial.Concat(address, ":", port), + Handler: h2c.NewHandler(listener, h2s), + ReadHeaderTimeout: time.Second * 4, } } else { - err = server.ServeTLS(streamListener, "", "") - if err != nil { - errors.LogInfoInner(ctx, err, "stopping serving TLS H2") + server = &http.Server{ + Addr: serial.Concat(address, ":", port), + TLSConfig: config.GetTLSConfig(tls.WithNextProto("h2")), + Handler: listener, + ReadHeaderTimeout: time.Second * 4, } } - }() + + listener.server = server + go func() { + var streamListener net.Listener + var err error + if port == net.Port(0) { // unix + streamListener, err = internet.ListenSystem(ctx, &net.UnixAddr{ + Name: address.Domain(), + Net: "unix", + }, streamSettings.SocketSettings) + if err != nil { + errors.LogErrorInner(ctx, err, "failed to listen on ", address) + return + } + } else { // tcp + streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{ + IP: address.IP(), + Port: int(port), + }, streamSettings.SocketSettings) + if err != nil { + errors.LogErrorInner(ctx, err, "failed to listen on ", address, ":", port) + return + } + } + + if config == nil { + if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { + streamListener = goreality.NewListener(streamListener, config.GetREALITYConfig()) + } + err = server.Serve(streamListener) + if err != nil { + errors.LogInfoInner(ctx, err, "stopping serving H2C or REALITY H2") + } + } else { + err = server.ServeTLS(streamListener, "", "") + if err != nil { + errors.LogInfoInner(ctx, err, "stopping serving TLS H2") + } + } + }() + } return listener, nil } diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index 423bf6e37fae..4545db644a2f 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -365,7 +365,13 @@ func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSet // Addr implements net.Listener.Addr(). func (ln *Listener) Addr() net.Addr { - return ln.listener.Addr() + if ln.h3listener != nil { + return ln.h3listener.Addr() + } + if ln.listener != nil { + return ln.listener.Addr() + } + return nil } // Close implements net.Listener.Close().