Skip to content

Commit

Permalink
Add WebSocket 0-RTT support (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
RPRX authored Mar 13, 2021
1 parent 9adce5a commit a10557c
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 14 deletions.
13 changes: 13 additions & 0 deletions infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package conf
import (
"encoding/json"
"math"
"net/url"
"strconv"
"strings"

"github.com/golang/protobuf/proto"
Expand Down Expand Up @@ -155,9 +157,20 @@ func (c *WebSocketConfig) Build() (proto.Message, error) {
Value: value,
})
}
var ed uint32
if u, err := url.Parse(path); err == nil {
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
path = u.String()
}
}
config := &websocket.Config{
Path: path,
Header: header,
Ed: ed,
}
if c.AcceptProxyProtocol {
config.AcceptProxyProtocol = c.AcceptProxyProtocol
Expand Down
15 changes: 12 additions & 3 deletions transport/internet/websocket/config.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions transport/internet/websocket/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ message Config {
repeated Header header = 3;

bool accept_proxy_protocol = 4;

uint32 ed = 5;
}
3 changes: 2 additions & 1 deletion transport/internet/websocket/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ type connection struct {
remoteAddr net.Addr
}

func newConnection(conn *websocket.Conn, remoteAddr net.Addr) *connection {
func newConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
return &connection{
conn: conn,
remoteAddr: remoteAddr,
reader: extraReader,
}
}

Expand Down
88 changes: 81 additions & 7 deletions transport/internet/websocket/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package websocket

import (
"context"
"encoding/base64"
"io"
"time"

"github.com/gorilla/websocket"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session"
Expand All @@ -15,10 +18,21 @@ import (
// Dial dials a WebSocket connection to the given destination.
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))

conn, err := dialWebsocket(ctx, dest, streamSettings)
if err != nil {
return nil, newError("failed to dial WebSocket").Base(err)
var conn net.Conn
if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
ctx, cancel := context.WithCancel(ctx)
conn = &delayDialConn{
dialed: make(chan bool, 1),
cancel: cancel,
ctx: ctx,
dest: dest,
streamSettings: streamSettings,
}
} else {
var err error
if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
return nil, newError("failed to dial WebSocket").Base(err)
}
}
return internet.Connection(conn), nil
}
Expand All @@ -27,7 +41,7 @@ func init() {
common.Must(internet.RegisterTransportDialer(protocolName, Dial))
}

func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (net.Conn, error) {
func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
wsSettings := streamSettings.ProtocolSettings.(*Config)

dialer := &websocket.Dialer{
Expand All @@ -52,7 +66,12 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
}
uri := protocol + "://" + host + wsSettings.GetNormalizedPath()

conn, resp, err := dialer.Dial(uri, wsSettings.GetRequestHeader())
header := wsSettings.GetRequestHeader()
if ed != nil {
header.Add("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))
}

conn, resp, err := dialer.Dial(uri, header)
if err != nil {
var reason string
if resp != nil {
Expand All @@ -61,5 +80,60 @@ func dialWebsocket(ctx context.Context, dest net.Destination, streamSettings *in
return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
}

return newConnection(conn, conn.RemoteAddr()), nil
return newConnection(conn, conn.RemoteAddr(), nil), nil
}

type delayDialConn struct {
net.Conn
closed bool
dialed chan bool
cancel context.CancelFunc
ctx context.Context
dest net.Destination
streamSettings *internet.MemoryStreamConfig
}

func (d *delayDialConn) Write(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
ed := b
if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
ed = nil
}
var err error
if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
d.Close()
return 0, newError("failed to dial WebSocket").Base(err)
}
d.dialed <- true
if ed != nil {
return len(ed), nil
}
}
return d.Conn.Write(b)
}

func (d *delayDialConn) Read(b []byte) (int, error) {
if d.closed {
return 0, io.ErrClosedPipe
}
if d.Conn == nil {
select {
case <-d.ctx.Done():
return 0, io.ErrUnexpectedEOF
case <-d.dialed:
}
}
return d.Conn.Read(b)
}

func (d *delayDialConn) Close() error {
d.closed = true
d.cancel()
if d.Conn == nil {
return nil
}
return d.Conn.Close()
}
13 changes: 12 additions & 1 deletion transport/internet/websocket/hub.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package websocket

import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"io"
"net/http"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -51,7 +55,14 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
}
}

h.ln.addConn(newConnection(conn, remoteAddr))
var extraReader io.Reader
if len(request.Header["Sec-WebSocket-Protocol"]) > 0 {
decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(request.Header["Sec-WebSocket-Protocol"][0]))
if ed, err := io.ReadAll(decoder); err == nil && len(ed) > 0 {
extraReader = bytes.NewReader(ed)
}
}
h.ln.addConn(newConnection(conn, remoteAddr, extraReader))
}

type Listener struct {
Expand Down
4 changes: 2 additions & 2 deletions transport/internet/websocket/ws.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*Package websocket implements Websocket transport
/*Package websocket implements WebSocket transport
Websocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability.
WebSocket transport implements an HTTP(S) compliable, surveillance proof transport method with plausible deniability.
*/
package websocket

Expand Down

0 comments on commit a10557c

Please sign in to comment.