Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket 0-RTT #374

Merged
merged 1 commit into from
Mar 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package conf
import (
"encoding/json"
"math"
"net/url"
"strconv"
"strings"

"github.com/golang/protobuf/proto"
Expand Down Expand Up @@ -155,9 +157,20 @@ func (c *WebSocketConfig) Build() (proto.Message, error) {
Value: value,
})
}
var ed uint32
if u, err := url.Parse(path); err == nil {
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
path = u.String()
}
}
config := &websocket.Config{
Path: path,
Header: header,
Ed: ed,
}
if c.AcceptProxyProtocol {
config.AcceptProxyProtocol = c.AcceptProxyProtocol
Expand Down
15 changes: 12 additions & 3 deletions transport/internet/websocket/config.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions transport/internet/websocket/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ message Config {
repeated Header header = 3;

bool accept_proxy_protocol = 4;

uint32 ed = 5;
}
3 changes: 2 additions & 1 deletion transport/internet/websocket/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ type connection struct {
remoteAddr net.Addr
}

func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection {
func newConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
return &connection{
conn: conn,
remoteAddr: remoteAddr,
reader: extraReader,
}
}

Expand Down
88 changes: 81 additions & 7 deletions transport/internet/websocket/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package websocket

import (
"context"
"encoding/base64"
"io"
"time"

"github.com/gorilla/websocket"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
Expand All @@ -15,10 +18,21 @@ import (
// Dial dials a WebSocket connection to the given destination.
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))

conn, err := dialWebsocket(ctx, dest, streamSettings)
if err != nil {
return nil, newError("failed to dial WebSocket").Base(err)
var conn net.Conn
if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
ctx, cancel := context.WithCancel(ctx)
conn = &delayDialConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
dest: dest,
streamSettings: streamSettings,
}
} else {
var err error
if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
return nil, newError("failed to dial WebSocket").Base(err)
}
}
return internet.Connection(conn), nil
}
Expand All @@ -27,7 +41,7 @@ func init() {
common.Must(internet.RegisterTransportDialer(protocolName, Dial))
}

func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
wsSettings := streamSettings.ProtocolSettings.(*Config)

dialer := &websocket.Dialer{
Expand All @@ -52,7 +66,12 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
}
uri := protocol + "://" + host + wsSettings.GetNormalizedPath()

conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
header := wsSettings.GetRequestHeader()
if ed != nil {
header.Add("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))
}

conn, resp, err := dialer.Dial(uri, header)
if err != nil {
var reason string
if resp != nil {
Expand All @@ -61,5 +80,60 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
}

return newConnection(conn, conn.RemoteAddr()), nil
return newConnection(conn, conn.RemoteAddr(), nil), nil
}

type delayDialConn struct {
net.Conn
closed bool
dialed chan bool
cancel context.CancelFunc
ctx context.Context
dest net.Destination
streamSettings *internet.MemoryStreamConfig
}

func (d *delayDialConn) Write(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
ed := b
if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
ed = nil
}
var err error
if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
d.Close()
return 0, newError("failed to dial WebSocket").Base(err)
}
d.dialed <- true
if ed != nil {
return len(ed), nil
}
}
return d.Conn.Write(b)
}

func (d *delayDialConn) Read(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
select {
case <-d.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-d.dialed:
}
}
return d.Conn.Read(b)
}

func (d *delayDialConn) Close() error {
d.closed = true
d.cancel()
if d.Conn == nil {
return nil
}
return d.Conn.Close()
}
13 changes: 12 additions & 1 deletion transport/internet/websocket/hub.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package websocket

import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"io"
"net/http"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -51,7 +55,14 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
}
}

h.ln.addConn(newConnection(conn, remoteAddr))
var extraReader io.Reader
if len(request.Header["Sec-WebSocket-Protocol"]) > 0 {
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(request.Header["Sec-WebSocket-Protocol"][0]))
if ed, err := io.ReadAll(decoder); err == nil && len(ed) > 0 {
extraReader = bytes.NewReader(ed)
}
}
h.ln.addConn(newConnection(conn, remoteAddr, extraReader))
}

type Listener struct {
Expand Down
4 changes: 2 additions & 2 deletions transport/internet/websocket/ws.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*Package websocket implements Websocket transport
/*Package websocket implements WebSocket transport

Websocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability.
WebSocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability.
*/
package websocket

Expand Down